Skip to content

Commit cb7f9fa

Browse files
committed
update pr base if not already part of a stack
1 parent bd76d31 commit cb7f9fa

5 files changed

Lines changed: 341 additions & 11 deletions

File tree

cmd/submit.go

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ func runSubmit(cfg *config.Config, opts *submitOptions) error {
100100
return ErrSilent
101101
}
102102

103-
// Create or update PRs
103+
// Create or update PRs — ensure every active branch has a PR with the
104+
// correct base branch. This makes submit idempotent: running it again
105+
// fills gaps and fixes base branches before syncing the stack.
104106
for i, b := range s.Branches {
105107
if s.Branches[i].IsMerged() {
106108
continue
@@ -154,14 +156,33 @@ func runSubmit(cfg *config.Config, opts *submitOptions) error {
154156
URL: newPR.URL,
155157
}
156158
} else {
157-
cfg.Printf("PR %s for %s is up to date", cfg.PRLink(pr.Number, pr.URL), b.Branch)
159+
// PR already exists — record it and fix base branch if needed.
158160
if s.Branches[i].PullRequest == nil {
159161
s.Branches[i].PullRequest = &stack.PullRequestRef{
160162
Number: pr.Number,
161163
ID: pr.ID,
162164
URL: pr.URL,
163165
}
164166
}
167+
168+
if pr.BaseRefName != baseBranch {
169+
if s.ID != "" {
170+
// PRs in an existing stack can't have their base updated
171+
// via the API — the stack owns the base relationships.
172+
cfg.Warningf("PR %s has base %q (expected %q) but cannot update while stacked",
173+
cfg.PRLink(pr.Number, pr.URL), pr.BaseRefName, baseBranch)
174+
} else {
175+
if err := client.UpdatePRBase(pr.Number, baseBranch); err != nil {
176+
cfg.Warningf("failed to update base branch for PR %s: %v",
177+
cfg.PRLink(pr.Number, pr.URL), err)
178+
} else {
179+
cfg.Successf("Updated base branch for PR %s to %s",
180+
cfg.PRLink(pr.Number, pr.URL), baseBranch)
181+
}
182+
}
183+
} else {
184+
cfg.Printf("PR %s for %s is up to date", cfg.PRLink(pr.Number, pr.URL), b.Branch)
185+
}
165186
}
166187
}
167188

