diff --git a/cmd/tapes/serve/proxy/proxy.go b/cmd/tapes/serve/proxy/proxy.go
index b8d7784..7db1129 100644
--- a/cmd/tapes/serve/proxy/proxy.go
+++ b/cmd/tapes/serve/proxy/proxy.go
@@ -40,7 +40,7 @@ const proxyLongDesc string = `Run the proxy server.
 
 The proxy intercepts all requests and transparently forwards them to the
 configured upstream URL, recording request/response conversation turns.
-Supported provider types: anthropic, openai, ollama
+Supported provider types: anthropic, openai, ollama, vertex
 
 Optionally configure vector storage and embeddings of text content for "tapes
 search" agentic functionality.`
@@ -109,7 +109,7 @@ func NewProxyCmd() *cobra.Command {
     defaults := config.NewDefaultConfig()
     cmd.Flags().StringVarP(&cmder.listen, "listen", "l", defaults.Proxy.Listen, "Address for proxy to listen on")
     cmd.Flags().StringVarP(&cmder.upstream, "upstream", "u", defaults.Proxy.Upstream, "Upstream LLM provider URL")
-    cmd.Flags().StringVarP(&cmder.providerType, "provider", "p", defaults.Proxy.Provider, "LLM provider type (anthropic, openai, ollama)")
+    cmd.Flags().StringVarP(&cmder.providerType, "provider", "p", defaults.Proxy.Provider, "LLM provider type (anthropic, openai, ollama, vertex)")
     cmd.Flags().StringVarP(&cmder.sqlitePath, "sqlite", "s", "", "Path to SQLite database (default: in-memory)")
     cmd.Flags().StringVar(&cmder.vectorStoreProvider, "vector-store-provider", defaults.VectorStore.Provider, "Vector store provider type (e.g., chroma, sqlite)")
     cmd.Flags().StringVar(&cmder.vectorStoreTarget, "vector-store-target", defaults.VectorStore.Target, "Vector store URL (e.g., http://localhost:8000)")
diff --git a/cmd/tapes/serve/serve.go b/cmd/tapes/serve/serve.go
index 1b76ee8..99ea794 100644
--- a/cmd/tapes/serve/serve.go
+++ b/cmd/tapes/serve/serve.go
@@ -143,7 +143,7 @@ func NewServeCmd() *cobra.Command {
     cmd.Flags().StringVarP(&cmder.proxyListen, "proxy-listen", "p", defaults.Proxy.Listen, "Address for proxy to listen on")
     cmd.Flags().StringVarP(&cmder.apiListen, "api-listen", "a", defaults.API.Listen, "Address for API server to listen on")
     cmd.Flags().StringVarP(&cmder.upstream, "upstream", "u", defaults.Proxy.Upstream, "Upstream LLM provider URL")
-    cmd.Flags().StringVar(&cmder.providerType, "provider", defaults.Proxy.Provider, "LLM provider type (anthropic, openai, ollama)")
+    cmd.Flags().StringVar(&cmder.providerType, "provider", defaults.Proxy.Provider, "LLM provider type (anthropic, openai, ollama, vertex)")
     cmd.Flags().StringVarP(&cmder.sqlitePath, "sqlite", "s", "", "Path to SQLite database (e.g., ./tapes.sqlite, in-memory)")
     cmd.Flags().StringVar(&cmder.vectorStoreProvider, "vector-store-provider", defaults.VectorStore.Provider, "Vector store provider type (e.g., chroma, sqlite)")
     cmd.Flags().StringVar(&cmder.vectorStoreTarget, "vector-store-target", defaults.VectorStore.Target, "Vector store target filepath for sqlite or URL for vector store service (e.g., http://localhost:8000, ./db.sqlite)")
diff --git a/pkg/llm/provider/provider.go b/pkg/llm/provider/provider.go
index 8878d8c..0df25bb 100644
--- a/pkg/llm/provider/provider.go
+++ b/pkg/llm/provider/provider.go
@@ -14,7 +14,7 @@ var ErrStreamingNotImplemented = errors.New("streaming not implemented for this
 // Each provider implementation knows how to parse its specific
 // API format into the internal representation.
 type Provider interface {
-    // Name returns the canonical provider name (e.g., "anthropic", "openai", "ollama")
+    // Name returns the canonical provider name (e.g., "anthropic", "openai", "ollama", "vertex")
     Name() string
 
     // DefaultStreaming reports whether this provider streams responses by default
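For orientation, a hypothetical invocation exercising the new flag value might look like the following (the exact subcommand path and the regional Vertex AI upstream host are assumptions, not taken from this diff):

    tapes proxy --provider vertex --upstream https://us-east5-aiplatform.googleapis.com --listen localhost:9090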
diff --git a/pkg/llm/provider/supported.go b/pkg/llm/provider/supported.go
index 969db3a..fe92242 100644
--- a/pkg/llm/provider/supported.go
+++ b/pkg/llm/provider/supported.go
@@ -6,6 +6,7 @@ import (
     "github.com/papercomputeco/tapes/pkg/llm/provider/anthropic"
     "github.com/papercomputeco/tapes/pkg/llm/provider/ollama"
     "github.com/papercomputeco/tapes/pkg/llm/provider/openai"
+    "github.com/papercomputeco/tapes/pkg/llm/provider/vertex"
 )
 
 // Supported provider type constants
@@ -13,11 +14,12 @@ const (
     Anthropic = "anthropic"
     OpenAI    = "openai"
     Ollama    = "ollama"
+    Vertex    = "vertex"
 )
 
 // SupportedProviders returns the list of all supported provider type names.
 func SupportedProviders() []string {
-    return []string{Anthropic, OpenAI, Ollama}
+    return []string{Anthropic, OpenAI, Ollama, Vertex}
 }
 
 // New creates a new Provider instance for the given provider type.
@@ -30,6 +32,8 @@ func New(providerType string) (Provider, error) {
         return openai.New(), nil
     case Ollama:
         return ollama.New(), nil
+    case Vertex:
+        return vertex.New(), nil
     default:
         return nil, fmt.Errorf("unknown provider type: %q (supported: %v)", providerType, SupportedProviders())
     }
diff --git a/pkg/llm/provider/vertex/types.go b/pkg/llm/provider/vertex/types.go
new file mode 100644
index 0000000..818c980
--- /dev/null
+++ b/pkg/llm/provider/vertex/types.go
@@ -0,0 +1,58 @@
+package vertex
+
+// vertexRequest represents a Vertex AI request for Anthropic Claude models.
+// This is the same as the Anthropic Messages API format with two differences:
+//   - "model" is usually omitted (it is specified in the Vertex AI endpoint
+//     URL), though it is still parsed here when present
+//   - "anthropic_version" is included in the body (not as a header)
+type vertexRequest struct {
+    AnthropicVersion string          `json:"anthropic_version,omitempty"`
+    Model            string          `json:"model,omitempty"`
+    Messages         []vertexMessage `json:"messages"`
+    System           any             `json:"system,omitempty"`
+    MaxTokens        int             `json:"max_tokens"`
+    Temperature      *float64        `json:"temperature,omitempty"`
+    TopP             *float64        `json:"top_p,omitempty"`
+    TopK             *int            `json:"top_k,omitempty"`
+    Stop             []string        `json:"stop_sequences,omitempty"`
+    Stream           *bool           `json:"stream,omitempty"`
+}
+
+type vertexMessage struct {
+    Role string `json:"role"`
+
+    // Union type: either a plain string or a []vertexContentBlock.
+    Content any `json:"content"`
+}
+
+type vertexContentBlock struct {
+    Type   string         `json:"type"`
+    Text   string         `json:"text,omitempty"`
+    Source *vertexSource  `json:"source,omitempty"`
+    ID     string         `json:"id,omitempty"`
+    Name   string         `json:"name,omitempty"`
+    Input  map[string]any `json:"input,omitempty"`
+}
+
+type vertexSource struct {
+    Type      string `json:"type"`
+    MediaType string `json:"media_type"`
+    Data      string `json:"data"`
+}
+
+// vertexResponse represents a Vertex AI response for Anthropic Claude models.
+// The response format is identical to the Anthropic Messages API.
+type vertexResponse struct {
+    ID           string               `json:"id"`
+    Type         string               `json:"type"`
+    Role         string               `json:"role"`
+    Content      []vertexContentBlock `json:"content"`
+    Model        string               `json:"model"`
+    StopReason   string               `json:"stop_reason"`
+    StopSequence *string              `json:"stop_sequence,omitempty"`
+    Usage        *vertexUsage         `json:"usage,omitempty"`
+}
+
+type vertexUsage struct {
+    InputTokens  int `json:"input_tokens"`
+    OutputTokens int `json:"output_tokens"`
+}
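For concreteness, the minimal request body these types model (the same shape the tests below exercise) is:

    {
        "anthropic_version": "vertex-2023-10-16",
        "max_tokens": 1024,
        "messages": [
            {"role": "user", "content": "Hello, Claude!"}
        ]
    }

Note the absent "model" key: on Vertex AI the model is addressed through the endpoint URL rather than the request body.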
diff --git a/pkg/llm/provider/vertex/vertex.go b/pkg/llm/provider/vertex/vertex.go
new file mode 100644
index 0000000..7819050
--- /dev/null
+++ b/pkg/llm/provider/vertex/vertex.go
@@ -0,0 +1,191 @@
+// Package vertex implements the Provider interface for Anthropic Claude models
+// accessed through Google Cloud's Vertex AI platform.
+//
+// The Vertex AI Claude API is nearly identical to the Anthropic Messages API,
+// with two key differences:
+//   - "model" is not passed in the request body (it is specified in the Vertex AI endpoint URL)
+//   - "anthropic_version" is passed in the request body (rather than as a header)
+package vertex
+
+import (
+    "encoding/json"
+    "errors"
+    "strings"
+    "time"
+
+    "github.com/papercomputeco/tapes/pkg/llm"
+)
+
+// Provider implements the Provider interface for Vertex AI (Anthropic Claude).
+type Provider struct{}
+
+// New returns a new Vertex AI provider.
+func New() *Provider { return &Provider{} }
+
+// Name returns the canonical provider name.
+func (p *Provider) Name() string {
+    return "vertex"
+}
+
+// DefaultStreaming reports false: Vertex AI only streams when the request
+// explicitly sets "stream": true.
+func (p *Provider) DefaultStreaming() bool {
+    return false
+}
+
+// ParseRequest converts a Vertex AI (Anthropic Claude) request payload into
+// the internal ChatRequest representation.
+func (p *Provider) ParseRequest(payload []byte) (*llm.ChatRequest, error) {
+    var req vertexRequest
+    if err := json.Unmarshal(payload, &req); err != nil {
+        return nil, err
+    }
+
+    system := parseVertexSystem(req.System)
+    messages := make([]llm.Message, 0, len(req.Messages))
+    for _, msg := range req.Messages {
+        converted := llm.Message{Role: msg.Role}
+
+        switch content := msg.Content.(type) {
+        case string:
+            converted.Content = []llm.ContentBlock{{Type: "text", Text: content}}
+        case []any:
+            for _, item := range content {
+                if block, ok := item.(map[string]any); ok {
+                    cb := llm.ContentBlock{}
+                    if t, ok := block["type"].(string); ok {
+                        cb.Type = t
+                    }
+                    if text, ok := block["text"].(string); ok {
+                        cb.Text = text
+                    }
+                    if source, ok := block["source"].(map[string]any); ok {
+                        if mt, ok := source["media_type"].(string); ok {
+                            cb.MediaType = mt
+                        }
+                        if data, ok := source["data"].(string); ok {
+                            cb.ImageBase64 = data
+                        }
+                    }
+
+                    // Tool use
+                    if id, ok := block["id"].(string); ok {
+                        cb.ToolUseID = id
+                    }
+                    if name, ok := block["name"].(string); ok {
+                        cb.ToolName = name
+                    }
+                    if input, ok := block["input"].(map[string]any); ok {
+                        cb.ToolInput = input
+                    }
+                    converted.Content = append(converted.Content, cb)
+                }
+            }
+        }
+
+        messages = append(messages, converted)
+    }
+
+    extra := map[string]any{}
+    if req.AnthropicVersion != "" {
+        extra["anthropic_version"] = req.AnthropicVersion
+    }
+
+    result := &llm.ChatRequest{
+        Model:       req.Model,
+        Messages:    messages,
+        System:      system,
+        MaxTokens:   &req.MaxTokens,
+        Temperature: req.Temperature,
+        TopP:        req.TopP,
+        TopK:        req.TopK,
+        Stop:        req.Stop,
+        Stream:      req.Stream,
+        RawRequest:  payload,
+    }
+
+    if len(extra) > 0 {
+        result.Extra = extra
+    }
+
+    return result, nil
+}
+
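+// parseVertexSystem flattens the Anthropic-style "system" field into a single
+// string. The field is a union: either a plain string, or an array of text
+// blocks such as [{"type": "text", "text": "Be concise."}], whose text values
+// are joined with newlines.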
+func parseVertexSystem(system any) string {
+    if system == nil {
+        return ""
+    }
+
+    switch value := system.(type) {
+    case string:
+        return value
+    case []any:
+        var builder strings.Builder
+        for _, item := range value {
+            block, ok := item.(map[string]any)
+            if !ok {
+                continue
+            }
+            blockType, _ := block["type"].(string)
+            text, _ := block["text"].(string)
+            if blockType == "text" && text != "" {
+                if builder.Len() > 0 {
+                    builder.WriteString("\n")
+                }
+                builder.WriteString(text)
+            }
+        }
+        return builder.String()
+    default:
+        return ""
+    }
+}
+
+// ParseResponse converts a Vertex AI (Anthropic Claude) response payload into
+// the internal ChatResponse representation.
+func (p *Provider) ParseResponse(payload []byte) (*llm.ChatResponse, error) {
+    var resp vertexResponse
+    if err := json.Unmarshal(payload, &resp); err != nil {
+        return nil, err
+    }
+
+    content := make([]llm.ContentBlock, 0, len(resp.Content))
+    for _, block := range resp.Content {
+        cb := llm.ContentBlock{Type: block.Type}
+        switch block.Type {
+        case "text":
+            cb.Text = block.Text
+        case "tool_use":
+            cb.ToolUseID = block.ID
+            cb.ToolName = block.Name
+            cb.ToolInput = block.Input
+        }
+        content = append(content, cb)
+    }
+
+    var usage *llm.Usage
+    if resp.Usage != nil {
+        usage = &llm.Usage{
+            PromptTokens:     resp.Usage.InputTokens,
+            CompletionTokens: resp.Usage.OutputTokens,
+            TotalTokens:      resp.Usage.InputTokens + resp.Usage.OutputTokens,
+        }
+    }
+
+    result := &llm.ChatResponse{
+        Model: resp.Model,
+        Message: llm.Message{
+            Role:    resp.Role,
+            Content: content,
+        },
+        Done:        true,
+        StopReason:  resp.StopReason,
+        Usage:       usage,
+        CreatedAt:   time.Now(),
+        RawResponse: payload,
+        Extra: map[string]any{
+            "id":   resp.ID,
+            "type": resp.Type,
+        },
+    }
+
+    return result, nil
+}
+
+// ParseStreamChunk is not yet implemented for Vertex AI. Returning an error,
+// rather than panicking, keeps a streaming payload from taking the proxy down;
+// it mirrors the ErrStreamingNotImplemented idiom in pkg/llm/provider, which
+// this package cannot import without creating a cycle.
+func (p *Provider) ParseStreamChunk(_ []byte) (*llm.StreamChunk, error) {
+    return nil, errors.New("streaming not implemented for the vertex provider")
+}
diff --git a/pkg/llm/provider/vertex/vertex_suite_test.go b/pkg/llm/provider/vertex/vertex_suite_test.go
new file mode 100644
index 0000000..b91fc71
--- /dev/null
+++ b/pkg/llm/provider/vertex/vertex_suite_test.go
@@ -0,0 +1,13 @@
+package vertex_test
+
+import (
+    "testing"
+
+    . "github.com/onsi/ginkgo/v2"
+    . "github.com/onsi/gomega"
+)
+
+func TestVertex(t *testing.T) {
+    RegisterFailHandler(Fail)
+    RunSpecs(t, "Vertex Provider Suite")
+}
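The test files below use Ginkgo/Gomega, which hook into the standard Go test runner, so the suite runs with plain go test:

    go test ./pkg/llm/provider/vertex/...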
"github.com/onsi/gomega" + + "github.com/papercomputeco/tapes/pkg/llm/provider" + "github.com/papercomputeco/tapes/pkg/llm/provider/vertex" +) + +var _ = Describe("Vertex Provider", func() { + var p provider.Provider + + BeforeEach(func() { + p = vertex.New() + }) + + Describe("Name", func() { + It("returns 'vertex'", func() { + Expect(p.Name()).To(Equal("vertex")) + }) + }) + + Describe("ParseRequest", func() { + Context("with a simple text request", func() { + It("parses messages correctly", func() { + payload := []byte(`{ + "anthropic_version": "vertex-2023-10-16", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, Claude!"} + ] + }`) + + req, err := p.ParseRequest(payload) + Expect(err).NotTo(HaveOccurred()) + Expect(*req.MaxTokens).To(Equal(1024)) + Expect(req.Messages).To(HaveLen(1)) + Expect(req.Messages[0].Role).To(Equal("user")) + Expect(req.Messages[0].GetText()).To(Equal("Hello, Claude!")) + Expect(req.Extra).To(HaveKeyWithValue("anthropic_version", "vertex-2023-10-16")) + }) + }) + + Context("with model in the body", func() { + It("parses the model field when present", func() { + payload := []byte(`{ + "anthropic_version": "vertex-2023-10-16", + "model": "claude-3-5-sonnet@20241022", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`) + + req, err := p.ParseRequest(payload) + Expect(err).NotTo(HaveOccurred()) + Expect(req.Model).To(Equal("claude-3-5-sonnet@20241022")) + }) + }) + + Context("without model in the body", func() { + It("parses successfully with empty model", func() { + payload := []byte(`{ + "anthropic_version": "vertex-2023-10-16", + "max_tokens": 256, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`) + + req, err := p.ParseRequest(payload) + Expect(err).NotTo(HaveOccurred()) + Expect(req.Model).To(BeEmpty()) + }) + }) + + Context("with content block array format", func() { + It("parses text content blocks", func() { + payload := []byte(`{ + "anthropic_version": "vertex-2023-10-16", + "max_tokens": 1024, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"} + ] + } + ] + }`) + + req, err := p.ParseRequest(payload) + Expect(err).NotTo(HaveOccurred()) + Expect(req.Messages).To(HaveLen(1)) + Expect(req.Messages[0].Content).To(HaveLen(1)) + Expect(req.Messages[0].Content[0].Type).To(Equal("text")) + Expect(req.Messages[0].Content[0].Text).To(Equal("What's in this image?")) + }) + + It("parses image content blocks with base64 source", func() { + payload := []byte(`{ + "anthropic_version": "vertex-2023-10-16", + "max_tokens": 1024, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "iVBORw0KGgo..." 
+
+        Context("with system prompt", func() {
+            It("parses a string system field", func() {
+                payload := []byte(`{
+                    "anthropic_version": "vertex-2023-10-16",
+                    "max_tokens": 1024,
+                    "system": "You are a helpful coding assistant.",
+                    "messages": [{"role": "user", "content": "Hello"}]
+                }`)
+
+                req, err := p.ParseRequest(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect(req.System).To(Equal("You are a helpful coding assistant."))
+            })
+
+            It("parses an array system field", func() {
+                payload := []byte(`{
+                    "anthropic_version": "vertex-2023-10-16",
+                    "max_tokens": 1024,
+                    "system": [
+                        {"type": "text", "text": "You are a helpful assistant."},
+                        {"type": "text", "text": "Be concise."}
+                    ],
+                    "messages": [{"role": "user", "content": "Hello"}]
+                }`)
+
+                req, err := p.ParseRequest(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect(req.System).To(Equal("You are a helpful assistant.\nBe concise."))
+            })
+        })
+
+        Context("with generation parameters", func() {
+            It("parses temperature, top_p, top_k", func() {
+                payload := []byte(`{
+                    "anthropic_version": "vertex-2023-10-16",
+                    "max_tokens": 1024,
+                    "temperature": 0.7,
+                    "top_p": 0.9,
+                    "top_k": 40,
+                    "messages": [{"role": "user", "content": "Hello"}]
+                }`)
+
+                req, err := p.ParseRequest(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect(*req.Temperature).To(BeNumerically("~", 0.7, 0.001))
+                Expect(*req.TopP).To(BeNumerically("~", 0.9, 0.001))
+                Expect(*req.TopK).To(Equal(40))
+            })
+
+            It("parses stop_sequences", func() {
+                payload := []byte(`{
+                    "anthropic_version": "vertex-2023-10-16",
+                    "max_tokens": 1024,
+                    "stop_sequences": ["END", "STOP"],
+                    "messages": [{"role": "user", "content": "Hello"}]
+                }`)
+
+                req, err := p.ParseRequest(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect(req.Stop).To(ConsistOf("END", "STOP"))
+            })
+        })
+
+        Context("with streaming flag", func() {
+            It("parses stream: true", func() {
+                payload := []byte(`{
+                    "anthropic_version": "vertex-2023-10-16",
+                    "max_tokens": 1024,
+                    "stream": true,
+                    "messages": [{"role": "user", "content": "Hello"}]
+                }`)
+
+                req, err := p.ParseRequest(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect(*req.Stream).To(BeTrue())
+            })
+
+            It("parses stream: false", func() {
+                payload := []byte(`{
+                    "anthropic_version": "vertex-2023-10-16",
+                    "max_tokens": 1024,
+                    "stream": false,
+                    "messages": [{"role": "user", "content": "Hello"}]
+                }`)
+
+                req, err := p.ParseRequest(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect(*req.Stream).To(BeFalse())
+            })
+        })
+
+        Context("with tool use in messages", func() {
+            It("parses tool_use content blocks", func() {
+                payload := []byte(`{
+                    "anthropic_version": "vertex-2023-10-16",
+                    "max_tokens": 1024,
+                    "messages": [
+                        {
+                            "role": "assistant",
+                            "content": [
+                                {
+                                    "type": "tool_use",
+                                    "id": "toolu_123",
+                                    "name": "get_weather",
+                                    "input": {"location": "San Francisco"}
+                                }
+                            ]
+                        }
+                    ]
+                }`)
+
+                req, err := p.ParseRequest(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect(req.Messages[0].Content).To(HaveLen(1))
+                Expect(req.Messages[0].Content[0].Type).To(Equal("tool_use"))
+                Expect(req.Messages[0].Content[0].ToolUseID).To(Equal("toolu_123"))
+                Expect(req.Messages[0].Content[0].ToolName).To(Equal("get_weather"))
+                Expect(req.Messages[0].Content[0].ToolInput).To(HaveKeyWithValue("location", "San Francisco"))
+            })
+        })
+
+        Context("with invalid payload", func() {
+            It("returns an error for invalid JSON", func() {
+                payload := []byte(`not valid json`)
+                _, err := p.ParseRequest(payload)
+                Expect(err).To(HaveOccurred())
+            })
+        })
+
+        Context("preserves raw request", func() {
+            It("stores the original payload in RawRequest", func() {
+                payload := []byte(`{"anthropic_version": "vertex-2023-10-16", "max_tokens": 1024, "messages": []}`)
+                req, err := p.ParseRequest(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect([]byte(req.RawRequest)).To(Equal(payload))
+            })
+        })
+    })
+
+    Describe("ParseResponse", func() {
+        Context("with a simple text response", func() {
+            It("parses the response correctly", func() {
+                payload := []byte(`{
+                    "id": "msg_01234567890",
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [
+                        {"type": "text", "text": "Hello! How can I help you today?"}
+                    ],
+                    "model": "claude-3-5-sonnet@20241022",
+                    "stop_reason": "end_turn",
+                    "usage": {
+                        "input_tokens": 10,
+                        "output_tokens": 25
+                    }
+                }`)
+
+                resp, err := p.ParseResponse(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect(resp.Model).To(Equal("claude-3-5-sonnet@20241022"))
+                Expect(resp.Message.Role).To(Equal("assistant"))
+                Expect(resp.Message.GetText()).To(Equal("Hello! How can I help you today?"))
+                Expect(resp.StopReason).To(Equal("end_turn"))
+                Expect(resp.Done).To(BeTrue())
+            })
+        })
+
+        Context("with usage metrics", func() {
+            It("parses token counts correctly", func() {
+                payload := []byte(`{
+                    "id": "msg_123",
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [{"type": "text", "text": "Hi"}],
+                    "model": "claude-3-5-sonnet@20241022",
+                    "stop_reason": "end_turn",
+                    "usage": {
+                        "input_tokens": 100,
+                        "output_tokens": 50
+                    }
+                }`)
+
+                resp, err := p.ParseResponse(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect(resp.Usage).NotTo(BeNil())
+                Expect(resp.Usage.PromptTokens).To(Equal(100))
+                Expect(resp.Usage.CompletionTokens).To(Equal(50))
+                Expect(resp.Usage.TotalTokens).To(Equal(150))
+            })
+        })
+
+        Context("with tool_use response", func() {
+            It("parses tool_use content blocks", func() {
+                payload := []byte(`{
+                    "id": "msg_123",
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [
+                        {"type": "text", "text": "I'll check the weather for you."},
+                        {
+                            "type": "tool_use",
+                            "id": "toolu_456",
+                            "name": "get_weather",
+                            "input": {"location": "NYC", "unit": "celsius"}
+                        }
+                    ],
+                    "model": "claude-3-5-sonnet@20241022",
+                    "stop_reason": "tool_use"
+                }`)
+
+                resp, err := p.ParseResponse(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect(resp.Message.Content).To(HaveLen(2))
+                Expect(resp.Message.Content[0].Type).To(Equal("text"))
+                Expect(resp.Message.Content[1].Type).To(Equal("tool_use"))
+                Expect(resp.Message.Content[1].ToolUseID).To(Equal("toolu_456"))
+                Expect(resp.Message.Content[1].ToolName).To(Equal("get_weather"))
+                Expect(resp.StopReason).To(Equal("tool_use"))
+            })
+        })
+
+        Context("with Extra fields", func() {
+            It("stores id and type in Extra", func() {
+                payload := []byte(`{
+                    "id": "msg_abc123",
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [{"type": "text", "text": "Hi"}],
+                    "model": "claude-3-5-sonnet@20241022",
+                    "stop_reason": "end_turn"
+                }`)
+
+                resp, err := p.ParseResponse(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect(resp.Extra).To(HaveKeyWithValue("id", "msg_abc123"))
+                Expect(resp.Extra).To(HaveKeyWithValue("type", "message"))
+            })
+        })
+
+        Context("preserves raw response", func() {
+            It("stores the original payload in RawResponse", func() {
+                payload := []byte(`{
+                    "id": "msg_123",
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [{"type": "text", "text": "Hi"}],
+                    "model": "claude-3-5-sonnet@20241022",
+                    "stop_reason": "end_turn"
+                }`)
+
+                resp, err := p.ParseResponse(payload)
+                Expect(err).NotTo(HaveOccurred())
+                Expect([]byte(resp.RawResponse)).To(Equal(payload))
+            })
+        })
response", func() { + It("stores the original payload in RawResponse", func() { + payload := []byte(`{ + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "Hi"}], + "model": "claude-3-5-sonnet@20241022", + "stop_reason": "end_turn" + }`) + + resp, err := p.ParseResponse(payload) + Expect(err).NotTo(HaveOccurred()) + Expect([]byte(resp.RawResponse)).To(Equal(payload)) + }) + }) + + Context("with invalid payload", func() { + It("returns an error for invalid JSON", func() { + payload := []byte(`not valid json`) + _, err := p.ParseResponse(payload) + Expect(err).To(HaveOccurred()) + }) + }) + }) +}) diff --git a/proxy/config.go b/proxy/config.go index f5ae6a5..7785c2f 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -13,7 +13,7 @@ type Config struct { // UpstreamURL is the upstream LLM provider URL (e.g., "http://localhost:11434") UpstreamURL string - // ProviderType specifies the LLM provider type (e.g., "anthropic", "openai", "ollama") + // ProviderType specifies the LLM provider type (e.g., "anthropic", "openai", "ollama", "vertex") // This determines how requests and responses are parsed. ProviderType string