Skip to content

Commit

Permalink
Updated Amazon Bedrock models.
Browse files Browse the repository at this point in the history
updated Calculate method and usedSettings
  • Loading branch information
curlyfro committed Feb 26, 2024
1 parent 31f46d4 commit b9f66a9
Show file tree
Hide file tree
Showing 12 changed files with 199 additions and 55 deletions.
1 change: 1 addition & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
<PackageVersion Include="LeonardoAi" Version="0.1.0" />
<PackageVersion Include="LLamaSharp" Version="0.10.0" />
<PackageVersion Include="LLamaSharp.Backend.Cpu" Version="0.10.0" />
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp" Version="4.8.0" />
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="4.8.0" />
<PackageVersion Include="Microsoft.CodeAnalysis.PublicApiAnalyzers" Version="3.3.4" />
<PackageVersion Include="Microsoft.Data.Sqlite.Core" Version="8.0.2" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Diagnostics;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using LangChain.Providers.Amazon.Bedrock.Internal;

Expand All @@ -10,15 +12,13 @@ public class Ai21LabsJurassic2ChatModel(
string id)
: ChatModel(id)
{
public override int ContextLength => 4096;

public override async Task<ChatResponse> GenerateAsync(
ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToSimplePrompt();

Expand All @@ -30,20 +30,21 @@ public override async Task<ChatResponse> GenerateAsync(
Id,
new JsonObject
{
{ "prompt", prompt },
{ "maxTokens", usedSettings.MaxTokens!.Value },
{ "temperature", usedSettings.Temperature!.Value }
["prompt"] = prompt,
["maxTokens"] = usedSettings.MaxTokens!.Value,
["temperature"] = usedSettings.Temperature!.Value,
["topP"] = usedSettings.TopP!.Value,
["stopSequences"] = usedSettings.StopSequences!.AsArray()
},
cancellationToken).ConfigureAwait(false);

var generatedText = response?["completions"]?
.AsArray()[0]?["data"]?
.AsObject()["text"]?.GetValue<string>() ?? "";

var result = request.Messages.ToList();
result.Add(generatedText.AsAiMessage());

// Unsupported
var usage = Usage.Empty with
{
Time = watch.Elapsed,
Expand Down
8 changes: 5 additions & 3 deletions src/Providers/Amazon.Bedrock/src/Chat/AmazonTitanChatModel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Diagnostics;
using System.Linq;
using System.Text.Json.Nodes;
using LangChain.Providers.Amazon.Bedrock.Internal;

Expand All @@ -18,7 +19,7 @@ public override async Task<ChatResponse> GenerateAsync(
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToSimplePrompt();

Expand All @@ -35,13 +36,14 @@ public override async Task<ChatResponse> GenerateAsync(
{
["maxTokenCount"] = usedSettings.MaxTokens!.Value,
["temperature"] = usedSettings.Temperature!.Value,
["topP"] = 0.9
["topP"] = usedSettings.TopP!.Value,
["stopSequences"] = usedSettings.StopSequences!.AsArray()
}
},
cancellationToken).ConfigureAwait(false);

var generatedText = response?["results"]?[0]?["outputText"]?.GetValue<string>() ?? string.Empty;

var result = request.Messages.ToList();
result.Add(generatedText.AsAiMessage());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@ public class AnthropicClaudeChatModel(
string id)
: ChatModel(id)
{
public override int ContextLength => 4096;

public override async Task<ChatResponse> GenerateAsync(
ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToRolePrompt();

Expand All @@ -30,10 +28,12 @@ public override async Task<ChatResponse> GenerateAsync(
Id,
new JsonObject
{
{ "prompt", prompt },
{ "max_tokens_to_sample", usedSettings.MaxTokens!.Value },
{ "temperature", usedSettings.Temperature!.Value },
{ "stop_sequences", new JsonArray("\n\nHuman:") }
["prompt"] = prompt,
["max_tokens_to_sample"] = usedSettings.MaxTokens!.Value,
["temperature"] = usedSettings.Temperature!.Value,
["top_p"] = usedSettings.TopP!.Value,
["top_k"] = usedSettings.TopK!.Value,
["stop_sequences"] = new JsonArray("\n\nHuman:")
},
cancellationToken).ConfigureAwait(false);

Expand Down
38 changes: 32 additions & 6 deletions src/Providers/Amazon.Bedrock/src/Chat/BedrockChatSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,32 @@ public class BedrockChatSettings : ChatSettings
User = ChatSettings.Default.User,
Temperature = 0.7,
MaxTokens = 4096,
TopP = 0.9,
TopK = 0.0
};

/// <summary>
/// Sampling temperature
/// </summary>
public double? Temperature { get; init; }

/// <summary>
///
/// </summary>
public int? MaxTokens { get; init; }

/// <summary>
/// The cumulative probability cutoff for token selection.
/// Lower values mean sampling from a smaller, more top-weighted nucleus
/// </summary>
public double? TopP { get; init; }

/// <summary>
/// Sample from the k most likely next tokens at each step.
/// Lower k focuses on higher probability tokens.
/// </summary>
public double? TopK { get; init; }

/// <summary>
/// Calculate the settings to use for the request.
/// </summary>
Expand All @@ -37,33 +51,45 @@ public class BedrockChatSettings : ChatSettings
var requestSettingsCasted = requestSettings as BedrockChatSettings;
var modelSettingsCasted = modelSettings as BedrockChatSettings;
var providerSettingsCasted = providerSettings as BedrockChatSettings;

return new BedrockChatSettings
{
StopSequences =
StopSequences =
requestSettingsCasted?.StopSequences ??
modelSettingsCasted?.StopSequences ??
providerSettingsCasted?.StopSequences ??
Default.StopSequences ??
throw new InvalidOperationException("Default StopSequences is not set."),
User =
User =
requestSettingsCasted?.User ??
modelSettingsCasted?.User ??
providerSettingsCasted?.User ??
Default.User ??
throw new InvalidOperationException("Default User is not set."),
Temperature =
Temperature =
requestSettingsCasted?.Temperature ??
modelSettingsCasted?.Temperature ??
providerSettingsCasted?.Temperature ??
Default.Temperature ??
throw new InvalidOperationException("Default Temperature is not set."),
MaxTokens =
MaxTokens =
requestSettingsCasted?.MaxTokens ??
modelSettingsCasted?.MaxTokens ??
providerSettingsCasted?.MaxTokens ??
Default.MaxTokens ??
throw new InvalidOperationException("Default MaxTokens is not set."),
TopP =
requestSettingsCasted?.TopP ??
modelSettingsCasted?.TopP ??
providerSettingsCasted?.TopP ??
Default.TopP ??
throw new InvalidOperationException("Default TopP is not set."),
TopK =
requestSettingsCasted?.TopK ??
modelSettingsCasted?.TopK ??
providerSettingsCasted?.TopK ??
Default.TopK ??
throw new InvalidOperationException("Default TopK is not set."),
};
}
}
17 changes: 7 additions & 10 deletions src/Providers/Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@ public abstract class CohereCommandChatModel(
string id)
: ChatModel(id)
{
public override int ContextLength => 4096;

public override async Task<ChatResponse> GenerateAsync(
ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToSimplePrompt();

Expand All @@ -30,20 +28,19 @@ public override async Task<ChatResponse> GenerateAsync(
Id,
new JsonObject
{
{ "prompt", prompt },
{ "max_tokens", usedSettings.MaxTokens!.Value },
{ "temperature", usedSettings.Temperature!.Value },
{ "p",1 },
{ "k",0 },
["prompt"] = prompt,
["max_tokens"] = usedSettings.MaxTokens!.Value,
["temperature"] = usedSettings.Temperature!.Value,
["p"] = usedSettings.TopP!.Value,
["k"] = usedSettings.TopK!.Value,
},
cancellationToken).ConfigureAwait(false);

var generatedText = response?["generations"]?[0]?["text"]?.GetValue<string>() ?? string.Empty;

var result = request.Messages.ToList();
result.Add(generatedText.AsAiMessage());

// Unsupported
var usage = Usage.Empty with
{
Time = watch.Elapsed,
Expand Down
24 changes: 11 additions & 13 deletions src/Providers/Amazon.Bedrock/src/Chat/MetaLlama2ChatModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,29 @@ public class MetaLlama2ChatModel(
: ChatModel(id)
{
public override int ContextLength => 4096;

public override async Task<ChatResponse> GenerateAsync(
ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToSimplePrompt();

// TODO: implement settings
// var usedSettings = MetaLlama2ChatSettings.Calculate(
// requestSettings: settings,
// modelSettings: Settings,
// providerSettings: provider.ChatSettings);
var usedSettings = BedrockChatSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.ChatSettings);
var response = await provider.Api.InvokeModelAsync(
Id,
new JsonObject
{
{ "prompt", prompt },
{ "max_gen_len", 512 },
{ "temperature", 0.5 },
{ "top_p", 0.9 },
["prompt"] = prompt,
["max_gen_len"] = usedSettings.MaxTokens!.Value,
["temperature"] = usedSettings.Temperature!.Value,
["topP"] = usedSettings.TopP!.Value,
},
cancellationToken).ConfigureAwait(false);

Expand All @@ -44,7 +43,6 @@ public override async Task<ChatResponse> GenerateAsync(
var result = request.Messages.ToList();
result.Add(generatedText.AsAiMessage());

// Unsupported
var usage = Usage.Empty with
{
Time = watch.Elapsed,
Expand All @@ -55,7 +53,7 @@ public override async Task<ChatResponse> GenerateAsync(
return new ChatResponse
{
Messages = result,
UsedSettings = ChatSettings.Default,
UsedSettings = usedSettings,
Usage = usage,
};
}
Expand Down
24 changes: 24 additions & 0 deletions src/Providers/Amazon.Bedrock/src/StringArrayExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System.Text.Json.Nodes;

namespace LangChain.Providers.Amazon.Bedrock;

/// <summary>
///
/// </summary>
public static class StringArrayExtensions
{
/// <summary>
///
/// </summary>
/// <param name="stringArray"></param>
/// <returns></returns>
public static JsonArray AsArray(this IReadOnlyList<string> stringArray)
{
var jsonArray = new JsonArray();
foreach (var arr in stringArray)
{
jsonArray.Add(arr);
}
return jsonArray;
}
}
10 changes: 4 additions & 6 deletions src/Providers/Amazon.Bedrock/test/BedrockTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using LangChain.Base;
using System.Diagnostics;
using LangChain.Chains.LLM;
using LangChain.Chains.Sequentials;
Expand All @@ -9,26 +8,25 @@
using LangChain.Prompts;
using LangChain.Providers.Amazon.Bedrock.Predefined.Ai21Labs;
using LangChain.Providers.Amazon.Bedrock.Predefined.Amazon;
using LangChain.Providers.Amazon.Bedrock.Predefined.Anthropic;
using LangChain.Providers.Amazon.Bedrock.Predefined.Meta;
using LangChain.Providers.Amazon.Bedrock.Predefined.Stability;
using LangChain.Schema;
using LangChain.Sources;
using LangChain.Splitters;
using LangChain.Splitters.Text;
using static LangChain.Chains.Chain;

namespace LangChain.Providers.Amazon.Bedrock.IntegrationTests;
namespace LangChain.Providers.Amazon.Bedrock.Tests;

[TestFixture, Explicit]
public class BedrockTests
{

[Test]
public async Task Chains()
{
var provider = new BedrockProvider();
//var llm = new Jurassic2MidModel(provider);
//var llm = new ClaudeV21Model(provider);
var llm = new Llama2With13BModel(provider);
var llm = new ClaudeV21Model(provider);
//var modelId = "amazon.titan-text-express-v1";
// var modelId = "cohere.command-light-text-v14";

Expand Down
Loading

0 comments on commit b9f66a9

Please sign in to comment.