Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added more Amazon Bedrock models and added SageMaker project #141

Merged
merged 3 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions LangChain.sln
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Bedrock
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Bedrock", "src\libs\Providers\LangChain.Providers.Bedrock\LangChain.Providers.Bedrock.csproj", "{67985CCB-F606-41F8-9D36-513459F58882}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Amazon.IntegrationTests", "src\tests\LangChain.Providers.Bedrock.IntegrationTests\LangChain.Providers.Amazon.IntegrationTests.csproj", "{73C76E80-95C5-4C96-A319-4F32043C903E}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Amazon.Bedrock", "src\libs\Providers\LangChain.Providers.Amazon.Bedrock\LangChain.Providers.Amazon.Bedrock.csproj", "{67985CCB-F606-41F8-9D36-513459F58882}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Amazon.SageMaker", "src\libs\Providers\LangChain.Providers.Amazon.Sagemaker\LangChain.Providers.Amazon.SageMaker.csproj", "{F1AD6925-219C-4B17-B8D8-0ACCA6F401C4}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -406,6 +412,10 @@ Global
{A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Release|Any CPU.Build.0 = Release|Any CPU
{BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Debug|Any CPU.Build.0 = Debug|Any CPU
{BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Release|Any CPU.ActiveCfg = Release|Any CPU
{BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Release|Any CPU.Build.0 = Release|Any CPU
{73C76E80-95C5-4C96-A319-4F32043C903E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{73C76E80-95C5-4C96-A319-4F32043C903E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{73C76E80-95C5-4C96-A319-4F32043C903E}.Release|Any CPU.ActiveCfg = Release|Any CPU
Expand All @@ -414,6 +424,10 @@ Global
{67985CCB-F606-41F8-9D36-513459F58882}.Debug|Any CPU.Build.0 = Debug|Any CPU
{67985CCB-F606-41F8-9D36-513459F58882}.Release|Any CPU.ActiveCfg = Release|Any CPU
{67985CCB-F606-41F8-9D36-513459F58882}.Release|Any CPU.Build.0 = Release|Any CPU
{F1AD6925-219C-4B17-B8D8-0ACCA6F401C4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{F1AD6925-219C-4B17-B8D8-0ACCA6F401C4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F1AD6925-219C-4B17-B8D8-0ACCA6F401C4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{F1AD6925-219C-4B17-B8D8-0ACCA6F401C4}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
1 change: 1 addition & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<PackageVersion Include="Aspose.PDF" Version="23.11.0" />
<PackageVersion Include="AWSSDK.BedrockRuntime" Version="3.7.301.33" />
<PackageVersion Include="AWSSDK.Kendra" Version="3.7.300.5" />
<PackageVersion Include="AWSSDK.SageMakerRuntime" Version="3.7.301.37" />
<PackageVersion Include="Azure.AI.OpenAI" Version="1.0.0-beta.12" />
<PackageVersion Include="Docker.DotNet" Version="3.125.15" />
<PackageVersion Include="DotNet.ReproducibleBuilds" Version="1.1.1" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
namespace LangChain.Providers.Amazon.Bedrock
{
    /// <summary>
    /// Model identifier strings for the foundation models addressed through Amazon Bedrock.
    /// </summary>
    /// <remarks>
    /// The identifiers are exposed as <c>static readonly</c> fields so that no caller can
    /// overwrite a shared identifier at runtime. They are deliberately not <c>const</c>:
    /// constants are inlined into consuming assemblies, which would pin an old identifier
    /// in callers if a value is ever corrected here.
    /// </remarks>
    public static class AmazonModelIds
    {
        // AI21 Labs
        public static readonly string AI21LabsJurassic2UltraV1 = "ai21.j2-ultra-v1";
        public static readonly string AI21LabsJurassic2MidV1 = "ai21.j2-mid-v1";

        // Amazon Titan
        public static readonly string AmazonTitanEmbeddingsG1TextV1 = "amazon.titan-embed-text-v1";
        public static readonly string AmazonTitanTextG1LiteV1 = "amazon.titan-text-lite-v1";
        public static readonly string AmazonTitanTextG1ExpressV1 = "amazon.titan-text-express-v1";
        public static readonly string AmazonTitanImageGeneratorG1V1 = "amazon.titan-image-generator-v1";
        public static readonly string AmazonTitanMultiModalEmbeddingsG1V1 = "amazon.titan-embed-image-v1";

        // Anthropic
        public static readonly string AnthropicClaude2_1 = "anthropic.claude-v2:1";
        public static readonly string AnthropicClaude2 = "anthropic.claude-v2";
        public static readonly string AnthropicClaude1_3 = "anthropic.claude-v1";
        public static readonly string AnthropicClaudeInstant1_2 = "anthropic.claude-instant-v1";

        // Cohere
        public static readonly string CohereCommand = "cohere.command-text-v14";
        public static readonly string CohereCommandLight = "cohere.command-light-text-v14";
        public static readonly string CohereEmbedEnglish = "cohere.embed-english-v3";
        public static readonly string CohereEmbedMultilingual = "cohere.embed-multilingual-v3";

        // Meta
        public static readonly string MetaLlama2Chat13B = "meta.llama2-13b-chat-v1";
        public static readonly string MetaLlama2Chat70B = "meta.llama2-70b-chat-v1";
        public static readonly string MetaLlama213B = "meta.llama2-13b-v1"; // TODO: this model id is a guess; verify it against the Bedrock model list.
        public static readonly string MetaLlama270B = "meta.llama2-70b-v1"; // TODO: this model id is a guess; verify it against the Bedrock model list.

        // Stability AI
        public static readonly string StabilityAISDXL0_8 = "stability.stable-diffusion-xl-v0";
        public static readonly string StabilityAISDXL1_0 = "stability.stable-diffusion-xl-v1";
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
using System.Text.Json.Nodes;
using Amazon.BedrockRuntime;
using Amazon.BedrockRuntime.Model;
using Amazon.Util;
using LangChain.TextSplitters;

namespace LangChain.Providers.Amazon.Bedrock.Embeddings;

/// <summary>
/// Requests text embeddings from a Bedrock-hosted model using a JSON payload of the
/// form <c>{"inputText": ...}</c> and reading the <c>"embedding"</c> array from the response.
/// </summary>
public class AmazonTitanEmbeddingsRequest : IBedrockEmbeddingsRequest
{
    /// <summary>Value passed as <c>chunkSize</c> to the text splitter before embedding.</summary>
    private const int ChunkSize = 10_000;

    /// <summary>
    /// Joins <paramref name="texts"/> with a single space, splits the joined text into chunks
    /// and requests one embedding per chunk.
    /// </summary>
    /// <param name="client">Bedrock runtime client used to invoke the model.</param>
    /// <param name="texts">Texts to embed.</param>
    /// <param name="configuration">Supplies the model id used for each request.</param>
    /// <returns>
    /// One embedding vector per chunk that was embedded successfully, in chunk order.
    /// Chunks whose request fails are skipped, so the result may contain fewer rows than chunks.
    /// NOTE(review): rows correspond to chunks of the joined text, not to the individual
    /// entries of <paramref name="texts"/>; confirm that callers expect this.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="texts"/> is null.</exception>
    public async Task<float[][]> EmbedDocumentsAsync(
        AmazonBedrockRuntimeClient client,
        string[] texts,
        BedrockEmbeddingsConfiguration configuration)
    {
        texts = texts ?? throw new ArgumentNullException(nameof(texts));

        var inputText = string.Join(" ", texts);
        var textSplitter = new RecursiveCharacterTextSplitter(chunkSize: ChunkSize);
        var splitText = textSplitter.SplitText(inputText);

        // Keep each response as its own vector. Flattening the floats and wrapping every
        // single float in an array would lose the vector boundaries.
        List<float[]> vectors = [];
        foreach (var chunk in splitText)
        {
            var vector = await InvokeEmbeddingAsync(client, chunk, configuration).ConfigureAwait(false);
            if (vector is not null)
            {
                vectors.Add(vector);
            }
        }

        return vectors.ToArray();
    }

    /// <summary>
    /// Requests a single embedding for <paramref name="text"/>.
    /// </summary>
    /// <param name="client">Bedrock runtime client used to invoke the model.</param>
    /// <param name="text">Text to embed.</param>
    /// <param name="configuration">Supplies the model id used for the request.</param>
    /// <returns>
    /// The embedding vector, or an empty array when the request fails or the response
    /// carries no <c>"embedding"</c> value. Never null.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
    public async Task<float[]> EmbedQueryAsync(
        AmazonBedrockRuntimeClient client,
        string text,
        BedrockEmbeddingsConfiguration configuration)
    {
        text = text ?? throw new ArgumentNullException(nameof(text));

        var vector = await InvokeEmbeddingAsync(client, text, configuration).ConfigureAwait(false);

        return vector ?? Array.Empty<float>();
    }

    /// <summary>
    /// Sends one <c>inputText</c> request and extracts the <c>"embedding"</c> array.
    /// Returns null (after writing a message to the console) when the service responds with a
    /// non-OK status, throws <see cref="AmazonBedrockRuntimeException"/>, or omits the array.
    /// </summary>
    private static async Task<float[]?> InvokeEmbeddingAsync(
        AmazonBedrockRuntimeClient client,
        string text,
        BedrockEmbeddingsConfiguration configuration)
    {
        string payload = new JsonObject
        {
            { "inputText", text },
        }.ToJsonString();

        try
        {
            InvokeModelResponse response = await client.InvokeModelAsync(new InvokeModelRequest()
            {
                ModelId = configuration.ModelId,
                Body = AWSSDKUtils.GenerateMemoryStreamFromString(payload),
                ContentType = "application/json",
                Accept = "application/json"
            }).ConfigureAwait(false);

            if (response.HttpStatusCode != System.Net.HttpStatusCode.OK)
            {
                Console.WriteLine("InvokeModelAsync failed with status code " + response.HttpStatusCode);
                return null;
            }

            // Null-conditional on every step: a missing body or a missing "embedding" key
            // yields null instead of a NullReferenceException.
            return JsonNode.Parse(response.Body)?["embedding"]
                ?.AsArray()
                .Select(x => (float)x!.AsValue())
                .ToArray();
        }
        catch (AmazonBedrockRuntimeException e)
        {
            // Best-effort behaviour kept from the original: report and continue.
            Console.WriteLine(e.Message);
            return null;
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using Amazon.BedrockRuntime;
using Amazon.BedrockRuntime.Model;
using Amazon.Util;

namespace LangChain.Providers.Amazon.Bedrock.Embeddings;

/// <summary>
/// Requests embeddings from a Bedrock-hosted model whose document payload carries both
/// <c>inputText</c> and <c>inputImage</c>, reading the <c>"embedding"</c> array from the response.
/// </summary>
public class AmazonTitanMultiModalEmbeddingsRequest : IBedrockEmbeddingsRequest
{
    /// <summary>
    /// Joins <paramref name="texts"/> with a single space and sends them, together with
    /// <c>configuration.Base64Image</c>, in one request.
    /// </summary>
    /// <param name="client">Bedrock runtime client used to invoke the model.</param>
    /// <param name="texts">Texts to embed; they are combined into one input.</param>
    /// <param name="configuration">Supplies the model id and the image sent as <c>inputImage</c>.</param>
    /// <returns>
    /// A single-row array holding the embedding vector of the combined input, or an empty
    /// array when the request fails or the response carries no <c>"embedding"</c> value.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="texts"/> is null.</exception>
    public async Task<float[][]> EmbedDocumentsAsync(
        AmazonBedrockRuntimeClient client,
        string[] texts,
        BedrockEmbeddingsConfiguration configuration)
    {
        texts = texts ?? throw new ArgumentNullException(nameof(texts));

        var inputText = string.Join(" ", texts);
        var payload = Encoding.UTF8.GetBytes(
            JsonSerializer.Serialize(new
            {
                inputText = inputText,
                inputImage = configuration.Base64Image
            })
        );

        var vector = await InvokeEmbeddingAsync(client, configuration, new MemoryStream(payload))
            .ConfigureAwait(false);

        // One request produces one vector. Return it as a single row rather than
        // wrapping each float in its own array.
        return vector is null ? Array.Empty<float[]>() : new[] { vector };
    }

    /// <summary>
    /// Requests a single embedding for <paramref name="text"/>. Only <c>inputText</c> is sent;
    /// the configured image is not included in this request.
    /// </summary>
    /// <param name="client">Bedrock runtime client used to invoke the model.</param>
    /// <param name="text">Text to embed.</param>
    /// <param name="configuration">Supplies the model id used for the request.</param>
    /// <returns>
    /// The embedding vector, or an empty array when the request fails or the response
    /// carries no <c>"embedding"</c> value. Never null.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
    public async Task<float[]> EmbedQueryAsync(
        AmazonBedrockRuntimeClient client,
        string text,
        BedrockEmbeddingsConfiguration configuration)
    {
        text = text ?? throw new ArgumentNullException(nameof(text));

        string payload = new JsonObject()
        {
            { "inputText", text },
        }.ToJsonString();

        var vector = await InvokeEmbeddingAsync(
            client,
            configuration,
            AWSSDKUtils.GenerateMemoryStreamFromString(payload)).ConfigureAwait(false);

        return vector ?? Array.Empty<float>();
    }

    /// <summary>
    /// Sends <paramref name="body"/> as a JSON request and extracts the <c>"embedding"</c> array.
    /// Returns null (after writing a message to the console) when the service responds with a
    /// non-OK status, throws <see cref="AmazonBedrockRuntimeException"/>, or omits the array.
    /// </summary>
    private static async Task<float[]?> InvokeEmbeddingAsync(
        AmazonBedrockRuntimeClient client,
        BedrockEmbeddingsConfiguration configuration,
        MemoryStream body)
    {
        try
        {
            InvokeModelResponse response = await client.InvokeModelAsync(new InvokeModelRequest()
            {
                ModelId = configuration.ModelId,
                Body = body,
                ContentType = "application/json",
                Accept = "application/json"
            }).ConfigureAwait(false);

            if (response.HttpStatusCode != System.Net.HttpStatusCode.OK)
            {
                Console.WriteLine("InvokeModelAsync failed with status code " + response.HttpStatusCode);
                return null;
            }

            // Null-conditional on every step: a missing body or a missing "embedding" key
            // yields null instead of a NullReferenceException.
            return JsonNode.Parse(response.Body)?["embedding"]
                ?.AsArray()
                .Select(x => (float)x!.AsValue())
                .ToArray();
        }
        catch (AmazonBedrockRuntimeException e)
        {
            // Best-effort behaviour kept from the original: report and continue.
            Console.WriteLine(e.Message);
            return null;
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using System.Diagnostics;

namespace LangChain.Providers.Amazon.Bedrock.Embeddings;

/// <summary>
/// Embeddings model bound to a single Bedrock model id. Each call delegates to
/// <c>CreateCompletionAsync</c> (not defined in this class; presumably inherited from
/// <see cref="BedrockEmbeddingsBase"/>) and adds the elapsed wall-clock time to <c>TotalUsage</c>.
/// </summary>
public class BedrockEmbeddings : BedrockEmbeddingsBase
{
    private readonly BedrockEmbeddingsConfiguration _configuration;

    /// <summary>
    /// Creates an instance for <paramref name="modelId"/> with a configuration that
    /// carries only the model id.
    /// </summary>
    /// <param name="modelId">Bedrock model identifier.</param>
    /// <exception cref="ArgumentException"><paramref name="modelId"/> is null.</exception>
    public BedrockEmbeddings(string modelId)
    {
        Id = modelId ?? throw new ArgumentException("ModelId is not defined", nameof(modelId));
        _configuration = new BedrockEmbeddingsConfiguration { ModelId = modelId };
    }

    /// <summary>
    /// Creates an instance for <paramref name="modelId"/> and copies <c>Base64Image</c>
    /// from <paramref name="configuration"/>. Only that property is copied.
    /// </summary>
    /// <param name="modelId">Bedrock model identifier.</param>
    /// <param name="configuration">Source of the <c>Base64Image</c> value.</param>
    /// <exception cref="ArgumentException"><paramref name="modelId"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
    public BedrockEmbeddings(string modelId, BedrockEmbeddingsConfiguration configuration) : this(modelId)
    {
        // Fail with a descriptive argument exception instead of a NullReferenceException.
        configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));

        _configuration.Base64Image = configuration.Base64Image;
    }

    /// <summary>
    /// Embeds <paramref name="texts"/> and records the elapsed time in <c>TotalUsage</c>.
    /// </summary>
    /// <param name="texts">Texts to embed.</param>
    /// <param name="cancellationToken">Token forwarded to <c>CreateCompletionAsync</c>.</param>
    /// <returns>The embeddings returned by <c>CreateCompletionAsync</c>.</returns>
    public override async Task<float[][]> EmbedDocumentsAsync(string[] texts, CancellationToken cancellationToken = default)
    {
        var watch = Stopwatch.StartNew();

        var response = await CreateCompletionAsync(texts, _configuration, cancellationToken).ConfigureAwait(false);

        watch.Stop();
        RecordElapsed(watch.Elapsed);

        return response;
    }

    /// <summary>
    /// Embeds <paramref name="text"/> and records the elapsed time in <c>TotalUsage</c>.
    /// </summary>
    /// <param name="text">Text to embed.</param>
    /// <param name="cancellationToken">Token forwarded to <c>CreateCompletionAsync</c>.</param>
    /// <returns>The embedding returned by <c>CreateCompletionAsync</c>.</returns>
    public override async Task<float[]> EmbedQueryAsync(string text, CancellationToken cancellationToken = default)
    {
        var watch = Stopwatch.StartNew();

        var response = await CreateCompletionAsync(text, _configuration, cancellationToken).ConfigureAwait(false);

        watch.Stop();
        RecordElapsed(watch.Elapsed);

        return response;
    }

    /// <summary>
    /// Adds a usage entry to <c>TotalUsage</c> under <c>_usageLock</c>. Only <c>Time</c> is
    /// set; every other field keeps its <c>Usage.Empty</c> value.
    /// </summary>
    private void RecordElapsed(TimeSpan elapsed)
    {
        var usage = Usage.Empty with
        {
            Time = elapsed,
        };
        lock (_usageLock)
        {
            TotalUsage += usage;
        }
    }
}
Loading
Loading