From 63439c60c77c43d5dee6660fef519eb27637db41 Mon Sep 17 00:00:00 2001 From: Sameen Karim Date: Tue, 5 May 2026 08:06:58 -0400 Subject: [PATCH 1/2] use pr template when opening prs --- cmd/link.go | 13 ++- cmd/link_test.go | 159 +++++++++++++++++++++++++++++++++++ cmd/submit.go | 28 ++++-- cmd/submit_test.go | 138 ++++++++++++++++++++++++++++-- internal/git/git.go | 5 ++ internal/git/gitops.go | 5 ++ internal/git/mock_ops.go | 8 ++ internal/pr/template.go | 35 ++++++++ internal/pr/template_test.go | 71 ++++++++++++++++ 9 files changed, 445 insertions(+), 17 deletions(-) create mode 100644 internal/pr/template.go create mode 100644 internal/pr/template_test.go diff --git a/cmd/link.go b/cmd/link.go index 2e0e3ab..93bb749 100644 --- a/cmd/link.go +++ b/cmd/link.go @@ -9,6 +9,7 @@ import ( "github.com/github/gh-stack/internal/config" "github.com/github/gh-stack/internal/git" "github.com/github/gh-stack/internal/github" + "github.com/github/gh-stack/internal/pr" "github.com/spf13/cobra" ) @@ -109,6 +110,12 @@ func runLink(cfg *config.Config, opts *linkOptions, args []string) error { } } + // Look up the repository's PR template (best-effort; skip if not in a repo). + var templateContent string + if repoRoot, tlErr := git.RootDir(); tlErr == nil { + templateContent = pr.FindTemplate(repoRoot) + } + // Phase 4: Create PRs for branches that don't have one yet needsCreation := 0 for _, r := range found { @@ -119,7 +126,7 @@ func runLink(cfg *config.Config, opts *linkOptions, args []string) error { if needsCreation > 0 { cfg.Printf("Creating %d %s...", needsCreation, plural(needsCreation, "PR", "PRs")) } - resolved, err := createMissingPRs(cfg, client, opts, args, found) + resolved, err := createMissingPRs(cfg, client, opts, args, found, templateContent) if err != nil { return err } @@ -303,7 +310,7 @@ func prevalidateStack(cfg *config.Config, stacks []github.RemoteStack, knownPRNu // createMissingPRs creates PRs for branches that don't have one yet. // Returns the fully resolved list with all branches mapped to PRs. -func createMissingPRs(cfg *config.Config, client github.ClientOps, opts *linkOptions, args []string, found []*resolvedArg) ([]resolvedArg, error) { +func createMissingPRs(cfg *config.Config, client github.ClientOps, opts *linkOptions, args []string, found []*resolvedArg, templateContent string) ([]resolvedArg, error) { resolved := make([]resolvedArg, len(args)) for i, arg := range args { @@ -319,7 +326,7 @@ func createMissingPRs(cfg *config.Config, client github.ClientOps, opts *linkOpt } title := humanize(arg) - body := generatePRBody("") + body := generatePRBody("", templateContent) newPR, err := client.CreatePR(baseBranch, arg, title, body, !opts.open) if err != nil { diff --git a/cmd/link_test.go b/cmd/link_test.go index e70b20e..b7b3e5f 100644 --- a/cmd/link_test.go +++ b/cmd/link_test.go @@ -3,6 +3,8 @@ package cmd import ( "fmt" "io" + "os" + "path/filepath" "testing" "github.com/cli/go-gh/v2/pkg/api" @@ -1203,3 +1205,160 @@ func TestLink_SkipsBaseFix_ForNewlyCreatedPRs(t *testing.T) { // Silence "imported and not used" for fmt in case test helpers use it. var _ = fmt.Sprintf + +func TestLink_FetchesBeforePush(t *testing.T) { + var callOrder []string + var fetchedBranches []string + + mock := newLinkGitMock("feat-a", "feat-b") + mock.FetchBranchesFn = func(remote string, branches []string) error { + callOrder = append(callOrder, "fetch") + fetchedBranches = branches + assert.Equal(t, "origin", remote) + return nil + } + mock.PushFn = func(remote string, branches []string, force, atomic bool) error { + callOrder = append(callOrder, "push") + return nil + } + + restore := git.SetOps(mock) + defer restore() + + prNum := 0 + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { + prNum++ + return &github.PullRequest{ + Number: prNum, + URL: fmt.Sprintf("https://github.com/o/r/pull/%d", prNum), + BaseRefName: "main", + HeadRefName: branch, + State: "OPEN", + }, nil + }, + ListStacksFn: func() ([]github.RemoteStack, error) { + return []github.RemoteStack{}, nil + }, + CreateStackFn: func(prNumbers []int) (int, error) { + return 42, nil + }, + } + + cmd := LinkCmd(cfg) + cmd.SetArgs([]string{"feat-a", "feat-b"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + _, _ = io.ReadAll(errR) + + assert.NoError(t, err) + assert.Equal(t, []string{"feat-a", "feat-b"}, fetchedBranches, "should fetch pushed branches") + require.Len(t, callOrder, 2) + assert.Equal(t, "fetch", callOrder[0], "fetch must happen before push") + assert.Equal(t, "push", callOrder[1]) +} + +func TestLink_BranchNames_UsesPRTemplate(t *testing.T) { + tmpDir := t.TempDir() + ghDir := filepath.Join(tmpDir, ".github") + require.NoError(t, os.MkdirAll(ghDir, 0o755)) + require.NoError(t, os.WriteFile( + filepath.Join(ghDir, "pull_request_template.md"), + []byte("## Summary\n\nDescribe your changes."), + 0o644, + )) + + mock := newLinkGitMock("feat-a", "feat-b") + mock.RootDirFn = func() (string, error) { return tmpDir, nil } + restore := git.SetOps(mock) + defer restore() + + var capturedBody string + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRForBranchFn: func(string) (*github.PullRequest, error) { + return nil, nil // No existing PRs + }, + CreatePRFn: func(base, head, title, body string, draft bool) (*github.PullRequest, error) { + capturedBody = body + return &github.PullRequest{ + Number: 1, HeadRefName: head, BaseRefName: base, + URL: "https://github.com/o/r/pull/1", + }, nil + }, + ListStacksFn: func() ([]github.RemoteStack, error) { + return []github.RemoteStack{}, nil + }, + CreateStackFn: func([]int) (int, error) { return 42, nil }, + } + + cmd := LinkCmd(cfg) + cmd.SetArgs([]string{"feat-a", "feat-b"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + _, _ = io.ReadAll(errR) + + assert.NoError(t, err) + assert.Contains(t, capturedBody, "## Summary") + assert.Contains(t, capturedBody, "Describe your changes.") + assert.NotContains(t, capturedBody, "GitHub Stacks CLI", "footer should not be present when template is used") +} + +func TestLink_PRNumbers_NoTemplateUsesFooter(t *testing.T) { + // When using PR numbers (no local repo context), no template is found + // and the footer should be present for newly created PRs. + mock := &git.MockOps{ + RootDirFn: func() (string, error) { + return "", fmt.Errorf("not in a git repo") + }, + } + restore := git.SetOps(mock) + defer restore() + + var capturedBody string + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRByNumberFn: func(n int) (*github.PullRequest, error) { + if n == 10 { + return &github.PullRequest{ + Number: 10, HeadRefName: "feat-a", BaseRefName: "main", + URL: "https://github.com/o/r/pull/10", + }, nil + } + return nil, nil // PR 20 doesn't exist → will create + }, + FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { + return nil, nil + }, + CreatePRFn: func(base, head, title, body string, draft bool) (*github.PullRequest, error) { + capturedBody = body + return &github.PullRequest{ + Number: 20, HeadRefName: head, BaseRefName: base, + URL: "https://github.com/o/r/pull/20", + }, nil + }, + ListStacksFn: func() ([]github.RemoteStack, error) { + return []github.RemoteStack{}, nil + }, + CreateStackFn: func([]int) (int, error) { return 42, nil }, + } + + cmd := LinkCmd(cfg) + cmd.SetArgs([]string{"10", "20"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + _, _ = io.ReadAll(errR) + + assert.NoError(t, err) + assert.Contains(t, capturedBody, "GitHub Stacks CLI", "footer should be present when no template") +} diff --git a/cmd/submit.go b/cmd/submit.go index 16db9ea..18849f2 100644 --- a/cmd/submit.go +++ b/cmd/submit.go @@ -12,6 +12,7 @@ import ( "github.com/github/gh-stack/internal/git" "github.com/github/gh-stack/internal/github" "github.com/github/gh-stack/internal/modify" + "github.com/github/gh-stack/internal/pr" "github.com/github/gh-stack/internal/stack" "github.com/spf13/cobra" ) @@ -148,6 +149,12 @@ func runSubmit(cfg *config.Config, opts *submitOptions) error { // remote yet. _ = git.FetchBranches(remote, activeBranches) + // Look up the repository's PR template once before creating any PRs. + var templateContent string + if repoRoot, err := git.RootDir(); err == nil { + templateContent = pr.FindTemplate(repoRoot) + } + // Push each branch and create/update its PR in stack order (bottom to top). // Sequential pushing ensures each branch's base is up-to-date on the // remote before the next branch is pushed, preventing race conditions. @@ -165,7 +172,7 @@ func runSubmit(cfg *config.Config, opts *submitOptions) error { // Find or create PR, and fix base if needed baseBranch := s.ActiveBaseBranch(b.Branch) - if err := ensurePR(cfg, client, s, i, baseBranch, opts); err != nil { + if err := ensurePR(cfg, client, s, i, baseBranch, opts, templateContent); err != nil { if errors.Is(err, errInterrupt) { printInterrupt(cfg) return ErrSilent @@ -195,7 +202,7 @@ func runSubmit(cfg *config.Config, opts *submitOptions) error { // ensurePR finds or creates a PR for the branch at index i, and updates // its base branch if needed. This is the single place where PR state is // reconciled during submit. -func ensurePR(cfg *config.Config, client github.ClientOps, s *stack.Stack, i int, baseBranch string, opts *submitOptions) error { +func ensurePR(cfg *config.Config, client github.ClientOps, s *stack.Stack, i int, baseBranch string, opts *submitOptions, templateContent string) error { b := s.Branches[i] pr, err := client.FindPRForBranch(b.Branch) @@ -205,7 +212,7 @@ func ensurePR(cfg *config.Config, client github.ClientOps, s *stack.Stack, i int } if pr == nil { - return createPR(cfg, client, s, i, baseBranch, opts) + return createPR(cfg, client, s, i, baseBranch, opts, templateContent) } // PR exists — record it and fix base if needed. @@ -250,7 +257,7 @@ func ensurePR(cfg *config.Config, client github.ClientOps, s *stack.Stack, i int } // createPR creates a new PR for the branch at index i. -func createPR(cfg *config.Config, client github.ClientOps, s *stack.Stack, i int, baseBranch string, opts *submitOptions) error { +func createPR(cfg *config.Config, client github.ClientOps, s *stack.Stack, i int, baseBranch string, opts *submitOptions, templateContent string) error { b := s.Branches[i] title, commitBody := defaultPRTitleBody(baseBranch, b.Branch) @@ -272,7 +279,7 @@ func createPR(cfg *config.Config, client github.ClientOps, s *stack.Stack, i int if title != originalTitle && commitBody != "" { prBody = originalTitle + "\n\n" + commitBody } - body := generatePRBody(prBody) + body := generatePRBody(prBody, templateContent) newPR, createErr := client.CreatePR(baseBranch, b.Branch, title, body, !opts.open) if createErr != nil { @@ -299,9 +306,14 @@ func defaultPRTitleBody(base, head string) (string, string) { return humanize(head), "" } -// generatePRBody builds a PR description from the commit body (if any) -// and a footer linking to the CLI and feedback form. -func generatePRBody(commitBody string) string { +// generatePRBody builds a PR description. When a templateContent is provided, +// it is used as the body and the attribution footer is omitted. Otherwise the +// body is built from the commit body with a footer linking to the CLI. +func generatePRBody(commitBody string, templateContent string) string { + if templateContent != "" { + return templateContent + } + var parts []string if commitBody != "" { diff --git a/cmd/submit_test.go b/cmd/submit_test.go index c4ded3d..1d86ed8 100644 --- a/cmd/submit_test.go +++ b/cmd/submit_test.go @@ -6,6 +6,7 @@ import ( "io" "net/url" "os" + "path/filepath" "testing" "github.com/cli/go-gh/v2/pkg/api" @@ -20,12 +21,14 @@ import ( func TestGeneratePRBody(t *testing.T) { tests := []struct { - name string - commitBody string - wantContains []string + name string + commitBody string + templateContent string + wantContains []string + wantNotContains []string }{ { - name: "empty commit body", + name: "empty commit body no template", commitBody: "", wantContains: []string{ "GitHub Stacks CLI", @@ -34,7 +37,7 @@ func TestGeneratePRBody(t *testing.T) { }, }, { - name: "with commit body", + name: "with commit body no template", commitBody: "This is a detailed description\nof the change.", wantContains: []string{ "This is a detailed description\nof the change.", @@ -42,14 +45,37 @@ func TestGeneratePRBody(t *testing.T) { "", }, }, + { + name: "with template", + commitBody: "some commit body", + templateContent: "## Description\n\nFill in details.", + wantContains: []string{ + "## Description", + "Fill in details.", + }, + wantNotContains: []string{ + "GitHub Stacks CLI", + feedbackURL, + "some commit body", + }, + }, + { + name: "template replaces footer", + templateContent: "Template body only", + wantContains: []string{"Template body only"}, + wantNotContains: []string{""}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := generatePRBody(tt.commitBody) + got := generatePRBody(tt.commitBody, tt.templateContent) for _, want := range tt.wantContains { assert.Contains(t, got, want) } + for _, notWant := range tt.wantNotContains { + assert.NotContains(t, got, notWant) + } }) } } @@ -58,6 +84,7 @@ func TestGeneratePRBody(t *testing.T) { func newSubmitMock(tmpDir string, currentBranch string) *git.MockOps { return &git.MockOps{ GitDirFn: func() (string, error) { return tmpDir, nil }, + RootDirFn: func() (string, error) { return tmpDir, nil }, CurrentBranchFn: func() (string, error) { return currentBranch, nil }, ResolveRemoteFn: func(string) (string, error) { return "origin", nil }, PushFn: func(string, []string, bool, bool) error { return nil }, @@ -1581,3 +1608,102 @@ func TestSubmit_FetchesBeforePush(t *testing.T) { require.True(t, len(callOrder) >= 3, "expected at least 3 calls (fetch + 2 pushes)") assert.Equal(t, "fetch", callOrder[0], "fetch must happen before any push") } + +func TestSubmit_UsesPRTemplate(t *testing.T) { + s := stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + {Branch: "b1"}, + }, + } + + tmpDir := t.TempDir() + writeStackFile(t, tmpDir, s) + + // Create a PR template in the repo root + ghDir := filepath.Join(tmpDir, ".github") + require.NoError(t, os.MkdirAll(ghDir, 0o755)) + require.NoError(t, os.WriteFile( + filepath.Join(ghDir, "pull_request_template.md"), + []byte("## What\n\nDescribe changes.\n\n## Why\n\nExplain motivation."), + 0o644, + )) + + var capturedBody string + + mock := newSubmitMock(tmpDir, "b1") + mock.PushFn = func(string, []string, bool, bool) error { return nil } + mock.LogRangeFn = func(base, head string) ([]git.CommitInfo, error) { + return []git.CommitInfo{{Subject: "add feature", Body: "detailed commit body"}}, nil + } + restore := git.SetOps(mock) + defer restore() + + cfg, _, _ := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + ListStacksFn: func() ([]github.RemoteStack, error) { return nil, nil }, + FindPRForBranchFn: func(string) (*github.PullRequest, error) { return nil, nil }, + CreatePRFn: func(base, head, title, body string, draft bool) (*github.PullRequest, error) { + capturedBody = body + return &github.PullRequest{Number: 1, ID: "PR_1", URL: "https://github.com/o/r/pull/1"}, nil + }, + CreateStackFn: func([]int) (int, error) { return 1, nil }, + } + + cmd := SubmitCmd(cfg) + cmd.SetArgs([]string{"--auto"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + assert.NoError(t, err) + assert.Contains(t, capturedBody, "## What") + assert.Contains(t, capturedBody, "## Why") + assert.NotContains(t, capturedBody, "GitHub Stacks CLI", "footer should not be present when template is used") + assert.NotContains(t, capturedBody, feedbackURL) +} + +func TestSubmit_NoTemplate_UsesFooter(t *testing.T) { + s := stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + {Branch: "b1"}, + }, + } + + tmpDir := t.TempDir() + writeStackFile(t, tmpDir, s) + + // No template file created + + var capturedBody string + + mock := newSubmitMock(tmpDir, "b1") + mock.PushFn = func(string, []string, bool, bool) error { return nil } + mock.LogRangeFn = func(base, head string) ([]git.CommitInfo, error) { + return []git.CommitInfo{{Subject: "fix bug"}}, nil + } + restore := git.SetOps(mock) + defer restore() + + cfg, _, _ := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + ListStacksFn: func() ([]github.RemoteStack, error) { return nil, nil }, + FindPRForBranchFn: func(string) (*github.PullRequest, error) { return nil, nil }, + CreatePRFn: func(base, head, title, body string, draft bool) (*github.PullRequest, error) { + capturedBody = body + return &github.PullRequest{Number: 1, ID: "PR_1", URL: "https://github.com/o/r/pull/1"}, nil + }, + CreateStackFn: func([]int) (int, error) { return 1, nil }, + } + + cmd := SubmitCmd(cfg) + cmd.SetArgs([]string{"--auto"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + assert.NoError(t, err) + assert.Contains(t, capturedBody, "GitHub Stacks CLI", "footer should be present when no template") + assert.Contains(t, capturedBody, feedbackURL) +} diff --git a/internal/git/git.go b/internal/git/git.go index 36e1325..7be063e 100644 --- a/internal/git/git.go +++ b/internal/git/git.go @@ -104,6 +104,11 @@ func GitDir() (string, error) { return ops.GitDir() } +// RootDir returns the repository's root directory. +func RootDir() (string, error) { + return ops.RootDir() +} + // CurrentBranch returns the name of the current branch. func CurrentBranch() (string, error) { return ops.CurrentBranch() diff --git a/internal/git/gitops.go b/internal/git/gitops.go index 3ed107f..a00deac 100644 --- a/internal/git/gitops.go +++ b/internal/git/gitops.go @@ -17,6 +17,7 @@ import ( // Tests can substitute a mock via SetOps(). type Ops interface { GitDir() (string, error) + RootDir() (string, error) CurrentBranch() (string, error) BranchExists(name string) bool CheckoutBranch(name string) error @@ -91,6 +92,10 @@ func (d *defaultOps) GitDir() (string, error) { return client.GitDir(context.Background()) } +func (d *defaultOps) RootDir() (string, error) { + return run("rev-parse", "--show-toplevel") +} + func (d *defaultOps) CurrentBranch() (string, error) { return client.CurrentBranch(context.Background()) } diff --git a/internal/git/mock_ops.go b/internal/git/mock_ops.go index 07bfb75..05b8c51 100644 --- a/internal/git/mock_ops.go +++ b/internal/git/mock_ops.go @@ -5,6 +5,7 @@ package git // Ops method call. When nil, a reasonable default is returned. type MockOps struct { GitDirFn func() (string, error) + RootDirFn func() (string, error) CurrentBranchFn func() (string, error) BranchExistsFn func(string) bool CheckoutBranchFn func(string) error @@ -60,6 +61,13 @@ func (m *MockOps) GitDir() (string, error) { return "/tmp/fake-git-dir", nil } +func (m *MockOps) RootDir() (string, error) { + if m.RootDirFn != nil { + return m.RootDirFn() + } + return "/tmp/fake-repo", nil +} + func (m *MockOps) CurrentBranch() (string, error) { if m.CurrentBranchFn != nil { return m.CurrentBranchFn() diff --git a/internal/pr/template.go b/internal/pr/template.go new file mode 100644 index 0000000..d9ccfae --- /dev/null +++ b/internal/pr/template.go @@ -0,0 +1,35 @@ +package pr + +import ( + "os" + "path/filepath" + "strings" +) + +// templatePaths lists the candidate locations for a pull request template. +var templatePaths = []string{ + ".github/pull_request_template.md", + ".github/PULL_REQUEST_TEMPLATE.md", + "pull_request_template.md", + "PULL_REQUEST_TEMPLATE.md", + "docs/pull_request_template.md", + "docs/PULL_REQUEST_TEMPLATE.md", +} + +// FindTemplate searches the repository root for a default pull request +// template and returns its content. Returns an empty string if no template +// is found or cannot be read. +func FindTemplate(repoRoot string) string { + for _, candidate := range templatePaths { + path := filepath.Join(repoRoot, candidate) + data, err := os.ReadFile(path) + if err != nil { + continue + } + content := strings.TrimSpace(string(data)) + if content != "" { + return content + } + } + return "" +} diff --git a/internal/pr/template_test.go b/internal/pr/template_test.go new file mode 100644 index 0000000..c55ec1d --- /dev/null +++ b/internal/pr/template_test.go @@ -0,0 +1,71 @@ +package pr + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFindTemplate_GitHubDir(t *testing.T) { + root := t.TempDir() + dir := filepath.Join(root, ".github") + os.MkdirAll(dir, 0o755) + os.WriteFile(filepath.Join(dir, "pull_request_template.md"), []byte("## Description\n\nFill in details."), 0o644) + + got := FindTemplate(root) + assert.Equal(t, "## Description\n\nFill in details.", got) +} + +func TestFindTemplate_RootDir(t *testing.T) { + root := t.TempDir() + os.WriteFile(filepath.Join(root, "pull_request_template.md"), []byte("Root template"), 0o644) + + got := FindTemplate(root) + assert.Equal(t, "Root template", got) +} + +func TestFindTemplate_DocsDir(t *testing.T) { + root := t.TempDir() + dir := filepath.Join(root, "docs") + os.MkdirAll(dir, 0o755) + os.WriteFile(filepath.Join(dir, "PULL_REQUEST_TEMPLATE.md"), []byte("Docs template"), 0o644) + + got := FindTemplate(root) + assert.Equal(t, "Docs template", got) +} + +func TestFindTemplate_PriorityOrder(t *testing.T) { + root := t.TempDir() + ghDir := filepath.Join(root, ".github") + os.MkdirAll(ghDir, 0o755) + os.WriteFile(filepath.Join(ghDir, "pull_request_template.md"), []byte("github template"), 0o644) + os.WriteFile(filepath.Join(root, "pull_request_template.md"), []byte("root template"), 0o644) + + got := FindTemplate(root) + assert.Equal(t, "github template", got) +} + +func TestFindTemplate_NoTemplate(t *testing.T) { + root := t.TempDir() + + got := FindTemplate(root) + assert.Equal(t, "", got) +} + +func TestFindTemplate_EmptyFile(t *testing.T) { + root := t.TempDir() + os.WriteFile(filepath.Join(root, "pull_request_template.md"), []byte(" \n "), 0o644) + + got := FindTemplate(root) + assert.Equal(t, "", got, "empty/whitespace-only template should be treated as no template") +} + +func TestFindTemplate_UpperCase(t *testing.T) { + root := t.TempDir() + os.WriteFile(filepath.Join(root, "PULL_REQUEST_TEMPLATE.md"), []byte("UPPER template"), 0o644) + + got := FindTemplate(root) + assert.Equal(t, "UPPER template", got) +} From 7d9ee8c8a4dd32544753b49447bc43cf849a6b85 Mon Sep 17 00:00:00 2001 From: Sameen Karim Date: Tue, 5 May 2026 15:51:58 -0400 Subject: [PATCH 2/2] add a helper to clearly distinguish between filesystem errors and actual test failures --- cmd/link_test.go | 56 ------------------------------------ internal/pr/template_test.go | 33 ++++++++++++--------- 2 files changed, 20 insertions(+), 69 deletions(-) diff --git a/cmd/link_test.go b/cmd/link_test.go index b7b3e5f..4514de2 100644 --- a/cmd/link_test.go +++ b/cmd/link_test.go @@ -1206,62 +1206,6 @@ func TestLink_SkipsBaseFix_ForNewlyCreatedPRs(t *testing.T) { // Silence "imported and not used" for fmt in case test helpers use it. var _ = fmt.Sprintf -func TestLink_FetchesBeforePush(t *testing.T) { - var callOrder []string - var fetchedBranches []string - - mock := newLinkGitMock("feat-a", "feat-b") - mock.FetchBranchesFn = func(remote string, branches []string) error { - callOrder = append(callOrder, "fetch") - fetchedBranches = branches - assert.Equal(t, "origin", remote) - return nil - } - mock.PushFn = func(remote string, branches []string, force, atomic bool) error { - callOrder = append(callOrder, "push") - return nil - } - - restore := git.SetOps(mock) - defer restore() - - prNum := 0 - cfg, _, errR := config.NewTestConfig() - cfg.GitHubClientOverride = &github.MockClient{ - FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { - prNum++ - return &github.PullRequest{ - Number: prNum, - URL: fmt.Sprintf("https://github.com/o/r/pull/%d", prNum), - BaseRefName: "main", - HeadRefName: branch, - State: "OPEN", - }, nil - }, - ListStacksFn: func() ([]github.RemoteStack, error) { - return []github.RemoteStack{}, nil - }, - CreateStackFn: func(prNumbers []int) (int, error) { - return 42, nil - }, - } - - cmd := LinkCmd(cfg) - cmd.SetArgs([]string{"feat-a", "feat-b"}) - cmd.SetOut(io.Discard) - cmd.SetErr(io.Discard) - err := cmd.Execute() - - cfg.Err.Close() - _, _ = io.ReadAll(errR) - - assert.NoError(t, err) - assert.Equal(t, []string{"feat-a", "feat-b"}, fetchedBranches, "should fetch pushed branches") - require.Len(t, callOrder, 2) - assert.Equal(t, "fetch", callOrder[0], "fetch must happen before push") - assert.Equal(t, "push", callOrder[1]) -} - func TestLink_BranchNames_UsesPRTemplate(t *testing.T) { tmpDir := t.TempDir() ghDir := filepath.Join(tmpDir, ".github") diff --git a/internal/pr/template_test.go b/internal/pr/template_test.go index c55ec1d..3f4ca7c 100644 --- a/internal/pr/template_test.go +++ b/internal/pr/template_test.go @@ -8,11 +8,22 @@ import ( "github.com/stretchr/testify/assert" ) +// writeTemplate is a test helper that creates a file with the given content, +// creating parent directories as needed. It calls t.Fatal on any error so +// that setup failures are clearly distinguished from feature failures. +func writeTemplate(t *testing.T, path string, content []byte) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("test setup: MkdirAll: %v", err) + } + if err := os.WriteFile(path, content, 0o644); err != nil { + t.Fatalf("test setup: WriteFile: %v", err) + } +} + func TestFindTemplate_GitHubDir(t *testing.T) { root := t.TempDir() - dir := filepath.Join(root, ".github") - os.MkdirAll(dir, 0o755) - os.WriteFile(filepath.Join(dir, "pull_request_template.md"), []byte("## Description\n\nFill in details."), 0o644) + writeTemplate(t, filepath.Join(root, ".github", "pull_request_template.md"), []byte("## Description\n\nFill in details.")) got := FindTemplate(root) assert.Equal(t, "## Description\n\nFill in details.", got) @@ -20,7 +31,7 @@ func TestFindTemplate_GitHubDir(t *testing.T) { func TestFindTemplate_RootDir(t *testing.T) { root := t.TempDir() - os.WriteFile(filepath.Join(root, "pull_request_template.md"), []byte("Root template"), 0o644) + writeTemplate(t, filepath.Join(root, "pull_request_template.md"), []byte("Root template")) got := FindTemplate(root) assert.Equal(t, "Root template", got) @@ -28,9 +39,7 @@ func TestFindTemplate_RootDir(t *testing.T) { func TestFindTemplate_DocsDir(t *testing.T) { root := t.TempDir() - dir := filepath.Join(root, "docs") - os.MkdirAll(dir, 0o755) - os.WriteFile(filepath.Join(dir, "PULL_REQUEST_TEMPLATE.md"), []byte("Docs template"), 0o644) + writeTemplate(t, filepath.Join(root, "docs", "PULL_REQUEST_TEMPLATE.md"), []byte("Docs template")) got := FindTemplate(root) assert.Equal(t, "Docs template", got) @@ -38,10 +47,8 @@ func TestFindTemplate_DocsDir(t *testing.T) { func TestFindTemplate_PriorityOrder(t *testing.T) { root := t.TempDir() - ghDir := filepath.Join(root, ".github") - os.MkdirAll(ghDir, 0o755) - os.WriteFile(filepath.Join(ghDir, "pull_request_template.md"), []byte("github template"), 0o644) - os.WriteFile(filepath.Join(root, "pull_request_template.md"), []byte("root template"), 0o644) + writeTemplate(t, filepath.Join(root, ".github", "pull_request_template.md"), []byte("github template")) + writeTemplate(t, filepath.Join(root, "pull_request_template.md"), []byte("root template")) got := FindTemplate(root) assert.Equal(t, "github template", got) @@ -56,7 +63,7 @@ func TestFindTemplate_NoTemplate(t *testing.T) { func TestFindTemplate_EmptyFile(t *testing.T) { root := t.TempDir() - os.WriteFile(filepath.Join(root, "pull_request_template.md"), []byte(" \n "), 0o644) + writeTemplate(t, filepath.Join(root, "pull_request_template.md"), []byte(" \n ")) got := FindTemplate(root) assert.Equal(t, "", got, "empty/whitespace-only template should be treated as no template") @@ -64,7 +71,7 @@ func TestFindTemplate_EmptyFile(t *testing.T) { func TestFindTemplate_UpperCase(t *testing.T) { root := t.TempDir() - os.WriteFile(filepath.Join(root, "PULL_REQUEST_TEMPLATE.md"), []byte("UPPER template"), 0o644) + writeTemplate(t, filepath.Join(root, "PULL_REQUEST_TEMPLATE.md"), []byte("UPPER template")) got := FindTemplate(root) assert.Equal(t, "UPPER template", got)