diff --git a/cmd/push.go b/cmd/push.go index 1f0f857..a0b809b 100644 --- a/cmd/push.go +++ b/cmd/push.go @@ -93,8 +93,12 @@ func runPush(cfg *config.Config, opts *pushOptions) error { cfg.Printf("No active branches to push (all merged or queued)") return nil } + // Best-effort fetch to update tracking refs (helps --force-with-lease + // in shallow clones). Silently ignored if branches don't exist on the + // remote yet. + _ = git.FetchBranches(remote, activeBranches) cfg.Printf("Pushing %d %s to %s...", len(activeBranches), plural(len(activeBranches), "branch", "branches"), remote) - if err := git.Push(remote, activeBranches, true, true); err != nil { + if err := git.Push(remote, activeBranches, true, false); err != nil { cfg.Errorf("failed to push: %s", err) return ErrSilent } diff --git a/cmd/push_test.go b/cmd/push_test.go index ff70e57..d44da1a 100644 --- a/cmd/push_test.go +++ b/cmd/push_test.go @@ -62,7 +62,7 @@ func TestPush_PushesAllBranches(t *testing.T) { assert.Equal(t, "origin", pushCalls[0].remote) assert.Equal(t, []string{"b1", "b2"}, pushCalls[0].branches) assert.True(t, pushCalls[0].force) - assert.True(t, pushCalls[0].atomic) + assert.False(t, pushCalls[0].atomic) assert.Contains(t, output, "Pushed 2 branches") assert.Contains(t, output, "gh stack submit", "should hint about submit when branches have no PRs") } @@ -182,6 +182,89 @@ func TestPush_PushFailure(t *testing.T) { assert.Contains(t, output, "failed to push") } +func TestPush_FetchesBeforePush(t *testing.T) { + s := stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + {Branch: "b1"}, + {Branch: "b2"}, + }, + } + + tmpDir := t.TempDir() + writeStackFile(t, tmpDir, s) + + var callOrder []string + + mock := newPushMock(tmpDir, "b1") + mock.FetchBranchesFn = func(remote string, branches []string) error { + callOrder = append(callOrder, "fetch") + assert.Equal(t, "origin", remote) + assert.Equal(t, []string{"b1", "b2"}, branches) + return nil + } + mock.PushFn = func(remote string, branches []string, force, atomic bool) error { + callOrder = append(callOrder, "push") + return nil + } + + restore := git.SetOps(mock) + defer restore() + + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{} + cmd := PushCmd(cfg) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + _, _ = io.ReadAll(errR) + + assert.NoError(t, err) + assert.Equal(t, []string{"fetch", "push"}, callOrder, "fetch must happen before push") +} + +func TestPush_FetchFailureIsNonFatal(t *testing.T) { + s := stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + {Branch: "b1"}, + }, + } + + tmpDir := t.TempDir() + writeStackFile(t, tmpDir, s) + + pushCalled := false + + mock := newPushMock(tmpDir, "b1") + mock.FetchBranchesFn = func(string, []string) error { + return fmt.Errorf("network error") + } + mock.PushFn = func(string, []string, bool, bool) error { + pushCalled = true + return nil + } + + restore := git.SetOps(mock) + defer restore() + + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{} + cmd := PushCmd(cfg) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + errOut, _ := io.ReadAll(errR) + + assert.NoError(t, err, "fetch failure should not abort push") + assert.True(t, pushCalled, "push should proceed after fetch failure") + assert.NotContains(t, string(errOut), "Failed to fetch", "fetch failure should be silent") +} + func TestPush_DoesNotCreatePRs(t *testing.T) { s := stack.Stack{ Trunk: stack.BranchRef{Branch: "main"}, diff --git a/cmd/submit.go b/cmd/submit.go index a830177..5c4c0a3 100644 --- a/cmd/submit.go +++ b/cmd/submit.go @@ -59,6 +59,8 @@ func runSubmit(cfg *config.Config, opts *submitOptions) error { return ErrNotInStack } + cfg.Printf("Checking stack state...") + // Find the stack for the current branch without switching branches. // Submit should never change the user's checked-out branch. stacks := sf.FindAllStacksForBranch(currentBranch) @@ -129,7 +131,7 @@ func runSubmit(cfg *config.Config, opts *submitOptions) error { return nil } - // If a modification is pending, delete the old remote stack first so that + // If a modification is pending, delete the old remote stack first so that // PR base updates are allowed and force-pushes don't trigger auto-merges. if stacksAvailable { if err := handlePendingModify(cfg, client, s, gitDir); err != nil { @@ -141,6 +143,11 @@ func runSubmit(cfg *config.Config, opts *submitOptions) error { } } + // Best-effort fetch to update tracking refs (helps --force-with-lease + // in shallow clones). Silently ignored if branches don't exist on the + // remote yet. + _ = git.FetchBranches(remote, activeBranches) + // Push each branch and create/update its PR in stack order (bottom to top). // Sequential pushing ensures each branch's base is up-to-date on the // remote before the next branch is pushed, preventing race conditions. diff --git a/cmd/submit_test.go b/cmd/submit_test.go index 2b0e40d..2ffe5d3 100644 --- a/cmd/submit_test.go +++ b/cmd/submit_test.go @@ -1370,3 +1370,68 @@ func TestSubmit_WithPendingModify_SequentialPush(t *testing.T) { // State file should be cleared assert.False(t, modify.StateExists(tmpDir), "modify state file should be cleared after success") } + +func TestSubmit_FetchesBeforePush(t *testing.T) { + s := stack.Stack{ + Trunk: stack.BranchRef{Branch: "main"}, + Branches: []stack.BranchRef{ + {Branch: "b1"}, + {Branch: "b2"}, + }, + } + + tmpDir := t.TempDir() + writeStackFile(t, tmpDir, s) + + var callOrder []string + var fetchedBranches []string + + mock := newSubmitMock(tmpDir, "b1") + mock.FetchBranchesFn = func(remote string, branches []string) error { + callOrder = append(callOrder, "fetch") + fetchedBranches = branches + assert.Equal(t, "origin", remote) + return nil + } + mock.PushFn = func(remote string, branches []string, force, atomic bool) error { + callOrder = append(callOrder, "push") + return nil + } + + restore := git.SetOps(mock) + defer restore() + + cfg, _, errR := config.NewTestConfig() + cfg.GitHubClientOverride = &github.MockClient{ + FindPRForBranchFn: func(branch string) (*github.PullRequest, error) { + return &github.PullRequest{ + Number: 1, + URL: "https://github.com/o/r/pull/1", + BaseRefName: "main", + HeadRefName: branch, + State: "OPEN", + }, nil + }, + ListStacksFn: func() ([]github.RemoteStack, error) { + return []github.RemoteStack{}, nil + }, + CreateStackFn: func(prNumbers []int) (int, error) { + return 42, nil + }, + } + + cmd := SubmitCmd(cfg) + cmd.SetArgs([]string{"--auto"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + err := cmd.Execute() + + cfg.Err.Close() + _, _ = io.ReadAll(errR) + + assert.NoError(t, err) + assert.Equal(t, []string{"b1", "b2"}, fetchedBranches, "should fetch active branches") + // fetch must come before all pushes + require.True(t, len(callOrder) >= 3, "expected at least 3 calls (fetch + 2 pushes)") + assert.Equal(t, "fetch", callOrder[0], "fetch must happen before any push") +} diff --git a/cmd/sync.go b/cmd/sync.go index 34a4baf..866105c 100644 --- a/cmd/sync.go +++ b/cmd/sync.go @@ -76,11 +76,11 @@ func runSync(cfg *config.Config, opts *syncOptions) error { return ErrSilent } - if err := git.Fetch(remote); err != nil { - cfg.Warningf("Failed to fetch %s: %v", remote, err) - } else { - cfg.Successf("Fetched latest changes from %s", remote) - } + // Fetch trunk + active branches so tracking refs are current for + // fast-forward detection (Step 2) and --force-with-lease (Step 4). + fetchTargets := append([]string{s.Trunk.Branch}, activeBranchNames(s)...) + _ = git.FetchBranches(remote, fetchTargets) + cfg.Successf("Fetched latest changes from %s", remote) // --- Step 2: Fast-forward trunk --- trunk := s.Trunk.Branch diff --git a/internal/git/git.go b/internal/git/git.go index e9ed717..36e1325 100644 --- a/internal/git/git.go +++ b/internal/git/git.go @@ -124,6 +124,12 @@ func Fetch(remote string) error { return ops.Fetch(remote) } +// FetchBranches fetches specific branches from a remote, +// updating their tracking refs. +func FetchBranches(remote string, branches []string) error { + return ops.FetchBranches(remote, branches) +} + // DefaultBranch returns the HEAD branch from origin. func DefaultBranch() (string, error) { return ops.DefaultBranch() diff --git a/internal/git/gitops.go b/internal/git/gitops.go index e3c1870..3ed107f 100644 --- a/internal/git/gitops.go +++ b/internal/git/gitops.go @@ -21,6 +21,7 @@ type Ops interface { BranchExists(name string) bool CheckoutBranch(name string) error Fetch(remote string) error + FetchBranches(remote string, branches []string) error DefaultBranch() (string, error) CreateBranch(name, base string) error Push(remote string, branches []string, force, atomic bool) error @@ -106,6 +107,33 @@ func (d *defaultOps) Fetch(remote string) error { return client.Fetch(context.Background(), remote, "") } +func (d *defaultOps) FetchBranches(remote string, branches []string) error { + // Only fetch branches that already have a remote tracking ref. + var tracked []string + for _, b := range branches { + ref := fmt.Sprintf("refs/remotes/%s/%s", remote, b) + if err := runSilent("rev-parse", "--verify", "--quiet", ref); err == nil { + tracked = append(tracked, b) + } + } + if len(tracked) == 0 { + return nil + } + // Fast path: fetch all tracked branches in a single call. + args := []string{"fetch", remote} + args = append(args, tracked...) + if err := runSilent(args...); err == nil { + return nil + } + // Fallback: a ref may have been deleted on the remote while the + // local tracking ref still exists. Fetch branches individually so + // one missing ref doesn't block the others. + for _, b := range tracked { + _ = runSilent("fetch", remote, b) + } + return nil +} + func (d *defaultOps) DefaultBranch() (string, error) { ref, err := run("symbolic-ref", "refs/remotes/origin/HEAD") if err != nil { diff --git a/internal/git/mock_ops.go b/internal/git/mock_ops.go index 12ddf5b..07bfb75 100644 --- a/internal/git/mock_ops.go +++ b/internal/git/mock_ops.go @@ -9,6 +9,7 @@ type MockOps struct { BranchExistsFn func(string) bool CheckoutBranchFn func(string) error FetchFn func(string) error + FetchBranchesFn func(string, []string) error DefaultBranchFn func() (string, error) CreateBranchFn func(string, string) error PushFn func(string, []string, bool, bool) error @@ -87,6 +88,13 @@ func (m *MockOps) Fetch(remote string) error { return nil } +func (m *MockOps) FetchBranches(remote string, branches []string) error { + if m.FetchBranchesFn != nil { + return m.FetchBranchesFn(remote, branches) + } + return nil +} + func (m *MockOps) DefaultBranch() (string, error) { if m.DefaultBranchFn != nil { return m.DefaultBranchFn()