Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added AiLabs Jamba model. updated Amazon.Bedrock nugets #376

Merged
merged 4 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
<ItemGroup>
<PackageVersion Include="AngleSharp" Version="1.1.2" />
<PackageVersion Include="Anthropic" Version="0.3.1" />
<PackageVersion Include="Anthropic.SDK" Version="3.2.3" />
<PackageVersion Include="Anthropic.SDK" Version="3.2.1" />
<PackageVersion Include="Anyscale" Version="1.0.2" />
<PackageVersion Include="Aspose.PDF" Version="24.5.1" />
<PackageVersion Include="AWSSDK.BedrockAgentRuntime" Version="3.7.308.23" />
<PackageVersion Include="AWSSDK.BedrockRuntime" Version="3.7.306.7" />
<PackageVersion Include="AWSSDK.Kendra" Version="3.7.301.56" />
<PackageVersion Include="AWSSDK.BedrockAgentRuntime" Version="3.7.309" />
<PackageVersion Include="AWSSDK.BedrockRuntime" Version="3.7.307" />
<PackageVersion Include="AWSSDK.Kendra" Version="3.7.301.44" />
<PackageVersion Include="AWSSDK.OpenSearchService" Version="3.7.305.8" />
<PackageVersion Include="AWSSDK.SageMakerRuntime" Version="3.7.301.37" />
<PackageVersion Include="Azure.AI.OpenAI" Version="1.0.0-beta.17" />
Expand Down Expand Up @@ -40,7 +40,7 @@
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="4.10.0" />
<PackageVersion Include="Microsoft.CodeAnalysis.PublicApiAnalyzers" Version="3.3.4" />
<PackageVersion Include="Microsoft.CSharp" Version="4.7.0" />
<PackageVersion Include="Microsoft.Data.Sqlite.Core" Version="8.0.7" />
<PackageVersion Include="Microsoft.Data.Sqlite.Core" Version="8.0.6" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.10.0" />
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.OpenAI" Version="1.10.0" />
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.AzureAISearch" Version="1.10.0-alpha" />
Expand Down Expand Up @@ -77,10 +77,10 @@
<PackageVersion Include="System.Net.Http.Json" Version="8.0.0" />
<PackageVersion Include="System.Text.Json" Version="8.0.4" />
<PackageVersion Include="System.ValueTuple" Version="4.5.0" />
<PackageVersion Include="Testcontainers" Version="3.9.0" />
<PackageVersion Include="Testcontainers.MongoDb" Version="3.9.0" />
<PackageVersion Include="Testcontainers.PostgreSql" Version="3.9.0" />
<PackageVersion Include="Testcontainers.Redis" Version="3.9.0" />
<PackageVersion Include="Testcontainers" Version="3.8.0" />
<PackageVersion Include="Testcontainers.MongoDb" Version="3.8.0" />
<PackageVersion Include="Testcontainers.PostgreSql" Version="3.8.0" />
<PackageVersion Include="Testcontainers.Redis" Version="3.8.0" />
<PackageVersion Include="Tiktoken" Version="2.0.2" />
<PackageVersion Include="tryAGI.OpenAI" Version="2.0.9" />
<PackageVersion Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="8.0.1" />
Expand Down
87 changes: 87 additions & 0 deletions src/Providers/Amazon.Bedrock/src/Chat/Ai21LabsJambaChatModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
using System.Diagnostics;
using System.Text.Json.Nodes;
using LangChain.Providers.Amazon.Bedrock.Internal;

// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock;

public class Ai21LabsJambaChatModel(
BedrockProvider provider,
string id)
: ChatModel(id)
{
/// <summary>
/// Generates a chat response based on the provided `ChatRequest`.
/// </summary>
/// <param name="request">The `ChatRequest` containing the input messages and other parameters.</param>
/// <param name="settings">Optional `ChatSettings` to override the model's default settings.</param>
/// <param name="cancellationToken">A cancellation token to cancel the operation.</param>
/// <returns>A `ChatResponse` containing the generated messages and usage information.</returns>
public override async Task<ChatResponse> GenerateAsync(
ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToSimplePrompt();

var usedSettings = Ai21LabJambaChatSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.ChatSettings);

var bodyJson = CreateBodyJson(prompt, usedSettings);

var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken)
.ConfigureAwait(false);

var generatedText = response?["choices"]?.AsArray()
[0]?["message"]?.AsObject()
.AsObject()["content"]?.GetValue<string>() ?? "";

var result = request.Messages.ToList();
result.Add(generatedText.AsAiMessage());

var usage = Usage.Empty with
{
Time = watch.Elapsed,
};
AddUsage(usage);
provider.AddUsage(usage);

return new ChatResponse
{
Messages = result,
UsedSettings = usedSettings,
Usage = usage,
};
}
Comment on lines +20 to +60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Method GenerateAsync looks good!

Consider adding error handling and logging to improve robustness and traceability.

+ try
+ {
    var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken)
        .ConfigureAwait(false);
+ }
+ catch (Exception ex)
+ {
+     // Log the exception and rethrow or handle it appropriately
+     throw new ApplicationException("Error invoking model", ex);
+ }

Committable suggestion was skipped due to low confidence.


