diff --git a/cli/search.go b/cli/search.go index 004a2bd..a47edc9 100644 --- a/cli/search.go +++ b/cli/search.go @@ -1,11 +1,13 @@ package cli import ( + "bytes" "context" "encoding/json" "fmt" "os" "strings" + "time" "github.com/alpkeskin/gotoon" "github.com/spf13/cobra" @@ -13,6 +15,7 @@ import ( "github.com/yoanbernabeu/grepai/embedder" "github.com/yoanbernabeu/grepai/rpg" "github.com/yoanbernabeu/grepai/search" + "github.com/yoanbernabeu/grepai/stats" "github.com/yoanbernabeu/grepai/store" ) @@ -283,63 +286,113 @@ func runSearch(cmd *cobra.Command, args []string) error { // JSON output mode if searchJSON { + var err error + var outputStr string if searchCompact { - return outputSearchCompactJSON(results, enrichments) + outputStr, err = captureSearchCompactJSON(results, enrichments) + } else { + outputStr, err = captureSearchJSON(results, enrichments) + } + if err != nil { + return err } - return outputSearchJSON(results, enrichments) + fmt.Print(outputStr) + recordSearchStats(projectRoot, stats.Search, outputModeFromFlags(searchJSON, searchTOON, searchCompact), len(results), outputStr) + return nil } // TOON output mode if searchTOON { + var err error + var outputStr string if searchCompact { - return outputSearchCompactTOON(results, enrichments) + outputStr, err = captureSearchCompactTOON(results, enrichments) + } else { + outputStr, err = captureSearchTOON(results, enrichments) + } + if err != nil { + return err } - return outputSearchTOON(results, enrichments) + fmt.Print(outputStr) + recordSearchStats(projectRoot, stats.Search, outputModeFromFlags(searchJSON, searchTOON, searchCompact), len(results), outputStr) + return nil } if len(results) == 0 { fmt.Println("No results found.") + recordSearchStats(projectRoot, stats.Search, stats.Full, 0, "") return nil } - // Display results - fmt.Printf("Found %d results for: %q\n\n", len(results), query) + // Display results (plain text — build output string for token estimation) + var buf strings.Builder + fmt.Fprintf(&buf, "Found %d results for: %q\n\n", len(results), query) for i, result := range results { - fmt.Printf("─── Result %d (score: %.4f) ───\n", i+1, result.Score) - fmt.Printf("File: %s:%d-%d\n", result.Chunk.FilePath, result.Chunk.StartLine, result.Chunk.EndLine) + fmt.Fprintf(&buf, "─── Result %d (score: %.4f) ───\n", i+1, result.Score) + fmt.Fprintf(&buf, "File: %s:%d-%d\n", result.Chunk.FilePath, result.Chunk.StartLine, result.Chunk.EndLine) if enrichments[i].FeaturePath != "" { - fmt.Printf("Feature: %s\n", enrichments[i].FeaturePath) + fmt.Fprintf(&buf, "Feature: %s\n", enrichments[i].FeaturePath) } if enrichments[i].SymbolName != "" { - fmt.Printf("Symbol: %s\n", enrichments[i].SymbolName) + fmt.Fprintf(&buf, "Symbol: %s\n", enrichments[i].SymbolName) } - fmt.Println() + buf.WriteString("\n") - // Display content with line numbers lines := strings.Split(result.Chunk.Content, "\n") - // Skip the "File: xxx" prefix line if present startIdx := 0 if len(lines) > 0 && strings.HasPrefix(lines[0], "File: ") { - startIdx = 2 // Skip "File: xxx" and empty line + startIdx = 2 } lineNum := result.Chunk.StartLine for j := startIdx; j < len(lines) && j < startIdx+15; j++ { - fmt.Printf("%4d │ %s\n", lineNum, lines[j]) + fmt.Fprintf(&buf, "%4d │ %s\n", lineNum, lines[j]) lineNum++ } if len(lines)-startIdx > 15 { - fmt.Printf(" │ ... (%d more lines)\n", len(lines)-startIdx-15) + fmt.Fprintf(&buf, " │ ... (%d more lines)\n", len(lines)-startIdx-15) } - fmt.Println() + buf.WriteString("\n") } + outputStr := buf.String() + fmt.Print(outputStr) + recordSearchStats(projectRoot, stats.Search, stats.Full, len(results), outputStr) return nil } -// outputSearchJSON outputs results in JSON format for AI agents -func outputSearchJSON(results []store.SearchResult, enrichments []rpgEnrichment) error { +// outputModeFromFlags determines the OutputMode from the active CLI flags. +func outputModeFromFlags(jsonFlag, toonFlag, compactFlag bool) stats.OutputMode { + if compactFlag { + return stats.Compact + } + if toonFlag { + return stats.Toon + } + return stats.Full +} + +// recordSearchStats fires a goroutine to record a stats entry without blocking. +func recordSearchStats(projectRoot, commandType, outputMode string, resultCount int, outputStr string) { + rec := stats.NewRecorder(projectRoot) + entry := stats.Entry{ + Timestamp: time.Now().UTC().Format(time.RFC3339), + CommandType: commandType, + OutputMode: outputMode, + ResultCount: resultCount, + OutputTokens: embedder.EstimateTokens(outputStr), + GrepTokens: stats.GrepEquivalentTokens(resultCount), + } + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _ = rec.Record(ctx, entry) + }() +} + +// captureSearchJSON returns JSON-encoded results as a string. +func captureSearchJSON(results []store.SearchResult, enrichments []rpgEnrichment) (string, error) { jsonResults := make([]SearchResultJSON, len(results)) for i, r := range results { jsonResults[i] = SearchResultJSON{ @@ -352,14 +405,17 @@ func outputSearchJSON(results []store.SearchResult, enrichments []rpgEnrichment) SymbolName: enrichments[i].SymbolName, } } - - encoder := json.NewEncoder(os.Stdout) - encoder.SetIndent("", " ") - return encoder.Encode(jsonResults) + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetIndent("", " ") + if err := enc.Encode(jsonResults); err != nil { + return "", err + } + return buf.String(), nil } -// outputSearchCompactJSON outputs results in minimal JSON format (without content) -func outputSearchCompactJSON(results []store.SearchResult, enrichments []rpgEnrichment) error { +// captureSearchCompactJSON returns compact JSON-encoded results as a string. +func captureSearchCompactJSON(results []store.SearchResult, enrichments []rpgEnrichment) (string, error) { jsonResults := make([]SearchResultCompactJSON, len(results)) for i, r := range results { jsonResults[i] = SearchResultCompactJSON{ @@ -371,22 +427,17 @@ func outputSearchCompactJSON(results []store.SearchResult, enrichments []rpgEnri SymbolName: enrichments[i].SymbolName, } } - - encoder := json.NewEncoder(os.Stdout) - encoder.SetIndent("", " ") - return encoder.Encode(jsonResults) -} - -// outputSearchErrorJSON outputs an error in JSON format -func outputSearchErrorJSON(err error) error { - encoder := json.NewEncoder(os.Stdout) - encoder.SetIndent("", " ") - _ = encoder.Encode(map[string]string{"error": err.Error()}) - return nil + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetIndent("", " ") + if err := enc.Encode(jsonResults); err != nil { + return "", err + } + return buf.String(), nil } -// outputSearchTOON outputs results in TOON format for AI agents -func outputSearchTOON(results []store.SearchResult, enrichments []rpgEnrichment) error { +// captureSearchTOON returns TOON-encoded results as a string. +func captureSearchTOON(results []store.SearchResult, enrichments []rpgEnrichment) (string, error) { toonResults := make([]SearchResultJSON, len(results)) for i, r := range results { toonResults[i] = SearchResultJSON{ @@ -399,17 +450,15 @@ func outputSearchTOON(results []store.SearchResult, enrichments []rpgEnrichment) SymbolName: enrichments[i].SymbolName, } } - output, err := gotoon.Encode(toonResults) if err != nil { - return fmt.Errorf("failed to encode TOON: %w", err) + return "", fmt.Errorf("failed to encode TOON: %w", err) } - fmt.Println(output) - return nil + return output + "\n", nil } -// outputSearchCompactTOON outputs results in minimal TOON format (without content) -func outputSearchCompactTOON(results []store.SearchResult, enrichments []rpgEnrichment) error { +// captureSearchCompactTOON returns compact TOON-encoded results as a string. +func captureSearchCompactTOON(results []store.SearchResult, enrichments []rpgEnrichment) (string, error) { toonResults := make([]SearchResultCompactJSON, len(results)) for i, r := range results { toonResults[i] = SearchResultCompactJSON{ @@ -421,12 +470,18 @@ func outputSearchCompactTOON(results []store.SearchResult, enrichments []rpgEnri SymbolName: enrichments[i].SymbolName, } } - output, err := gotoon.Encode(toonResults) if err != nil { - return fmt.Errorf("failed to encode TOON: %w", err) + return "", fmt.Errorf("failed to encode TOON: %w", err) } - fmt.Println(output) + return output + "\n", nil +} + +// outputSearchErrorJSON outputs an error in JSON format +func outputSearchErrorJSON(err error) error { + encoder := json.NewEncoder(os.Stdout) + encoder.SetIndent("", " ") + _ = encoder.Encode(map[string]string{"error": err.Error()}) return nil } @@ -592,40 +647,62 @@ func runWorkspaceSearch(ctx context.Context, query string, projects []string, pa // Workspace mode doesn't have RPG enrichment (no single projectRoot) enrichments := make([]rpgEnrichment, len(results)) + projectRoot, _ := config.FindProjectRoot() + // JSON output mode if searchJSON { + var outputStr string + var err error if searchCompact { - return outputSearchCompactJSON(results, enrichments) + outputStr, err = captureSearchCompactJSON(results, enrichments) + } else { + outputStr, err = captureSearchJSON(results, enrichments) } - return outputSearchJSON(results, enrichments) + if err != nil { + return err + } + fmt.Print(outputStr) + recordSearchStats(projectRoot, stats.Search, outputModeFromFlags(searchJSON, searchTOON, searchCompact), len(results), outputStr) + return nil } // TOON output mode if searchTOON { + var outputStr string + var err error if searchCompact { - return outputSearchCompactTOON(results, enrichments) + outputStr, err = captureSearchCompactTOON(results, enrichments) + } else { + outputStr, err = captureSearchTOON(results, enrichments) } - return outputSearchTOON(results, enrichments) + if err != nil { + return err + } + fmt.Print(outputStr) + recordSearchStats(projectRoot, stats.Search, outputModeFromFlags(searchJSON, searchTOON, searchCompact), len(results), outputStr) + return nil } if len(results) == 0 { fmt.Println("No results found.") + recordSearchStats(projectRoot, stats.Search, stats.Full, 0, "") return nil } // Display results - fmt.Printf("Found %d results for: %q in workspace %q\n\n", len(results), query, searchWorkspace) + var buf strings.Builder + fmt.Fprintf(&buf, "Found %d results for: %q in workspace %q\n\n", len(results), query, searchWorkspace) for i, result := range results { - fmt.Printf("─── Result %d (score: %.4f) ───\n", i+1, result.Score) - fmt.Printf("File: %s:%d-%d\n", result.Chunk.FilePath, result.Chunk.StartLine, result.Chunk.EndLine) + fmt.Fprintf(&buf, "─── Result %d (score: %.4f) ───\n", i+1, result.Score) + fmt.Fprintf(&buf, "File: %s:%d-%d\n", result.Chunk.FilePath, result.Chunk.StartLine, result.Chunk.EndLine) if enrichments[i].FeaturePath != "" { - fmt.Printf("Feature: %s\n", enrichments[i].FeaturePath) + fmt.Fprintf(&buf, "Feature: %s\n", enrichments[i].FeaturePath) } if enrichments[i].SymbolName != "" { - fmt.Printf("Symbol: %s\n", enrichments[i].SymbolName) + fmt.Fprintf(&buf, "Symbol: %s\n", enrichments[i].SymbolName) } - fmt.Println() + buf.WriteString("\n") // Display content with line numbers lines := strings.Split(result.Chunk.Content, "\n") @@ -636,14 +713,17 @@ func runWorkspaceSearch(ctx context.Context, query string, projects []string, pa lineNum := result.Chunk.StartLine for j := startIdx; j < len(lines) && j < startIdx+15; j++ { - fmt.Printf("%4d │ %s\n", lineNum, lines[j]) + fmt.Fprintf(&buf, "%4d │ %s\n", lineNum, lines[j]) lineNum++ } if len(lines)-startIdx > 15 { - fmt.Printf(" │ ... (%d more lines)\n", len(lines)-startIdx-15) + fmt.Fprintf(&buf, " │ ... (%d more lines)\n", len(lines)-startIdx-15) } - fmt.Println() + buf.WriteString("\n") } + outputStr := buf.String() + fmt.Print(outputStr) + recordSearchStats(projectRoot, stats.Search, stats.Full, len(results), outputStr) return nil } diff --git a/cli/stats.go b/cli/stats.go new file mode 100644 index 0000000..14331bd --- /dev/null +++ b/cli/stats.go @@ -0,0 +1,203 @@ +package cli + +import ( + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/charmbracelet/lipgloss" + "github.com/spf13/cobra" + "github.com/yoanbernabeu/grepai/config" + "github.com/yoanbernabeu/grepai/stats" +) + +var ( + statsJSON bool + statsHistory bool + statsLimit int + statsNoUI bool +) + +var statsCmd = &cobra.Command{ + Use: "stats", + Short: "Show token savings achieved by using grepai", + Long: `Display a summary of token savings achieved by grepai compared to +a traditional grep-based workflow. + +Every successful search and trace command records an entry locally in +.grepai/stats.json. This command aggregates those entries and shows +how many tokens (and optionally dollars) have been saved.`, + RunE: runStats, +} + +func init() { + rootCmd.AddCommand(statsCmd) + statsCmd.Flags().BoolVarP(&statsJSON, "json", "j", false, "Output results in JSON format") + statsCmd.Flags().BoolVar(&statsHistory, "history", false, "Show per-day history breakdown") + statsCmd.Flags().IntVarP(&statsLimit, "limit", "l", 30, "Max days shown with --history") + statsCmd.Flags().BoolVar(&statsNoUI, "no-ui", false, "Print plain text instead of interactive UI") +} + +func runStats(cmd *cobra.Command, args []string) error { + projectRoot, err := config.FindProjectRoot() + if err != nil { + return err + } + + cfg, err := config.Load(projectRoot) + if err != nil { + return fmt.Errorf("failed to load configuration: %w", err) + } + + statsPath := stats.StatsPath(projectRoot) + entries, err := stats.ReadAll(statsPath) + if err != nil { + return fmt.Errorf("failed to read stats: %w", err) + } + + if len(entries) == 0 { + fmt.Println("No stats recorded yet.") + fmt.Println("Run a search or trace command to start tracking token savings.") + return nil + } + + summary := stats.Summarize(entries, cfg.Embedder.Provider) + + if statsJSON { + return outputStatsJSON(summary, entries) + } + + if shouldUseStatsUI(isInteractiveTerminal(), statsNoUI) && !statsHistory { + return runStatsUI(summary, entries, cfg.Embedder.Provider) + } + + return outputStatsHuman(summary, entries, cfg.Embedder.Provider) +} + +// outputStatsJSON renders the summary (and optional history) as JSON. +func outputStatsJSON(summary stats.Summary, entries []stats.Entry) error { + if !statsHistory { + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + return enc.Encode(summary) + } + + days := stats.HistoryByDay(entries) + if statsLimit > 0 && len(days) > statsLimit { + days = days[:statsLimit] + } + + out := struct { + Summary stats.Summary `json:"summary"` + History []stats.DaySummary `json:"history"` + }{ + Summary: summary, + History: days, + } + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + return enc.Encode(out) +} + +// outputStatsHuman renders the summary using lipgloss styles. +func outputStatsHuman(summary stats.Summary, entries []stats.Entry, provider string) error { + // Styles + headerStyle := lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("205")) + labelStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("252")).Width(22) + valueStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("229")).Bold(true) + dimStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("240")) + boxStyle := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color("62")). + Padding(1, 2) + + content := headerStyle.Render("grepai stats — Token Savings Report") + "\n\n" + + content += labelStyle.Render("Total queries") + valueStyle.Render(fmt.Sprintf("%d", summary.TotalQueries)) + "\n" + content += labelStyle.Render("Tokens (grepai)") + valueStyle.Render(formatInt(summary.OutputTokens)) + "\n" + content += labelStyle.Render("Tokens (grep est.)") + valueStyle.Render(formatInt(summary.GrepTokens)) + "\n" + content += labelStyle.Render("Tokens saved") + + valueStyle.Render(fmt.Sprintf("%s ▲ %.1f%%", formatInt(summary.TokensSaved), summary.SavingsPct)) + "\n" + + if summary.CostSavedUSD != nil { + content += labelStyle.Render("Est. cost saved") + + valueStyle.Render(fmt.Sprintf("$%.4f", *summary.CostSavedUSD)) + + dimStyle.Render(" (cloud provider)") + "\n" + } + + // Command breakdown + content += "\n" + cmdLine := "By command: " + for _, k := range []string{"search", "trace-callers", "trace-callees", "trace-graph"} { + if v := summary.ByCommandType[k]; v > 0 { + cmdLine += fmt.Sprintf("%s %d · ", k, v) + } + } + content += dimStyle.Render(strings.TrimSuffix(cmdLine, " · ")) + "\n" + + modeLine := "By mode: " + for _, k := range []string{"full", "compact", "toon"} { + if v := summary.ByOutputMode[k]; v > 0 { + modeLine += fmt.Sprintf("%s %d · ", k, v) + } + } + content += dimStyle.Render(strings.TrimSuffix(modeLine, " · ")) + "\n" + + fmt.Println(boxStyle.Render(content)) + + if statsHistory { + printHistoryTable(entries, dimStyle, valueStyle) + } + + return nil +} + +func printHistoryTable(entries []stats.Entry, dimStyle, valueStyle lipgloss.Style) { + days := stats.HistoryByDay(entries) + if statsLimit > 0 && len(days) > statsLimit { + days = days[:statsLimit] + } + + colDate := lipgloss.NewStyle().Width(14) + colNum := lipgloss.NewStyle().Width(10) + colSaved := lipgloss.NewStyle().Width(16) + colPct := lipgloss.NewStyle().Width(10) + + header := dimStyle.Render( + colDate.Render("Date") + + colNum.Render("Queries") + + colSaved.Render("Tokens saved") + + colPct.Render("Savings"), + ) + sep := dimStyle.Render(fmt.Sprintf("%-14s %-10s %-16s %-10s", "──────────────", "─────────", "───────────────", "────────")) + fmt.Println(header) + fmt.Println(sep) + + for _, d := range days { + pct := 0.0 + if d.GrepTokens > 0 { + pct = float64(d.TokensSaved) / float64(d.GrepTokens) * 100 + } + row := colDate.Render(d.Date) + + colNum.Render(fmt.Sprintf("%d", d.QueryCount)) + + colSaved.Render(formatInt(d.TokensSaved)) + + colPct.Render(fmt.Sprintf("%.1f%%", pct)) + fmt.Println(valueStyle.Render(row)) + } +} + +func formatInt(n int) string { + if n == 0 { + return "0" + } + s := fmt.Sprintf("%d", n) + result := "" + for i, c := range s { + if i > 0 && (len(s)-i)%3 == 0 { + result += "," + } + result += string(c) + } + return result +} diff --git a/cli/status.go b/cli/status.go index 2422e0e..ee8b123 100644 --- a/cli/status.go +++ b/cli/status.go @@ -14,6 +14,7 @@ import ( "github.com/yoanbernabeu/grepai/config" "github.com/yoanbernabeu/grepai/daemon" "github.com/yoanbernabeu/grepai/git" + "github.com/yoanbernabeu/grepai/stats" "github.com/yoanbernabeu/grepai/store" ) @@ -38,25 +39,29 @@ const ( viewStats viewState = iota viewFiles viewChunks + viewTokenSavings ) type model struct { - st store.VectorStore - cfg *config.Config - state viewState - stats *store.IndexStats - files []store.FileStats - chunks []store.Chunk - selectedFile int - selectedChunk int - width int - height int - watchRunning bool - watchPID int - watchLogDir string - watchLogFile string - worktreeID string - err error + st store.VectorStore + cfg *config.Config + state viewState + stats *store.IndexStats + files []store.FileStats + chunks []store.Chunk + selectedFile int + selectedChunk int + width int + height int + watchRunning bool + watchPID int + watchLogDir string + watchLogFile string + worktreeID string + err error + savingsSummary *stats.Summary + savingsDays []stats.DaySummary + savingsSelected int } func init() { @@ -105,6 +110,14 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.state = viewStats case viewChunks: m.state = viewFiles + case viewTokenSavings: + m.state = viewStats + } + + case "s": + if m.state == viewStats { + m.state = viewTokenSavings + m.savingsSelected = 0 } case "enter": @@ -135,6 +148,10 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.selectedChunk > 0 { m.selectedChunk-- } + case viewTokenSavings: + if m.savingsSelected > 0 { + m.savingsSelected-- + } } case "down", "j": @@ -147,6 +164,10 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.selectedChunk < len(m.chunks)-1 { m.selectedChunk++ } + case viewTokenSavings: + if m.savingsSelected < len(m.savingsDays)-1 { + m.savingsSelected++ + } } } @@ -170,6 +191,8 @@ func (m model) View() string { return m.viewFiles() case viewChunks: return m.viewChunks() + case viewTokenSavings: + return m.viewTokenSavingsView() } return "" @@ -214,7 +237,7 @@ func (m model) viewStats() string { } sb.WriteString("\n") - sb.WriteString(helpStyle.Render("[Enter] Browse files [q] Quit")) + sb.WriteString(helpStyle.Render("[Enter] Browse files [s] Token savings [q] Quit")) return boxStyle.Render(sb.String()) } @@ -325,6 +348,91 @@ func (m model) viewChunks() string { return boxStyle.Render(sb.String()) } +func (m model) viewTokenSavingsView() string { + var sb strings.Builder + + sb.WriteString(titleStyle.Render("Token Savings")) + sb.WriteString("\n\n") + + if m.savingsSummary == nil { + sb.WriteString(dimStyle.Render("No stats recorded yet.")) + sb.WriteString("\n\n") + sb.WriteString(helpStyle.Render("[Esc] Back [q] Quit")) + return boxStyle.Render(sb.String()) + } + + s := m.savingsSummary + label := normalStyle.Width(20) + + sb.WriteString(label.Render("Queries")) + sb.WriteString(fmt.Sprintf("%d\n", s.TotalQueries)) + sb.WriteString(label.Render("Tokens saved")) + sb.WriteString(fmt.Sprintf("%s\n", formatInt(s.TokensSaved))) + sb.WriteString(label.Render("Savings")) + sb.WriteString(fmt.Sprintf("%.1f%%\n", s.SavingsPct)) + if s.CostSavedUSD != nil { + sb.WriteString(label.Render("Cost saved")) + sb.WriteString(fmt.Sprintf("~$%.4f", *s.CostSavedUSD)) + sb.WriteString(dimStyle.Render(" (cloud provider)")) + sb.WriteString("\n") + } + + if len(m.savingsDays) > 0 { + sb.WriteString("\n") + colDate := 16 + colQ := 10 + colSaved := 16 + colPct := 10 + sb.WriteString(dimStyle.Render(fmt.Sprintf("%-*s %-*s %-*s %-*s", + colDate, "Date", colQ, "Queries", colSaved, "Tokens saved", colPct, "Savings"))) + sb.WriteString("\n") + sb.WriteString(dimStyle.Render(fmt.Sprintf("%-*s %-*s %-*s %-*s", + colDate, "────────────────", colQ, "─────────", colSaved, "───────────────", colPct, "────────"))) + sb.WriteString("\n") + + maxVisible := 10 + if m.height > 0 { + maxVisible = m.height - 20 + } + if maxVisible < 3 { + maxVisible = 3 + } + start := 0 + if m.savingsSelected >= maxVisible { + start = m.savingsSelected - maxVisible + 1 + } + end := start + maxVisible + if end > len(m.savingsDays) { + end = len(m.savingsDays) + } + + for i := start; i < end; i++ { + d := m.savingsDays[i] + pct := 0.0 + if d.GrepTokens > 0 { + pct = float64(d.TokensSaved) / float64(d.GrepTokens) * 100 + } + row := fmt.Sprintf("%-*s %-*d %-*s %-*.1f%%", + colDate, d.Date, + colQ, d.QueryCount, + colSaved, formatInt(d.TokensSaved), + colPct-1, pct, + ) + if i == m.savingsSelected { + sb.WriteString(selectedStyle.Render("> " + row)) + } else { + sb.WriteString(normalStyle.Render(" " + row)) + } + sb.WriteString("\n") + } + } + + sb.WriteString("\n") + sb.WriteString(helpStyle.Render("[↑/↓] Navigate [Esc] Back [q] Quit")) + + return boxStyle.Render(sb.String()) +} + func runStatus(cmd *cobra.Command, args []string) error { ctx := context.Background() @@ -371,17 +479,27 @@ func runStatus(cmd *cobra.Command, args []string) error { } defer st.Close() - // Get stats - stats, err := st.GetStats(ctx) + // Get index stats + indexStats, err := st.GetStats(ctx) if err != nil { return fmt.Errorf("failed to get stats: %w", err) } + // Load token savings stats (non-fatal) + var savingsSummary *stats.Summary + var savingsDays []stats.DaySummary + statsPath := stats.StatsPath(projectRoot) + if entries, serr := stats.ReadAll(statsPath); serr == nil && len(entries) > 0 { + s := stats.Summarize(entries, cfg.Embedder.Provider) + savingsSummary = &s + savingsDays = stats.HistoryByDay(entries) + } + watchStatus := resolveWatcherRuntimeStatus(projectRoot) useUI := shouldUseStatusUI(isInteractiveTerminal(), statusNoUI) if !useUI { - fmt.Print(renderStatusSummary(cfg, stats, watchStatus)) + fmt.Print(renderStatusSummary(cfg, indexStats, watchStatus)) return nil } @@ -392,16 +510,18 @@ func runStatus(cmd *cobra.Command, args []string) error { // Create model m := model{ - st: st, - cfg: cfg, - state: viewStats, - stats: stats, - files: files, - watchRunning: watchStatus.running, - watchPID: watchStatus.pid, - watchLogDir: watchStatus.logDir, - watchLogFile: watchStatus.logFile, - worktreeID: watchStatus.worktreeID, + st: st, + cfg: cfg, + state: viewStats, + stats: indexStats, + files: files, + watchRunning: watchStatus.running, + watchPID: watchStatus.pid, + watchLogDir: watchStatus.logDir, + watchLogFile: watchStatus.logFile, + worktreeID: watchStatus.worktreeID, + savingsSummary: savingsSummary, + savingsDays: savingsDays, } // Run TUI diff --git a/cli/trace.go b/cli/trace.go index 8747c04..d82b596 100644 --- a/cli/trace.go +++ b/cli/trace.go @@ -1,17 +1,20 @@ package cli import ( + "bytes" "context" "encoding/json" "fmt" "log" - "os" "strings" + "time" "github.com/alpkeskin/gotoon" "github.com/spf13/cobra" "github.com/yoanbernabeu/grepai/config" + "github.com/yoanbernabeu/grepai/embedder" "github.com/yoanbernabeu/grepai/rpg" + gstats "github.com/yoanbernabeu/grepai/stats" "github.com/yoanbernabeu/grepai/trace" ) @@ -246,7 +249,7 @@ func runTraceCallers(cmd *cobra.Command, args []string) error { enrichTraceWithRPG(projectRoot, cfg, &result) } - return outputTraceResult(result, traceViewCallers) + return outputAndRecord(result, traceViewCallers, projectRoot, gstats.TraceCallers, len(result.Callers)) } func runTraceCallees(cmd *cobra.Command, args []string) error { @@ -395,7 +398,7 @@ func runTraceCallees(cmd *cobra.Command, args []string) error { enrichTraceWithRPG(projectRoot, cfg, &result) } - return outputTraceResult(result, traceViewCallees) + return outputAndRecord(result, traceViewCallees, projectRoot, gstats.TraceCallees, len(result.Callees)) } func runTraceGraph(cmd *cobra.Command, args []string) error { @@ -504,7 +507,35 @@ func runTraceGraph(cmd *cobra.Command, args []string) error { enrichTraceWithRPG(projectRoot, cfg, &result) } - return outputTraceResult(result, traceViewGraph) + nodeCount := 0 + if result.Graph != nil { + nodeCount = len(result.Graph.Nodes) + } + + return outputAndRecord(result, traceViewGraph, projectRoot, gstats.TraceGraph, nodeCount) +} + +func outputAndRecord(result trace.TraceResult, view traceViewKind, projectRoot, commandType string, resultCount int) error { + if traceJSON { + outputStr := captureJSON(result) + fmt.Print(outputStr) + recordTraceStats(projectRoot, commandType, resultCount, outputStr) + return nil + } + if traceTOON { + outputStr, err := captureTOON(result) + if err != nil { + return err + } + fmt.Print(outputStr) + recordTraceStats(projectRoot, commandType, resultCount, outputStr) + return nil + } + if err := outputTraceResult(result, view); err != nil { + return err + } + recordTraceStats(projectRoot, commandType, resultCount, "") + return nil } // enrichTraceWithRPG enriches all symbols in a TraceResult with RPG feature paths. @@ -606,21 +637,56 @@ func showTraceActionCardUIError(displayErr error, title, why, action string) err return displayErr } -func outputJSON(result trace.TraceResult) error { - enc := json.NewEncoder(os.Stdout) +// captureJSON serializes a TraceResult to a JSON string without writing to stdout. +func captureJSON(result trace.TraceResult) string { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) enc.SetIndent("", " ") - return enc.Encode(result) + _ = enc.Encode(result) + return buf.String() +} + +func outputJSON(result trace.TraceResult) error { + fmt.Print(captureJSON(result)) + return nil } func outputTOON(result trace.TraceResult) error { - output, err := gotoon.Encode(result) + s, err := captureTOON(result) if err != nil { - return fmt.Errorf("failed to encode TOON: %w", err) + return err } - fmt.Println(output) + fmt.Print(s) return nil } +// captureTOON serializes a TraceResult to a TOON string without writing to stdout. +func captureTOON(result trace.TraceResult) (string, error) { + output, err := gotoon.Encode(result) + if err != nil { + return "", fmt.Errorf("failed to encode TOON: %w", err) + } + return output + "\n", nil +} + +// recordTraceStats fires a goroutine to record a trace stats entry without blocking. +func recordTraceStats(projectRoot, commandType string, resultCount int, outputStr string) { + rec := gstats.NewRecorder(projectRoot) + entry := gstats.Entry{ + Timestamp: time.Now().UTC().Format(time.RFC3339), + CommandType: commandType, + OutputMode: outputModeFromFlags(traceJSON, traceTOON, false), + ResultCount: resultCount, + OutputTokens: embedder.EstimateTokens(outputStr), + GrepTokens: gstats.GrepEquivalentTokens(resultCount), + } + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _ = rec.Record(ctx, entry) + }() +} + func displayCallersResult(result trace.TraceResult) error { fmt.Printf("Symbol: %s (%s)\n", result.Symbol.Name, result.Symbol.Kind) fmt.Printf("File: %s:%d\n", result.Symbol.File, result.Symbol.Line) diff --git a/cli/trace_test.go b/cli/trace_test.go index ba13628..26fae2f 100644 --- a/cli/trace_test.go +++ b/cli/trace_test.go @@ -203,28 +203,12 @@ func TestOutputJSON_should_produce_valid_json(t *testing.T) { }, } - // Capture stdout - old := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - err := outputJSON(result) - - w.Close() - os.Stdout = old - - if err != nil { - t.Fatalf("outputJSON() failed: %v", err) - } - - var buf bytes.Buffer - buf.ReadFrom(r) - output := buf.String() + output := captureJSON(result) // Verify it's valid JSON var decoded trace.TraceResult if err := json.Unmarshal([]byte(output), &decoded); err != nil { - t.Fatalf("outputJSON() produced invalid JSON: %v\nOutput: %s", err, output) + t.Fatalf("captureJSON() produced invalid JSON: %v\nOutput: %s", err, output) } if decoded.Query != "TestSymbol" { t.Errorf("decoded query = %q, want %q", decoded.Query, "TestSymbol") diff --git a/cli/tui_runtime.go b/cli/tui_runtime.go index 0a751c9..e5a9c08 100644 --- a/cli/tui_runtime.go +++ b/cli/tui_runtime.go @@ -37,3 +37,7 @@ func shouldUseWatchUI(isTTY, noUI, background, status, stop bool, workspace stri func shouldUseStatusUI(isTTY, noUI bool) bool { return isTTY && !noUI } + +func shouldUseStatsUI(isTTY, noUI bool) bool { + return isTTY && !noUI +} diff --git a/cli/tui_stats.go b/cli/tui_stats.go new file mode 100644 index 0000000..9ebcddc --- /dev/null +++ b/cli/tui_stats.go @@ -0,0 +1,239 @@ +package cli + +import ( + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/yoanbernabeu/grepai/stats" +) + +type statsView int + +const ( + statsViewSummary statsView = iota + statsViewHistory +) + +type statsUIModel struct { + theme tuiTheme + summary stats.Summary + entries []stats.Entry + days []stats.DaySummary + provider string + view statsView + selected int + width int + height int +} + +func newStatsUIModel(summary stats.Summary, entries []stats.Entry, provider string) statsUIModel { + days := stats.HistoryByDay(entries) + return statsUIModel{ + theme: newTUITheme(), + summary: summary, + entries: entries, + days: days, + provider: provider, + view: statsViewSummary, + selected: 0, + } +} + +func (m statsUIModel) Init() tea.Cmd { return nil } + +func (m statsUIModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + + case tea.KeyMsg: + switch msg.String() { + case "q", "ctrl+c": + return m, tea.Quit + + case "h", "tab": + if m.view == statsViewSummary { + m.view = statsViewHistory + m.selected = 0 + } else { + m.view = statsViewSummary + } + + case "esc": + if m.view == statsViewHistory { + m.view = statsViewSummary + } + + case "up", "k": + if m.view == statsViewHistory && m.selected > 0 { + m.selected-- + } + + case "down", "j": + if m.view == statsViewHistory && m.selected < len(m.days)-1 { + m.selected++ + } + } + } + + return m, nil +} + +func (m statsUIModel) View() string { + switch m.view { + case statsViewHistory: + return m.viewHistory() + default: + return m.viewSummary() + } +} + +func (m statsUIModel) viewSummary() string { + t := m.theme + var sb strings.Builder + + sb.WriteString(t.title.Render("Token Savings — grepai")) + sb.WriteString("\n\n") + + label := t.muted.Width(20) + value := t.text + + sb.WriteString(label.Render("Queries")) + sb.WriteString(value.Render(fmt.Sprintf("%d", m.summary.TotalQueries))) + sb.WriteString("\n") + + sb.WriteString(label.Render("Tokens saved")) + sb.WriteString(value.Render(formatInt(m.summary.TokensSaved))) + sb.WriteString("\n") + + sb.WriteString(label.Render("Savings")) + sb.WriteString(t.ok.Render(fmt.Sprintf("%.1f%%", m.summary.SavingsPct))) + sb.WriteString("\n") + + if m.summary.CostSavedUSD != nil { + sb.WriteString(label.Render("Cost saved")) + sb.WriteString(value.Render(fmt.Sprintf("~$%.4f", *m.summary.CostSavedUSD))) + sb.WriteString(t.muted.Render(" (cloud provider)")) + sb.WriteString("\n") + } + + // Command breakdown + sb.WriteString("\n") + cmdParts := []string{} + for _, k := range []string{"search", "trace-callers", "trace-callees", "trace-graph"} { + if v := m.summary.ByCommandType[k]; v > 0 { + cmdParts = append(cmdParts, fmt.Sprintf("%s %d", k, v)) + } + } + if len(cmdParts) > 0 { + sb.WriteString(t.muted.Render("By command: " + strings.Join(cmdParts, " · "))) + sb.WriteString("\n") + } + + modeParts := []string{} + for _, k := range []string{"full", "compact", "toon"} { + if v := m.summary.ByOutputMode[k]; v > 0 { + modeParts = append(modeParts, fmt.Sprintf("%s %d", k, v)) + } + } + if len(modeParts) > 0 { + sb.WriteString(t.muted.Render("By mode: " + strings.Join(modeParts, " · "))) + sb.WriteString("\n") + } + + sb.WriteString("\n") + sb.WriteString(t.help.Render("[h/tab] history [q] quit")) + + return t.panel.Render(sb.String()) +} + +func (m statsUIModel) viewHistory() string { + t := m.theme + var sb strings.Builder + + sb.WriteString(t.title.Render("History — Token Savings")) + sb.WriteString("\n\n") + + colDate := 16 + colQueries := 10 + colSaved := 16 + colPct := 10 + + header := t.muted.Render( + fmt.Sprintf("%-*s %-*s %-*s %-*s", + colDate, "Date", + colQueries, "Queries", + colSaved, "Tokens saved", + colPct, "Savings", + ), + ) + sep := t.muted.Render( + fmt.Sprintf("%-*s %-*s %-*s %-*s", + colDate, "────────────────", + colQueries, "─────────", + colSaved, "───────────────", + colPct, "────────", + ), + ) + sb.WriteString(header) + sb.WriteString("\n") + sb.WriteString(sep) + sb.WriteString("\n") + + maxVisible := 15 + if m.height > 0 { + maxVisible = m.height - 12 + } + if maxVisible < 5 { + maxVisible = 5 + } + + start := 0 + if m.selected >= maxVisible { + start = m.selected - maxVisible + 1 + } + end := start + maxVisible + if end > len(m.days) { + end = len(m.days) + } + + for i := start; i < end; i++ { + d := m.days[i] + pct := 0.0 + if d.GrepTokens > 0 { + pct = float64(d.TokensSaved) / float64(d.GrepTokens) * 100 + } + row := fmt.Sprintf("%-*s %-*d %-*s %-*.1f%%", + colDate, d.Date, + colQueries, d.QueryCount, + colSaved, formatInt(d.TokensSaved), + colPct-1, pct, + ) + if i == m.selected { + sb.WriteString(t.highlight.Render("> " + row)) + } else { + sb.WriteString(t.text.Render(" " + row)) + } + sb.WriteString("\n") + } + + if len(m.days) > maxVisible { + sb.WriteString("\n") + sb.WriteString(t.muted.Render(fmt.Sprintf("... showing %d-%d of %d days", start+1, end, len(m.days)))) + sb.WriteString("\n") + } + + sb.WriteString("\n") + sb.WriteString(t.help.Render("[↑/↓] navigate [esc/tab] back [q] quit")) + + return t.panel.Render(sb.String()) +} + +func runStatsUI(summary stats.Summary, entries []stats.Entry, provider string) error { + m := newStatsUIModel(summary, entries, provider) + p := tea.NewProgram(m, tea.WithAltScreen()) + _, err := p.Run() + return err +} diff --git a/cli/tui_stats_test.go b/cli/tui_stats_test.go new file mode 100644 index 0000000..921d519 --- /dev/null +++ b/cli/tui_stats_test.go @@ -0,0 +1,291 @@ +package cli + +import ( + "fmt" + "strings" + "testing" + + tea "github.com/charmbracelet/bubbletea" + "github.com/yoanbernabeu/grepai/stats" +) + +var testEntries = []stats.Entry{ + {Timestamp: "2026-01-15T10:00:00Z", CommandType: "search", OutputMode: "compact", ResultCount: 3, OutputTokens: 100, GrepTokens: 500}, + {Timestamp: "2026-01-15T11:00:00Z", CommandType: "trace-callers", OutputMode: "full", ResultCount: 2, OutputTokens: 80, GrepTokens: 300}, + {Timestamp: "2026-01-16T09:00:00Z", CommandType: "search", OutputMode: "toon", ResultCount: 1, OutputTokens: 50, GrepTokens: 150}, +} + +// TestStatsUIModelInit vérifie l'état initial du modèle. +func TestStatsUIModelInit(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + + if m.view != statsViewSummary { + t.Errorf("view: got %v, want statsViewSummary", m.view) + } + if m.selected != 0 { + t.Errorf("selected: got %d, want 0", m.selected) + } + // 3 entries réparties sur 2 jours distincts + if len(m.days) != 2 { + t.Errorf("days: got %d, want 2", len(m.days)) + } +} + +// TestStatsViewSummaryEmpty vérifie l'affichage avec des entrées vides. +func TestStatsViewSummaryEmpty(t *testing.T) { + summary := stats.Summarize(nil, "ollama") + m := newStatsUIModel(summary, nil, "ollama") + out := stripANSI(m.View()) + + if !strings.Contains(out, "0") { + t.Errorf("expected '0' queries in empty summary, got:\n%s", out) + } + if !strings.Contains(out, "0.0%") { + t.Errorf("expected '0.0%%' savings in empty summary, got:\n%s", out) + } +} + +// TestStatsViewSummaryCloudProvider vérifie l'affichage du coût pour un provider cloud. +func TestStatsViewSummaryCloudProvider(t *testing.T) { + summary := stats.Summarize(testEntries, "openai") + m := newStatsUIModel(summary, testEntries, "openai") + out := stripANSI(m.View()) + + if !strings.Contains(out, "Cost saved") { + t.Errorf("expected 'Cost saved' for cloud provider, got:\n%s", out) + } +} + +// TestStatsViewSummaryLocalProvider vérifie l'absence du coût pour un provider local. +func TestStatsViewSummaryLocalProvider(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + out := stripANSI(m.View()) + + if strings.Contains(out, "Cost saved") { + t.Errorf("unexpected 'Cost saved' for local provider, got:\n%s", out) + } +} + +// TestStatsViewSummaryBreakdown vérifie que le breakdown par commande et mode est affiché. +func TestStatsViewSummaryBreakdown(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + out := stripANSI(m.View()) + + if !strings.Contains(out, "search") { + t.Errorf("expected 'search' in command breakdown, got:\n%s", out) + } + if !strings.Contains(out, "compact") { + t.Errorf("expected 'compact' in mode breakdown, got:\n%s", out) + } +} + +// TestStatsViewHistoryHeader vérifie l'en-tête de la vue historique. +func TestStatsViewHistoryHeader(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + m.view = statsViewHistory + out := stripANSI(m.View()) + + if !strings.Contains(out, "History — Token Savings") { + t.Errorf("expected history header, got:\n%s", out) + } +} + +// TestStatsViewHistoryLines vérifie que les lignes par date sont affichées. +func TestStatsViewHistoryLines(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + m.view = statsViewHistory + out := stripANSI(m.View()) + + if !strings.Contains(out, "2026-01-16") { + t.Errorf("expected date '2026-01-16', got:\n%s", out) + } + if !strings.Contains(out, "2026-01-15") { + t.Errorf("expected date '2026-01-15', got:\n%s", out) + } +} + +// TestStatsViewHistorySelectedMarker vérifie que la ligne sélectionnée a le marqueur ">". +func TestStatsViewHistorySelectedMarker(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + m.view = statsViewHistory + m.selected = 0 + out := stripANSI(m.View()) + + // La première ligne de données doit avoir ">" + found := false + for _, line := range strings.Split(out, "\n") { + if strings.Contains(line, ">") && strings.Contains(line, "2026-01") { + found = true + break + } + } + if !found { + t.Errorf("expected '>' marker on selected row, got:\n%s", out) + } +} + +// TestStatsViewHistoryPagination vérifie le message de pagination quand > maxVisible jours. +func TestStatsViewHistoryPagination(t *testing.T) { + // Créer suffisamment d'entrées pour dépasser maxVisible (15 par défaut) + var manyEntries []stats.Entry + for i := 1; i <= 20; i++ { + manyEntries = append(manyEntries, stats.Entry{ + Timestamp: "2026-01-" + pad2(i) + "T10:00:00Z", + CommandType: "search", + OutputMode: "compact", + ResultCount: 1, + OutputTokens: 50, + GrepTokens: 200, + }) + } + summary := stats.Summarize(manyEntries, "ollama") + m := newStatsUIModel(summary, manyEntries, "ollama") + m.view = statsViewHistory + // Sans height défini, maxVisible = 15, donc 20 > 15 → pagination + out := stripANSI(m.View()) + + if !strings.Contains(out, "showing") { + t.Errorf("expected pagination message with 'showing', got:\n%s", out) + } + if !strings.Contains(out, "of 20 days") { + t.Errorf("expected 'of 20 days' in pagination, got:\n%s", out) + } +} + +func pad2(n int) string { + return fmt.Sprintf("%02d", n) +} + +// TestStatsUpdateTabToHistory vérifie la navigation h/tab depuis summary → history. +func TestStatsUpdateTabToHistory(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + + newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("h")}) + m2 := newModel.(statsUIModel) + + if m2.view != statsViewHistory { + t.Errorf("expected statsViewHistory after 'h', got %v", m2.view) + } + if m2.selected != 0 { + t.Errorf("expected selected=0 after switch to history, got %d", m2.selected) + } +} + +// TestStatsUpdateTabBackToSummary vérifie le retour en summary avec tab depuis history. +func TestStatsUpdateTabBackToSummary(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + m.view = statsViewHistory + + newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab}) + m2 := newModel.(statsUIModel) + + if m2.view != statsViewSummary { + t.Errorf("expected statsViewSummary after tab from history, got %v", m2.view) + } +} + +// TestStatsUpdateEscBackToSummary vérifie le retour en summary avec esc. +func TestStatsUpdateEscBackToSummary(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + m.view = statsViewHistory + + newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyEsc}) + m2 := newModel.(statsUIModel) + + if m2.view != statsViewSummary { + t.Errorf("expected statsViewSummary after esc, got %v", m2.view) + } +} + +// TestStatsUpdateNavigationDown vérifie que down/j incrémente selected. +func TestStatsUpdateNavigationDown(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + m.view = statsViewHistory + m.selected = 0 + + newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("j")}) + m2 := newModel.(statsUIModel) + + if m2.selected != 1 { + t.Errorf("expected selected=1 after down, got %d", m2.selected) + } +} + +// TestStatsUpdateNavigationUp vérifie que up/k décrémente selected. +func TestStatsUpdateNavigationUp(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + m.view = statsViewHistory + m.selected = 1 + + newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("k")}) + m2 := newModel.(statsUIModel) + + if m2.selected != 0 { + t.Errorf("expected selected=0 after up, got %d", m2.selected) + } +} + +// TestStatsUpdateNavigationBounds vérifie que les bornes sont respectées. +func TestStatsUpdateNavigationBounds(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + m.view = statsViewHistory + m.selected = 0 + + // Ne peut pas aller en dessous de 0 + newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyUp}) + m2 := newModel.(statsUIModel) + if m2.selected != 0 { + t.Errorf("expected selected=0 at lower bound, got %d", m2.selected) + } + + // Ne peut pas dépasser len(days)-1 + m.selected = len(m.days) - 1 + newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyDown}) + m3 := newModel.(statsUIModel) + if m3.selected != len(m.days)-1 { + t.Errorf("expected selected=%d at upper bound, got %d", len(m.days)-1, m3.selected) + } +} + +// TestStatsUpdateQuit vérifie que q retourne tea.Quit. +func TestStatsUpdateQuit(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + + _, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("q")}) + if cmd == nil { + t.Fatal("expected a command after 'q', got nil") + } + msg := cmd() + if _, ok := msg.(tea.QuitMsg); !ok { + t.Errorf("expected tea.QuitMsg, got %T", msg) + } +} + +// TestStatsUpdateWindowSize vérifie que WindowSizeMsg met à jour width/height. +func TestStatsUpdateWindowSize(t *testing.T) { + summary := stats.Summarize(testEntries, "ollama") + m := newStatsUIModel(summary, testEntries, "ollama") + + newModel, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40}) + m2 := newModel.(statsUIModel) + + if m2.width != 120 { + t.Errorf("expected width=120, got %d", m2.width) + } + if m2.height != 40 { + t.Errorf("expected height=40, got %d", m2.height) + } +} diff --git a/config/config_test.go b/config/config_test.go index c06ab29..d053dac 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -3,6 +3,7 @@ package config import ( "os" "path/filepath" + "runtime" "testing" ) @@ -306,6 +307,9 @@ func TestFindProjectRootWithSymlink(t *testing.T) { symlinkParent := t.TempDir() symlinkPath := filepath.Join(symlinkParent, "symlink-project") if err := os.Symlink(realDir, symlinkPath); err != nil { + if runtime.GOOS == "windows" { + t.Skipf("skipping: symlink creation requires elevated privileges on Windows: %v", err) + } t.Fatalf("failed to create symlink: %v", err) } diff --git a/mcp/server.go b/mcp/server.go index 0ec5ce1..e1b1936 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -14,6 +14,8 @@ import ( "sort" "strings" + "time" + "github.com/alpkeskin/gotoon" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -21,6 +23,7 @@ import ( "github.com/yoanbernabeu/grepai/embedder" "github.com/yoanbernabeu/grepai/rpg" "github.com/yoanbernabeu/grepai/search" + "github.com/yoanbernabeu/grepai/stats" "github.com/yoanbernabeu/grepai/store" "github.com/yoanbernabeu/grepai/trace" ) @@ -30,6 +33,7 @@ type Server struct { mcpServer *server.MCPServer projectRoot string workspaceName string // non-empty when started via --workspace or auto-detect + recorder *stats.Recorder } // SearchResult is a lightweight struct for MCP output. @@ -111,6 +115,7 @@ func encodeOutput(data any, format string) (string, error) { func NewServer(projectRoot string) (*Server, error) { s := &Server{ projectRoot: projectRoot, + recorder: stats.NewRecorder(projectRoot), } // Create MCP server @@ -132,6 +137,7 @@ func NewServerWithWorkspace(projectRoot, workspaceName string) (*Server, error) s := &Server{ projectRoot: projectRoot, workspaceName: workspaceName, + recorder: stats.NewRecorder(projectRoot), } s.mcpServer = server.NewMCPServer( @@ -334,6 +340,18 @@ func (s *Server) registerTools() { ), ) s.mcpServer.AddTool(rpgExploreTool, s.handleRPGExplore) + + // grepai_stats tool + statsTool := mcp.NewTool("grepai_stats", + mcp.WithDescription("Show token savings summary achieved by using grepai instead of grep-based workflows. Returns aggregated metrics from local stats."), + mcp.WithBoolean("history", + mcp.Description("Include per-day history breakdown (default: false)"), + ), + mcp.WithNumber("limit", + mcp.Description("Max days in history (default: 30, only used when history=true)"), + ), + ) + s.mcpServer.AddTool(statsTool, s.handleStats) } // handleSearch handles the grepai_search tool call. @@ -474,6 +492,7 @@ func (s *Server) handleSearch(ctx context.Context, request mcp.CallToolRequest) return mcp.NewToolResultError(fmt.Sprintf("failed to encode results: %v", err)), nil } + s.recordMCPStats(stats.Search, mcpOutputMode(compact, format), len(results), output) return mcp.NewToolResultText(output), nil } @@ -1137,6 +1156,11 @@ func (s *Server) handleTraceCallersFromStores(ctx context.Context, symbolName st return mcp.NewToolResultError(fmt.Sprintf("failed to encode results: %v", err)), nil } + resultCount := 0 + if tr, ok := data.(trace.TraceResult); ok { + resultCount = len(tr.Callers) + } + s.recordMCPStats(stats.TraceCallers, mcpOutputMode(compact, format), resultCount, output) return mcp.NewToolResultText(output), nil } @@ -1311,6 +1335,11 @@ func (s *Server) handleTraceCalleesFromStores(ctx context.Context, symbolName st return mcp.NewToolResultError(fmt.Sprintf("failed to encode results: %v", err)), nil } + resultCount := 0 + if tr, ok := data.(trace.TraceResult); ok { + resultCount = len(tr.Callees) + } + s.recordMCPStats(stats.TraceCallees, mcpOutputMode(compact, format), resultCount, output) return mcp.NewToolResultText(output), nil } @@ -1394,8 +1423,8 @@ func (s *Server) handleTraceGraph(ctx context.Context, request mcp.CallToolReque } defer symbolStore.Close() - stats, err := symbolStore.GetStats(ctx) - if err != nil || stats.TotalSymbols == 0 { + symStats, err := symbolStore.GetStats(ctx) + if err != nil || symStats.TotalSymbols == 0 { return mcp.NewToolResultError("symbol index is empty. Run 'grepai watch' first to build the index"), nil } @@ -1434,6 +1463,11 @@ func (s *Server) handleTraceGraph(ctx context.Context, request mcp.CallToolReque return mcp.NewToolResultError(fmt.Sprintf("failed to encode results: %v", err)), nil } + nodeCount := 0 + if result.Graph != nil { + nodeCount = len(result.Graph.Nodes) + } + s.recordMCPStats(stats.TraceGraph, mcpOutputMode(false, format), nodeCount, output) return mcp.NewToolResultText(output), nil } @@ -2007,3 +2041,80 @@ func (s *Server) handleRPGExplore(ctx context.Context, request mcp.CallToolReque return mcp.NewToolResultText(output), nil } + +// mcpOutputMode maps MCP compact/format params to a stats.OutputMode string. +func mcpOutputMode(compact bool, format string) stats.OutputMode { + if compact { + return stats.Compact + } + if format == "toon" { + return stats.Toon + } + return stats.Full +} + +// recordMCPStats fires a goroutine to record a stats entry without blocking. +func (s *Server) recordMCPStats(commandType, outputMode string, resultCount int, outputStr string) { + if s.recorder == nil { + return + } + entry := stats.Entry{ + Timestamp: time.Now().UTC().Format(time.RFC3339), + CommandType: commandType, + OutputMode: outputMode, + ResultCount: resultCount, + OutputTokens: embedder.EstimateTokens(outputStr), + GrepTokens: stats.GrepEquivalentTokens(resultCount), + } + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _ = s.recorder.Record(ctx, entry) + }() +} + +// handleStats handles the grepai_stats MCP tool call. +func (s *Server) handleStats(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + includeHistory := request.GetBool("history", false) + limit := request.GetInt("limit", 30) + + statsFilePath := stats.StatsPath(s.projectRoot) + entries, readErr := stats.ReadAll(statsFilePath) + if readErr != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to read stats: %v", readErr)), nil + } + + provider := "" + if cfg, cfgErr := config.Load(s.projectRoot); cfgErr == nil { + provider = cfg.Embedder.Provider + } + + summary := stats.Summarize(entries, provider) + + if !includeHistory { + output, encErr := encodeOutput(summary, "json") + if encErr != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to encode stats: %v", encErr)), nil + } + return mcp.NewToolResultText(output), nil + } + + days := stats.HistoryByDay(entries) + if limit > 0 && len(days) > limit { + days = days[:limit] + } + + histResult := struct { + Summary stats.Summary `json:"summary"` + History []stats.DaySummary `json:"history"` + }{ + Summary: summary, + History: days, + } + + output, encErr := encodeOutput(histResult, "json") + if encErr != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to encode stats: %v", encErr)), nil + } + return mcp.NewToolResultText(output), nil +} diff --git a/search/path_normalizer_test.go b/search/path_normalizer_test.go index 195b017..88f1c14 100644 --- a/search/path_normalizer_test.go +++ b/search/path_normalizer_test.go @@ -1,6 +1,7 @@ package search import ( + "os" "path/filepath" "testing" @@ -59,6 +60,43 @@ func TestNormalizeProjectPathPrefix(t *testing.T) { } } +func TestNormalizeForPathMatch_RelativePath(t *testing.T) { + // Branche filepath.Abs : projectRoot est un chemin relatif. + // On obtient le workdir courant pour construire le pathPrefix absolu attendu. + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + // projectRoot relatif pointant vers le répertoire courant + projectRoot := "." + // pathPrefix absolu sous le workdir courant + pathPrefix := filepath.Join(wd, "path_normalizer.go") + + got, err := NormalizeProjectPathPrefix(pathPrefix, projectRoot) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "path_normalizer.go" { + t.Fatalf("got %q, want %q", got, "path_normalizer.go") + } +} + +func TestNormalizeForPathMatch_NonExistentRoot(t *testing.T) { + // Branche EvalSymlinks error : root does not exist, fallback to unresolved path. + // Use a sub-directory of TempDir that is never created — absolute on all platforms. + projectRoot := filepath.Join(t.TempDir(), "nonexistent") + pathPrefix := filepath.Join(projectRoot, "src", "foo.go") + + got, err := NormalizeProjectPathPrefix(pathPrefix, projectRoot) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := filepath.ToSlash(filepath.Join("src", "foo.go")) + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + func TestNormalizeWorkspacePathPrefix(t *testing.T) { root := t.TempDir() projA := filepath.Join(root, "projA") diff --git a/stats/doc.go b/stats/doc.go new file mode 100644 index 0000000..4e30b61 --- /dev/null +++ b/stats/doc.go @@ -0,0 +1,4 @@ +// Package stats provides token savings tracking for grepai commands. +// It records each search and trace invocation locally and computes +// savings estimates compared to a traditional grep-based workflow. +package stats diff --git a/stats/flock_unix.go b/stats/flock_unix.go new file mode 100644 index 0000000..fbaf198 --- /dev/null +++ b/stats/flock_unix.go @@ -0,0 +1,20 @@ +//go:build !windows + +package stats + +import ( + "fmt" + "os" + "syscall" +) + +func flockExclusive(f *os.File) error { + if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil { + return fmt.Errorf("failed to acquire exclusive lock: %w", err) + } + return nil +} + +func funlock(f *os.File) error { + return syscall.Flock(int(f.Fd()), syscall.LOCK_UN) +} diff --git a/stats/flock_unix_test.go b/stats/flock_unix_test.go new file mode 100644 index 0000000..a796e1b --- /dev/null +++ b/stats/flock_unix_test.go @@ -0,0 +1,21 @@ +//go:build !windows + +package stats + +import ( + "os" + "testing" +) + +func TestFlockExclusive_Error(t *testing.T) { + f, err := os.CreateTemp(t.TempDir(), "flock-test-*") + if err != nil { + t.Fatalf("CreateTemp: %v", err) + } + // Fermer le fichier pour invalider le fd — syscall.Flock retournera EBADF. + f.Close() + + if err := flockExclusive(f); err == nil { + t.Fatal("flockExclusive() expected error on closed fd, got nil") + } +} diff --git a/stats/flock_windows.go b/stats/flock_windows.go new file mode 100644 index 0000000..95ddbd4 --- /dev/null +++ b/stats/flock_windows.go @@ -0,0 +1,45 @@ +//go:build windows + +package stats + +import ( + "fmt" + "os" + "syscall" + "unsafe" +) + +var ( + modkernel32 = syscall.NewLazyDLL("kernel32.dll") + procLockFileEx = modkernel32.NewProc("LockFileEx") + procUnlockFile = modkernel32.NewProc("UnlockFileEx") +) + +const winLockfileExclusiveLock = 0x00000002 + +func flockExclusive(f *os.File) error { + var overlapped syscall.Overlapped + ret, _, err := procLockFileEx.Call( + f.Fd(), + uintptr(winLockfileExclusiveLock), + 0, 1, 0, + uintptr(unsafe.Pointer(&overlapped)), + ) + if ret == 0 { + return fmt.Errorf("failed to acquire exclusive lock: %w", err) + } + return nil +} + +func funlock(f *os.File) error { + var overlapped syscall.Overlapped + ret, _, err := procUnlockFile.Call( + f.Fd(), + 0, 1, 0, + uintptr(unsafe.Pointer(&overlapped)), + ) + if ret == 0 { + return fmt.Errorf("failed to unlock file: %w", err) + } + return nil +} diff --git a/stats/reader.go b/stats/reader.go new file mode 100644 index 0000000..0e0073a --- /dev/null +++ b/stats/reader.go @@ -0,0 +1,118 @@ +package stats + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "sort" + "strings" +) + +// ReadAll reads all entries from the NDJSON stats file at statsPath. +// Malformed lines are skipped with a warning to stderr. +// Returns an empty slice (not an error) when the file does not exist. +func ReadAll(statsPath string) ([]Entry, error) { + f, err := os.Open(statsPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, fmt.Errorf("stats: open: %w", err) + } + defer f.Close() + + var entries []Entry + scanner := bufio.NewScanner(f) + lineNum := 0 + for scanner.Scan() { + lineNum++ + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + var e Entry + if err := json.Unmarshal([]byte(line), &e); err != nil { + fmt.Fprintf(os.Stderr, "stats: skipping malformed line %d: %v\n", lineNum, err) + continue + } + entries = append(entries, e) + } + if err := scanner.Err(); err != nil && !errors.Is(err, io.EOF) { + return entries, fmt.Errorf("stats: read: %w", err) + } + return entries, nil +} + +// Summarize aggregates entries into a Summary. +// CostSavedUSD is set only for cloud providers. +func Summarize(entries []Entry, provider string) Summary { + s := Summary{ + ByCommandType: map[string]int{ + Search: 0, + TraceCallers: 0, + TraceCallees: 0, + TraceGraph: 0, + }, + ByOutputMode: map[string]int{ + Full: 0, + Compact: 0, + Toon: 0, + }, + } + + for _, e := range entries { + s.TotalQueries++ + s.OutputTokens += e.OutputTokens + s.GrepTokens += e.GrepTokens + s.ByCommandType[e.CommandType]++ + s.ByOutputMode[e.OutputMode]++ + } + + s.TokensSaved = s.GrepTokens - s.OutputTokens + if s.GrepTokens > 0 { + s.SavingsPct = float64(s.TokensSaved) / float64(s.GrepTokens) * 100 + } + + if IsCloudProvider(provider) { + saved := float64(s.TokensSaved) / 1_000_000 * CostPerMTokenUSD + s.CostSavedUSD = &saved + } + + return s +} + +// HistoryByDay groups entries by calendar day (UTC) and returns a slice +// sorted in descending order (most recent first). +func HistoryByDay(entries []Entry) []DaySummary { + byDate := map[string]*DaySummary{} + + for _, e := range entries { + day := "" + if len(e.Timestamp) >= 10 { + day = e.Timestamp[:10] // "YYYY-MM-DD" + } else { + day = "unknown" + } + d, ok := byDate[day] + if !ok { + d = &DaySummary{Date: day} + byDate[day] = d + } + d.QueryCount++ + d.OutputTokens += e.OutputTokens + d.GrepTokens += e.GrepTokens + d.TokensSaved += e.GrepTokens - e.OutputTokens + } + + days := make([]DaySummary, 0, len(byDate)) + for _, d := range byDate { + days = append(days, *d) + } + sort.Slice(days, func(i, j int) bool { + return days[i].Date > days[j].Date + }) + return days +} diff --git a/stats/recorder.go b/stats/recorder.go new file mode 100644 index 0000000..e5d7631 --- /dev/null +++ b/stats/recorder.go @@ -0,0 +1,69 @@ +package stats + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// Recorder appends stat entries to the local NDJSON stats file. +type Recorder struct { + statsPath string + lockPath string +} + +// NewRecorder creates a Recorder that writes to the stats file inside projectRoot. +func NewRecorder(projectRoot string) *Recorder { + return &Recorder{ + statsPath: StatsPath(projectRoot), + lockPath: LockPath(projectRoot), + } +} + +// Record appends one entry to the stats NDJSON file. +// The write is protected by a file lock for cross-process safety. +// If the context is canceled or the write fails, the error is returned +// but the caller is expected to discard it (fire-and-forget pattern). +func (r *Recorder) Record(ctx context.Context, e Entry) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + line, err := json.Marshal(e) + if err != nil { + return fmt.Errorf("stats: marshal entry: %w", err) + } + line = append(line, '\n') + + lockFile, err := os.OpenFile(r.lockPath, os.O_CREATE|os.O_RDWR, 0o644) + if err != nil { + // Proceed without locking rather than failing the caller. + return r.appendLine(line) + } + defer lockFile.Close() + + if err := flockExclusive(lockFile); err != nil { + return r.appendLine(line) + } + defer func() { _ = funlock(lockFile) }() + + return r.appendLine(line) +} + +func (r *Recorder) appendLine(line []byte) error { + if err := os.MkdirAll(filepath.Dir(r.statsPath), 0o755); err != nil { + return fmt.Errorf("stats: create dir: %w", err) + } + f, err := os.OpenFile(r.statsPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + return fmt.Errorf("stats: open file: %w", err) + } + defer f.Close() + + _, err = f.Write(line) + return err +} diff --git a/stats/stats.go b/stats/stats.go new file mode 100644 index 0000000..e71ec7a --- /dev/null +++ b/stats/stats.go @@ -0,0 +1,107 @@ +package stats + +import "path/filepath" + +// CommandType represents the type of grepai command that was executed. +type CommandType = string + +const ( + Search CommandType = "search" + TraceCallers CommandType = "trace-callers" + TraceCallees CommandType = "trace-callees" + TraceGraph CommandType = "trace-graph" +) + +// OutputMode represents the output format used for the command result. +type OutputMode = string + +const ( + Full OutputMode = "full" + Compact OutputMode = "compact" + Toon OutputMode = "toon" +) + +// GrepExpansionFactor is the multiplier applied to result count when estimating +// how many tokens a grep-based workflow would have consumed. A factor of 3 +// accounts for grep returning full file sections rather than isolated chunks. +const GrepExpansionFactor = 3 + +// DefaultChunkTokens is the default chunk size in tokens used for grep estimation. +// Mirrors indexer.DefaultChunkSize. +const DefaultChunkTokens = 512 + +// CostPerMTokenUSD is the reference cost per million input tokens used for +// estimating USD savings on cloud providers (conservative middle-ground rate). +const CostPerMTokenUSD = 5.00 + +// MinGrepTokens is the minimum grep-equivalent token estimate when result count +// is zero, to avoid division-by-zero in savings percentage. +const MinGrepTokens = 50 + +// StatsFileName is the name of the NDJSON stats file inside .grepai/. +const StatsFileName = "stats.json" + +// LockFileName is the name of the lock file used for safe concurrent writes. +const LockFileName = "stats.json.lock" + +// cloudProviders is the set of provider names that have an associated token cost. +var cloudProviders = map[string]bool{ + "openai": true, + "openrouter": true, + "synthetic": true, +} + +// IsCloudProvider returns true when the given provider name has a token cost. +func IsCloudProvider(provider string) bool { + return cloudProviders[provider] +} + +// GrepEquivalentTokens estimates how many tokens a grep-based workflow would +// have consumed for a given number of results. +func GrepEquivalentTokens(resultCount int) int { + if resultCount == 0 { + return MinGrepTokens + } + return resultCount * DefaultChunkTokens * GrepExpansionFactor +} + +// Entry represents a single recorded command event. +type Entry struct { + Timestamp string `json:"timestamp"` // RFC3339 UTC + CommandType string `json:"command_type"` // search | trace-callers | trace-callees | trace-graph + OutputMode string `json:"output_mode"` // full | compact | toon + ResultCount int `json:"result_count"` + OutputTokens int `json:"output_tokens"` // estimated tokens in grepai output + GrepTokens int `json:"grep_tokens"` // estimated tokens for grep equivalent +} + +// Summary is the aggregated view of all recorded entries. +type Summary struct { + TotalQueries int `json:"total_queries"` + OutputTokens int `json:"output_tokens"` + GrepTokens int `json:"grep_tokens"` + TokensSaved int `json:"tokens_saved"` + SavingsPct float64 `json:"savings_pct"` + CostSavedUSD *float64 `json:"cost_saved_usd"` + ByCommandType map[string]int `json:"by_command_type"` + ByOutputMode map[string]int `json:"by_output_mode"` +} + +// DaySummary holds per-day aggregated stats for the --history view. +type DaySummary struct { + Date string `json:"date"` + QueryCount int `json:"query_count"` + OutputTokens int `json:"output_tokens"` + GrepTokens int `json:"grep_tokens"` + TokensSaved int `json:"tokens_saved"` +} + +// StatsPath returns the absolute path of the stats NDJSON file. +func StatsPath(projectRoot string) string { + return filepath.Join(projectRoot, ".grepai", StatsFileName) +} + +// LockPath returns the absolute path of the stats lock file. +func LockPath(projectRoot string) string { + return filepath.Join(projectRoot, ".grepai", LockFileName) +} diff --git a/stats/stats_test.go b/stats/stats_test.go new file mode 100644 index 0000000..9f007c8 --- /dev/null +++ b/stats/stats_test.go @@ -0,0 +1,256 @@ +package stats_test + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/yoanbernabeu/grepai/stats" +) + +// ---- helpers ---- + +func writeStatsFile(t *testing.T, dir string, lines []string) { + t.Helper() + grepaiDir := filepath.Join(dir, ".grepai") + if err := os.MkdirAll(grepaiDir, 0o755); err != nil { + t.Fatalf("mkdir .grepai: %v", err) + } + path := filepath.Join(grepaiDir, stats.StatsFileName) + f, err := os.Create(path) + if err != nil { + t.Fatalf("create stats file: %v", err) + } + defer f.Close() + for _, l := range lines { + f.WriteString(l + "\n") + } +} + +func entryJSON(t *testing.T, e stats.Entry) string { + t.Helper() + b, err := json.Marshal(e) + if err != nil { + t.Fatalf("marshal entry: %v", err) + } + return string(b) +} + +func makeEntry(ts, ct, mode string, results, out, grep int) stats.Entry { + return stats.Entry{ + Timestamp: ts, + CommandType: ct, + OutputMode: mode, + ResultCount: results, + OutputTokens: out, + GrepTokens: grep, + } +} + +// ---- GrepEquivalentTokens ---- + +func TestGrepEquivalentTokens_NonZero(t *testing.T) { + got := stats.GrepEquivalentTokens(4) + want := 4 * stats.DefaultChunkTokens * stats.GrepExpansionFactor + if got != want { + t.Errorf("GrepEquivalentTokens(4) = %d, want %d", got, want) + } +} + +func TestGrepEquivalentTokens_Zero(t *testing.T) { + got := stats.GrepEquivalentTokens(0) + if got != stats.MinGrepTokens { + t.Errorf("GrepEquivalentTokens(0) = %d, want %d", got, stats.MinGrepTokens) + } +} + +// ---- Round-trip Record → ReadAll ---- + +func TestRecordReadAll_RoundTrip(t *testing.T) { + dir := t.TempDir() + rec := stats.NewRecorder(dir) + ctx := context.Background() + + entries := []stats.Entry{ + makeEntry(time.Now().UTC().Format(time.RFC3339), stats.Search, stats.Compact, 5, 100, 2560), + makeEntry(time.Now().UTC().Format(time.RFC3339), stats.TraceCallers, stats.Full, 3, 300, 1536), + } + for _, e := range entries { + if err := rec.Record(ctx, e); err != nil { + t.Fatalf("Record: %v", err) + } + } + + got, err := stats.ReadAll(stats.StatsPath(dir)) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if len(got) != len(entries) { + t.Fatalf("ReadAll returned %d entries, want %d", len(got), len(entries)) + } + for i, e := range got { + if e.CommandType != entries[i].CommandType { + t.Errorf("[%d] CommandType = %q, want %q", i, e.CommandType, entries[i].CommandType) + } + if e.OutputTokens != entries[i].OutputTokens { + t.Errorf("[%d] OutputTokens = %d, want %d", i, e.OutputTokens, entries[i].OutputTokens) + } + } +} + +// ---- ReadAll: file not found ---- + +func TestReadAll_FileNotFound(t *testing.T) { + dir := t.TempDir() + entries, err := stats.ReadAll(filepath.Join(dir, stats.StatsFileName)) + if err != nil { + t.Fatalf("expected nil error for missing file, got %v", err) + } + if len(entries) != 0 { + t.Errorf("expected empty slice, got %d entries", len(entries)) + } +} + +// ---- ReadAll: corrupted line is skipped ---- + +func TestReadAll_CorruptedLineSkipped(t *testing.T) { + dir := t.TempDir() + good := makeEntry("2026-02-22T10:00:00Z", stats.Search, stats.Full, 2, 80, 1024) + writeStatsFile(t, dir, []string{ + entryJSON(t, good), + "THIS IS NOT JSON", + entryJSON(t, good), + }) + + entries, err := stats.ReadAll(stats.StatsPath(dir)) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if len(entries) != 2 { + t.Errorf("expected 2 valid entries, got %d", len(entries)) + } +} + +// ---- Summarize ---- + +func TestSummarize_Totals(t *testing.T) { + entries := []stats.Entry{ + makeEntry("2026-02-22T10:00:00Z", stats.Search, stats.Compact, 5, 100, 2560), + makeEntry("2026-02-22T11:00:00Z", stats.TraceCallers, stats.Full, 3, 300, 1536), + } + s := stats.Summarize(entries, "ollama") + + if s.TotalQueries != 2 { + t.Errorf("TotalQueries = %d, want 2", s.TotalQueries) + } + if s.OutputTokens != 400 { + t.Errorf("OutputTokens = %d, want 400", s.OutputTokens) + } + if s.GrepTokens != 4096 { + t.Errorf("GrepTokens = %d, want 4096", s.GrepTokens) + } + if s.TokensSaved != 3696 { + t.Errorf("TokensSaved = %d, want 3696", s.TokensSaved) + } +} + +func TestSummarize_SavingsPct(t *testing.T) { + entries := []stats.Entry{ + makeEntry("2026-02-22T10:00:00Z", stats.Search, stats.Full, 4, 200, 2048), + } + s := stats.Summarize(entries, "openai") + want := float64(2048-200) / float64(2048) * 100 + if s.SavingsPct < want-0.01 || s.SavingsPct > want+0.01 { + t.Errorf("SavingsPct = %.2f, want ~%.2f", s.SavingsPct, want) + } +} + +func TestSummarize_NoGrepTokens_NoPanic(t *testing.T) { + entries := []stats.Entry{ + makeEntry("2026-02-22T10:00:00Z", stats.Search, stats.Full, 0, 10, stats.MinGrepTokens), + } + s := stats.Summarize(entries, "ollama") + if s.SavingsPct < 0 { + t.Errorf("SavingsPct should not be negative") + } +} + +// ---- Summarize: cloud vs local provider ---- + +func TestSummarize_CloudProvider_CostSet(t *testing.T) { + entries := []stats.Entry{ + makeEntry("2026-02-22T10:00:00Z", stats.Search, stats.Compact, 5, 100, 2560), + } + s := stats.Summarize(entries, "openai") + if s.CostSavedUSD == nil { + t.Fatal("expected CostSavedUSD to be non-nil for cloud provider") + } + if *s.CostSavedUSD <= 0 { + t.Errorf("CostSavedUSD = %f, want > 0", *s.CostSavedUSD) + } +} + +func TestSummarize_LocalProvider_CostNil(t *testing.T) { + entries := []stats.Entry{ + makeEntry("2026-02-22T10:00:00Z", stats.Search, stats.Full, 5, 100, 2560), + } + for _, provider := range []string{"ollama", "lmstudio"} { + s := stats.Summarize(entries, provider) + if s.CostSavedUSD != nil { + t.Errorf("provider %q: expected CostSavedUSD nil, got %v", provider, *s.CostSavedUSD) + } + } +} + +// ---- HistoryByDay ---- + +func TestHistoryByDay_Grouping(t *testing.T) { + entries := []stats.Entry{ + makeEntry("2026-02-20T10:00:00Z", stats.Search, stats.Full, 2, 80, 1024), + makeEntry("2026-02-21T09:00:00Z", stats.Search, stats.Full, 3, 120, 1536), + makeEntry("2026-02-21T15:00:00Z", stats.TraceCallers, stats.Compact, 1, 40, 512), + makeEntry("2026-02-22T08:00:00Z", stats.Search, stats.Compact, 5, 100, 2560), + } + days := stats.HistoryByDay(entries) + + if len(days) != 3 { + t.Fatalf("expected 3 days, got %d", len(days)) + } + // Sorted descending + if days[0].Date != "2026-02-22" { + t.Errorf("days[0].Date = %q, want 2026-02-22", days[0].Date) + } + if days[1].Date != "2026-02-21" { + t.Errorf("days[1].Date = %q, want 2026-02-21", days[1].Date) + } + if days[1].QueryCount != 2 { + t.Errorf("days[1].QueryCount = %d, want 2", days[1].QueryCount) + } +} + +func TestHistoryByDay_Empty(t *testing.T) { + days := stats.HistoryByDay(nil) + if len(days) != 0 { + t.Errorf("expected empty slice for nil entries") + } +} + +// ---- IsCloudProvider ---- + +func TestIsCloudProvider(t *testing.T) { + cloud := []string{"openai", "openrouter", "synthetic"} + for _, p := range cloud { + if !stats.IsCloudProvider(p) { + t.Errorf("IsCloudProvider(%q) = false, want true", p) + } + } + local := []string{"ollama", "lmstudio", ""} + for _, p := range local { + if stats.IsCloudProvider(p) { + t.Errorf("IsCloudProvider(%q) = true, want false", p) + } + } +}