Skip to content

Commit

Permalink
feat: added Amazon Bedrock Cohere R plus and Cohere R models (#285)
Browse files Browse the repository at this point in the history
fix: fixed small issues
  • Loading branch information
curlyfro authored May 2, 2024
1 parent 7e173b7 commit 2e8aa1c
Show file tree
Hide file tree
Showing 9 changed files with 279 additions and 10 deletions.
105 changes: 105 additions & 0 deletions src/Providers/Amazon.Bedrock/src/Chat/CohereCommandRModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
using System.Diagnostics;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using Amazon.BedrockRuntime.Model;
using LangChain.Providers.Amazon.Bedrock.Internal;

// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock;

public abstract class CohereCommandRModel(
BedrockProvider provider,
string id)
: ChatModel(id)
{
/// <summary>
/// Generates a chat response based on the provided `ChatRequest`.
/// </summary>
/// <param name="request">The `ChatRequest` containing the input messages and other parameters.</param>
/// <param name="settings">Optional `ChatSettings` to override the model's default settings.</param>
/// <param name="cancellationToken">A cancellation token to cancel the operation.</param>
/// <returns>A `ChatResponse` containing the generated messages and usage information.</returns>
public override async Task<ChatResponse> GenerateAsync(
ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToSimplePrompt();
var messages = request.Messages.ToList();

var stringBuilder = new StringBuilder();

var usedSettings = CohereCommandChatSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.ChatSettings);

var bodyJson = CreateBodyJson(prompt, usedSettings);

if (usedSettings.UseStreaming == true)
{
var streamRequest = BedrockModelRequest.CreateStreamRequest(Id, bodyJson);
var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken).ConfigureAwait(false);

foreach (var payloadPart in response.Body)
{
var streamEvent = (PayloadPart)payloadPart;
var chunk = await JsonSerializer.DeserializeAsync<JsonObject>(streamEvent.Bytes, cancellationToken: cancellationToken)
.ConfigureAwait(false);
var delta = chunk?["text"]?.GetValue<string>() ?? string.Empty;

OnPartialResponseGenerated(delta);
stringBuilder.Append(delta);

var finished = chunk?["finish_reason"]?.GetValue<string>() ?? string.Empty;
if (string.Equals(finished.ToUpperInvariant(), "COMPLETE", StringComparison.Ordinal))
{
OnCompletedResponseGenerated(stringBuilder.ToString());
}
}
}
else
{
var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken)
.ConfigureAwait(false);

var generatedText = response?["text"]?.GetValue<string>() ?? string.Empty;

messages.Add(generatedText.AsAiMessage());
OnCompletedResponseGenerated(generatedText);
}

var usage = Usage.Empty with
{
Time = watch.Elapsed,
};
AddUsage(usage);
provider.AddUsage(usage);

return new ChatResponse
{
Messages = messages,
UsedSettings = usedSettings,
Usage = usage,
};
}

