diff --git a/pkg/deck/pricing.go b/pkg/deck/pricing.go index 52935d0..9dc290e 100644 --- a/pkg/deck/pricing.go +++ b/pkg/deck/pricing.go @@ -34,6 +34,7 @@ func DefaultPricing() PricingTable { "claude-sonnet-4": {Input: 3.00, Output: 15.00, CacheRead: 0.30, CacheWrite: 3.75}, "claude-sonnet-3.7": {Input: 3.00, Output: 15.00, CacheRead: 0.30, CacheWrite: 3.75}, "claude-haiku-4.5": {Input: 1.00, Output: 5.00, CacheRead: 0.10, CacheWrite: 1.25}, + "claude-haiku-4.6": {Input: 1.00, Output: 5.00, CacheRead: 0.10, CacheWrite: 1.25}, "claude-3.5-sonnet": {Input: 3.00, Output: 15.00, CacheRead: 0.30, CacheWrite: 3.75}, "claude-3.5-haiku": {Input: 0.80, Output: 4.00, CacheRead: 0.08, CacheWrite: 1.00}, "claude-3-opus": {Input: 15.00, Output: 75.00, CacheRead: 1.50, CacheWrite: 18.75}, diff --git a/pkg/deck/pricing_test.go b/pkg/deck/pricing_test.go index 5e9d908..630d1c2 100644 --- a/pkg/deck/pricing_test.go +++ b/pkg/deck/pricing_test.go @@ -193,6 +193,19 @@ var _ = Describe("PricingForModel", func() { Expect(p.CacheRead).To(Equal(0.14)) }) + It("resolves claude-haiku-4.6", func() { + p, ok := PricingForModel(pricing, "claude-haiku-4.6") + Expect(ok).To(BeTrue()) + Expect(p.Input).To(Equal(1.00)) + Expect(p.Output).To(Equal(5.00)) + }) + + It("resolves claude-haiku-4-6 with date suffix", func() { + p, ok := PricingForModel(pricing, "claude-haiku-4-6-20260219") + Expect(ok).To(BeTrue()) + Expect(p.Input).To(Equal(1.00)) + }) + It("returns false for unknown models", func() { _, ok := PricingForModel(pricing, "totally-unknown-model") Expect(ok).To(BeFalse()) diff --git a/pkg/deck/query.go b/pkg/deck/query.go index 82ccbb2..3dcebd5 100644 --- a/pkg/deck/query.go +++ b/pkg/deck/query.go @@ -561,10 +561,20 @@ func (q *Query) buildSessionMessages(nodes []*ent.Node) ([]SessionMessage, map[s toolFrequency := map[string]int{} var lastTime time.Time + var lastModel string for i, node := range nodes { blocks, _ := parseContentBlocks(node.Content) t := tokenCounts(node) - inputCost, outputCost, totalCost := q.costForNode(node, t) + + model := normalizeModel(node.Model) + if model == "" { + model = lastModel + } + if model != "" { + lastModel = model + } + + inputCost, outputCost, totalCost := q.costForModel(model, t) toolCalls := extractToolCalls(blocks) for _, tool := range toolCalls { @@ -581,7 +591,7 @@ func (q *Query) buildSessionMessages(nodes []*ent.Node) ([]SessionMessage, map[s messages = append(messages, SessionMessage{ Hash: node.ID, Role: node.Role, - Model: node.Model, + Model: model, Timestamp: node.CreatedAt, Delta: delta, InputTokens: t.Input, @@ -759,6 +769,7 @@ func (q *Query) buildSessionSummaryFromNodes(nodes []*ent.Node) (SessionSummary, hasToolError := false hasGitActivity := false + var lastModel string for _, n := range nodes { blocks, _ := parseContentBlocks(n.Content) toolCalls += countToolCalls(blocks) @@ -782,9 +793,13 @@ func (q *Query) buildSessionSummaryFromNodes(nodes []*ent.Node) (SessionSummary, outputTokens += t.Output model := normalizeModel(n.Model) + if model == "" { + model = lastModel + } if model == "" { continue } + lastModel = model pricing, ok := PricingForModel(q.pricing, model) if !ok { @@ -876,8 +891,7 @@ func (q *Query) loadAncestry(ctx context.Context, leaf *ent.Node) ([]*ent.Node, return nodes, nil } -func (q *Query) costForNode(node *ent.Node, t nodeTokens) (float64, float64, float64) { - model := normalizeModel(node.Model) +func (q *Query) costForModel(model string, t nodeTokens) (float64, float64, float64) { if model == "" { return 0, 0, 0 } diff --git a/pkg/deck/query_test.go b/pkg/deck/query_test.go index 3ac02aa..a976291 100644 --- a/pkg/deck/query_test.go +++ b/pkg/deck/query_test.go @@ -71,6 +71,102 @@ var _ = Describe("Session labels", func() { }) }) +var _ = Describe("Empty-model cost fallback", func() { + intPtr := func(v int) *int { return &v } + + It("uses last-seen model for response nodes with empty model", func() { + pricing := DefaultPricing() + q := &Query{pricing: pricing} + + nodes := []*ent.Node{ + { + ID: "node-1", + Role: "user", + Model: "claude-opus-4-6-20260219", + Content: []map[string]any{{ + "text": "Hello", + "type": "text", + }}, + PromptTokens: intPtr(100), + CompletionTokens: intPtr(0), + }, + { + ID: "node-2", + Role: "assistant", + Model: "", // empty model — the bug + Content: []map[string]any{{"text": "Hi!", "type": "text"}}, + PromptTokens: intPtr(0), + CompletionTokens: intPtr(50), + }, + } + + summary, modelCosts, _, err := q.buildSessionSummaryFromNodes(nodes) + Expect(err).NotTo(HaveOccurred()) + + // The assistant node should have been costed using the user node's model + Expect(summary.TotalCost).To(BeNumerically(">", 0)) + Expect(modelCosts).To(HaveKey("claude-opus-4.6")) + cost := modelCosts["claude-opus-4.6"] + Expect(cost.OutputTokens).To(Equal(int64(50))) + Expect(cost.TotalCost).To(BeNumerically(">", 0)) + }) + + It("keeps summary and message totals consistent", func() { + pricing := DefaultPricing() + q := &Query{pricing: pricing} + + nodes := []*ent.Node{ + { + ID: "node-1", + Role: "user", + Model: "claude-opus-4-6-20260219", + Content: []map[string]any{{ + "text": "Hello", + "type": "text", + }}, + PromptTokens: intPtr(100), + }, + { + ID: "node-2", + Role: "assistant", + Model: "", + Content: []map[string]any{{"text": "Hi!", "type": "text"}}, + CompletionTokens: intPtr(500000), + }, + } + + summary, _, _, err := q.buildSessionSummaryFromNodes(nodes) + Expect(err).NotTo(HaveOccurred()) + + messages, _ := q.buildSessionMessages(nodes) + messageTotal := 0.0 + for _, msg := range messages { + messageTotal += msg.TotalCost + } + + Expect(summary.TotalCost).To(BeNumerically("~", messageTotal, 1e-12)) + }) + + It("skips nodes when no model has been seen yet", func() { + pricing := DefaultPricing() + q := &Query{pricing: pricing} + + nodes := []*ent.Node{ + { + ID: "node-1", + Role: "assistant", + Model: "", // no model, and no prior model + Content: []map[string]any{{"text": "orphan", "type": "text"}}, + CompletionTokens: intPtr(50), + }, + } + + _, modelCosts, _, err := q.buildSessionSummaryFromNodes(nodes) + Expect(err).NotTo(HaveOccurred()) + Expect(modelCosts).To(BeEmpty()) + }) +}) + var _ = Describe("Analytics helper functions", func() { Describe("buildDurationBuckets", func() { It("distributes sessions into correct duration buckets", func() { diff --git a/proxy/proxy.go b/proxy/proxy.go index 9256416..a39097e 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -559,6 +559,9 @@ func (p *Proxy) enqueueStreamedResponse(allChunks [][]byte, fullContent string, finalResp := p.reconstructStreamedResponse(allChunks, fullContent, streamUsage, meta, prov) if finalResp != nil { + if finalResp.Model == "" && parsedReq.Model != "" { + finalResp.Model = parsedReq.Model + } p.workerPool.Enqueue(worker.Job{ Provider: prov.Name(), AgentName: agentName,