diff --git a/README.md b/README.md index 6184fc7..e53c16e 100644 --- a/README.md +++ b/README.md @@ -38,14 +38,14 @@ guti.ContainsAll() ### AI Operations -The `ai` package provides a flexible interface for interacting with various Language Learning Models (LLMs). Currently supports OpenAI's GPT models with an extensible interface for other providers. +The `ai` package provides a comprehensive interface for working with Language Learning Models (LLMs) and embedding models. It supports multiple providers (OpenAI, Anthropic), streaming responses, and various embedding models. -#### Basic Usage +#### LLM Integration + +Basic text generation with LLMs: ```go -import ( - "github.com/shaharia-lab/guti/ai" -) +import "github.com/shaharia-lab/guti/ai" // Create an OpenAI provider provider := ai.NewOpenAILLMProvider(ai.OpenAIProviderConfig{ @@ -53,44 +53,93 @@ provider := ai.NewOpenAILLMProvider(ai.OpenAIProviderConfig{ Model: "gpt-3.5-turbo", // Optional, defaults to gpt-3.5-turbo }) -// Create a request with default configuration -request := ai.NewLLMRequest(ai.NewRequestConfig()) +// Create request with configuration +config := ai.NewRequestConfig( + ai.WithMaxToken(2000), + ai.WithTemperature(0.7), +) +request := ai.NewLLMRequest(config, provider) -// Generate a response -response, err := request.Generate([]LLMMessage{{Role: "user", Text: "What is the capital of France?"}}, provider) +// Generate response +response, err := request.Generate([]ai.LLMMessage{ + {Role: ai.SystemRole, Text: "You are a helpful assistant"}, + {Role: ai.UserRole, Text: "What is the capital of France?"}, +}) if err != nil { log.Fatal(err) } fmt.Printf("Response: %s\n", response.Text) -fmt.Printf("Input tokens: %d\n", response.TotalInputToken) -fmt.Printf("Output tokens: %d\n", response.TotalOutputToken) -fmt.Printf("Completion time: %.2f seconds\n", response.CompletionTime) +fmt.Printf("Tokens used: %d\n", response.TotalOutputToken) ``` -#### Custom Configuration +#### Streaming Responses -You can customize the LLM request configuration using the functional options pattern: +Get realtime token-by-token responses: ```go -// Use specific configuration options -config := ai.NewRequestConfig( - ai.WithMaxToken(2000), - ai.WithTemperature(0.8), - ai.WithTopP(0.95), - ai.WithTopK(100), -) +stream, err := request.GenerateStream(context.Background(), []ai.LLMMessage{ + {Role: ai.UserRole, Text: "Tell me a story"}, +}) +if err != nil { + log.Fatal(err) +} -request := ai.NewLLMRequest(config) +for response := range stream { + if response.Error != nil { + break + } + if response.Done { + break + } + fmt.Print(response.Text) +} +``` + +#### Anthropic Integration + +Use Claude models through Anthropic's API: + +```go +// Create Anthropic client and provider +client := ai.NewRealAnthropicClient("your-api-key") +provider := ai.NewAnthropicLLMProvider(ai.AnthropicProviderConfig{ + Client: client, + Model: "claude-3-sonnet-20240229", // Optional, defaults to latest 3.5 Sonnet +}) + +request := ai.NewLLMRequest(config, provider) ``` -#### Using Templates +#### Embedding Generation -The package also supports templated prompts: +Generate vector embeddings for text: + +```go +provider := ai.NewEmbeddingService("http://api.example.com", nil) + +embedding, err := provider.GenerateEmbedding( + context.Background(), + "Hello world", + ai.EmbeddingModelAllMiniLML6V2, +) +if err != nil { + log.Fatal(err) +} +``` + +Supported embedding models: +- `EmbeddingModelAllMiniLML6V2`: Lightweight, general-purpose model +- `EmbeddingModelAllMpnetBaseV2`: 
Higher quality, more compute intensive +- `EmbeddingModelParaphraseMultilingualMiniLML12V2`: Optimized for multilingual text + +#### Template Support + +Create dynamic prompts using Go templates: ```go template := &ai.LLMPromptTemplate{ - Template: "Hello {{.Name}}! Please tell me about {{.Topic}}.", + Template: "Hello {{.Name}}! Tell me about {{.Topic}}.", Data: map[string]interface{}{ "Name": "Alice", "Topic": "artificial intelligence", @@ -102,72 +151,32 @@ if err != nil { log.Fatal(err) } -response, err := request.Generate(prompt, provider) +response, err := request.Generate([]ai.LLMMessage{ + {Role: ai.UserRole, Text: prompt}, +}) ``` #### Configuration Options -| Option | Default | Description | -|-------------|---------|--------------------------------------| -| MaxToken | 1000 | Maximum number of tokens to generate | -| TopP | 0.9 | Nucleus sampling parameter (0-1) | -| Temperature | 0.7 | Randomness in output (0-2) | -| TopK | 50 | Top-k sampling parameter | - -#### Error Handling - -The package provides structured error handling: - -```go -response, err := request.Generate(prompt, provider) -if err != nil { - if llmErr, ok := err.(*ai.LLMError); ok { - fmt.Printf("LLM Error %d: %s\n", llmErr.Code, llmErr.Message) - } else { - fmt.Printf("Error: %v\n", err) - } -} -``` +| Option | Default | Description | +|-------------|---------|----------------------------| +| MaxToken | 1000 | Maximum tokens to generate | +| TopP | 0.9 | Nucleus sampling (0-1) | +| Temperature | 0.7 | Output randomness (0-2) | +| TopK | 50 | Top-k sampling parameter | #### Custom Providers -You can implement the `LLMProvider` interface to add support for additional LLM providers: +Implement the provider interfaces to add support for additional services: ```go type LLMProvider interface { GetResponse(messages []LLMMessage, config LLMRequestConfig) (LLMResponse, error) -} -``` - -#### Generate Embedding Vector - -You can generate embeddings using the provider-based approach: - -```go -import ( - "github.com/shaharia-lab/guti/ai" -) - -// Create an embedding provider -provider := ai.NewLocalEmbeddingProvider(ai.LocalProviderConfig{ - BaseURL: "http://localhost:8000", - Client: &http.Client{}, -}) - -// Generate embedding -embedding, err := provider.GenerateEmbedding(context.Background(), "Hello world", ai.EmbeddingModelAllMiniLML6V2) -if err != nil { - log.Fatal(err) + GetStreamingResponse(ctx context.Context, messages []LLMMessage, config LLMRequestConfig) (<-chan StreamingLLMResponse, error) } -fmt.Printf("Embedding vector: %+v\n", embedding) -``` - -The library supports multiple embedding providers. You can implement the `EmbeddingProvider` interface to add support for additional providers: - -```go type EmbeddingProvider interface { - GenerateEmbedding(ctx context.Context, text string, model EmbeddingModel) ([]float32, error) + GenerateEmbedding(ctx context.Context, input interface{}, model string) (*EmbeddingResponse, error) } ``` diff --git a/ai/llm.go b/ai/llm.go index 4517719..a074788 100644 --- a/ai/llm.go +++ b/ai/llm.go @@ -1,21 +1,90 @@ // Package ai provides a flexible interface for interacting with various Language Learning Models (LLMs). package ai +import "context" + // LLMRequest handles the configuration and execution of LLM requests. // It provides a consistent interface for interacting with different LLM providers. type LLMRequest struct { requestConfig LLMRequestConfig + provider LLMProvider } -// NewLLMRequest creates a new LLMRequest with the specified configuration. 
-func NewLLMRequest(requestConfig LLMRequestConfig) *LLMRequest { +// NewLLMRequest creates a new LLMRequest with the specified configuration and provider. +// The provider parameter allows injecting different LLM implementations (OpenAI, Anthropic, etc.). +// +// Example usage: +// +// // Create provider +// provider := ai.NewOpenAILLMProvider(ai.OpenAIProviderConfig{ +// APIKey: "your-api-key", +// Model: "gpt-3.5-turbo", +// }) +// +// // Configure request options +// config := ai.NewRequestConfig( +// ai.WithMaxToken(2000), +// ai.WithTemperature(0.7), +// ) +// +// // Create LLM request client +// llm := ai.NewLLMRequest(config, provider) +func NewLLMRequest(config LLMRequestConfig, provider LLMProvider) *LLMRequest { return &LLMRequest{ - requestConfig: requestConfig, + requestConfig: config, + provider: provider, } } -// Generate sends a prompt to the specified LLM provider and returns the response. -// Returns LLMResponse containing the generated text and metadata, or an error if the operation fails. -func (r *LLMRequest) Generate(messages []LLMMessage, llmProvider LLMProvider) (LLMResponse, error) { - return llmProvider.GetResponse(messages, r.requestConfig) +// Generate sends messages to the configured LLM provider and returns the response. +// It uses the provider and configuration specified during initialization. +// +// Example usage: +// +// messages := []ai.LLMMessage{ +// {Role: ai.SystemRole, Text: "You are a helpful assistant"}, +// {Role: ai.UserRole, Text: "What is the capital of France?"}, +// } +// +// response, err := llm.Generate(messages) +// if err != nil { +// log.Fatal(err) +// } +// fmt.Printf("Response: %s\n", response.Text) +// fmt.Printf("Tokens used: %d\n", response.TotalOutputToken) +// +// The method returns LLMResponse containing: +// - Generated text +// - Token usage statistics +// - Completion time +// - Other provider-specific metadata +func (r *LLMRequest) Generate(messages []LLMMessage) (LLMResponse, error) { + return r.provider.GetResponse(messages, r.requestConfig) +} + +// GenerateStream creates a streaming response channel for the given messages. +// It returns a channel that receives StreamingLLMResponse chunks and an error if initialization fails. +// +// Example usage: +// +// request := NewLLMRequest(config) +// stream, err := request.GenerateStream(context.Background(), []LLMMessage{ +// {Role: UserRole, Text: "Tell me a story"}, +// }) +// if err != nil { +// log.Fatal(err) +// } +// +// for response := range stream { +// if response.Error != nil { +// log.Printf("Error: %v", response.Error) +// break +// } +// if response.Done { +// break +// } +// fmt.Print(response.Text) +// } +func (r *LLMRequest) GenerateStream(ctx context.Context, messages []LLMMessage) (<-chan StreamingLLMResponse, error) { + return r.provider.GetStreamingResponse(ctx, messages, r.requestConfig) } diff --git a/ai/llm_provider_anthropic.go b/ai/llm_provider_anthropic.go index f4c5b40..9278ce0 100644 --- a/ai/llm_provider_anthropic.go +++ b/ai/llm_provider_anthropic.go @@ -6,20 +6,19 @@ import ( "time" "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/option" ) // AnthropicLLMProvider implements the LLMProvider interface using Anthropic's official Go SDK. // It provides access to Claude models through Anthropic's API. type AnthropicLLMProvider struct { - client *anthropic.Client + client AnthropicClient model anthropic.Model } // AnthropicProviderConfig holds the configuration options for creating an Anthropic provider. 
type AnthropicProviderConfig struct { - // APIKey is the authentication key for Anthropic's API - APIKey string + // Client is the AnthropicClient implementation to use + Client AnthropicClient // Model specifies which Anthropic model to use (e.g., "claude-3-opus-20240229", "claude-3-sonnet-20240229") Model anthropic.Model @@ -27,34 +26,45 @@ type AnthropicProviderConfig struct { // NewAnthropicLLMProvider creates a new Anthropic provider with the specified configuration. // If no model is specified, it defaults to Claude 3.5 Sonnet. +// +// Example usage: +// +// client := NewRealAnthropicClient("your-api-key") +// provider := NewAnthropicLLMProvider(AnthropicProviderConfig{ +// Client: client, +// Model: anthropic.ModelClaude_3_5_Sonnet_20240620, +// }) +// +// response, err := provider.GetResponse(messages, config) +// if err != nil { +// log.Fatal(err) +// } func NewAnthropicLLMProvider(config AnthropicProviderConfig) *AnthropicLLMProvider { if config.Model == "" { config.Model = anthropic.ModelClaude_3_5_Sonnet_20240620 } return &AnthropicLLMProvider{ - client: anthropic.NewClient(option.WithAPIKey(config.APIKey)), + client: config.Client, model: config.Model, } } -// GetResponse generates a response using Anthropic's API for the given messages and configuration. -// It supports different message roles (user, assistant, system) and handles them appropriately. -// System messages are handled separately through Anthropic's system parameter. -// The function returns an LLMResponse containing the generated text and metadata, or an error if the operation fails. -func (p *AnthropicLLMProvider) GetResponse(messages []LLMMessage, config LLMRequestConfig) (LLMResponse, error) { - startTime := time.Now() - +// prepareMessageParams creates the Anthropic message parameters from LLM messages and config. +// This is an internal helper function to reduce code duplication. +func (p *AnthropicLLMProvider) prepareMessageParams(messages []LLMMessage, config LLMRequestConfig) anthropic.MessageNewParams { var anthropicMessages []anthropic.MessageParam + var systemMessage string + + // Process messages based on their role for _, msg := range messages { switch msg.Role { + case SystemRole: + systemMessage = msg.Text case UserRole: anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Text))) case AssistantRole: anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Text))) - case SystemRole: - // Anthropic handles system messages differently - we'll add it to params.System - continue default: anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Text))) } @@ -69,16 +79,23 @@ func (p *AnthropicLLMProvider) GetResponse(messages []LLMMessage, config LLMRequ } // Add system message if present - for _, msg := range messages { - if msg.Role == SystemRole { - params.System = anthropic.F([]anthropic.TextBlockParam{ - anthropic.NewTextBlock(msg.Text), - }) - break - } + if systemMessage != "" { + params.System = anthropic.F([]anthropic.TextBlockParam{ + anthropic.NewTextBlock(systemMessage), + }) } - message, err := p.client.Messages.New(context.Background(), params) + return params +} + +// GetResponse generates a response using Anthropic's API for the given messages and configuration. +// It supports different message roles (user, assistant, system) and handles them appropriately. +// System messages are handled separately through Anthropic's system parameter. 
+func (p *AnthropicLLMProvider) GetResponse(messages []LLMMessage, config LLMRequestConfig) (LLMResponse, error) { + startTime := time.Now() + + params := p.prepareMessageParams(messages, config) + message, err := p.client.CreateMessage(context.Background(), params) if err != nil { return LLMResponse{}, err } @@ -98,3 +115,76 @@ func (p *AnthropicLLMProvider) GetResponse(messages []LLMMessage, config LLMRequ CompletionTime: time.Since(startTime).Seconds(), }, nil } + +// GetStreamingResponse generates a streaming response using Anthropic's API. +// It returns a channel that receives chunks of the response as they're generated. +// +// Example usage: +// +// client := NewRealAnthropicClient("your-api-key") +// provider := NewAnthropicLLMProvider(AnthropicProviderConfig{ +// Client: client, +// Model: anthropic.ModelClaude_3_5_Sonnet_20240620, +// }) +// +// streamingResp, err := provider.GetStreamingResponse(ctx, messages, config) +// if err != nil { +// log.Fatal(err) +// } +// +// for chunk := range streamingResp { +// if chunk.Error != nil { +// log.Printf("Error: %v", chunk.Error) +// break +// } +// fmt.Print(chunk.Text) +// } +func (p *AnthropicLLMProvider) GetStreamingResponse(ctx context.Context, messages []LLMMessage, config LLMRequestConfig) (<-chan StreamingLLMResponse, error) { + params := p.prepareMessageParams(messages, config) + stream := p.client.CreateStreamingMessage(ctx, params) + responseChan := make(chan StreamingLLMResponse, 100) + + go func() { + defer close(responseChan) + + for stream.Next() { + select { + case <-ctx.Done(): + responseChan <- StreamingLLMResponse{ + Error: ctx.Err(), + Done: true, + } + return + default: + event := stream.Current() + + switch event.Type { + case anthropic.MessageStreamEventTypeContentBlockDelta: + delta, ok := event.Delta.(anthropic.ContentBlockDeltaEventDelta) + if !ok { + continue + } + + if delta.Type == anthropic.ContentBlockDeltaEventDeltaTypeTextDelta && delta.Text != "" { + responseChan <- StreamingLLMResponse{ + Text: delta.Text, + TokenCount: 1, + } + } + case anthropic.MessageStreamEventTypeMessageStop: + responseChan <- StreamingLLMResponse{Done: true} + return + } + } + } + + if err := stream.Err(); err != nil { + responseChan <- StreamingLLMResponse{ + Error: err, + Done: true, + } + } + }() + + return responseChan, nil +} diff --git a/ai/llm_provider_anthropic_client.go b/ai/llm_provider_anthropic_client.go new file mode 100644 index 0000000..7ec4538 --- /dev/null +++ b/ai/llm_provider_anthropic_client.go @@ -0,0 +1,63 @@ +// Package ai provides a flexible interface for interacting with various Language Learning Models (LLMs). +package ai + +import ( + "context" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" +) + +// AnthropicClient defines the interface for interacting with Anthropic's API. +// This interface abstracts the essential message-related operations used by AnthropicLLMProvider. +type AnthropicClient interface { + // CreateMessage creates a new message using Anthropic's API. + // The method takes a context and MessageNewParams and returns a Message response or an error. + CreateMessage(ctx context.Context, params anthropic.MessageNewParams) (*anthropic.Message, error) + + // CreateStreamingMessage creates a streaming message using Anthropic's API. + // It returns a stream that can be used to receive message chunks as they're generated. 
+ CreateStreamingMessage(ctx context.Context, params anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEvent] +} + +// RealAnthropicClient implements the AnthropicClient interface using Anthropic's official SDK. +type RealAnthropicClient struct { + messages *anthropic.MessageService +} + +// NewRealAnthropicClient creates a new instance of RealAnthropicClient with the provided API key. +// +// Example usage: +// +// // Regular message generation +// client := NewRealAnthropicClient("your-api-key") +// provider := NewAnthropicLLMProvider(AnthropicProviderConfig{ +// Client: client, +// Model: "claude-3-sonnet-20240229", +// }) +// +// // Streaming message generation +// streamingResp, err := provider.GetStreamingResponse(ctx, messages, config) +// if err != nil { +// log.Fatal(err) +// } +// for chunk := range streamingResp { +// fmt.Print(chunk.Text) +// } +func NewRealAnthropicClient(apiKey string) *RealAnthropicClient { + client := anthropic.NewClient(option.WithAPIKey(apiKey)) + return &RealAnthropicClient{ + messages: client.Messages, + } +} + +// CreateMessage implements the AnthropicClient interface using the real Anthropic client. +func (c *RealAnthropicClient) CreateMessage(ctx context.Context, params anthropic.MessageNewParams) (*anthropic.Message, error) { + return c.messages.New(ctx, params) +} + +// CreateStreamingMessage implements the streaming support for the AnthropicClient interface. +func (c *RealAnthropicClient) CreateStreamingMessage(ctx context.Context, params anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEvent] { + return c.messages.NewStreaming(ctx, params) +} diff --git a/ai/llm_provider_anthropic_test.go b/ai/llm_provider_anthropic_test.go index 5de8cb0..c64a65c 100644 --- a/ai/llm_provider_anthropic_test.go +++ b/ai/llm_provider_anthropic_test.go @@ -1,21 +1,103 @@ package ai import ( + "context" + "encoding/json" + "strings" "testing" "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" + "github.com/stretchr/testify/assert" ) +// MockAnthropicClient implements AnthropicClient interface for testing +type MockAnthropicClient struct { + createMessageFunc func(ctx context.Context, params anthropic.MessageNewParams) (*anthropic.Message, error) + createStreamingMessageFunc func(ctx context.Context, params anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEvent] +} + +func (m *MockAnthropicClient) CreateMessage(ctx context.Context, params anthropic.MessageNewParams) (*anthropic.Message, error) { + if m.createMessageFunc != nil { + return m.createMessageFunc(ctx, params) + } + return nil, nil +} + +func (m *MockAnthropicClient) CreateStreamingMessage(ctx context.Context, params anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEvent] { + if m.createStreamingMessageFunc != nil { + return m.createStreamingMessageFunc(ctx, params) + } + return nil +} + +type mockEventStream struct { + events []anthropic.MessageStreamEvent + index int +} + +// Implement ssestream.Runner interface +func (m *mockEventStream) Run() error { + return nil +} + +type mockDecoder struct { + events []anthropic.MessageStreamEvent + index int +} + +func (d *mockDecoder) Event() ssestream.Event { + if d.index < 0 || d.index >= len(d.events) { + return ssestream.Event{} + } + + event := d.events[d.index] + + // Create a custom payload that can be unmarshaled correctly + payload := map[string]interface{}{ + "type": event.Type, + "delta": event.Delta, + "index": event.Index, + 
} + + if event.Type == anthropic.MessageStreamEventTypeMessageStart { + payload["message"] = event.Message + } + + data, err := json.Marshal(payload) + if err != nil { + return ssestream.Event{} + } + + return ssestream.Event{ + Type: string(event.Type), + Data: data, + } +} + +func (d *mockDecoder) Next() bool { + d.index++ + return d.index < len(d.events) +} + +func (d *mockDecoder) Err() error { + return nil +} + +func (d *mockDecoder) Close() error { + return nil +} + func TestAnthropicLLMProvider_NewAnthropicLLMProvider(t *testing.T) { tests := []struct { name string config AnthropicProviderConfig - expectedModel string + expectedModel anthropic.Model }{ { name: "with specified model", config: AnthropicProviderConfig{ - APIKey: "test-key", + Client: &MockAnthropicClient{}, Model: "claude-3-opus-20240229", }, expectedModel: "claude-3-opus-20240229", @@ -23,9 +105,9 @@ func TestAnthropicLLMProvider_NewAnthropicLLMProvider(t *testing.T) { { name: "with default model", config: AnthropicProviderConfig{ - APIKey: "test-key", + Client: &MockAnthropicClient{}, }, - expectedModel: string(anthropic.ModelClaude_3_5_Sonnet_20240620), + expectedModel: anthropic.ModelClaude_3_5_Sonnet_20240620, }, } @@ -33,12 +115,216 @@ func TestAnthropicLLMProvider_NewAnthropicLLMProvider(t *testing.T) { t.Run(tt.name, func(t *testing.T) { provider := NewAnthropicLLMProvider(tt.config) - if provider.model != tt.expectedModel { - t.Errorf("expected model %q, got %q", tt.expectedModel, provider.model) + assert.Equal(t, tt.expectedModel, provider.model, "unexpected model") + assert.NotNil(t, provider.client, "expected client to be initialized") + }) + } +} + +func TestAnthropicLLMProvider_GetResponse(t *testing.T) { + tests := []struct { + name string + messages []LLMMessage + config LLMRequestConfig + expectedResult LLMResponse + expectError bool + }{ + { + name: "successful response with all message types", + messages: []LLMMessage{ + {Role: SystemRole, Text: "You are a helpful assistant"}, + {Role: UserRole, Text: "Hello"}, + {Role: AssistantRole, Text: "Hi there"}, + }, + config: LLMRequestConfig{ + MaxToken: 100, + TopP: 0.9, + Temperature: 0.7, + }, + expectedResult: LLMResponse{ + Text: "Test response", + TotalInputToken: 10, + TotalOutputToken: 5, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := &MockAnthropicClient{ + createMessageFunc: func(ctx context.Context, params anthropic.MessageNewParams) (*anthropic.Message, error) { + message := &anthropic.Message{ + Role: anthropic.MessageRoleAssistant, + Model: anthropic.ModelClaude_3_5_Sonnet_20240620, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 5, + }, + Type: anthropic.MessageTypeMessage, + } + + block := anthropic.ContentBlock{} + if err := block.UnmarshalJSON([]byte(`{ + "type": "text", + "text": "Test response" + }`)); err != nil { + t.Fatal(err) + } + + message.Content = []anthropic.ContentBlock{block} + return message, nil + }, + } + + provider := NewAnthropicLLMProvider(AnthropicProviderConfig{ + Client: mockClient, + Model: anthropic.ModelClaude_3_5_Sonnet_20240620, + }) + + result, err := provider.GetResponse(tt.messages, tt.config) + + if tt.expectError { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.expectedResult.Text, result.Text) + assert.Equal(t, tt.expectedResult.TotalInputToken, result.TotalInputToken) + assert.Equal(t, tt.expectedResult.TotalOutputToken, result.TotalOutputToken) + assert.Greater(t, 
result.CompletionTime, float64(0), "completion time should be greater than 0") + }) + } +} + +func TestAnthropicLLMProvider_GetStreamingResponse(t *testing.T) { + tests := []struct { + name string + messages []LLMMessage + config LLMRequestConfig + streamText []string + expectError bool + }{ + { + name: "successful streaming response", + messages: []LLMMessage{ + {Role: UserRole, Text: "Hello"}, + }, + config: LLMRequestConfig{ + MaxToken: 100, + TopP: 0.9, + Temperature: 0.7, + }, + streamText: []string{"Hello", " world", "!"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := &MockAnthropicClient{ + createStreamingMessageFunc: func(ctx context.Context, params anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEvent] { + var events []anthropic.MessageStreamEvent + + // Create start event + events = append(events, anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeMessageStart, + Message: anthropic.Message{ + Role: anthropic.MessageRoleAssistant, + Model: anthropic.ModelClaude_3_5_Sonnet_20240620, + }, + }) + + // Create content block delta events + for i, text := range tt.streamText { + t.Logf("Adding delta event %d with text: %q", i, text) + events = append(events, anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeContentBlockDelta, + Index: int64(i), + Delta: anthropic.ContentBlockDeltaEventDelta{ + Type: anthropic.ContentBlockDeltaEventDeltaTypeTextDelta, + Text: text, + }, + }) + } + + // Add stop event + events = append(events, anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeMessageStop, + }) + + decoder := &mockDecoder{ + events: events, + index: -1, + } + + stream := ssestream.NewStream[anthropic.MessageStreamEvent](decoder, nil) + return stream + }, } - if provider.client == nil { - t.Error("expected client to be initialized") + + provider := NewAnthropicLLMProvider(AnthropicProviderConfig{ + Client: mockClient, + Model: anthropic.ModelClaude_3_5_Sonnet_20240620, + }) + + ctx := context.Background() + stream, err := provider.GetStreamingResponse(ctx, tt.messages, tt.config) + assert.NoError(t, err) + + var receivedText string + for chunk := range stream { + t.Logf("Received streaming chunk: %+v", chunk) + if chunk.Error != nil { + t.Fatalf("Unexpected error: %v", chunk.Error) + } + if !chunk.Done { + receivedText += chunk.Text + t.Logf("Current accumulated text: %q", receivedText) + } } + + t.Logf("Final text: %q", receivedText) + assert.Equal(t, strings.Join(tt.streamText, ""), receivedText) }) } } + +func createStreamEvent(eventType string, index int64, text string) anthropic.MessageStreamEvent { + var event anthropic.MessageStreamEvent + + switch eventType { + case "message_start": + event = anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeMessageStart, + Message: anthropic.Message{ + Role: anthropic.MessageRoleAssistant, + Model: anthropic.ModelClaude_3_5_Sonnet_20240620, + }, + } + case "content_block_delta": + /*textDelta := anthropic.TextDelta{ + Type: anthropic.TextDeltaTypeTextDelta, + Text: text, + }*/ + event = anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeContentBlockDelta, + Index: index, + Delta: anthropic.ContentBlockDeltaEventDelta{ + Type: anthropic.ContentBlockDeltaEventDeltaTypeTextDelta, + Text: text, + }, + } + case "content_block_stop": + event = anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeContentBlockStop, + Index: index, + } + case "message_stop": + event = 
anthropic.MessageStreamEvent{ + Type: anthropic.MessageStreamEventTypeMessageStop, + } + } + + return event +} diff --git a/ai/llm_provider_openai.go b/ai/llm_provider_openai.go index 95bd30a..ecb76e7 100644 --- a/ai/llm_provider_openai.go +++ b/ai/llm_provider_openai.go @@ -34,11 +34,8 @@ func NewOpenAILLMProvider(config OpenAIProviderConfig) *OpenAILLMProvider { } } -// GetResponse generates a response using OpenAI's API for the given messages and configuration. -// It supports different message roles (user, assistant, system) and handles them appropriately. -func (p *OpenAILLMProvider) GetResponse(messages []LLMMessage, config LLMRequestConfig) (LLMResponse, error) { - startTime := time.Now() - +// convertToOpenAIMessages converts internal message format to OpenAI's format +func (p *OpenAILLMProvider) convertToOpenAIMessages(messages []LLMMessage) []openai.ChatCompletionMessageParamUnion { var openAIMessages []openai.ChatCompletionMessageParamUnion for _, msg := range messages { switch msg.Role { @@ -52,14 +49,27 @@ func (p *OpenAILLMProvider) GetResponse(messages []LLMMessage, config LLMRequest openAIMessages = append(openAIMessages, openai.UserMessage(msg.Text)) } } + return openAIMessages +} - params := openai.ChatCompletionNewParams{ - Messages: openai.F(openAIMessages), +// createCompletionParams creates OpenAI API parameters from request config +func (p *OpenAILLMProvider) createCompletionParams(messages []openai.ChatCompletionMessageParamUnion, config LLMRequestConfig) openai.ChatCompletionNewParams { + return openai.ChatCompletionNewParams{ + Messages: openai.F(messages), Model: openai.F(p.model), MaxTokens: openai.Int(config.MaxToken), TopP: openai.Float(config.TopP), Temperature: openai.Float(config.Temperature), } +} + +// GetResponse generates a response using OpenAI's API for the given messages and configuration. +// It supports different message roles (user, assistant, system) and handles them appropriately. +func (p *OpenAILLMProvider) GetResponse(messages []LLMMessage, config LLMRequestConfig) (LLMResponse, error) { + startTime := time.Now() + + openAIMessages := p.convertToOpenAIMessages(messages) + params := p.createCompletionParams(openAIMessages, config) completion, err := p.client.Chat.Completions.New(context.Background(), params) if err != nil { @@ -77,3 +87,48 @@ func (p *OpenAILLMProvider) GetResponse(messages []LLMMessage, config LLMRequest CompletionTime: time.Since(startTime).Seconds(), }, nil } + +// GetStreamingResponse generates a streaming response using OpenAI's API. +// It supports streaming tokens as they're generated and handles context cancellation. 
+func (p *OpenAILLMProvider) GetStreamingResponse(ctx context.Context, messages []LLMMessage, config LLMRequestConfig) (<-chan StreamingLLMResponse, error) { + openAIMessages := p.convertToOpenAIMessages(messages) + params := p.createCompletionParams(openAIMessages, config) + + stream := p.client.Chat.Completions.NewStreaming(ctx, params) + responseChan := make(chan StreamingLLMResponse, 100) + + go func() { + defer close(responseChan) + + for stream.Next() { + select { + case <-ctx.Done(): + responseChan <- StreamingLLMResponse{ + Error: ctx.Err(), + Done: true, + } + return + default: + chunk := stream.Current() + if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" { + responseChan <- StreamingLLMResponse{ + Text: chunk.Choices[0].Delta.Content, + TokenCount: 1, + } + } + } + } + + if err := stream.Err(); err != nil { + responseChan <- StreamingLLMResponse{ + Error: err, + Done: true, + } + return + } + + responseChan <- StreamingLLMResponse{Done: true} + }() + + return responseChan, nil +} diff --git a/ai/llm_provider_openai_test.go b/ai/llm_provider_openai_test.go index b6da52b..26dd911 100644 --- a/ai/llm_provider_openai_test.go +++ b/ai/llm_provider_openai_test.go @@ -1,9 +1,14 @@ package ai import ( + "context" + "io" + "net/http" "testing" + "time" "github.com/openai/openai-go" + "github.com/openai/openai-go/option" ) func TestOpenAILLMProvider_NewOpenAILLMProvider(t *testing.T) { @@ -42,3 +47,94 @@ func TestOpenAILLMProvider_NewOpenAILLMProvider(t *testing.T) { }) } } + +func TestOpenAILLMProvider_GetStreamingResponse(t *testing.T) { + tests := []struct { + name string + messages []LLMMessage + timeout time.Duration + delay time.Duration + wantErr bool + }{ + { + name: "successful streaming", + timeout: 100 * time.Millisecond, + delay: 0, + }, + { + name: "context cancellation", + timeout: 5 * time.Millisecond, + delay: 50 * time.Millisecond, + }, + } + + responses := []string{ + `data: {"id":"123","choices":[{"delta":{"content":"Hello"}}]}`, + `data: {"id":"123","choices":[{"delta":{"content":" world"}}]}`, + `data: [DONE]`, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider := &OpenAILLMProvider{ + client: openai.NewClient( + option.WithHTTPClient(&http.Client{ + Transport: &mockTransport{ + responses: responses, + delay: tt.delay, + }, + }), + ), + model: openai.ChatModelGPT3_5Turbo, + } + + ctx, cancel := context.WithTimeout(context.Background(), tt.timeout) + defer cancel() + + stream, err := provider.GetStreamingResponse(ctx, []LLMMessage{{Role: UserRole, Text: "test"}}, LLMRequestConfig{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var gotCancel bool + for resp := range stream { + if resp.Error != nil { + gotCancel = true + break + } + } + + if tt.delay > tt.timeout && !gotCancel { + t.Error("expected context cancellation") + } + }) + } +} + +type mockTransport struct { + responses []string + delay time.Duration +} + +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if m.delay > 0 { + time.Sleep(m.delay) + } + + pr, pw := io.Pipe() + go func() { + defer pw.Close() + for _, resp := range m.responses { + time.Sleep(10 * time.Millisecond) // Simulate streaming delay + pw.Write([]byte(resp + "\n")) + } + }() + + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + }, + Body: pr, + }, nil +} diff --git a/ai/llm_test.go b/ai/llm_test.go index f1c7866..9c0c6c6 100644 --- a/ai/llm_test.go +++ b/ai/llm_test.go 
@@ -1,19 +1,41 @@ package ai import ( + "context" "errors" "testing" ) type mockProvider struct { - response LLMResponse - err error + response LLMResponse + err error + streamResponses []StreamingLLMResponse + streamErr error } func (m *mockProvider) GetResponse(messages []LLMMessage, _ LLMRequestConfig) (LLMResponse, error) { return m.response, m.err } +func (m *mockProvider) GetStreamingResponse(ctx context.Context, messages []LLMMessage, config LLMRequestConfig) (<-chan StreamingLLMResponse, error) { + if m.streamErr != nil { + return nil, m.streamErr + } + + ch := make(chan StreamingLLMResponse, len(m.streamResponses)) + go func() { + defer close(ch) + for _, resp := range m.streamResponses { + select { + case <-ctx.Done(): + return + case ch <- resp: + } + } + }() + return ch, nil +} + func TestLLMRequest_Generate(t *testing.T) { tests := []struct { name string @@ -62,11 +84,11 @@ func TestLLMRequest_Generate(t *testing.T) { err: tt.mockError, } - request := NewLLMRequest(tt.config) + request := NewLLMRequest(tt.config, provider) response, err := request.Generate([]LLMMessage{{ Role: "user", Text: "test prompt", - }}, provider) + }}) if tt.expectedError { if err == nil { @@ -92,3 +114,95 @@ func TestLLMRequest_Generate(t *testing.T) { }) } } + +func TestLLMRequest_GenerateStream(t *testing.T) { + tests := []struct { + name string + config LLMRequestConfig + messages []LLMMessage + streamResponses []StreamingLLMResponse + streamErr error + wantErr bool + }{ + { + name: "successful streaming", + config: LLMRequestConfig{ + MaxToken: 100, + }, + messages: []LLMMessage{ + {Role: UserRole, Text: "Hello"}, + }, + streamResponses: []StreamingLLMResponse{ + {Text: "Hello", TokenCount: 1}, + {Text: "World", TokenCount: 1}, + {Done: true}, + }, + }, + { + name: "provider error", + config: LLMRequestConfig{ + MaxToken: 100, + }, + messages: []LLMMessage{ + {Role: UserRole, Text: "Hello"}, + }, + streamErr: errors.New("stream error"), + wantErr: true, + }, + { + name: "context cancellation", + config: LLMRequestConfig{ + MaxToken: 100, + }, + messages: []LLMMessage{ + {Role: UserRole, Text: "Hello"}, + }, + streamResponses: []StreamingLLMResponse{ + {Text: "Hello", TokenCount: 1}, + {Error: context.Canceled, Done: true}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider := &mockProvider{ + streamResponses: tt.streamResponses, + streamErr: tt.streamErr, + } + + request := NewLLMRequest(tt.config, provider) + stream, err := request.GenerateStream(context.Background(), tt.messages) + + if (err != nil) != tt.wantErr { + t.Errorf("GenerateStream() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + + var got []StreamingLLMResponse + for resp := range stream { + got = append(got, resp) + } + + if len(got) != len(tt.streamResponses) { + t.Errorf("expected %d responses, got %d", len(tt.streamResponses), len(got)) + return + } + + for i, want := range tt.streamResponses { + if got[i].Text != want.Text { + t.Errorf("response[%d].Text = %v, want %v", i, got[i].Text, want.Text) + } + if got[i].Done != want.Done { + t.Errorf("response[%d].Done = %v, want %v", i, got[i].Done, want.Done) + } + if got[i].Error != want.Error { + t.Errorf("response[%d].Error = %v, want %v", i, got[i].Error, want.Error) + } + } + }) + } +} diff --git a/ai/types.go b/ai/types.go index a3bcdda..4484b94 100644 --- a/ai/types.go +++ b/ai/types.go @@ -1,7 +1,10 @@ // Package ai provides a flexible interface for interacting with various Language Learning 
Models (LLMs). package ai -import "fmt" +import ( + "context" + "fmt" +) // LLMMessageRole represents the role of a message in a conversation. type LLMMessageRole string @@ -123,10 +126,24 @@ type LLMMessage struct { Text string } +// StreamingLLMResponse represents a chunk of streaming response from an LLM provider. +// It contains partial text, completion status, any errors, and token usage information. +type StreamingLLMResponse struct { + // Text contains the partial response text + Text string + // Done indicates if this is the final chunk + Done bool + // Error contains any error that occurred during streaming + Error error + // TokenCount is the number of tokens in this chunk + TokenCount int +} + // LLMProvider defines the interface that all LLM providers must implement. // This allows for easy swapping between different LLM providers. type LLMProvider interface { // GetResponse generates a response for the given question using the specified configuration. // Returns LLMResponse containing the generated text and metadata, or an error if the operation fails. GetResponse(messages []LLMMessage, config LLMRequestConfig) (LLMResponse, error) + GetStreamingResponse(ctx context.Context, messages []LLMMessage, config LLMRequestConfig) (<-chan StreamingLLMResponse, error) }
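
As a reference for the interface change in `ai/types.go` above, here is a minimal sketch of a custom provider that satisfies the updated two-method `LLMProvider` contract and is wired through `NewLLMRequest`. The `echoProvider` type, the `lastText` helper, and the `main` wiring are hypothetical and exist only to illustrate the shape of the contract; they are not part of the library, and `ai.NewRequestConfig()` with no options is assumed to fall back to the documented defaults.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/shaharia-lab/guti/ai"
)

// echoProvider is an illustrative provider that echoes the last message back.
// It implements both the blocking and streaming halves of ai.LLMProvider.
type echoProvider struct{}

// lastText returns the text of the final message, or "" for an empty slice.
func lastText(messages []ai.LLMMessage) string {
	if len(messages) == 0 {
		return ""
	}
	return messages[len(messages)-1].Text
}

func (p *echoProvider) GetResponse(messages []ai.LLMMessage, config ai.LLMRequestConfig) (ai.LLMResponse, error) {
	return ai.LLMResponse{Text: "echo: " + lastText(messages), TotalOutputToken: 1}, nil
}

func (p *echoProvider) GetStreamingResponse(ctx context.Context, messages []ai.LLMMessage, config ai.LLMRequestConfig) (<-chan ai.StreamingLLMResponse, error) {
	ch := make(chan ai.StreamingLLMResponse, 2)
	go func() {
		defer close(ch)
		select {
		case <-ctx.Done():
			ch <- ai.StreamingLLMResponse{Error: ctx.Err(), Done: true}
		case ch <- ai.StreamingLLMResponse{Text: "echo: " + lastText(messages), TokenCount: 1}:
			ch <- ai.StreamingLLMResponse{Done: true}
		}
	}()
	return ch, nil
}

func main() {
	// Wire the custom provider through the same request type used for OpenAI/Anthropic.
	request := ai.NewLLMRequest(ai.NewRequestConfig(), &echoProvider{})

	response, err := request.Generate([]ai.LLMMessage{{Role: ai.UserRole, Text: "ping"}})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(response.Text) // "echo: ping"
}
```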