/// <summary>
/// Creates the request body JSON for the Cohere model based on the provided prompt and settings.
/// </summary>
/// <param name="prompt">The input prompt for the model.</param>
/// <param name="usedSettings">The settings to use for the request.</param>
/// <returns>A `JsonObject` representing the request body.</returns>
private static JsonObject CreateBodyJson(string prompt, CohereCommandChatSettings usedSettings)
{
var bodyJson = new JsonObject
{
["message"] = prompt,
};
return bodyJson;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public class MetaLlama2ChatSettings : BedrockChatSettings
User = ChatSettings.Default.User,
UseStreaming = false,
Temperature = 0.5,
MaxTokens = 4000,
MaxTokens = 2048,
TopP = 0.9,
TopK = 0.0
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using System.Diagnostics;
using System.Text.Json.Nodes;
using LangChain.Providers.Amazon.Bedrock.Internal;

// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock;

public abstract class AmazonTitanEmbeddingV2Model(
BedrockProvider provider,
string id)
: Model<EmbeddingSettings>(id), IEmbeddingModel
{

/// <summary>
/// Creates embeddings for the input strings using the Amazon model.
/// </summary>
/// <param name="request">The `EmbeddingRequest` containing the input strings.</param>
/// <param name="settings">Optional `EmbeddingSettings` to override the model's default settings.</param>
/// <param name="cancellationToken">A cancellation token to cancel the operation.</param>
/// <returns>An `EmbeddingResponse` containing the generated embeddings and usage information.</returns>

public async Task<EmbeddingResponse> CreateEmbeddingsAsync(
EmbeddingRequest request,
EmbeddingSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();

var usedSettings = AmazonV2EmbeddingSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.EmbeddingSettings);

var embeddings = new List<float[]>(capacity: request.Strings.Count);

var tasks = request.Strings.Select(text =>
{
var bodyJson = CreateBodyJson(text, usedSettings);
return provider.Api.InvokeModelAsync(Id, bodyJson,
cancellationToken);
})
.ToList();
var results = await Task.WhenAll(tasks).ConfigureAwait(false);

foreach (var response in results)
{
var embedding = response?["embedding"]?.AsArray();
if (embedding == null) continue;

var f = new float[(int)usedSettings.Dimensions!];
for (var i = 0; i < embedding.Count; i++)
{
f[i] = (float)embedding[(Index)i]?.AsValue()!;
}

embeddings.Add(f);
}

var usage = Usage.Empty with
{
Time = watch.Elapsed,
};
AddUsage(usage);
provider.AddUsage(usage);

return new EmbeddingResponse
{
Values = embeddings.ToArray(),
Usage = Usage.Empty,
UsedSettings = usedSettings,
Dimensions = embeddings.FirstOrDefault()?.Length ?? 0,
};
}

private static JsonObject CreateBodyJson(string? prompt, AmazonV2EmbeddingSettings usedSettings)
{
var bodyJson = new JsonObject
{
["inputText"] = prompt,
["dimensions"] = usedSettings.Dimensions,
["normalize"] = usedSettings.Normalize
};

return bodyJson;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock;

public class AmazonV2EmbeddingSettings : BedrockEmbeddingSettings
{
public new static AmazonV2EmbeddingSettings Default { get; } = new()
{
Dimensions = 1024,
MaximumInputLength = 10_000,
Normalize = true
};

public bool Normalize { get; set; }

/// <summary>
/// Calculate the settings to use for the request.
/// </summary>
/// <param name="requestSettings"></param>
/// <param name="modelSettings"></param>
/// <param name="providerSettings"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public new static AmazonV2EmbeddingSettings Calculate(
EmbeddingSettings? requestSettings,
EmbeddingSettings? modelSettings,
EmbeddingSettings? providerSettings)
{
var requestSettingsCasted = requestSettings as AmazonV2EmbeddingSettings;
var modelSettingsCasted = modelSettings as AmazonV2EmbeddingSettings;
var providerSettingsCasted = providerSettings as AmazonV2EmbeddingSettings;

return new AmazonV2EmbeddingSettings
{
Dimensions =
requestSettingsCasted?.Dimensions ??
modelSettingsCasted?.Dimensions ??
providerSettingsCasted?.Dimensions ??
Default.Dimensions ??
throw new InvalidOperationException("Default Dimensions is not set."),

MaximumInputLength =
requestSettingsCasted?.MaximumInputLength ??
modelSettingsCasted?.MaximumInputLength ??
providerSettingsCasted?.MaximumInputLength ??
Default.MaximumInputLength ??
throw new InvalidOperationException("Default MaximumInputLength is not set."),

Normalize =
requestSettingsCasted?.Normalize ??
modelSettingsCasted?.Normalize ??
providerSettingsCasted?.Normalize ??
Default.Normalize,
};
}
}
4 changes: 4 additions & 0 deletions src/Providers/Amazon.Bedrock/src/Predefined/Amazon.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ public class TitanTextLiteV1Model(BedrockProvider provider)
public class TitanEmbedTextV1Model(BedrockProvider provider)
: AmazonTitanEmbeddingModel(provider, id: "amazon.titan-embed-text-v1");

/// <inheritdoc />
public class TitanEmbedTextV2Model(BedrockProvider provider)
: AmazonTitanEmbeddingV2Model(provider, id: "amazon.titan-embed-text-v2:0");

/// <inheritdoc />
public class TitanEmbedImageV1Model(BedrockProvider provider)
: AmazonTitanImageEmbeddingModel(provider, id: "amazon.titan-embed-image-v1");
Expand Down
8 changes: 8 additions & 0 deletions src/Providers/Amazon.Bedrock/src/Predefined/Cohere.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,11 @@ public class EmbedEnglishV3Model(BedrockProvider provider)
/// <inheritdoc />
public class EmbedMultilingualV3Model(BedrockProvider provider)
: CohereEmbeddingModel(provider, id: "cohere.embed-multilingual-v3");

/// <inheritdoc />
public class CommandRPlusModel(BedrockProvider provider)
: CohereCommandRModel(provider, id: "cohere.command-r-plus-v1:0");

/// <inheritdoc />
public class CommandRModel(BedrockProvider provider)
: CohereCommandRModel(provider, id: "cohere.command-r-v1:0");
2 changes: 1 addition & 1 deletion src/Providers/Amazon.Bedrock/src/Predefined/Mistral.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ public class Mistral7BInstruct(BedrockProvider provider)

/// <inheritdoc />
public class Mistral8x7BInstruct(BedrockProvider provider)
: MistralModel(provider, id: "mistral.mistral-8x7b-instruct-v0:1");
: MistralModel(provider, id: "mistral.mixtral-8x7b-instruct-v0:1");

/// <inheritdoc />
public class MistralLarge(BedrockProvider provider)
Expand Down
7 changes: 4 additions & 3 deletions src/Providers/Amazon.Bedrock/test/BedrockTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
using LangChain.Providers.Amazon.Bedrock.Predefined.Anthropic;
using LangChain.Providers.Amazon.Bedrock.Predefined.Cohere;
using LangChain.Providers.Amazon.Bedrock.Predefined.Meta;
using LangChain.Providers.Amazon.Bedrock.Predefined.Mistral;
using LangChain.Providers.Amazon.Bedrock.Predefined.Stability;
using LangChain.Schema;
using LangChain.Splitters.Text;
using static LangChain.Chains.Chain;
using Microsoft.SemanticKernel.AI.Embeddings;

namespace LangChain.Providers.Amazon.Bedrock.Tests;

Expand All @@ -31,7 +31,7 @@ public async Task Chains()
//var llm = new Jurassic2MidModel(provider);
//var llm = new ClaudeV21Model(provider);
//var llm = new Mistral7BInstruct(provider);
var llm = new Claude3SonnetModel(provider);
var llm = new CommandRModel(provider);

var template = "What is a good name for a company that makes {product}?";
var prompt = new PromptTemplate(new PromptTemplateInput(template, new List<string>(1) { "product" }));
Expand Down Expand Up @@ -271,7 +271,7 @@ public async Task ClaudeImageToText()
public async Task SimpleTest(bool useStreaming, bool useChatSettings)
{
var provider = new BedrockProvider();
var llm = new CommandLightTextV14Model(provider);
var llm = new CommandRModel(provider);

llm.PromptSent += (_, prompt) => Console.WriteLine($"Prompt: {prompt}");
llm.PartialResponseGenerated += (_, delta) => Console.Write(delta);
Expand All @@ -290,6 +290,7 @@ you are a comic book writer. you will be given a question and you will answer i
{
var response = await llm.GenerateAsync(prompt);
response.LastMessageContent.Should().NotBeNull();
Console.WriteLine(response.LastMessageContent);
}
}

Expand Down
18 changes: 13 additions & 5 deletions src/Providers/Amazon.Bedrock/test/BedrockTextModelTests.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.CSharp;
using System.Reflection;
using Amazon;

namespace LangChain.Providers.Amazon.Bedrock.Tests;

Expand All @@ -20,7 +21,7 @@ public Task TestAllTextLLMs()

var derivedTypeNames = FindDerivedTypes(predefinedDir);

var provider = new BedrockProvider();
var provider = new BedrockProvider(RegionEndpoint.USWest2);

var failedTypes = new List<string>();
var workingTypes = new Dictionary<string, double>();
Expand All @@ -44,13 +45,20 @@ public Task TestAllTextLLMs()

Console.WriteLine($"############## {type.FullName}");

object[] args = { provider };
object[] args = [provider];
var llm = (ChatModel)Activator.CreateInstance(type, args)!;
var result = llm.GenerateAsync("who's your favor superhero?");
try
{
var result = llm.GenerateAsync("who's your favorite superhero?");

workingTypes.Add(className, result.Result.Usage.Time.TotalSeconds);
workingTypes.Add(className, result.Result.Usage.Time.TotalSeconds);

Console.WriteLine(result.Result + "\n\n\n");
Console.WriteLine(result.Result + "\n\n\n");
}
catch (Exception e)
{
Console.WriteLine($"**** **** **** ERROR: " + e);
}
}
}
catch (Exception e)
Expand Down

0 comments on commit 2e8aa1c

Please sign in to comment.