@@ -291,7 +312,7 @@ func createNewStack(cfg *config.Config, client github.ClientOps, s *stack.Stack,
291312

292313
switch httpErr.StatusCode {
293314
case 422:
294-
handleCreate422(cfg, httpErr)
315+
handleCreate422(cfg, httpErr, prNumbers)
295316
case 404:
296317
cfg.Warningf("Stacked PRs are not yet available for this repository")
297318
default:
@@ -304,11 +325,18 @@ func createNewStack(cfg *config.Config, client github.ClientOps, s *stack.Stack,
304325
// - "Stack must contain at least two pull requests"
305326
// - "Pull requests must form a stack, where each PR's base ref is the previous PR's head ref"
306327
// - "Pull requests #123, #124, #125 are already stacked"
307-
func handleCreate422(cfg *config.Config, httpErr *api.HTTPError) {
328+
func handleCreate422(cfg *config.Config, httpErr *api.HTTPError, prNumbers []int) {
308329
msg := httpErr.Message
309330

310331
if strings.Contains(msg, "already stacked") {
311-
cfg.Warningf("One or more PRs are already part of an existing stack on GitHub")
332+
// Check if the error lists exactly the same PRs we're trying to
333+
// stack. If so, they're already in a stack together — nothing to do.
334+
// If only a subset matches, the PRs are in a different stack.
335+
if allPRsInMessage(msg, prNumbers) {
336+
cfg.Successf("Stack with %d PRs is up to date", len(prNumbers))
337+
return
338+
}
339+
cfg.Warningf("One or more PRs are already part of a different stack on GitHub")
312340
cfg.Printf(" To fix this, unstack the PRs from the web, then `%s`",
313341
cfg.ColorCyan("gh stack submit"))
314342
return
@@ -323,3 +351,15 @@ func handleCreate422(cfg *config.Config, httpErr *api.HTTPError) {
323351
// "at least two" or any other validation error
324352
cfg.Warningf("Could not create stack: %s", msg)
325353
}
354+
355+
// allPRsInMessage checks whether every PR number in prNumbers appears
356+
// in the error message (e.g. as "#65"). This distinguishes "our PRs are
357+
// already stacked together" from "some PRs are in a different stack."
358+
func allPRsInMessage(msg string, prNumbers []int) bool {
359+
for _, n := range prNumbers {
360+
if !strings.Contains(msg, fmt.Sprintf("#%d", n)) {
361+
return false
362+
}
363+
}
364+
return true
365+
}

cmd/submit_test.go

Lines changed: 269 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,8 @@ func TestSyncStack_ExistingStack_Update404(t *testing.T) {
397397
assert.Contains(t, output, "Stack created on GitHub with 2 PRs")
398398
}
399399

400-
func TestSyncStack_AlreadyStacked_422(t *testing.T) {
400+
func TestSyncStack_AlreadyStacked_OurStack(t *testing.T) {
401+
// All our PRs are listed as "already stacked" — this is our stack, show up-to-date.
401402
s := &stack.Stack{
402403
Trunk: stack.BranchRef{Branch: "main"},
403404
Branches: []stack.BranchRef{
@@ -423,7 +424,40 @@ func TestSyncStack_AlreadyStacked_422(t *testing.T) {
423424
errOut, _ := io.ReadAll(errR)
424425
output := string(errOut)
425426

426-
assert.Contains(t, output, "already part of an existing stack")
427+
assert.Contains(t, output, "Stack with 2 PRs is up to date")
428+
assert.NotContains(t, output, "different stack")
429+
}
430+
431+
func TestSyncStack_AlreadyStacked_DifferentStack(t *testing.T) {
432+
// Only a subset of our PRs are listed — they're in a different stack.
433+
s := &stack.Stack{
434+
Trunk: stack.BranchRef{Branch: "main"},
435+
Branches: []stack.BranchRef{
436+
{Branch: "b1", PullRequest: &stack.PullRequestRef{Number: 10}},
437+
{Branch: "b2", PullRequest: &stack.PullRequestRef{Number: 11}},
438+
{Branch: "b3", PullRequest: &stack.PullRequestRef{Number: 12}},
439+
},
440+
}
441+
442+
mock := &github.MockClient{
443+
CreateStackFn: func([]int) (int, error) {
444+
return 0, &api.HTTPError{
445+
StatusCode: 422,
446+
Message: "Pull requests #10, #11 are already stacked",
447+
RequestURL: &url.URL{Path: "/repos/o/r/cli_internal/pulls/stacks"},
448+
}
449+
},
450+
}
451+
452+
cfg, _, errR := config.NewTestConfig()
453+
syncStack(cfg, mock, s)
454+
455+
cfg.Err.Close()
456+
errOut, _ := io.ReadAll(errR)
457+
output := string(errOut)
458+
459+
assert.Contains(t, output, "different stack")
460+
assert.NotContains(t, output, "up to date")
427461
}
428462

429463
func TestSyncStack_InvalidChain_422(t *testing.T) {
@@ -544,7 +578,7 @@ func TestSyncStack_SkipsBranchesWithoutPR(t *testing.T) {
544578
Trunk: stack.BranchRef{Branch: "main"},
545579
Branches: []stack.BranchRef{
546580
{Branch: "b1", PullRequest: &stack.PullRequestRef{Number: 10}},
547-
{Branch: "b2"}, // no PR yet
581+
{Branch: "b2"}, // no PR — skipped
548582
{Branch: "b3", PullRequest: &stack.PullRequestRef{Number: 12}},
549583
},
550584
}
@@ -563,3 +597,235 @@ func TestSyncStack_SkipsBranchesWithoutPR(t *testing.T) {
563597

564598
assert.Equal(t, []int{10, 12}, gotNumbers, "should skip branches without PRs")
565599
}
600+
601+
func TestSubmit_UpdatesBaseBranch(t *testing.T) {
602+
// b1's PR has base "main" but it should be "main" (correct).
603+
// b2's PR has base "main" but it should be "b1" (wrong — needs update).
604+
s := stack.Stack{
605+
Trunk: stack.BranchRef{Branch: "main"},
606+
Branches: []stack.BranchRef{
607+
{Branch: "b1", PullRequest: &stack.PullRequestRef{Number: 10}},
608+
{Branch: "b2", PullRequest: &stack.PullRequestRef{Number: 11}},
609+
},
610+
}
611+
612+
tmpDir := t.TempDir()
613+
writeStackFile(t, tmpDir, s)
614+
615+
mock := newSubmitMock(tmpDir, "b1")
616+
617+
restore := git.SetOps(mock)
618+
defer restore()
619+
620+
var updatedPRs []struct {
621+
number int
622+
base string
623+
}
624+
625+
cfg, _, errR := config.NewTestConfig()
626+
cfg.GitHubClientOverride = &github.MockClient{
627+
FindPRForBranchFn: func(branch string) (*github.PullRequest, error) {
628+
switch branch {
629+
case "b1":
630+
return &github.PullRequest{
631+
Number: 10, ID: "PR_10",
632+
URL: "https://github.com/owner/repo/pull/10",
633+
BaseRefName: "main", HeadRefName: "b1",
634+
}, nil
635+
case "b2":
636+
return &github.PullRequest{
637+
Number: 11, ID: "PR_11",
638+
URL: "https://github.com/owner/repo/pull/11",
639+
BaseRefName: "main", HeadRefName: "b2", // wrong base
640+
}, nil
641+
}
642+
return nil, nil
643+
},
644+
UpdatePRBaseFn: func(number int, base string) error {
645+
updatedPRs = append(updatedPRs, struct {
646+
number int
647+
base string
648+
}{number, base})
649+
return nil
650+
},
651+
CreateStackFn: func(prNumbers []int) (int, error) {
652+
return 42, nil
653+
},
654+
}
655+
656+
cmd := SubmitCmd(cfg)
657+
cmd.SetArgs([]string{"--auto"})
658+
cmd.SetOut(io.Discard)
659+
cmd.SetErr(io.Discard)
660+
err := cmd.Execute()
661+
662+
cfg.Err.Close()
663+
errOut, _ := io.ReadAll(errR)
664+
output := string(errOut)
665+
666+
assert.NoError(t, err)
667+
// b1's base is "main" which is correct — no update.
668+
// b2's base is "main" but should be "b1" — should be updated.
669+
require.Len(t, updatedPRs, 1)
670+
assert.Equal(t, 11, updatedPRs[0].number)
671+
assert.Equal(t, "b1", updatedPRs[0].base)
672+
assert.Contains(t, output, "Updated base branch for PR")
673+
}
674+
675+
func TestSubmit_SkipsBaseUpdateWhenStacked(t *testing.T) {
676+
// Stack already exists (s.ID is set), so base updates should be skipped.
677+
s := stack.Stack{
678+
ID: "99",
679+
Trunk: stack.BranchRef{Branch: "main"},
680+
Branches: []stack.BranchRef{
681+
{Branch: "b1", PullRequest: &stack.PullRequestRef{Number: 10}},
682+
{Branch: "b2", PullRequest: &stack.PullRequestRef{Number: 11}},
683+
},
684+
}
685+
686+
tmpDir := t.TempDir()
687+
writeStackFile(t, tmpDir, s)
688+
689+
mock := newSubmitMock(tmpDir, "b1")
690+
691+
restore := git.SetOps(mock)
692+
defer restore()
693+
694+
updateCalled := false
695+
cfg, _, errR := config.NewTestConfig()
696+
cfg.GitHubClientOverride = &github.MockClient{
697+
FindPRForBranchFn: func(branch string) (*github.PullRequest, error) {
698+
switch branch {
699+
case "b1":
700+
return &github.PullRequest{
701+
Number: 10, ID: "PR_10",
702+
URL: "https://github.com/owner/repo/pull/10",
703+
BaseRefName: "main", HeadRefName: "b1",
704+
}, nil
705+
case "b2":
706+
return &github.PullRequest{
707+
Number: 11, ID: "PR_11",
708+
URL: "https://github.com/owner/repo/pull/11",
709+
BaseRefName: "main", HeadRefName: "b2", // wrong base
710+
}, nil
711+
}
712+
return nil, nil
713+
},
714+
UpdatePRBaseFn: func(number int, base string) error {
715+
updateCalled = true
716+
return nil
717+
},
718+
UpdateStackFn: func(stackID string, prNumbers []int) error {
719+
return nil
720+
},
721+
}
722+
723+
cmd := SubmitCmd(cfg)
724+
cmd.SetArgs([]string{"--auto"})
725+
cmd.SetOut(io.Discard)
726+
cmd.SetErr(io.Discard)
727+
err := cmd.Execute()
728+
729+
cfg.Err.Close()
730+
errOut, _ := io.ReadAll(errR)
731+
output := string(errOut)
732+
733+
assert.NoError(t, err)
734+
assert.False(t, updateCalled, "should not call UpdatePRBase when stack exists")
735+
assert.Contains(t, output, "cannot update while stacked")
736+
}
737+
738+
func TestSubmit_CreatesMissingPRsAndUpdatesExisting(t *testing.T) {
739+
// b1 has a PR, b2 does not, b3 has a PR with wrong base.
740+
// Submit should create b2's PR and fix b3's base.
741+
s := stack.Stack{
742+
Trunk: stack.BranchRef{Branch: "main"},
743+
Branches: []stack.BranchRef{
744+
{Branch: "b1", PullRequest: &stack.PullRequestRef{Number: 10}},
745+
{Branch: "b2"},
746+
{Branch: "b3", PullRequest: &stack.PullRequestRef{Number: 12}},
747+
},
748+
}
749+
750+
tmpDir := t.TempDir()
751+
writeStackFile(t, tmpDir, s)
752+
753+
mock := newSubmitMock(tmpDir, "b1")
754+
mock.LogRangeFn = func(base, head string) ([]git.CommitInfo, error) {
755+
return []git.CommitInfo{{Subject: "commit for " + head}}, nil
756+
}
757+
758+
restore := git.SetOps(mock)
759+
defer restore()
760+
761+
var createdPRs []string
762+
var updatedBases []struct {
763+
number int
764+
base string
765+
}
766+
767+
cfg, _, errR := config.NewTestConfig()
768+
cfg.GitHubClientOverride = &github.MockClient{
769+
FindPRForBranchFn: func(branch string) (*github.PullRequest, error) {
770+
switch branch {
771+
case "b1":
772+
return &github.PullRequest{
773+
Number: 10, ID: "PR_10",
774+
URL: "https://github.com/owner/repo/pull/10",
775+
BaseRefName: "main", HeadRefName: "b1",
776+
}, nil
777+
case "b2":
778+
return nil, nil // no PR
779+
case "b3":
780+
return &github.PullRequest{
781+
Number: 12, ID: "PR_12",
782+
URL: "https://github.com/owner/repo/pull/12",
783+
BaseRefName: "main", HeadRefName: "b3", // wrong base — should be b2
784+
}, nil
785+
}
786+
return nil, nil
787+
},
788+
CreatePRFn: func(base, head, title, body string, draft bool) (*github.PullRequest, error) {
789+
createdPRs = append(createdPRs, head)
790+
return &github.PullRequest{
791+
Number: 11, ID: "PR_11",
792+
URL: "https://github.com/owner/repo/pull/11",
793+
}, nil
794+
},
795+
UpdatePRBaseFn: func(number int, base string) error {
796+
updatedBases = append(updatedBases, struct {
797+
number int
798+
base string
799+
}{number, base})
800+
return nil
801+
},
802+
CreateStackFn: func(prNumbers []int) (int, error) {
803+
return 42, nil
804+
},
805+
}
806+
807+
cmd := SubmitCmd(cfg)
808+
cmd.SetArgs([]string{"--auto"})
809+
cmd.SetOut(io.Discard)
810+
cmd.SetErr(io.Discard)
811+
err := cmd.Execute()
812+
813+
cfg.Err.Close()
814+
errOut, _ := io.ReadAll(errR)
815+
output := string(errOut)
816+
817+
assert.NoError(t, err)
818+
819+
// b2 should have been created
820+
assert.Equal(t, []string{"b2"}, createdPRs)
821+
assert.Contains(t, output, "Created PR")
822+
823+
// b3's base should have been updated from "main" to "b2"
824+
require.Len(t, updatedBases, 1)
825+
assert.Equal(t, 12, updatedBases[0].number)
826+
assert.Equal(t, "b2", updatedBases[0].base)
827+
assert.Contains(t, output, "Updated base branch for PR")
828+
829+
// Stack should be created with all 3 PRs
830+
assert.Contains(t, output, "Stack created on GitHub with 3 PRs")
831+
}

internal/github/client_interface.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ type ClientOps interface {
88
FindAnyPRForBranch(branch string) (*PullRequest, error)
99
FindPRDetailsForBranch(branch string) (*PRDetails, error)
1010
CreatePR(base, head, title, body string, draft bool) (*PullRequest, error)
11+
UpdatePRBase(number int, base string) error
1112
CreateStack(prNumbers []int) (int, error)
1213
UpdateStack(stackID string, prNumbers []int) error
1314
}

0 commit comments

Comments
 (0)