-
-
Notifications
You must be signed in to change notification settings - Fork 95
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: added more Amazon Bedrock models and added SageMaker project (#141
) * feat: added AWS Bedrock provider and tests. TODO add Amazon Titan text and image models and Cohere models * updated Bedrock and added a project for SageMaker --------- Co-authored-by: Ty Augustine <[email protected]> Co-authored-by: Konstantin S <[email protected]>
- Loading branch information
1 parent
5d98e93
commit 622e0f9
Showing
37 changed files
with
1,595 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
src/libs/Providers/LangChain.Providers.Amazon.Bedrock/AmazonModelNames.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
namespace LangChain.Providers.Amazon.Bedrock | ||
{ | ||
public static class AmazonModelIds | ||
{ | ||
public static string AI21LabsJurassic2UltraV1 = "ai21.j2-ultra-v1"; | ||
public static string AI21LabsJurassic2MidV1 = "ai21.j2-mid-v1"; | ||
|
||
public static string AmazonTitanEmbeddingsG1TextV1 = "amazon.titan-embed-text-v1"; | ||
public static string AmazonTitanTextG1LiteV1 = "amazon.titan-text-lite-v1"; | ||
public static string AmazonTitanTextG1ExpressV1 = "amazon.titan-text-express-v1"; | ||
public static string AmazonTitanImageGeneratorG1V1 = "amazon.titan-image-generator-v1"; | ||
public static string AmazonTitanMultiModalEmbeddingsG1V1 = "amazon.titan-embed-image-v1"; | ||
|
||
public static string AnthropicClaude2_1 = "anthropic.claude-v2:1"; | ||
public static string AnthropicClaude2 = "anthropic.claude-v2"; | ||
public static string AnthropicClaude1_3 = "anthropic.claude-v1"; | ||
public static string AnthropicClaudeInstant1_2 = "anthropic.claude-instant-v1"; | ||
|
||
public static string CohereCommand = "cohere.command-text-v14"; | ||
public static string CohereCommandLight = "cohere.command-light-text-v14"; | ||
public static string CohereEmbedEnglish= "cohere.embed-english-v3"; | ||
public static string CohereEmbedMultilingual = "cohere.embed-multilingual-v3"; | ||
|
||
public static string MetaLlama2Chat13B = "meta.llama2-13b-chat-v1"; | ||
public static string MetaLlama2Chat70B = "meta.llama2-70b-chat-v1"; | ||
public static string MetaLlama213B = "meta.llama2-13b-v1"; //TODO i'm guessing the model id | ||
public static string MetaLlama270B = "meta.llama2-70b-v1"; //TODO i'm guessing the model id | ||
|
||
public static string StabilityAISDXL0_8 = "stability.stable-diffusion-xl-v0"; | ||
public static string StabilityAISDXL1_0 = "stability.stable-diffusion-xl-v1"; | ||
} | ||
} |
107 changes: 107 additions & 0 deletions
107
...s/Providers/LangChain.Providers.Amazon.Bedrock/Embeddings/AmazonTitanEmbeddingsRequest.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
using System.Text.Json.Nodes; | ||
using Amazon.BedrockRuntime; | ||
using Amazon.BedrockRuntime.Model; | ||
using Amazon.Util; | ||
using LangChain.TextSplitters; | ||
|
||
namespace LangChain.Providers.Amazon.Bedrock.Embeddings; | ||
|
||
public class AmazonTitanEmbeddingsRequest : IBedrockEmbeddingsRequest | ||
{ | ||
public async Task<float[][]> EmbedDocumentsAsync( | ||
AmazonBedrockRuntimeClient client, | ||
string[] texts, | ||
BedrockEmbeddingsConfiguration configuration) | ||
{ | ||
texts = texts ?? throw new ArgumentNullException(nameof(texts)); | ||
|
||
List<float> arrEmbeddings = []; | ||
var inputText = string.Join(" ", texts); | ||
var textSplitter = new RecursiveCharacterTextSplitter(chunkSize: 10_000); | ||
var splitText = textSplitter.SplitText(inputText); | ||
|
||
foreach (var text in splitText) | ||
{ | ||
try | ||
{ | ||
string payload = new JsonObject | ||
{ | ||
{ "inputText", text }, | ||
}.ToJsonString(); | ||
|
||
InvokeModelResponse response = await client.InvokeModelAsync(new InvokeModelRequest() | ||
{ | ||
ModelId = configuration.ModelId, | ||
Body = AWSSDKUtils.GenerateMemoryStreamFromString(payload), | ||
ContentType = "application/json", | ||
Accept = "application/json" | ||
}); | ||
|
||
if (response.HttpStatusCode == System.Net.HttpStatusCode.OK) | ||
{ | ||
var body = JsonNode.Parse(response.Body); | ||
var embeddings = body?["embedding"] | ||
.AsArray() | ||
.Select(x => (float)x.AsValue()) | ||
.ToArray(); | ||
arrEmbeddings.AddRange(embeddings); | ||
} | ||
else | ||
{ | ||
Console.WriteLine("InvokeModelAsync failed with status code " + response.HttpStatusCode); | ||
} | ||
} | ||
catch (AmazonBedrockRuntimeException e) | ||
{ | ||
Console.WriteLine(e.Message); | ||
} | ||
} | ||
|
||
var result = arrEmbeddings.Select(f => (new[] { f })).ToArray(); | ||
|
||
return result; | ||
} | ||
|
||
public async Task<float[]> EmbedQueryAsync( | ||
AmazonBedrockRuntimeClient client, | ||
string text, | ||
BedrockEmbeddingsConfiguration configuration) | ||
{ | ||
text = text ?? throw new ArgumentNullException(nameof(text)); | ||
float[]? embeddings = []; | ||
|
||
try | ||
{ | ||
string payload = new JsonObject() | ||
{ | ||
{ "inputText", text }, | ||
}.ToJsonString(); | ||
|
||
InvokeModelResponse response = await client.InvokeModelAsync(new InvokeModelRequest() | ||
{ | ||
ModelId = configuration.ModelId, | ||
Body = AWSSDKUtils.GenerateMemoryStreamFromString(payload), | ||
ContentType = "application/json", | ||
Accept = "application/json" | ||
}); | ||
|
||
if (response.HttpStatusCode == System.Net.HttpStatusCode.OK) | ||
{ | ||
embeddings = JsonNode.Parse(response.Body)?["embedding"] | ||
.AsArray() | ||
.Select(x => (float)x.AsValue()) | ||
.ToArray(); | ||
} | ||
else | ||
{ | ||
Console.WriteLine("InvokeModelAsync failed with status code " + response.HttpStatusCode); | ||
} | ||
} | ||
catch (AmazonBedrockRuntimeException e) | ||
{ | ||
Console.WriteLine(e.Message); | ||
} | ||
|
||
return embeddings; | ||
} | ||
} |
106 changes: 106 additions & 0 deletions
106
...s/LangChain.Providers.Amazon.Bedrock/Embeddings/AmazonTitanMultiModalEmbeddingsRequest.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
using System.Text; | ||
using System.Text.Json; | ||
using System.Text.Json.Nodes; | ||
using Amazon.BedrockRuntime; | ||
using Amazon.BedrockRuntime.Model; | ||
using Amazon.Util; | ||
|
||
namespace LangChain.Providers.Amazon.Bedrock.Embeddings; | ||
|
||
public class AmazonTitanMultiModalEmbeddingsRequest : IBedrockEmbeddingsRequest | ||
{ | ||
public async Task<float[][]> EmbedDocumentsAsync( | ||
AmazonBedrockRuntimeClient client, | ||
string[] texts, | ||
BedrockEmbeddingsConfiguration configuration) | ||
{ | ||
texts = texts ?? throw new ArgumentNullException(nameof(texts)); | ||
|
||
List<float> arrEmbeddings = []; | ||
var inputText = string.Join(" ", texts); | ||
|
||
try | ||
{ | ||
var payload = Encoding.UTF8.GetBytes( | ||
JsonSerializer.Serialize(new | ||
{ | ||
inputText = inputText, | ||
inputImage = configuration.Base64Image | ||
}) | ||
); | ||
|
||
InvokeModelResponse response = await client.InvokeModelAsync(new InvokeModelRequest() | ||
{ | ||
ModelId = configuration.ModelId, | ||
Body = new MemoryStream(payload), | ||
ContentType = "application/json", | ||
Accept = "application/json" | ||
}); | ||
|
||
if (response.HttpStatusCode == System.Net.HttpStatusCode.OK) | ||
{ | ||
var body = JsonNode.Parse(response.Body); | ||
var embeddings = body?["embedding"] | ||
.AsArray() | ||
.Select(x => (float)x.AsValue()) | ||
.ToArray(); | ||
arrEmbeddings.AddRange(embeddings); | ||
} | ||
else | ||
{ | ||
Console.WriteLine("InvokeModelAsync failed with status code " + response.HttpStatusCode); | ||
} | ||
} | ||
catch (AmazonBedrockRuntimeException e) | ||
{ | ||
Console.WriteLine(e.Message); | ||
} | ||
|
||
var result = arrEmbeddings.Select(f => (new[] { f })).ToArray(); | ||
|
||
return result; | ||
} | ||
|
||
public async Task<float[]> EmbedQueryAsync( | ||
AmazonBedrockRuntimeClient client, | ||
string text, | ||
BedrockEmbeddingsConfiguration configuration) | ||
{ | ||
text = text ?? throw new ArgumentNullException(nameof(text)); | ||
float[]? embeddings = []; | ||
|
||
try | ||
{ | ||
string payload = new JsonObject() | ||
{ | ||
{ "inputText", text }, | ||
}.ToJsonString(); | ||
|
||
InvokeModelResponse response = await client.InvokeModelAsync(new InvokeModelRequest() | ||
{ | ||
ModelId = configuration.ModelId, | ||
Body = AWSSDKUtils.GenerateMemoryStreamFromString(payload), | ||
ContentType = "application/json", | ||
Accept = "application/json" | ||
}); | ||
|
||
if (response.HttpStatusCode == System.Net.HttpStatusCode.OK) | ||
{ | ||
embeddings = JsonNode.Parse(response.Body)?["embedding"] | ||
.AsArray() | ||
.Select(x => (float)x.AsValue()) | ||
.ToArray(); | ||
} | ||
else | ||
{ | ||
Console.WriteLine("InvokeModelAsync failed with status code " + response.HttpStatusCode); | ||
} | ||
} | ||
catch (AmazonBedrockRuntimeException e) | ||
{ | ||
Console.WriteLine(e.Message); | ||
} | ||
|
||
return embeddings; | ||
} | ||
} |
61 changes: 61 additions & 0 deletions
61
src/libs/Providers/LangChain.Providers.Amazon.Bedrock/Embeddings/BedrockEmbeddings.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
using System.Diagnostics; | ||
|
||
namespace LangChain.Providers.Amazon.Bedrock.Embeddings; | ||
|
||
public class BedrockEmbeddings : BedrockEmbeddingsBase | ||
{ | ||
private readonly BedrockEmbeddingsConfiguration _configuration; | ||
|
||
public BedrockEmbeddings(string modelId) | ||
{ | ||
Id = modelId ?? throw new ArgumentException("ModelId is not defined", nameof(modelId)); | ||
_configuration = new BedrockEmbeddingsConfiguration { ModelId = modelId }; | ||
} | ||
|
||
public BedrockEmbeddings(string modelId, BedrockEmbeddingsConfiguration configuration) : this(modelId) | ||
{ | ||
_configuration.Base64Image = configuration.Base64Image; | ||
} | ||
|
||
public override async Task<float[][]> EmbedDocumentsAsync(string[] texts, CancellationToken cancellationToken = default) | ||
{ | ||
var watch = Stopwatch.StartNew(); | ||
|
||
var response = await CreateCompletionAsync(texts, _configuration, cancellationToken).ConfigureAwait(false); | ||
|
||
watch.Stop(); | ||
|
||
// Unsupported | ||
var usage = Usage.Empty with | ||
{ | ||
Time = watch.Elapsed, | ||
}; | ||
lock (_usageLock) | ||
{ | ||
TotalUsage += usage; | ||
} | ||
|
||
return response; | ||
} | ||
|
||
public override async Task<float[]> EmbedQueryAsync(string text, CancellationToken cancellationToken = default) | ||
{ | ||
var watch = Stopwatch.StartNew(); | ||
|
||
var response = await CreateCompletionAsync(text, _configuration, cancellationToken).ConfigureAwait(false); | ||
|
||
watch.Stop(); | ||
|
||
// Unsupported | ||
var usage = Usage.Empty with | ||
{ | ||
Time = watch.Elapsed, | ||
}; | ||
lock (_usageLock) | ||
{ | ||
TotalUsage += usage; | ||
} | ||
|
||
return response; | ||
} | ||
} |
Oops, something went wrong.