/// <summary>
/// Creates the request body JSON for the Ai21Labs model based on the provided prompt and settings.
/// </summary>
/// <param name="prompt">The input prompt for the model.</param>
/// <param name="usedSettings">The settings to use for the request.</param>
/// <returns>A `JsonObject` representing the request body.</returns>
private static JsonObject CreateBodyJson(string prompt, Ai21LabJambaChatSettings usedSettings)
{
var bodyJson = new JsonObject
{
["messages"] = new JsonArray
{
new JsonObject
{
["role"] = "user",
["content"] = prompt
}
},
["max_tokens"] = usedSettings.MaxTokens!.Value,
["top_p"] = usedSettings.TopP!.Value,
["temperature"] = usedSettings.Temperature!.Value,

};
return bodyJson;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ private static JsonObject CreateBodyJson(

var bodyJson = new JsonObject
{
["anthropic_version"] = "bedrock-2023-05-31",
["max_tokens"] = usedSettings.MaxTokens!.Value,
["messages"] = new JsonArray
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock;

public class Ai21LabJambaChatSettings : BedrockChatSettings
{
public new static Ai21LabJambaChatSettings Default { get; } = new()
{
StopSequences = ChatSettings.Default.StopSequences,
User = ChatSettings.Default.User,
UseStreaming = false,
Temperature = 0.7,
MaxTokens = 4000,
TopP = 0.8,
TopK = 0.0
};

/// <summary>
/// Calculate the settings to use for the request.
/// </summary>
/// <param name="requestSettings"></param>
/// <param name="modelSettings"></param>
/// <param name="providerSettings"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public new static Ai21LabJambaChatSettings Calculate(
ChatSettings? requestSettings,
ChatSettings? modelSettings,
ChatSettings? providerSettings)
{
var requestSettingsCasted = requestSettings as Ai21LabJambaChatSettings;
var modelSettingsCasted = modelSettings as Ai21LabJambaChatSettings;
var providerSettingsCasted = providerSettings as Ai21LabJambaChatSettings;

return new Ai21LabJambaChatSettings
{
StopSequences =
requestSettingsCasted?.StopSequences ??
modelSettingsCasted?.StopSequences ??
providerSettingsCasted?.StopSequences ??
Default.StopSequences ??
throw new InvalidOperationException("Default StopSequences is not set."),
User =
requestSettingsCasted?.User ??
modelSettingsCasted?.User ??
providerSettingsCasted?.User ??
Default.User ??
throw new InvalidOperationException("Default User is not set."),
UseStreaming =
requestSettings?.UseStreaming ??
modelSettings?.UseStreaming ??
providerSettings?.UseStreaming ??
Default.UseStreaming ??
throw new InvalidOperationException("Default UseStreaming is not set."),
Temperature =
requestSettingsCasted?.Temperature ??
modelSettingsCasted?.Temperature ??
providerSettingsCasted?.Temperature ??
Default.Temperature ??
throw new InvalidOperationException("Default Temperature is not set."),
MaxTokens =
requestSettingsCasted?.MaxTokens ??
modelSettingsCasted?.MaxTokens ??
providerSettingsCasted?.MaxTokens ??
Default.MaxTokens ??
throw new InvalidOperationException("Default MaxTokens is not set."),
TopP =
requestSettingsCasted?.TopP ??
modelSettingsCasted?.TopP ??
providerSettingsCasted?.TopP ??
Default.TopP ??
throw new InvalidOperationException("Default TopP is not set."),
TopK =
requestSettingsCasted?.TopK ??
modelSettingsCasted?.TopK ??
providerSettingsCasted?.TopK ??
Default.TopK ??
throw new InvalidOperationException("Default TopK is not set."),
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,11 @@
<ProjectReference Include="..\..\Abstractions\src\LangChain.Providers.Abstractions.csproj" />
</ItemGroup>

<ItemGroup>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the future - this has already been applied to all packages by default via the Directory.Build.props file

<PackageReference Update="DotNet.ReproducibleBuilds">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
</ItemGroup>

</Project>
6 changes: 5 additions & 1 deletion src/Providers/Amazon.Bedrock/src/Predefined/Ai21Labs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,8 @@ public class Jurassic2MidModel(BedrockProvider provider)

/// <inheritdoc />
public class Jurassic2UltraModel(BedrockProvider provider)
: Ai21LabsJurassic2ChatModel(provider, id: "ai21.j2-ultra-v1");
: Ai21LabsJurassic2ChatModel(provider, id: "ai21.j2-ultra-v1");

/// <inheritdoc />
public class JambaInstructModel(BedrockProvider provider)
: Ai21LabsJambaChatModel(provider, id: "ai21.jamba-instruct-v1:0");
4 changes: 2 additions & 2 deletions src/Providers/Amazon.Bedrock/test/BedrockTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ public class BedrockTests
[Test]
public async Task Chains()
{
var provider = new BedrockProvider(RegionEndpoint.USWest2);
var provider = new BedrockProvider();
//var llm = new Jurassic2MidModel(provider);
//var llm = new ClaudeV21Model(provider);
//var llm = new Mistral7BInstruct(provider);
var llm = new CommandRModel(provider);
var llm = new JambaInstructModel(provider);

var template = "What is a good name for a company that makes {product}?";
var prompt = new PromptTemplate(new PromptTemplateInput(template, new List<string>(1) { "product" }));
Expand Down
Loading