Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion cmd/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,12 @@ func runPush(cfg *config.Config, opts *pushOptions) error {
cfg.Printf("No active branches to push (all merged or queued)")
return nil
}
// Best-effort fetch to update tracking refs (helps --force-with-lease
// in shallow clones). Silently ignored if branches don't exist on the
// remote yet.
_ = git.FetchBranches(remote, activeBranches)
cfg.Printf("Pushing %d %s to %s...", len(activeBranches), plural(len(activeBranches), "branch", "branches"), remote)
if err := git.Push(remote, activeBranches, true, true); err != nil {
if err := git.Push(remote, activeBranches, true, false); err != nil {
Comment thread
skarim marked this conversation as resolved.
cfg.Errorf("failed to push: %s", err)
return ErrSilent
}
Expand Down
85 changes: 84 additions & 1 deletion cmd/push_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestPush_PushesAllBranches(t *testing.T) {
assert.Equal(t, "origin", pushCalls[0].remote)
assert.Equal(t, []string{"b1", "b2"}, pushCalls[0].branches)
assert.True(t, pushCalls[0].force)
assert.True(t, pushCalls[0].atomic)
assert.False(t, pushCalls[0].atomic)
assert.Contains(t, output, "Pushed 2 branches")
assert.Contains(t, output, "gh stack submit", "should hint about submit when branches have no PRs")
}
Expand Down Expand Up @@ -182,6 +182,89 @@ func TestPush_PushFailure(t *testing.T) {
assert.Contains(t, output, "failed to push")
}

func TestPush_FetchesBeforePush(t *testing.T) {
s := stack.Stack{
Trunk: stack.BranchRef{Branch: "main"},
Branches: []stack.BranchRef{
{Branch: "b1"},
{Branch: "b2"},
},
}

tmpDir := t.TempDir()
writeStackFile(t, tmpDir, s)

var callOrder []string

mock := newPushMock(tmpDir, "b1")
mock.FetchBranchesFn = func(remote string, branches []string) error {
callOrder = append(callOrder, "fetch")
assert.Equal(t, "origin", remote)
assert.Equal(t, []string{"b1", "b2"}, branches)
return nil
}
mock.PushFn = func(remote string, branches []string, force, atomic bool) error {
callOrder = append(callOrder, "push")
return nil
}

restore := git.SetOps(mock)
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{}
cmd := PushCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()

cfg.Err.Close()
_, _ = io.ReadAll(errR)

assert.NoError(t, err)
assert.Equal(t, []string{"fetch", "push"}, callOrder, "fetch must happen before push")
}

func TestPush_FetchFailureIsNonFatal(t *testing.T) {
s := stack.Stack{
Trunk: stack.BranchRef{Branch: "main"},
Branches: []stack.BranchRef{
{Branch: "b1"},
},
}

tmpDir := t.TempDir()
writeStackFile(t, tmpDir, s)

pushCalled := false

mock := newPushMock(tmpDir, "b1")
mock.FetchBranchesFn = func(string, []string) error {
return fmt.Errorf("network error")
}
mock.PushFn = func(string, []string, bool, bool) error {
pushCalled = true
return nil
}

restore := git.SetOps(mock)
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{}
cmd := PushCmd(cfg)
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()

cfg.Err.Close()
errOut, _ := io.ReadAll(errR)

assert.NoError(t, err, "fetch failure should not abort push")
assert.True(t, pushCalled, "push should proceed after fetch failure")
assert.NotContains(t, string(errOut), "Failed to fetch", "fetch failure should be silent")
}

func TestPush_DoesNotCreatePRs(t *testing.T) {
s := stack.Stack{
Trunk: stack.BranchRef{Branch: "main"},
Expand Down
9 changes: 8 additions & 1 deletion cmd/submit.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ func runSubmit(cfg *config.Config, opts *submitOptions) error {
return ErrNotInStack
}

cfg.Printf("Checking stack state...")

// Find the stack for the current branch without switching branches.
// Submit should never change the user's checked-out branch.
stacks := sf.FindAllStacksForBranch(currentBranch)
Expand Down Expand Up @@ -129,7 +131,7 @@ func runSubmit(cfg *config.Config, opts *submitOptions) error {
return nil
}

// If a modification is pending, delete the old remote stack first so that
// If a modification is pending, delete the old remote stack first so that
// PR base updates are allowed and force-pushes don't trigger auto-merges.
if stacksAvailable {
if err := handlePendingModify(cfg, client, s, gitDir); err != nil {
Expand All @@ -141,6 +143,11 @@ func runSubmit(cfg *config.Config, opts *submitOptions) error {
}
}

// Best-effort fetch to update tracking refs (helps --force-with-lease
// in shallow clones). Silently ignored if branches don't exist on the
// remote yet.
_ = git.FetchBranches(remote, activeBranches)

// Push each branch and create/update its PR in stack order (bottom to top).
// Sequential pushing ensures each branch's base is up-to-date on the
// remote before the next branch is pushed, preventing race conditions.
Expand Down
65 changes: 65 additions & 0 deletions cmd/submit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1370,3 +1370,68 @@ func TestSubmit_WithPendingModify_SequentialPush(t *testing.T) {
// State file should be cleared
assert.False(t, modify.StateExists(tmpDir), "modify state file should be cleared after success")
}

func TestSubmit_FetchesBeforePush(t *testing.T) {
s := stack.Stack{
Trunk: stack.BranchRef{Branch: "main"},
Branches: []stack.BranchRef{
{Branch: "b1"},
{Branch: "b2"},
},
}

tmpDir := t.TempDir()
writeStackFile(t, tmpDir, s)

var callOrder []string
var fetchedBranches []string

mock := newSubmitMock(tmpDir, "b1")
mock.FetchBranchesFn = func(remote string, branches []string) error {
callOrder = append(callOrder, "fetch")
fetchedBranches = branches
assert.Equal(t, "origin", remote)
return nil
}
mock.PushFn = func(remote string, branches []string, force, atomic bool) error {
callOrder = append(callOrder, "push")
return nil
}

restore := git.SetOps(mock)
defer restore()

cfg, _, errR := config.NewTestConfig()
cfg.GitHubClientOverride = &github.MockClient{
FindPRForBranchFn: func(branch string) (*github.PullRequest, error) {
return &github.PullRequest{
Number: 1,
URL: "https://github.com/o/r/pull/1",
BaseRefName: "main",
HeadRefName: branch,
State: "OPEN",
}, nil
},
ListStacksFn: func() ([]github.RemoteStack, error) {
return []github.RemoteStack{}, nil
},
CreateStackFn: func(prNumbers []int) (int, error) {
return 42, nil
},
}

cmd := SubmitCmd(cfg)
cmd.SetArgs([]string{"--auto"})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()

cfg.Err.Close()
_, _ = io.ReadAll(errR)

assert.NoError(t, err)
assert.Equal(t, []string{"b1", "b2"}, fetchedBranches, "should fetch active branches")
// fetch must come before all pushes
require.True(t, len(callOrder) >= 3, "expected at least 3 calls (fetch + 2 pushes)")
assert.Equal(t, "fetch", callOrder[0], "fetch must happen before any push")
}
10 changes: 5 additions & 5 deletions cmd/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ func runSync(cfg *config.Config, opts *syncOptions) error {
return ErrSilent
}

if err := git.Fetch(remote); err != nil {
cfg.Warningf("Failed to fetch %s: %v", remote, err)
} else {
cfg.Successf("Fetched latest changes from %s", remote)
}
// Fetch trunk + active branches so tracking refs are current for
// fast-forward detection (Step 2) and --force-with-lease (Step 4).
fetchTargets := append([]string{s.Trunk.Branch}, activeBranchNames(s)...)
_ = git.FetchBranches(remote, fetchTargets)
cfg.Successf("Fetched latest changes from %s", remote)

// --- Step 2: Fast-forward trunk ---
trunk := s.Trunk.Branch
Expand Down
6 changes: 6 additions & 0 deletions internal/git/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ func Fetch(remote string) error {
return ops.Fetch(remote)
}

// FetchBranches fetches specific branches from a remote,
// updating their tracking refs.
func FetchBranches(remote string, branches []string) error {
return ops.FetchBranches(remote, branches)
}

// DefaultBranch returns the HEAD branch from origin.
func DefaultBranch() (string, error) {
return ops.DefaultBranch()
Expand Down
28 changes: 28 additions & 0 deletions internal/git/gitops.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Ops interface {
BranchExists(name string) bool
CheckoutBranch(name string) error
Fetch(remote string) error
FetchBranches(remote string, branches []string) error
DefaultBranch() (string, error)
CreateBranch(name, base string) error
Push(remote string, branches []string, force, atomic bool) error
Expand Down Expand Up @@ -106,6 +107,33 @@ func (d *defaultOps) Fetch(remote string) error {
return client.Fetch(context.Background(), remote, "")
}

func (d *defaultOps) FetchBranches(remote string, branches []string) error {
// Only fetch branches that already have a remote tracking ref.
var tracked []string
for _, b := range branches {
ref := fmt.Sprintf("refs/remotes/%s/%s", remote, b)
if err := runSilent("rev-parse", "--verify", "--quiet", ref); err == nil {
tracked = append(tracked, b)
}
}
if len(tracked) == 0 {
return nil
}
// Fast path: fetch all tracked branches in a single call.
args := []string{"fetch", remote}
args = append(args, tracked...)
if err := runSilent(args...); err == nil {
return nil
}
// Fallback: a ref may have been deleted on the remote while the
// local tracking ref still exists. Fetch branches individually so
// one missing ref doesn't block the others.
for _, b := range tracked {
_ = runSilent("fetch", remote, b)
}
return nil
}

func (d *defaultOps) DefaultBranch() (string, error) {
ref, err := run("symbolic-ref", "refs/remotes/origin/HEAD")
if err != nil {
Expand Down
8 changes: 8 additions & 0 deletions internal/git/mock_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type MockOps struct {
BranchExistsFn func(string) bool
CheckoutBranchFn func(string) error
FetchFn func(string) error
FetchBranchesFn func(string, []string) error
DefaultBranchFn func() (string, error)
CreateBranchFn func(string, string) error
PushFn func(string, []string, bool, bool) error
Expand Down Expand Up @@ -87,6 +88,13 @@ func (m *MockOps) Fetch(remote string) error {
return nil
}

func (m *MockOps) FetchBranches(remote string, branches []string) error {
if m.FetchBranchesFn != nil {
return m.FetchBranchesFn(remote, branches)
}
return nil
}

func (m *MockOps) DefaultBranch() (string, error) {
if m.DefaultBranchFn != nil {
return m.DefaultBranchFn()
Expand Down
Loading