From bd0f4782dca4d86b19734b2c8f0c7c9283054a0e Mon Sep 17 00:00:00 2001
From: Shaharia Azam
Date: Mon, 16 Dec 2024 19:21:47 +0100
Subject: [PATCH] Added some improvements & new features to ai package (#15)

ai: enhance Anthropic integration with streaming support and refactored design

This PR adds streaming capability to the Anthropic LLM provider and improves the overall design through better abstraction and testability.

Key changes:
- Add streaming response support via GetStreamingResponse
- Extract AnthropicClient interface to improve testability
- Create RealAnthropicClient implementation wrapping the official SDK
- Reduce code duplication by extracting common message handling logic
- Add comprehensive tests with mock implementations
- Update documentation with streaming examples
- Improve general code organization and maintainability
- Fix import ordering and godoc formatting

The changes maintain backward compatibility while adding new streaming capabilities in line with other providers. The refactoring improves the codebase's testability and reduces duplication through better abstraction.
---
 README.md                           | 163 ++++++++-------
 ai/llm.go                           |  83 +++++++-
 ai/llm_provider_anthropic.go        | 136 ++++++++++---
 ai/llm_provider_anthropic_client.go |  63 ++++++
 ai/llm_provider_anthropic_test.go   | 302 +++++++++++++++++++++++++++-
 ai/llm_provider_openai.go           |  69 ++++++-
 ai/llm_provider_openai_test.go      |  96 +++++++++
 ai/llm_test.go                      | 122 ++++++++++-
 ai/types.go                         |  19 +-
 9 files changed, 926 insertions(+), 127 deletions(-)
 create mode 100644 ai/llm_provider_anthropic_client.go

diff --git a/README.md b/README.md
index 6184fc7..e53c16e 100644
--- a/README.md
+++ b/README.md
@@ -38,14 +38,14 @@ guti.ContainsAll()
 
 ### AI Operations
 
-The `ai` package provides a flexible interface for interacting with various Language Learning Models (LLMs). Currently supports OpenAI's GPT models with an extensible interface for other providers.
+The `ai` package provides a comprehensive interface for working with Language Learning Models (LLMs) and embedding models. It supports multiple providers (OpenAI, Anthropic), streaming responses, and various embedding models.
-#### Basic Usage +#### LLM Integration + +Basic text generation with LLMs: ```go -import ( - "github.com/shaharia-lab/guti/ai" -) +import "github.com/shaharia-lab/guti/ai" // Create an OpenAI provider provider := ai.NewOpenAILLMProvider(ai.OpenAIProviderConfig{ @@ -53,44 +53,93 @@ provider := ai.NewOpenAILLMProvider(ai.OpenAIProviderConfig{ Model: "gpt-3.5-turbo", // Optional, defaults to gpt-3.5-turbo }) -// Create a request with default configuration -request := ai.NewLLMRequest(ai.NewRequestConfig()) +// Create request with configuration +config := ai.NewRequestConfig( + ai.WithMaxToken(2000), + ai.WithTemperature(0.7), +) +request := ai.NewLLMRequest(config, provider) -// Generate a response -response, err := request.Generate([]LLMMessage{{Role: "user", Text: "What is the capital of France?"}}, provider) +// Generate response +response, err := request.Generate([]ai.LLMMessage{ + {Role: ai.SystemRole, Text: "You are a helpful assistant"}, + {Role: ai.UserRole, Text: "What is the capital of France?"}, +}) if err != nil { log.Fatal(err) } fmt.Printf("Response: %s\n", response.Text) -fmt.Printf("Input tokens: %d\n", response.TotalInputToken) -fmt.Printf("Output tokens: %d\n", response.TotalOutputToken) -fmt.Printf("Completion time: %.2f seconds\n", response.CompletionTime) +fmt.Printf("Tokens used: %d\n", response.TotalOutputToken) ``` -#### Custom Configuration +#### Streaming Responses -You can customize the LLM request configuration using the functional options pattern: +Get realtime token-by-token responses: ```go -// Use specific configuration options -config := ai.NewRequestConfig( - ai.WithMaxToken(2000), - ai.WithTemperature(0.8), - ai.WithTopP(0.95), - ai.WithTopK(100), -) +stream, err := request.GenerateStream(context.Background(), []ai.LLMMessage{ + {Role: ai.UserRole, Text: "Tell me a story"}, +}) +if err != nil { + log.Fatal(err) +} -request := ai.NewLLMRequest(config) +for response := range stream { + if response.Error != nil { + break + } + if response.Done { + break + } + fmt.Print(response.Text) +} +``` + +#### Anthropic Integration + +Use Claude models through Anthropic's API: + +```go +// Create Anthropic client and provider +client := ai.NewRealAnthropicClient("your-api-key") +provider := ai.NewAnthropicLLMProvider(ai.AnthropicProviderConfig{ + Client: client, + Model: "claude-3-sonnet-20240229", // Optional, defaults to latest 3.5 Sonnet +}) + +request := ai.NewLLMRequest(config, provider) ``` -#### Using Templates +#### Embedding Generation -The package also supports templated prompts: +Generate vector embeddings for text: + +```go +provider := ai.NewEmbeddingService("http://api.example.com", nil) + +embedding, err := provider.GenerateEmbedding( + context.Background(), + "Hello world", + ai.EmbeddingModelAllMiniLML6V2, +) +if err != nil { + log.Fatal(err) +} +``` + +Supported embedding models: +- `EmbeddingModelAllMiniLML6V2`: Lightweight, general-purpose model +- `EmbeddingModelAllMpnetBaseV2`: Higher quality, more compute intensive +- `EmbeddingModelParaphraseMultilingualMiniLML12V2`: Optimized for multilingual text + +#### Template Support + +Create dynamic prompts using Go templates: ```go template := &ai.LLMPromptTemplate{ - Template: "Hello {{.Name}}! Please tell me about {{.Topic}}.", + Template: "Hello {{.Name}}! 
Tell me about {{.Topic}}.", Data: map[string]interface{}{ "Name": "Alice", "Topic": "artificial intelligence", @@ -102,72 +151,32 @@ if err != nil { log.Fatal(err) } -response, err := request.Generate(prompt, provider) +response, err := request.Generate([]ai.LLMMessage{ + {Role: ai.UserRole, Text: prompt}, +}) ``` #### Configuration Options -| Option | Default | Description | -|-------------|---------|--------------------------------------| -| MaxToken | 1000 | Maximum number of tokens to generate | -| TopP | 0.9 | Nucleus sampling parameter (0-1) | -| Temperature | 0.7 | Randomness in output (0-2) | -| TopK | 50 | Top-k sampling parameter | - -#### Error Handling - -The package provides structured error handling: - -```go -response, err := request.Generate(prompt, provider) -if err != nil { - if llmErr, ok := err.(*ai.LLMError); ok { - fmt.Printf("LLM Error %d: %s\n", llmErr.Code, llmErr.Message) - } else { - fmt.Printf("Error: %v\n", err) - } -} -``` +| Option | Default | Description | +|-------------|---------|----------------------------| +| MaxToken | 1000 | Maximum tokens to generate | +| TopP | 0.9 | Nucleus sampling (0-1) | +| Temperature | 0.7 | Output randomness (0-2) | +| TopK | 50 | Top-k sampling parameter | #### Custom Providers -You can implement the `LLMProvider` interface to add support for additional LLM providers: +Implement the provider interfaces to add support for additional services: ```go type LLMProvider interface { GetResponse(messages []LLMMessage, config LLMRequestConfig) (LLMResponse, error) -} -``` - -#### Generate Embedding Vector - -You can generate embeddings using the provider-based approach: - -```go -import ( - "github.com/shaharia-lab/guti/ai" -) - -// Create an embedding provider -provider := ai.NewLocalEmbeddingProvider(ai.LocalProviderConfig{ - BaseURL: "http://localhost:8000", - Client: &http.Client{}, -}) - -// Generate embedding -embedding, err := provider.GenerateEmbedding(context.Background(), "Hello world", ai.EmbeddingModelAllMiniLML6V2) -if err != nil { - log.Fatal(err) + GetStreamingResponse(ctx context.Context, messages []LLMMessage, config LLMRequestConfig) (<-chan StreamingLLMResponse, error) } -fmt.Printf("Embedding vector: %+v\n", embedding) -``` - -The library supports multiple embedding providers. You can implement the `EmbeddingProvider` interface to add support for additional providers: - -```go type EmbeddingProvider interface { - GenerateEmbedding(ctx context.Context, text string, model EmbeddingModel) ([]float32, error) + GenerateEmbedding(ctx context.Context, input interface{}, model string) (*EmbeddingResponse, error) } ``` diff --git a/ai/llm.go b/ai/llm.go index 4517719..a074788 100644 --- a/ai/llm.go +++ b/ai/llm.go @@ -1,21 +1,90 @@ // Package ai provides a flexible interface for interacting with various Language Learning Models (LLMs). package ai +import "context" + // LLMRequest handles the configuration and execution of LLM requests. // It provides a consistent interface for interacting with different LLM providers. type LLMRequest struct { requestConfig LLMRequestConfig + provider LLMProvider } -// NewLLMRequest creates a new LLMRequest with the specified configuration. -func NewLLMRequest(requestConfig LLMRequestConfig) *LLMRequest { +// NewLLMRequest creates a new LLMRequest with the specified configuration and provider. +// The provider parameter allows injecting different LLM implementations (OpenAI, Anthropic, etc.). 
+// +// Example usage: +// +// // Create provider +// provider := ai.NewOpenAILLMProvider(ai.OpenAIProviderConfig{ +// APIKey: "your-api-key", +// Model: "gpt-3.5-turbo", +// }) +// +// // Configure request options +// config := ai.NewRequestConfig( +// ai.WithMaxToken(2000), +// ai.WithTemperature(0.7), +// ) +// +// // Create LLM request client +// llm := ai.NewLLMRequest(config, provider) +func NewLLMRequest(config LLMRequestConfig, provider LLMProvider) *LLMRequest { return &LLMRequest{ - requestConfig: requestConfig, + requestConfig: config, + provider: provider, } } -// Generate sends a prompt to the specified LLM provider and returns the response. -// Returns LLMResponse containing the generated text and metadata, or an error if the operation fails. -func (r *LLMRequest) Generate(messages []LLMMessage, llmProvider LLMProvider) (LLMResponse, error) { - return llmProvider.GetResponse(messages, r.requestConfig) +// Generate sends messages to the configured LLM provider and returns the response. +// It uses the provider and configuration specified during initialization. +// +// Example usage: +// +// messages := []ai.LLMMessage{ +// {Role: ai.SystemRole, Text: "You are a helpful assistant"}, +// {Role: ai.UserRole, Text: "What is the capital of France?"}, +// } +// +// response, err := llm.Generate(messages) +// if err != nil { +// log.Fatal(err) +// } +// fmt.Printf("Response: %s\n", response.Text) +// fmt.Printf("Tokens used: %d\n", response.TotalOutputToken) +// +// The method returns LLMResponse containing: +// - Generated text +// - Token usage statistics +// - Completion time +// - Other provider-specific metadata +func (r *LLMRequest) Generate(messages []LLMMessage) (LLMResponse, error) { + return r.provider.GetResponse(messages, r.requestConfig) +} + +// GenerateStream creates a streaming response channel for the given messages. +// It returns a channel that receives StreamingLLMResponse chunks and an error if initialization fails. +// +// Example usage: +// +// request := NewLLMRequest(config) +// stream, err := request.GenerateStream(context.Background(), []LLMMessage{ +// {Role: UserRole, Text: "Tell me a story"}, +// }) +// if err != nil { +// log.Fatal(err) +// } +// +// for response := range stream { +// if response.Error != nil { +// log.Printf("Error: %v", response.Error) +// break +// } +// if response.Done { +// break +// } +// fmt.Print(response.Text) +// } +func (r *LLMRequest) GenerateStream(ctx context.Context, messages []LLMMessage) (<-chan StreamingLLMResponse, error) { + return r.provider.GetStreamingResponse(ctx, messages, r.requestConfig) } diff --git a/ai/llm_provider_anthropic.go b/ai/llm_provider_anthropic.go index f4c5b40..9278ce0 100644 --- a/ai/llm_provider_anthropic.go +++ b/ai/llm_provider_anthropic.go @@ -6,20 +6,19 @@ import ( "time" "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/option" ) // AnthropicLLMProvider implements the LLMProvider interface using Anthropic's official Go SDK. // It provides access to Claude models through Anthropic's API. type AnthropicLLMProvider struct { - client *anthropic.Client + client AnthropicClient model anthropic.Model } // AnthropicProviderConfig holds the configuration options for creating an Anthropic provider. 
type AnthropicProviderConfig struct { - // APIKey is the authentication key for Anthropic's API - APIKey string + // Client is the AnthropicClient implementation to use + Client AnthropicClient // Model specifies which Anthropic model to use (e.g., "claude-3-opus-20240229", "claude-3-sonnet-20240229") Model anthropic.Model @@ -27,34 +26,45 @@ type AnthropicProviderConfig struct { // NewAnthropicLLMProvider creates a new Anthropic provider with the specified configuration. // If no model is specified, it defaults to Claude 3.5 Sonnet. +// +// Example usage: +// +// client := NewRealAnthropicClient("your-api-key") +// provider := NewAnthropicLLMProvider(AnthropicProviderConfig{ +// Client: client, +// Model: anthropic.ModelClaude_3_5_Sonnet_20240620, +// }) +// +// response, err := provider.GetResponse(messages, config) +// if err != nil { +// log.Fatal(err) +// } func NewAnthropicLLMProvider(config AnthropicProviderConfig) *AnthropicLLMProvider { if config.Model == "" { config.Model = anthropic.ModelClaude_3_5_Sonnet_20240620 } return &AnthropicLLMProvider{ - client: anthropic.NewClient(option.WithAPIKey(config.APIKey)), + client: config.Client, model: config.Model, } } -// GetResponse generates a response using Anthropic's API for the given messages and configuration. -// It supports different message roles (user, assistant, system) and handles them appropriately. -// System messages are handled separately through Anthropic's system parameter. -// The function returns an LLMResponse containing the generated text and metadata, or an error if the operation fails. -func (p *AnthropicLLMProvider) GetResponse(messages []LLMMessage, config LLMRequestConfig) (LLMResponse, error) { - startTime := time.Now() - +// prepareMessageParams creates the Anthropic message parameters from LLM messages and config. +// This is an internal helper function to reduce code duplication. +func (p *AnthropicLLMProvider) prepareMessageParams(messages []LLMMessage, config LLMRequestConfig) anthropic.MessageNewParams { var anthropicMessages []anthropic.MessageParam + var systemMessage string + + // Process messages based on their role for _, msg := range messages { switch msg.Role { + case SystemRole: + systemMessage = msg.Text case UserRole: anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Text))) case AssistantRole: anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Text))) - case SystemRole: - // Anthropic handles system messages differently - we'll add it to params.System - continue default: anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Text))) } @@ -69,16 +79,23 @@ func (p *AnthropicLLMProvider) GetResponse(messages []LLMMessage, config LLMRequ } // Add system message if present - for _, msg := range messages { - if msg.Role == SystemRole { - params.System = anthropic.F([]anthropic.TextBlockParam{ - anthropic.NewTextBlock(msg.Text), - }) - break - } + if systemMessage != "" { + params.System = anthropic.F([]anthropic.TextBlockParam{ + anthropic.NewTextBlock(systemMessage), + }) } - message, err := p.client.Messages.New(context.Background(), params) + return params +} + +// GetResponse generates a response using Anthropic's API for the given messages and configuration. +// It supports different message roles (user, assistant, system) and handles them appropriately. +// System messages are handled separately through Anthropic's system parameter. 
+func (p *AnthropicLLMProvider) GetResponse(messages []LLMMessage, config LLMRequestConfig) (LLMResponse, error) { + startTime := time.Now() + + params := p.prepareMessageParams(messages, config) + message, err := p.client.CreateMessage(context.Background(), params) if err != nil { return LLMResponse{}, err } @@ -98,3 +115,76 @@ func (p *AnthropicLLMProvider) GetResponse(messages []LLMMessage, config LLMRequ CompletionTime: time.Since(startTime).Seconds(), }, nil } + +// GetStreamingResponse generates a streaming response using Anthropic's API. +// It returns a channel that receives chunks of the response as they're generated. +// +// Example usage: +// +// client := NewRealAnthropicClient("your-api-key") +// provider := NewAnthropicLLMProvider(AnthropicProviderConfig{ +// Client: client, +// Model: anthropic.ModelClaude_3_5_Sonnet_20240620, +// }) +// +// streamingResp, err := provider.GetStreamingResponse(ctx, messages, config) +// if err != nil { +// log.Fatal(err) +// } +// +// for chunk := range streamingResp { +// if chunk.Error != nil { +// log.Printf("Error: %v", chunk.Error) +// break +// } +// fmt.Print(chunk.Text) +// } +func (p *AnthropicLLMProvider) GetStreamingResponse(ctx context.Context, messages []LLMMessage, config LLMRequestConfig) (<-chan StreamingLLMResponse, error) { + params := p.prepareMessageParams(messages, config) + stream := p.client.CreateStreamingMessage(ctx, params) + responseChan := make(chan StreamingLLMResponse, 100) + + go func() { + defer close(responseChan) + + for stream.Next() { + select { + case <-ctx.Done(): + responseChan <- StreamingLLMResponse{ + Error: ctx.Err(), + Done: true, + } + return + default: + event := stream.Current() + + switch event.Type { + case anthropic.MessageStreamEventTypeContentBlockDelta: + delta, ok := event.Delta.(anthropic.ContentBlockDeltaEventDelta) + if !ok { + continue + } + + if delta.Type == anthropic.ContentBlockDeltaEventDeltaTypeTextDelta && delta.Text != "" { + responseChan <- StreamingLLMResponse{ + Text: delta.Text, + TokenCount: 1, + } + } + case anthropic.MessageStreamEventTypeMessageStop: + responseChan <- StreamingLLMResponse{Done: true} + return + } + } + } + + if err := stream.Err(); err != nil { + responseChan <- StreamingLLMResponse{ + Error: err, + Done: true, + } + } + }() + + return responseChan, nil +} diff --git a/ai/llm_provider_anthropic_client.go b/ai/llm_provider_anthropic_client.go new file mode 100644 index 0000000..7ec4538 --- /dev/null +++ b/ai/llm_provider_anthropic_client.go @@ -0,0 +1,63 @@ +// Package ai provides a flexible interface for interacting with various Language Learning Models (LLMs). +package ai + +import ( + "context" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" +) + +// AnthropicClient defines the interface for interacting with Anthropic's API. +// This interface abstracts the essential message-related operations used by AnthropicLLMProvider. +type AnthropicClient interface { + // CreateMessage creates a new message using Anthropic's API. + // The method takes a context and MessageNewParams and returns a Message response or an error. + CreateMessage(ctx context.Context, params anthropic.MessageNewParams) (*anthropic.Message, error) + + // CreateStreamingMessage creates a streaming message using Anthropic's API. + // It returns a stream that can be used to receive message chunks as they're generated. 
+ CreateStreamingMessage(ctx context.Context, params anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEvent] +} + +// RealAnthropicClient implements the AnthropicClient interface using Anthropic's official SDK. +type RealAnthropicClient struct { + messages *anthropic.MessageService +} + +// NewRealAnthropicClient creates a new instance of RealAnthropicClient with the provided API key. +// +// Example usage: +// +// // Regular message generation +// client := NewRealAnthropicClient("your-api-key") +// provider := NewAnthropicLLMProvider(AnthropicProviderConfig{ +// Client: client, +// Model: "claude-3-sonnet-20240229", +// }) +// +// // Streaming message generation +// streamingResp, err := provider.GetStreamingResponse(ctx, messages, config) +// if err != nil { +// log.Fatal(err) +// } +// for chunk := range streamingResp { +// fmt.Print(chunk.Text) +// } +func NewRealAnthropicClient(apiKey string) *RealAnthropicClient { + client := anthropic.NewClient(option.WithAPIKey(apiKey)) + return &RealAnthropicClient{ + messages: client.Messages, + } +} + +// CreateMessage implements the AnthropicClient interface using the real Anthropic client. +func (c *RealAnthropicClient) CreateMessage(ctx context.Context, params anthropic.MessageNewParams) (*anthropic.Message, error) { + return c.messages.New(ctx, params) +} + +// CreateStreamingMessage implements the streaming support for the AnthropicClient interface. +func (c *RealAnthropicClient) CreateStreamingMessage(ctx context.Context, params anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEvent] { + return c.messages.NewStreaming(ctx, params) +} diff --git a/ai/llm_provider_anthropic_test.go b/ai/llm_provider_anthropic_test.go index 5de8cb0..c64a65c 100644 --- a/ai/llm_provider_anthropic_test.go +++ b/ai/llm_provider_anthropic_test.go @@ -1,21 +1,103 @@ package ai import ( + "context" + "encoding/json" + "strings" "testing" "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" + "github.com/stretchr/testify/assert" ) +// MockAnthropicClient implements AnthropicClient interface for testing +type MockAnthropicClient struct { + createMessageFunc func(ctx context.Context, params anthropic.MessageNewParams) (*anthropic.Message, error) + createStreamingMessageFunc func(ctx context.Context, params anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEvent] +} + +func (m *MockAnthropicClient) CreateMessage(ctx context.Context, params anthropic.MessageNewParams) (*anthropic.Message, error) { + if m.createMessageFunc != nil { + return m.createMessageFunc(ctx, params) + } + return nil, nil +} + +func (m *MockAnthropicClient) CreateStreamingMessage(ctx context.Context, params anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEvent] { + if m.createStreamingMessageFunc != nil { + return m.createStreamingMessageFunc(ctx, params) + } + return nil +} + +type mockEventStream struct { + events []anthropic.MessageStreamEvent + index int +} + +// Implement ssestream.Runner interface +func (m *mockEventStream) Run() error { + return nil +} + +type mockDecoder struct { + events []anthropic.MessageStreamEvent + index int +} + +func (d *mockDecoder) Event() ssestream.Event { + if d.index < 0 || d.index >= len(d.events) { + return ssestream.Event{} + } + + event := d.events[d.index] + + // Create a custom payload that can be unmarshaled correctly + payload := map[string]interface{}{ + "type": event.Type, + "delta": event.Delta, + "index": event.Index, + 
} + + if event.Type == anthropic.MessageStreamEventTypeMessageStart { + payload["message"] = event.Message + } + + data, err := json.Marshal(payload) + if err != nil { + return ssestream.Event{} + } + + return ssestream.Event{ + Type: string(event.Type), + Data: data, + } +} + +func (d *mockDecoder) Next() bool { + d.index++ + return d.index < len(d.events) +} + +func (d *mockDecoder) Err() error { + return nil +} + +func (d *mockDecoder) Close() error { + return nil +} + func TestAnthropicLLMProvider_NewAnthropicLLMProvider(t *testing.T) { tests := []struct { name string config AnthropicProviderConfig - expectedModel string + expectedModel anthropic.Model }{ { name: "with specified model", config: AnthropicProviderConfig{ - APIKey: "test-key", + Client: &MockAnthropicClient{}, Model: "claude-3-opus-20240229", }, expectedModel: "claude-3-opus-20240229", @@ -23,9 +105,9 @@ func TestAnthropicLLMProvider_NewAnthropicLLMProvider(t *testing.T) { { name: "with default model", config: AnthropicProviderConfig{ - APIKey: "test-key", + Client: &MockAnthropicClient{}, }, - expectedModel: string(anthropic.ModelClaude_3_5_Sonnet_20240620), + expectedModel: anthropic.ModelClaude_3_5_Sonnet_20240620, }, } @@ -33,12 +115,216 @@ func TestAnthropicLLMProvider_NewAnthropicLLMProvider(t *testing.T) { t.Run(tt.name, func(t *testing.T) { provider := NewAnthropicLLMProvider(tt.config) - if provider.model != tt.expectedModel { - t.Errorf("expected model %q, got %q", tt.expectedModel, provider.model) + assert.Equal(t, tt.expectedModel, provider.model, "unexpected model") + assert.NotNil(t, provider.client, "expected client to be initialized") + }) + } +} + +func TestAnthropicLLMProvider_GetResponse(t *testing.T) { + tests := []struct { + name string + messages []LLMMessage + config LLMRequestConfig + expectedResult LLMResponse + expectError bool + }{ + { + name: "successful response with all message types", + messages: []LLMMessage{ + {Role: SystemRole, Text: "You are a helpful assistant"}, + {Role: UserRole, Text: "Hello"}, + {Role: AssistantRole, Text: "Hi there"}, + }, + config: LLMRequestConfig{ + MaxToken: 100, + TopP: 0.9, + Temperature: 0.7, + }, + expectedResult: LLMResponse{ + Text: "Test response", + TotalInputToken: 10, + TotalOutputToken: 5, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := &MockAnthropicClient{ + createMessageFunc: func(ctx context.Context, params anthropic.MessageNewParams) (*anthropic.Message, error) { + message := &anthropic.Message{ + Role: anthropic.MessageRoleAssistant, + Model: anthropic.ModelClaude_3_5_Sonnet_20240620, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 5, + }, + Type: anthropic.MessageTypeMessage, + } + + block := anthropic.ContentBlock{} + if err := block.UnmarshalJSON([]byte(`{ + "type": "text", + "text": "Test response" + }`)); err != nil { + t.Fatal(err) + } + + message.Content = []anthropic.ContentBlock{block} + return message, nil + }, + } + + provider := NewAnthropicLLMProvider(AnthropicProviderConfig{ + Client: mockClient, + Model: anthropic.ModelClaude_3_5_Sonnet_20240620, + }) + + result, err := provider.GetResponse(tt.messages, tt.config) + + if tt.expectError { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.expectedResult.Text, result.Text) + assert.Equal(t, tt.expectedResult.TotalInputToken, result.TotalInputToken) + assert.Equal(t, tt.expectedResult.TotalOutputToken, result.TotalOutputToken) + assert.Greater(t, 
result.CompletionTime, float64(0), "completion time should be greater than 0") + }) + } +} + +func TestAnthropicLLMProvider_GetStreamingResponse(t *testing.T) { + tests := []struct { + name string + messages []LLMMessage + config LLMRequestConfig + streamText []string + expectError bool + }{ + { + name: "successful streaming response", + messages: []LLMMessage{ + {Role: UserRole, Text: "Hello"}, + }, + config: LLMRequestConfig{ + MaxToken: 100, + TopP: 0.9, + Temperature: 0.7, + }, + streamText: []string{"Hello", " world", "!"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := &MockAnthropicClient{ + createStreamingMessageFunc: func(ctx context.Context, params anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEvent] { + var events []anthropic.MessageStreamEvent + + // Create start event + events = append(events, anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeMessageStart, + Message: anthropic.Message{ + Role: anthropic.MessageRoleAssistant, + Model: anthropic.ModelClaude_3_5_Sonnet_20240620, + }, + }) + + // Create content block delta events + for i, text := range tt.streamText { + t.Logf("Adding delta event %d with text: %q", i, text) + events = append(events, anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeContentBlockDelta, + Index: int64(i), + Delta: anthropic.ContentBlockDeltaEventDelta{ + Type: anthropic.ContentBlockDeltaEventDeltaTypeTextDelta, + Text: text, + }, + }) + } + + // Add stop event + events = append(events, anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeMessageStop, + }) + + decoder := &mockDecoder{ + events: events, + index: -1, + } + + stream := ssestream.NewStream[anthropic.MessageStreamEvent](decoder, nil) + return stream + }, } - if provider.client == nil { - t.Error("expected client to be initialized") + + provider := NewAnthropicLLMProvider(AnthropicProviderConfig{ + Client: mockClient, + Model: anthropic.ModelClaude_3_5_Sonnet_20240620, + }) + + ctx := context.Background() + stream, err := provider.GetStreamingResponse(ctx, tt.messages, tt.config) + assert.NoError(t, err) + + var receivedText string + for chunk := range stream { + t.Logf("Received streaming chunk: %+v", chunk) + if chunk.Error != nil { + t.Fatalf("Unexpected error: %v", chunk.Error) + } + if !chunk.Done { + receivedText += chunk.Text + t.Logf("Current accumulated text: %q", receivedText) + } } + + t.Logf("Final text: %q", receivedText) + assert.Equal(t, strings.Join(tt.streamText, ""), receivedText) }) } } + +func createStreamEvent(eventType string, index int64, text string) anthropic.MessageStreamEvent { + var event anthropic.MessageStreamEvent + + switch eventType { + case "message_start": + event = anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeMessageStart, + Message: anthropic.Message{ + Role: anthropic.MessageRoleAssistant, + Model: anthropic.ModelClaude_3_5_Sonnet_20240620, + }, + } + case "content_block_delta": + /*textDelta := anthropic.TextDelta{ + Type: anthropic.TextDeltaTypeTextDelta, + Text: text, + }*/ + event = anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeContentBlockDelta, + Index: index, + Delta: anthropic.ContentBlockDeltaEventDelta{ + Type: anthropic.ContentBlockDeltaEventDeltaTypeTextDelta, + Text: text, + }, + } + case "content_block_stop": + event = anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeContentBlockStop, + Index: index, + } + case "message_stop": + event = 
anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeMessageStop, + } + } + + return event +} diff --git a/ai/llm_provider_openai.go b/ai/llm_provider_openai.go index 95bd30a..ecb76e7 100644 --- a/ai/llm_provider_openai.go +++ b/ai/llm_provider_openai.go @@ -34,11 +34,8 @@ func NewOpenAILLMProvider(config OpenAIProviderConfig) *OpenAILLMProvider { } } -// GetResponse generates a response using OpenAI's API for the given messages and configuration. -// It supports different message roles (user, assistant, system) and handles them appropriately. -func (p *OpenAILLMProvider) GetResponse(messages []LLMMessage, config LLMRequestConfig) (LLMResponse, error) { - startTime := time.Now() - +// convertToOpenAIMessages converts internal message format to OpenAI's format +func (p *OpenAILLMProvider) convertToOpenAIMessages(messages []LLMMessage) []openai.ChatCompletionMessageParamUnion { var openAIMessages []openai.ChatCompletionMessageParamUnion for _, msg := range messages { switch msg.Role { @@ -52,14 +49,27 @@ func (p *OpenAILLMProvider) GetResponse(messages []LLMMessage, config LLMRequest openAIMessages = append(openAIMessages, openai.UserMessage(msg.Text)) } } + return openAIMessages +} - params := openai.ChatCompletionNewParams{ - Messages: openai.F(openAIMessages), +// createCompletionParams creates OpenAI API parameters from request config +func (p *OpenAILLMProvider) createCompletionParams(messages []openai.ChatCompletionMessageParamUnion, config LLMRequestConfig) openai.ChatCompletionNewParams { + return openai.ChatCompletionNewParams{ + Messages: openai.F(messages), Model: openai.F(p.model), MaxTokens: openai.Int(config.MaxToken), TopP: openai.Float(config.TopP), Temperature: openai.Float(config.Temperature), } +} + +// GetResponse generates a response using OpenAI's API for the given messages and configuration. +// It supports different message roles (user, assistant, system) and handles them appropriately. +func (p *OpenAILLMProvider) GetResponse(messages []LLMMessage, config LLMRequestConfig) (LLMResponse, error) { + startTime := time.Now() + + openAIMessages := p.convertToOpenAIMessages(messages) + params := p.createCompletionParams(openAIMessages, config) completion, err := p.client.Chat.Completions.New(context.Background(), params) if err != nil { @@ -77,3 +87,48 @@ func (p *OpenAILLMProvider) GetResponse(messages []LLMMessage, config LLMRequest CompletionTime: time.Since(startTime).Seconds(), }, nil } + +// GetStreamingResponse generates a streaming response using OpenAI's API. +// It supports streaming tokens as they're generated and handles context cancellation. 
+func (p *OpenAILLMProvider) GetStreamingResponse(ctx context.Context, messages []LLMMessage, config LLMRequestConfig) (<-chan StreamingLLMResponse, error) { + openAIMessages := p.convertToOpenAIMessages(messages) + params := p.createCompletionParams(openAIMessages, config) + + stream := p.client.Chat.Completions.NewStreaming(ctx, params) + responseChan := make(chan StreamingLLMResponse, 100) + + go func() { + defer close(responseChan) + + for stream.Next() { + select { + case <-ctx.Done(): + responseChan <- StreamingLLMResponse{ + Error: ctx.Err(), + Done: true, + } + return + default: + chunk := stream.Current() + if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" { + responseChan <- StreamingLLMResponse{ + Text: chunk.Choices[0].Delta.Content, + TokenCount: 1, + } + } + } + } + + if err := stream.Err(); err != nil { + responseChan <- StreamingLLMResponse{ + Error: err, + Done: true, + } + return + } + + responseChan <- StreamingLLMResponse{Done: true} + }() + + return responseChan, nil +} diff --git a/ai/llm_provider_openai_test.go b/ai/llm_provider_openai_test.go index b6da52b..26dd911 100644 --- a/ai/llm_provider_openai_test.go +++ b/ai/llm_provider_openai_test.go @@ -1,9 +1,14 @@ package ai import ( + "context" + "io" + "net/http" "testing" + "time" "github.com/openai/openai-go" + "github.com/openai/openai-go/option" ) func TestOpenAILLMProvider_NewOpenAILLMProvider(t *testing.T) { @@ -42,3 +47,94 @@ func TestOpenAILLMProvider_NewOpenAILLMProvider(t *testing.T) { }) } } + +func TestOpenAILLMProvider_GetStreamingResponse(t *testing.T) { + tests := []struct { + name string + messages []LLMMessage + timeout time.Duration + delay time.Duration + wantErr bool + }{ + { + name: "successful streaming", + timeout: 100 * time.Millisecond, + delay: 0, + }, + { + name: "context cancellation", + timeout: 5 * time.Millisecond, + delay: 50 * time.Millisecond, + }, + } + + responses := []string{ + `data: {"id":"123","choices":[{"delta":{"content":"Hello"}}]}`, + `data: {"id":"123","choices":[{"delta":{"content":" world"}}]}`, + `data: [DONE]`, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider := &OpenAILLMProvider{ + client: openai.NewClient( + option.WithHTTPClient(&http.Client{ + Transport: &mockTransport{ + responses: responses, + delay: tt.delay, + }, + }), + ), + model: openai.ChatModelGPT3_5Turbo, + } + + ctx, cancel := context.WithTimeout(context.Background(), tt.timeout) + defer cancel() + + stream, err := provider.GetStreamingResponse(ctx, []LLMMessage{{Role: UserRole, Text: "test"}}, LLMRequestConfig{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var gotCancel bool + for resp := range stream { + if resp.Error != nil { + gotCancel = true + break + } + } + + if tt.delay > tt.timeout && !gotCancel { + t.Error("expected context cancellation") + } + }) + } +} + +type mockTransport struct { + responses []string + delay time.Duration +} + +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if m.delay > 0 { + time.Sleep(m.delay) + } + + pr, pw := io.Pipe() + go func() { + defer pw.Close() + for _, resp := range m.responses { + time.Sleep(10 * time.Millisecond) // Simulate streaming delay + pw.Write([]byte(resp + "\n")) + } + }() + + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + }, + Body: pr, + }, nil +} diff --git a/ai/llm_test.go b/ai/llm_test.go index f1c7866..9c0c6c6 100644 --- a/ai/llm_test.go +++ b/ai/llm_test.go 
@@ -1,19 +1,41 @@ package ai import ( + "context" "errors" "testing" ) type mockProvider struct { - response LLMResponse - err error + response LLMResponse + err error + streamResponses []StreamingLLMResponse + streamErr error } func (m *mockProvider) GetResponse(messages []LLMMessage, _ LLMRequestConfig) (LLMResponse, error) { return m.response, m.err } +func (m *mockProvider) GetStreamingResponse(ctx context.Context, messages []LLMMessage, config LLMRequestConfig) (<-chan StreamingLLMResponse, error) { + if m.streamErr != nil { + return nil, m.streamErr + } + + ch := make(chan StreamingLLMResponse, len(m.streamResponses)) + go func() { + defer close(ch) + for _, resp := range m.streamResponses { + select { + case <-ctx.Done(): + return + case ch <- resp: + } + } + }() + return ch, nil +} + func TestLLMRequest_Generate(t *testing.T) { tests := []struct { name string @@ -62,11 +84,11 @@ func TestLLMRequest_Generate(t *testing.T) { err: tt.mockError, } - request := NewLLMRequest(tt.config) + request := NewLLMRequest(tt.config, provider) response, err := request.Generate([]LLMMessage{{ Role: "user", Text: "test prompt", - }}, provider) + }}) if tt.expectedError { if err == nil { @@ -92,3 +114,95 @@ func TestLLMRequest_Generate(t *testing.T) { }) } } + +func TestLLMRequest_GenerateStream(t *testing.T) { + tests := []struct { + name string + config LLMRequestConfig + messages []LLMMessage + streamResponses []StreamingLLMResponse + streamErr error + wantErr bool + }{ + { + name: "successful streaming", + config: LLMRequestConfig{ + MaxToken: 100, + }, + messages: []LLMMessage{ + {Role: UserRole, Text: "Hello"}, + }, + streamResponses: []StreamingLLMResponse{ + {Text: "Hello", TokenCount: 1}, + {Text: "World", TokenCount: 1}, + {Done: true}, + }, + }, + { + name: "provider error", + config: LLMRequestConfig{ + MaxToken: 100, + }, + messages: []LLMMessage{ + {Role: UserRole, Text: "Hello"}, + }, + streamErr: errors.New("stream error"), + wantErr: true, + }, + { + name: "context cancellation", + config: LLMRequestConfig{ + MaxToken: 100, + }, + messages: []LLMMessage{ + {Role: UserRole, Text: "Hello"}, + }, + streamResponses: []StreamingLLMResponse{ + {Text: "Hello", TokenCount: 1}, + {Error: context.Canceled, Done: true}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider := &mockProvider{ + streamResponses: tt.streamResponses, + streamErr: tt.streamErr, + } + + request := NewLLMRequest(tt.config, provider) + stream, err := request.GenerateStream(context.Background(), tt.messages) + + if (err != nil) != tt.wantErr { + t.Errorf("GenerateStream() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + + var got []StreamingLLMResponse + for resp := range stream { + got = append(got, resp) + } + + if len(got) != len(tt.streamResponses) { + t.Errorf("expected %d responses, got %d", len(tt.streamResponses), len(got)) + return + } + + for i, want := range tt.streamResponses { + if got[i].Text != want.Text { + t.Errorf("response[%d].Text = %v, want %v", i, got[i].Text, want.Text) + } + if got[i].Done != want.Done { + t.Errorf("response[%d].Done = %v, want %v", i, got[i].Done, want.Done) + } + if got[i].Error != want.Error { + t.Errorf("response[%d].Error = %v, want %v", i, got[i].Error, want.Error) + } + } + }) + } +} diff --git a/ai/types.go b/ai/types.go index a3bcdda..4484b94 100644 --- a/ai/types.go +++ b/ai/types.go @@ -1,7 +1,10 @@ // Package ai provides a flexible interface for interacting with various Language Learning 
Models (LLMs). package ai -import "fmt" +import ( + "context" + "fmt" +) // LLMMessageRole represents the role of a message in a conversation. type LLMMessageRole string @@ -123,10 +126,24 @@ type LLMMessage struct { Text string } +// StreamingLLMResponse represents a chunk of streaming response from an LLM provider. +// It contains partial text, completion status, any errors, and token usage information. +type StreamingLLMResponse struct { + // Text contains the partial response text + Text string + // Done indicates if this is the final chunk + Done bool + // Error contains any error that occurred during streaming + Error error + // TokenCount is the number of tokens in this chunk + TokenCount int +} + // LLMProvider defines the interface that all LLM providers must implement. // This allows for easy swapping between different LLM providers. type LLMProvider interface { // GetResponse generates a response for the given question using the specified configuration. // Returns LLMResponse containing the generated text and metadata, or an error if the operation fails. GetResponse(messages []LLMMessage, config LLMRequestConfig) (LLMResponse, error) + GetStreamingResponse(ctx context.Context, messages []LLMMessage, config LLMRequestConfig) (<-chan StreamingLLMResponse, error) }
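
Taken together, the README and provider changes above define a complete streaming path. The following is a minimal end-to-end sketch assembled from the examples in this patch; the `main` wrapper and the `ANTHROPIC_API_KEY` environment variable are assumptions for illustration and are not part of the change itself.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/shaharia-lab/guti/ai"
)

func main() {
	// Wire up the Anthropic provider introduced in this patch.
	// Reading the key from ANTHROPIC_API_KEY is an assumption of this sketch.
	client := ai.NewRealAnthropicClient(os.Getenv("ANTHROPIC_API_KEY"))
	provider := ai.NewAnthropicLLMProvider(ai.AnthropicProviderConfig{
		Client: client, // Model is optional; defaults to Claude 3.5 Sonnet
	})

	// Configure and build the request client with the new two-argument constructor.
	config := ai.NewRequestConfig(
		ai.WithMaxToken(2000),
		ai.WithTemperature(0.7),
	)
	request := ai.NewLLMRequest(config, provider)

	// Stream the response chunk by chunk.
	stream, err := request.GenerateStream(context.Background(), []ai.LLMMessage{
		{Role: ai.SystemRole, Text: "You are a helpful assistant"},
		{Role: ai.UserRole, Text: "Tell me a short story"},
	})
	if err != nil {
		log.Fatal(err)
	}

	for chunk := range stream {
		if chunk.Error != nil {
			log.Printf("stream error: %v", chunk.Error)
			break
		}
		if chunk.Done {
			break
		}
		fmt.Print(chunk.Text)
	}
}
```

Because the OpenAI provider implements the same GetStreamingResponse method, the identical consumption loop works when the provider value is swapped.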
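
The "Custom Providers" section of the README lists the two methods of the updated LLMProvider interface without a worked example. The sketch below shows one way a third-party provider could satisfy it; EchoProvider, its package name, and its word-splitting behaviour are hypothetical and exist only to illustrate the contract.

```go
package myprovider

import (
	"context"
	"strings"

	"github.com/shaharia-lab/guti/ai"
)

// EchoProvider is a hypothetical provider that simply echoes the last message.
// A real provider would call an external API here.
type EchoProvider struct{}

// GetResponse returns the last message text as a single blocking response.
func (e *EchoProvider) GetResponse(messages []ai.LLMMessage, config ai.LLMRequestConfig) (ai.LLMResponse, error) {
	var last string
	if len(messages) > 0 {
		last = messages[len(messages)-1].Text
	}
	return ai.LLMResponse{Text: last}, nil
}

// GetStreamingResponse emits the same text word by word, honouring context
// cancellation and signalling completion, mirroring the contract used by the
// OpenAI and Anthropic providers in this patch.
func (e *EchoProvider) GetStreamingResponse(ctx context.Context, messages []ai.LLMMessage, config ai.LLMRequestConfig) (<-chan ai.StreamingLLMResponse, error) {
	out := make(chan ai.StreamingLLMResponse, 8)
	var last string
	if len(messages) > 0 {
		last = messages[len(messages)-1].Text
	}

	go func() {
		defer close(out)
		for _, word := range strings.Fields(last) {
			select {
			case <-ctx.Done():
				out <- ai.StreamingLLMResponse{Error: ctx.Err(), Done: true}
				return
			case out <- ai.StreamingLLMResponse{Text: word + " ", TokenCount: 1}:
			}
		}
		out <- ai.StreamingLLMResponse{Done: true}
	}()
	return out, nil
}

// Compile-time check that EchoProvider satisfies the interface.
var _ ai.LLMProvider = (*EchoProvider)(nil)
```

A provider written this way can be passed straight to ai.NewLLMRequest and exercised with the same Generate and GenerateStream calls shown above.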