Skip to content

Commit

Permalink
feat: added more Amazon Bedrock models and added SageMaker project (#141
Browse files Browse the repository at this point in the history
)

* feat: added AWS Bedrock provider and tests.  TODO add Amazon Titan text and image models and Cohere models

* updated Bedrock and added a project for SageMaker

---------

Co-authored-by: Ty Augustine <[email protected]>
Co-authored-by: Konstantin S <[email protected]>
  • Loading branch information
3 people authored Feb 12, 2024
1 parent 5d98e93 commit 622e0f9
Show file tree
Hide file tree
Showing 37 changed files with 1,595 additions and 13 deletions.
14 changes: 14 additions & 0 deletions LangChain.sln
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Bedrock
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Bedrock", "src\libs\Providers\LangChain.Providers.Bedrock\LangChain.Providers.Bedrock.csproj", "{67985CCB-F606-41F8-9D36-513459F58882}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Amazon.IntegrationTests", "src\tests\LangChain.Providers.Bedrock.IntegrationTests\LangChain.Providers.Amazon.IntegrationTests.csproj", "{73C76E80-95C5-4C96-A319-4F32043C903E}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Amazon.Bedrock", "src\libs\Providers\LangChain.Providers.Amazon.Bedrock\LangChain.Providers.Amazon.Bedrock.csproj", "{67985CCB-F606-41F8-9D36-513459F58882}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Amazon.SageMaker", "src\libs\Providers\LangChain.Providers.Amazon.Sagemaker\LangChain.Providers.Amazon.SageMaker.csproj", "{F1AD6925-219C-4B17-B8D8-0ACCA6F401C4}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -406,6 +412,10 @@ Global
{A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Release|Any CPU.Build.0 = Release|Any CPU
{BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Debug|Any CPU.Build.0 = Debug|Any CPU
{BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Release|Any CPU.ActiveCfg = Release|Any CPU
{BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Release|Any CPU.Build.0 = Release|Any CPU
{73C76E80-95C5-4C96-A319-4F32043C903E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{73C76E80-95C5-4C96-A319-4F32043C903E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{73C76E80-95C5-4C96-A319-4F32043C903E}.Release|Any CPU.ActiveCfg = Release|Any CPU
Expand All @@ -414,6 +424,10 @@ Global
{67985CCB-F606-41F8-9D36-513459F58882}.Debug|Any CPU.Build.0 = Debug|Any CPU
{67985CCB-F606-41F8-9D36-513459F58882}.Release|Any CPU.ActiveCfg = Release|Any CPU
{67985CCB-F606-41F8-9D36-513459F58882}.Release|Any CPU.Build.0 = Release|Any CPU
{F1AD6925-219C-4B17-B8D8-0ACCA6F401C4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{F1AD6925-219C-4B17-B8D8-0ACCA6F401C4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F1AD6925-219C-4B17-B8D8-0ACCA6F401C4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{F1AD6925-219C-4B17-B8D8-0ACCA6F401C4}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
1 change: 1 addition & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<PackageVersion Include="Aspose.PDF" Version="23.11.0" />
<PackageVersion Include="AWSSDK.BedrockRuntime" Version="3.7.301.33" />
<PackageVersion Include="AWSSDK.Kendra" Version="3.7.300.5" />
<PackageVersion Include="AWSSDK.SageMakerRuntime" Version="3.7.301.37" />
<PackageVersion Include="Azure.AI.OpenAI" Version="1.0.0-beta.12" />
<PackageVersion Include="Docker.DotNet" Version="3.125.15" />
<PackageVersion Include="DotNet.ReproducibleBuilds" Version="1.1.1" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
namespace LangChain.Providers.Amazon.Bedrock
{
public static class AmazonModelIds
{
public static string AI21LabsJurassic2UltraV1 = "ai21.j2-ultra-v1";
public static string AI21LabsJurassic2MidV1 = "ai21.j2-mid-v1";

public static string AmazonTitanEmbeddingsG1TextV1 = "amazon.titan-embed-text-v1";
public static string AmazonTitanTextG1LiteV1 = "amazon.titan-text-lite-v1";
public static string AmazonTitanTextG1ExpressV1 = "amazon.titan-text-express-v1";
public static string AmazonTitanImageGeneratorG1V1 = "amazon.titan-image-generator-v1";
public static string AmazonTitanMultiModalEmbeddingsG1V1 = "amazon.titan-embed-image-v1";

public static string AnthropicClaude2_1 = "anthropic.claude-v2:1";
public static string AnthropicClaude2 = "anthropic.claude-v2";
public static string AnthropicClaude1_3 = "anthropic.claude-v1";
public static string AnthropicClaudeInstant1_2 = "anthropic.claude-instant-v1";

public static string CohereCommand = "cohere.command-text-v14";
public static string CohereCommandLight = "cohere.command-light-text-v14";
public static string CohereEmbedEnglish= "cohere.embed-english-v3";
public static string CohereEmbedMultilingual = "cohere.embed-multilingual-v3";

public static string MetaLlama2Chat13B = "meta.llama2-13b-chat-v1";
public static string MetaLlama2Chat70B = "meta.llama2-70b-chat-v1";
public static string MetaLlama213B = "meta.llama2-13b-v1"; //TODO i'm guessing the model id
public static string MetaLlama270B = "meta.llama2-70b-v1"; //TODO i'm guessing the model id

public static string StabilityAISDXL0_8 = "stability.stable-diffusion-xl-v0";
public static string StabilityAISDXL1_0 = "stability.stable-diffusion-xl-v1";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
using System.Text.Json.Nodes;
using Amazon.BedrockRuntime;
using Amazon.BedrockRuntime.Model;
using Amazon.Util;
using LangChain.TextSplitters;

namespace LangChain.Providers.Amazon.Bedrock.Embeddings;

public class AmazonTitanEmbeddingsRequest : IBedrockEmbeddingsRequest
{
public async Task<float[][]> EmbedDocumentsAsync(
AmazonBedrockRuntimeClient client,
string[] texts,
BedrockEmbeddingsConfiguration configuration)
{
texts = texts ?? throw new ArgumentNullException(nameof(texts));

List<float> arrEmbeddings = [];
var inputText = string.Join(" ", texts);
var textSplitter = new RecursiveCharacterTextSplitter(chunkSize: 10_000);
var splitText = textSplitter.SplitText(inputText);

foreach (var text in splitText)
{
try
{
string payload = new JsonObject
{
{ "inputText", text },
}.ToJsonString();

InvokeModelResponse response = await client.InvokeModelAsync(new InvokeModelRequest()
{
ModelId = configuration.ModelId,
Body = AWSSDKUtils.GenerateMemoryStreamFromString(payload),
ContentType = "application/json",
Accept = "application/json"
});

if (response.HttpStatusCode == System.Net.HttpStatusCode.OK)
{
var body = JsonNode.Parse(response.Body);
var embeddings = body?["embedding"]
.AsArray()
.Select(x => (float)x.AsValue())
.ToArray();
arrEmbeddings.AddRange(embeddings);
}
else
{
Console.WriteLine("InvokeModelAsync failed with status code " + response.HttpStatusCode);
}
}
catch (AmazonBedrockRuntimeException e)
{
Console.WriteLine(e.Message);
}
}

var result = arrEmbeddings.Select(f => (new[] { f })).ToArray();

return result;
}

public async Task<float[]> EmbedQueryAsync(
AmazonBedrockRuntimeClient client,
string text,
BedrockEmbeddingsConfiguration configuration)
{
text = text ?? throw new ArgumentNullException(nameof(text));
float[]? embeddings = [];

try
{
string payload = new JsonObject()
{
{ "inputText", text },
}.ToJsonString();

InvokeModelResponse response = await client.InvokeModelAsync(new InvokeModelRequest()
{
ModelId = configuration.ModelId,
Body = AWSSDKUtils.GenerateMemoryStreamFromString(payload),
ContentType = "application/json",
Accept = "application/json"
});

if (response.HttpStatusCode == System.Net.HttpStatusCode.OK)
{
embeddings = JsonNode.Parse(response.Body)?["embedding"]
.AsArray()
.Select(x => (float)x.AsValue())
.ToArray();
}
else
{
Console.WriteLine("InvokeModelAsync failed with status code " + response.HttpStatusCode);
}
}
catch (AmazonBedrockRuntimeException e)
{
Console.WriteLine(e.Message);
}

return embeddings;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using Amazon.BedrockRuntime;
using Amazon.BedrockRuntime.Model;
using Amazon.Util;

namespace LangChain.Providers.Amazon.Bedrock.Embeddings;

public class AmazonTitanMultiModalEmbeddingsRequest : IBedrockEmbeddingsRequest
{
public async Task<float[][]> EmbedDocumentsAsync(
AmazonBedrockRuntimeClient client,
string[] texts,
BedrockEmbeddingsConfiguration configuration)
{
texts = texts ?? throw new ArgumentNullException(nameof(texts));

List<float> arrEmbeddings = [];
var inputText = string.Join(" ", texts);

try
{
var payload = Encoding.UTF8.GetBytes(
JsonSerializer.Serialize(new
{
inputText = inputText,
inputImage = configuration.Base64Image
})
);

InvokeModelResponse response = await client.InvokeModelAsync(new InvokeModelRequest()
{
ModelId = configuration.ModelId,
Body = new MemoryStream(payload),
ContentType = "application/json",
Accept = "application/json"
});

if (response.HttpStatusCode == System.Net.HttpStatusCode.OK)
{
var body = JsonNode.Parse(response.Body);
var embeddings = body?["embedding"]
.AsArray()
.Select(x => (float)x.AsValue())
.ToArray();
arrEmbeddings.AddRange(embeddings);
}
else
{
Console.WriteLine("InvokeModelAsync failed with status code " + response.HttpStatusCode);
}
}
catch (AmazonBedrockRuntimeException e)
{
Console.WriteLine(e.Message);
}

var result = arrEmbeddings.Select(f => (new[] { f })).ToArray();

return result;
}

public async Task<float[]> EmbedQueryAsync(
AmazonBedrockRuntimeClient client,
string text,
BedrockEmbeddingsConfiguration configuration)
{
text = text ?? throw new ArgumentNullException(nameof(text));
float[]? embeddings = [];

try
{
string payload = new JsonObject()
{
{ "inputText", text },
}.ToJsonString();

InvokeModelResponse response = await client.InvokeModelAsync(new InvokeModelRequest()
{
ModelId = configuration.ModelId,
Body = AWSSDKUtils.GenerateMemoryStreamFromString(payload),
ContentType = "application/json",
Accept = "application/json"
});

if (response.HttpStatusCode == System.Net.HttpStatusCode.OK)
{
embeddings = JsonNode.Parse(response.Body)?["embedding"]
.AsArray()
.Select(x => (float)x.AsValue())
.ToArray();
}
else
{
Console.WriteLine("InvokeModelAsync failed with status code " + response.HttpStatusCode);
}
}
catch (AmazonBedrockRuntimeException e)
{
Console.WriteLine(e.Message);
}

return embeddings;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using System.Diagnostics;

namespace LangChain.Providers.Amazon.Bedrock.Embeddings;

public class BedrockEmbeddings : BedrockEmbeddingsBase
{
private readonly BedrockEmbeddingsConfiguration _configuration;

public BedrockEmbeddings(string modelId)
{
Id = modelId ?? throw new ArgumentException("ModelId is not defined", nameof(modelId));
_configuration = new BedrockEmbeddingsConfiguration { ModelId = modelId };
}

public BedrockEmbeddings(string modelId, BedrockEmbeddingsConfiguration configuration) : this(modelId)
{
_configuration.Base64Image = configuration.Base64Image;
}

public override async Task<float[][]> EmbedDocumentsAsync(string[] texts, CancellationToken cancellationToken = default)
{
var watch = Stopwatch.StartNew();

var response = await CreateCompletionAsync(texts, _configuration, cancellationToken).ConfigureAwait(false);

watch.Stop();

// Unsupported
var usage = Usage.Empty with
{
Time = watch.Elapsed,
};
lock (_usageLock)
{
TotalUsage += usage;
}

return response;
}

public override async Task<float[]> EmbedQueryAsync(string text, CancellationToken cancellationToken = default)
{
var watch = Stopwatch.StartNew();

var response = await CreateCompletionAsync(text, _configuration, cancellationToken).ConfigureAwait(false);

watch.Stop();

// Unsupported
var usage = Usage.Empty with
{
Time = watch.Elapsed,
};
lock (_usageLock)
{
TotalUsage += usage;
}

return response;
}
}
Loading

0 comments on commit 622e0f9

Please sign in to comment.