Skip to content

Commit

Permalink
fix: anthropic system cache, vertex
Browse files Browse the repository at this point in the history
  • Loading branch information
stillmatic committed Sep 1, 2024
1 parent 04bb27e commit 8d6e66d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 20 deletions.
26 changes: 16 additions & 10 deletions packages/llm/providers/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,21 @@ func NewAnthropicProviderWithCache(apiKey string) *Provider {
}
}

func reqToMessages(req llm.InferRequest) ([]anthropic.Message, *string, error) {
systemPrompt := ""
func reqToMessages(req llm.InferRequest) ([]anthropic.Message, []anthropic.MessageSystemPart, error) {
msgs := make([]anthropic.Message, 0)
systemMsgs := make([]anthropic.MessageSystemPart, 0)
for _, m := range req.Messages {
if m.Role == "system" {
systemPrompt += m.Content
msgContent := anthropic.MessageSystemPart{
Type: "text",
Text: m.Content,
}
if m.ShouldCache {
msgContent.CacheControl = &anthropic.MessageCacheControl{
Type: anthropic.CacheControlTypeEphemeral,
}
}
systemMsgs = append(systemMsgs, msgContent)
continue
}

Expand Down Expand Up @@ -65,10 +74,7 @@ func reqToMessages(req llm.InferRequest) ([]anthropic.Message, *string, error) {
msgs = append(msgs, newMsg)
}

if systemPrompt != "" {
return msgs, &systemPrompt, nil
}
return msgs, nil, nil
return msgs, systemMsgs, nil
}

func (p *Provider) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) {
Expand All @@ -82,8 +88,8 @@ func (p *Provider) GenerateResponse(ctx context.Context, req llm.InferRequest) (
MaxTokens: req.MessageOptions.MaxTokens,
Temperature: &req.MessageOptions.Temperature,
}
if systemPrompt != nil {
msgsReq.System = *systemPrompt
if systemPrompt != nil && len(systemPrompt) > 0 {
msgsReq.MultiSystem = systemPrompt
}
res, err := p.client.CreateMessagesStream(ctx, anthropic.MessagesStreamRequest{
MessagesRequest: msgsReq,
Expand Down Expand Up @@ -111,7 +117,7 @@ func (p *Provider) GenerateResponseAsync(ctx context.Context, req llm.InferReque
Temperature: &req.MessageOptions.Temperature,
}
if systemPrompt != nil {
msgsReq.System = *systemPrompt
msgsReq.MultiSystem = systemPrompt
}

_, err = p.client.CreateMessagesStream(ctx, anthropic.MessagesStreamRequest{
Expand Down
44 changes: 34 additions & 10 deletions packages/llm/providers/vertex/vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,18 @@ func (p *VertexAIProvider) generateResponseSingleTurn(ctx context.Context, req l

func (p *VertexAIProvider) generateResponseMultiTurn(ctx context.Context, req llm.InferRequest) (string, error) {
model := p.getModel(req)
cs := model.StartChat()

// Add previous messages to chat history
for _, msg := range req.Messages[:len(req.Messages)-1] {
parts := messageToParts(msg)
cs.History = append(cs.History, &genai.Content{
Parts: parts,
Role: msg.Role,
})
msgs, sysInstr := multiTurnMessageToParts(req.Messages[:len(req.Messages)-1])
if sysInstr != nil {
model.SystemInstruction = sysInstr
}

cs := model.StartChat()
cs.History = msgs
mostRecentMessage := req.Messages[len(req.Messages)-1]

// Send the last message
lastMsg := req.Messages[len(req.Messages)-1]
resp, err := cs.SendMessage(ctx, messageToParts(lastMsg)...)
resp, err := cs.SendMessage(ctx, genai.Text(mostRecentMessage.Content))
if err != nil {
return "", errors.Wrap(err, "failed to send message in chat")
}
Expand Down Expand Up @@ -188,6 +186,32 @@ func messageToParts(message llm.InferMessage) []genai.Part {
return parts
}

// multiTurnMessageToParts splits a conversation into Gemini chat history and
// an optional system instruction. Messages with role "system" are folded into
// a single *genai.Content returned as the second value (nil when the input
// contains no system messages); every other message becomes one history entry,
// preserving input order and roles. Messages carrying image bytes get an
// additional PNG image part alongside their text part.
func multiTurnMessageToParts(messages []llm.InferMessage) ([]*genai.Content, *genai.Content) {
	// nil slice is idiomatic here; append works on nil and len(nil) == 0.
	var sysInstructionParts []genai.Part
	hist := make([]*genai.Content, 0, len(messages))
	for _, message := range messages {
		parts := []genai.Part{genai.Text(message.Content)}
		// len of a nil slice is 0, so a separate nil check is redundant
		// (staticcheck S1009).
		if len(message.Image) > 0 {
			// NOTE(review): image bytes are assumed to be PNG-encoded — confirm
			// against callers before accepting other formats.
			parts = append(parts, genai.ImageData("png", message.Image))
		}
		if message.Role == "system" {
			// System messages go into the model's system instruction rather
			// than the chat history.
			sysInstructionParts = append(sysInstructionParts, parts...)
			continue
		}
		hist = append(hist, &genai.Content{
			Parts: parts,
			Role:  message.Role,
		})
	}
	if len(sysInstructionParts) > 0 {
		return hist, &genai.Content{Parts: sysInstructionParts}
	}
	return hist, nil
}

func flattenResponse(resp *genai.GenerateContentResponse) string {
var result string
for _, cand := range resp.Candidates {
Expand Down

0 comments on commit 8d6e66d

Please sign in to comment.