diff --git a/.claude/rules/testing.md b/.claude/rules/testing.md index 8eed2cd..0c32944 100644 --- a/.claude/rules/testing.md +++ b/.claude/rules/testing.md @@ -30,7 +30,46 @@ go test ./internal/services/... -run TestCreateSession ## Test Patterns -Follow the established patterns from existing tests: +### Table-Driven Tests (Preferred) + +Use table-driven tests for functions with multiple scenarios: + +```go +func TestFunctionName(t *testing.T) { + tests := []struct { + name string + input string + expected int + assertErr assert.ErrorAssertionFunc + }{ + { + name: "valid input", + input: "hello", + expected: 5, + assertErr: assert.NoError, + }, + { + name: "empty input returns error", + input: "", + expected: 0, + assertErr: assert.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := FunctionName(tt.input) + + tt.assertErr(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} +``` + +### Single Scenario Tests + +For tests with mocks or complex setup: ```go func TestFunctionName_Scenario(t *testing.T) { diff --git a/internal/adapters/git/cli_repository.go b/internal/adapters/git/cli_repository.go index 7adef3f..8ff1d2c 100644 --- a/internal/adapters/git/cli_repository.go +++ b/internal/adapters/git/cli_repository.go @@ -125,6 +125,11 @@ func (r *CLIRepository) FetchGitStats(ctx context.Context, worktreePath string) // PRInfoProvider methods +// FetchAllPRs implements PRInfoProvider.FetchAllPRs +func (r *CLIRepository) FetchAllPRs(ctx context.Context, repoPath string) (map[string]*domain.PRInfo, error) { + return fetchAllPRs(ctx, repoPath) +} + // FetchPRInfo implements PRInfoProvider.FetchPRInfo func (r *CLIRepository) FetchPRInfo(ctx context.Context, worktreePath, branchName string) (*domain.PRInfo, error) { return fetchPRInfo(ctx, worktreePath, branchName) diff --git a/internal/adapters/git/pr.go b/internal/adapters/git/pr.go index f7f0bd9..54c0a6d 100644 --- a/internal/adapters/git/pr.go +++ b/internal/adapters/git/pr.go @@ -20,6 +20,14 @@ type ghPRResponse struct { URL string `json:"url"` } +// ghPRListResponse represents a single PR from gh pr list output +type ghPRListResponse struct { + HeadRefName string `json:"headRefName"` + Number int `json:"number"` + State string `json:"state"` + URL string `json:"url"` +} + // fetchPRInfo fetches PR information for a branch using gh CLI. // Returns (nil, nil) if gh CLI is not installed. // Returns (PRInfo with Number=0, nil) if no PR exists for the branch. @@ -74,6 +82,54 @@ func fetchPRInfo(ctx context.Context, worktreePath, branchName string) (*domain. }, nil } +// fetchAllPRs fetches all PRs for a repository in one call. +// Returns map[branchName]*PRInfo where branchName is the head branch of the PR. +// Returns (nil, nil) if gh CLI is not installed. +func fetchAllPRs(ctx context.Context, repoPath string) (map[string]*domain.PRInfo, error) { + logging.Logger.Debug("Fetching all PRs for repo", "path", repoPath) + + // Check if gh is available + if _, err := exec.LookPath("gh"); err != nil { + logging.Logger.Debug("gh CLI not found, skipping PR fetch") + return nil, nil + } + + // Create context with timeout + ctx, cancel := context.WithTimeout(ctx, prInfoFetchTimeout) + defer cancel() + + // Run gh pr list for all PRs in the repo + cmd := exec.CommandContext(ctx, "gh", "pr", "list", "--state", "all", "--json", "number,headRefName,state,url", "--limit", "100") + cmd.Dir = repoPath + + output, err := cmd.Output() + if err != nil { + logging.Logger.Debug("gh pr list failed", "error", err) + return nil, fmt.Errorf("gh pr list failed: %w", err) + } + + var prList []ghPRListResponse + if err := json.Unmarshal(output, &prList); err != nil { + logging.Logger.Debug("Failed to parse gh pr list output", "error", err, "output", string(output)) + return nil, fmt.Errorf("failed to parse gh response: %w", err) + } + + // Build map by branch name + result := make(map[string]*domain.PRInfo, len(prList)) + checkedAt := time.Now().UTC() + for _, pr := range prList { + result[pr.HeadRefName] = &domain.PRInfo{ + CheckedAt: checkedAt, + Number: pr.Number, + State: pr.State, + URL: pr.URL, + } + } + + logging.Logger.Debug("Fetched all PRs", "repo", repoPath, "count", len(result)) + return result, nil +} + // openPRInBrowser opens the PR URL in the default browser using gh CLI func openPRInBrowser(worktreePath string) error { logging.Logger.Debug("Opening PR in browser", "path", worktreePath) diff --git a/internal/ports/git_repository.go b/internal/ports/git_repository.go index 1d37899..5b1b676 100644 --- a/internal/ports/git_repository.go +++ b/internal/ports/git_repository.go @@ -49,6 +49,7 @@ type GitStatsProvider interface { // PRInfoProvider provides PR information for UI type PRInfoProvider interface { + FetchAllPRs(ctx context.Context, repoPath string) (map[string]*domain.PRInfo, error) FetchPRInfo(ctx context.Context, worktreePath, branchName string) (*domain.PRInfo, error) OpenPRInBrowser(worktreePath string) error } diff --git a/internal/ports/mocks/mock_git_repository.go b/internal/ports/mocks/mock_git_repository.go index ba910d8..ddbf27d 100644 --- a/internal/ports/mocks/mock_git_repository.go +++ b/internal/ports/mocks/mock_git_repository.go @@ -164,6 +164,74 @@ func (_c *MockGitRepository_CreateWorktree_Call) RunAndReturn(run func(repoPath return _c } +// FetchAllPRs provides a mock function for the type MockGitRepository +func (_mock *MockGitRepository) FetchAllPRs(ctx context.Context, repoPath string) (map[string]*domain.PRInfo, error) { + ret := _mock.Called(ctx, repoPath) + + if len(ret) == 0 { + panic("no return value specified for FetchAllPRs") + } + + var r0 map[string]*domain.PRInfo + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (map[string]*domain.PRInfo, error)); ok { + return returnFunc(ctx, repoPath) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) map[string]*domain.PRInfo); ok { + r0 = returnFunc(ctx, repoPath) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]*domain.PRInfo) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, repoPath) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockGitRepository_FetchAllPRs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FetchAllPRs' +type MockGitRepository_FetchAllPRs_Call struct { + *mock.Call +} + +// FetchAllPRs is a helper method to define mock.On call +// - ctx context.Context +// - repoPath string +func (_e *MockGitRepository_Expecter) FetchAllPRs(ctx interface{}, repoPath interface{}) *MockGitRepository_FetchAllPRs_Call { + return &MockGitRepository_FetchAllPRs_Call{Call: _e.mock.On("FetchAllPRs", ctx, repoPath)} +} + +func (_c *MockGitRepository_FetchAllPRs_Call) Run(run func(ctx context.Context, repoPath string)) *MockGitRepository_FetchAllPRs_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockGitRepository_FetchAllPRs_Call) Return(stringToPRInfo map[string]*domain.PRInfo, err error) *MockGitRepository_FetchAllPRs_Call { + _c.Call.Return(stringToPRInfo, err) + return _c +} + +func (_c *MockGitRepository_FetchAllPRs_Call) RunAndReturn(run func(ctx context.Context, repoPath string) (map[string]*domain.PRInfo, error)) *MockGitRepository_FetchAllPRs_Call { + _c.Call.Return(run) + return _c +} + // FetchGitStats provides a mock function for the type MockGitRepository func (_mock *MockGitRepository) FetchGitStats(ctx context.Context, worktreePath string) (*domain.GitStats, error) { ret := _mock.Called(ctx, worktreePath) diff --git a/internal/services/git.go b/internal/services/git.go index b32a70c..4926577 100644 --- a/internal/services/git.go +++ b/internal/services/git.go @@ -75,6 +75,11 @@ func (s *GitService) GetBranchName(path string) string { return s.gitRepo.GetBranchName(path) } +// FetchAllPRs fetches all PRs for a repository in one call +func (s *GitService) FetchAllPRs(ctx context.Context, repoPath string) (map[string]*domain.PRInfo, error) { + return s.gitRepo.FetchAllPRs(ctx, repoPath) +} + // FetchPRInfo fetches PR information for a branch func (s *GitService) FetchPRInfo(ctx context.Context, worktreePath, branchName string) (*domain.PRInfo, error) { return s.gitRepo.FetchPRInfo(ctx, worktreePath, branchName) diff --git a/internal/ui/model.go b/internal/ui/model.go index 254267d..06a4aa9 100644 --- a/internal/ui/model.go +++ b/internal/ui/model.go @@ -142,8 +142,18 @@ func NewModel( } func (m *Model) Init() tea.Cmd { - // Delegate to session list component (starts auto-refresh polling) - return m.sessionList.Init() + cmds := []tea.Cmd{m.sessionList.Init()} + + // Batch fetch PR info for all sessions on startup + if m.showPRNumber { + requests := GroupSessionsByRepo(m.sessionState.Sessions) + if len(requests) > 0 { + logging.Logger.Debug("Triggering batch PR fetch on init", "repos", len(requests)) + cmds = append(cmds, StartBatchPRInfoFetcher(m.gitService, requests)) + } + } + + return tea.Batch(cmds...) } func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -397,6 +407,24 @@ func (m *Model) updateList(msg tea.Msg) (tea.Model, tea.Cmd) { logging.Logger.Debug("PR info fetch failed", "session", msg.SessionName, "error", msg.Err) return m, nil + case BatchPRInfoReadyMsg: + // Batch PR info fetched - update all sessions + logging.Logger.Debug("Received batch PR info", "count", len(msg.Results)) + for sessionName, prInfo := range msg.Results { + if sessionInfo, exists := m.sessionState.Sessions[sessionName]; exists { + sessionInfo.PRInfo = prInfo + m.sessionState.Sessions[sessionName] = sessionInfo + + // Persist to database + if err := m.sessionService.UpdatePRInfo(context.Background(), sessionName, prInfo); err != nil { + logging.Logger.Error("Failed to persist PR info", "error", err, "session", sessionName) + } + } + } + // Single UI refresh after all updates + refreshCmd := m.sessionList.RefreshFromState() + return m, tea.Batch(refreshCmd, m.sessionList.Init()) + case OpenPRMsg: // Open PR in browser for session sessionInfo, exists := m.sessionState.Sessions[msg.SessionName] @@ -432,24 +460,17 @@ func (m *Model) updateList(msg tea.Msg) (tea.Model, tea.Cmd) { } // Handle detach message - session list auto-refreshes via polling - if detachMsg, ok := msg.(detachedMsg); ok { + if _, ok := msg.(detachedMsg); ok { m.state = stateList refreshCmd := m.sessionList.RefreshFromState() - // Trigger PR fetch for detached session if enabled + // Trigger batch PR fetch for all sessions if enabled var prFetchCmd tea.Cmd - if m.showPRNumber && detachMsg.SessionName != "" { - if sessionInfo, exists := m.sessionState.Sessions[detachMsg.SessionName]; exists { - if sessionInfo.WorktreePath != "" && sessionInfo.BranchName != "" { - logging.Logger.Debug("Triggering PR fetch on detach", - "session", detachMsg.SessionName, - "branch", sessionInfo.BranchName) - prFetchCmd = StartPRInfoFetcher(m.gitService, PRInfoRequest{ - BranchName: sessionInfo.BranchName, - SessionName: detachMsg.SessionName, - WorktreePath: sessionInfo.WorktreePath, - }) - } + if m.showPRNumber { + requests := GroupSessionsByRepo(m.sessionState.Sessions) + if len(requests) > 0 { + logging.Logger.Debug("Triggering batch PR fetch on detach", "repos", len(requests)) + prFetchCmd = StartBatchPRInfoFetcher(m.gitService, requests) } } diff --git a/internal/ui/pr_info_cmd.go b/internal/ui/pr_info_cmd.go index 4a5e92e..089f015 100644 --- a/internal/ui/pr_info_cmd.go +++ b/internal/ui/pr_info_cmd.go @@ -2,6 +2,7 @@ package ui import ( "context" + "sync" "time" tea "github.com/charmbracelet/bubbletea" @@ -62,3 +63,93 @@ func StartPRInfoFetcher(gitService *services.GitService, request PRInfoRequest) } } } + +// BatchPRInfoSession represents a session to fetch PR info for +type BatchPRInfoSession struct { + BranchName string + SessionName string +} + +// BatchPRInfoRequest groups sessions by repo for batch fetching +type BatchPRInfoRequest struct { + RepoPath string + Sessions []BatchPRInfoSession +} + +// BatchPRInfoReadyMsg contains PR info for multiple sessions +type BatchPRInfoReadyMsg struct { + Results map[string]*domain.PRInfo // sessionName -> PRInfo +} + +// StartBatchPRInfoFetcher fetches PRs for multiple repos in parallel. +// Each repo gets one `gh pr list` call, and results are matched to sessions by branch name. +func StartBatchPRInfoFetcher(gitService *services.GitService, requests []BatchPRInfoRequest) tea.Cmd { + return func() tea.Msg { + results := make(map[string]*domain.PRInfo) + var wg sync.WaitGroup + var mu sync.Mutex + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + for _, req := range requests { + wg.Add(1) + go func(r BatchPRInfoRequest) { + defer wg.Done() + + logging.Logger.Debug("Batch fetching PRs for repo", "repo", r.RepoPath, "sessions", len(r.Sessions)) + + prMap, err := gitService.FetchAllPRs(ctx, r.RepoPath) + if err != nil { + logging.Logger.Warn("Failed to batch fetch PRs", "repo", r.RepoPath, "error", err) + return + } + if prMap == nil { + logging.Logger.Debug("gh CLI not available, skipping batch fetch") + return + } + + mu.Lock() + for _, sess := range r.Sessions { + if pr, ok := prMap[sess.BranchName]; ok { + results[sess.SessionName] = pr + logging.Logger.Debug("Matched PR to session", + "session", sess.SessionName, + "branch", sess.BranchName, + "pr", pr.Number) + } + } + mu.Unlock() + }(req) + } + + wg.Wait() + logging.Logger.Debug("Batch PR fetch complete", "results", len(results)) + return BatchPRInfoReadyMsg{Results: results} + } +} + +// GroupSessionsByRepo groups sessions needing PR fetch by repository. +// Returns batch requests ready for StartBatchPRInfoFetcher. +func GroupSessionsByRepo(sessions map[string]domain.Session) []BatchPRInfoRequest { + byRepo := make(map[string][]BatchPRInfoSession) + + for name, sess := range sessions { + if sess.RepoPath == "" || sess.BranchName == "" { + continue + } + byRepo[sess.RepoPath] = append(byRepo[sess.RepoPath], BatchPRInfoSession{ + BranchName: sess.BranchName, + SessionName: name, + }) + } + + requests := make([]BatchPRInfoRequest, 0, len(byRepo)) + for repoPath, sessionList := range byRepo { + requests = append(requests, BatchPRInfoRequest{ + RepoPath: repoPath, + Sessions: sessionList, + }) + } + return requests +} diff --git a/internal/ui/pr_info_cmd_test.go b/internal/ui/pr_info_cmd_test.go new file mode 100644 index 0000000..aa122c5 --- /dev/null +++ b/internal/ui/pr_info_cmd_test.go @@ -0,0 +1,143 @@ +package ui + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/renato0307/rocha/internal/domain" +) + +func TestGroupSessionsByRepo(t *testing.T) { + tests := []struct { + name string + sessions map[string]domain.Session + expectedRepos int + expectedSessions map[string]int // repoPath -> expected session count + }{ + { + name: "empty sessions", + sessions: map[string]domain.Session{}, + expectedRepos: 0, + expectedSessions: map[string]int{}, + }, + { + name: "session without repo path is skipped", + sessions: map[string]domain.Session{ + "session1": { + BranchName: "feature-1", + Name: "session1", + RepoPath: "", + }, + }, + expectedRepos: 0, + expectedSessions: map[string]int{}, + }, + { + name: "session without branch name is skipped", + sessions: map[string]domain.Session{ + "session1": { + BranchName: "", + Name: "session1", + RepoPath: "/path/to/repo", + }, + }, + expectedRepos: 0, + expectedSessions: map[string]int{}, + }, + { + name: "single valid session", + sessions: map[string]domain.Session{ + "session1": { + BranchName: "feature-1", + Name: "session1", + RepoPath: "/path/to/repo", + }, + }, + expectedRepos: 1, + expectedSessions: map[string]int{"/path/to/repo": 1}, + }, + { + name: "multiple sessions same repo", + sessions: map[string]domain.Session{ + "session1": { + BranchName: "feature-1", + Name: "session1", + RepoPath: "/path/to/repo", + }, + "session2": { + BranchName: "feature-2", + Name: "session2", + RepoPath: "/path/to/repo", + }, + }, + expectedRepos: 1, + expectedSessions: map[string]int{"/path/to/repo": 2}, + }, + { + name: "multiple repos", + sessions: map[string]domain.Session{ + "session1": { + BranchName: "feature-1", + Name: "session1", + RepoPath: "/path/to/repo1", + }, + "session2": { + BranchName: "feature-2", + Name: "session2", + RepoPath: "/path/to/repo2", + }, + "session3": { + BranchName: "feature-3", + Name: "session3", + RepoPath: "/path/to/repo1", + }, + }, + expectedRepos: 2, + expectedSessions: map[string]int{ + "/path/to/repo1": 2, + "/path/to/repo2": 1, + }, + }, + { + name: "mixed valid and invalid sessions", + sessions: map[string]domain.Session{ + "valid": { + BranchName: "feature-1", + Name: "valid", + RepoPath: "/path/to/repo", + }, + "no-branch": { + BranchName: "", + Name: "no-branch", + RepoPath: "/path/to/repo", + }, + "no-repo": { + BranchName: "feature-2", + Name: "no-repo", + RepoPath: "", + }, + }, + expectedRepos: 1, + expectedSessions: map[string]int{"/path/to/repo": 1}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GroupSessionsByRepo(tt.sessions) + + assert.Len(t, result, tt.expectedRepos) + + // Build map for easier assertion (order not guaranteed) + byRepo := make(map[string]int) + for _, req := range result { + byRepo[req.RepoPath] = len(req.Sessions) + } + + for repoPath, expectedCount := range tt.expectedSessions { + assert.Equal(t, expectedCount, byRepo[repoPath], "repo %s", repoPath) + } + }) + } +}