diff --git a/LangChain.sln b/LangChain.sln index d9565651..2e218ebf 100644 --- a/LangChain.sln +++ b/LangChain.sln @@ -465,6 +465,7 @@ Global {B953ABEC-50DD-4A63-A12A-E82F124C7D5B} = {FDEE2E22-C239-4921-83B2-9797F765FD6A} {32FC123E-F269-4352-848C-0161B53093CC} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68} {DEAFA0CB-462D-4D74-B16F-68FD83FE3858} = {FDEE2E22-C239-4921-83B2-9797F765FD6A} + {18F5AAB1-1750-41BD-B623-6339CA5754D9} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68} {72B1E2CC-1A34-470E-A579-034CB0972BB7} = {FDEE2E22-C239-4921-83B2-9797F765FD6A} {4913844F-74EC-4E74-AE8A-EA825569E6BA} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68} {BF4C7B87-0997-4208-84EF-D368DF7B9861} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68} diff --git a/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.Embeddings.cs b/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.Embeddings.cs new file mode 100644 index 00000000..c666cbc2 --- /dev/null +++ b/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.Embeddings.cs @@ -0,0 +1,125 @@ +using Azure; +using Azure.AI.OpenAI; +using OpenAI; +using OpenAI.Constants; +using OpenAI.Embeddings; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using static System.Net.Mime.MediaTypeNames; +using OpenAIClient = Azure.AI.OpenAI.OpenAIClient; + +namespace LangChain.Providers.Azure; + +public partial class AzureOpenAIModel : IEmbeddingModel +{ + #region Properties + + /// + public string EmbeddingModelId { get; init; } = EmbeddingModel.Ada002; + + /// + /// API has limit of 2048 elements in array per request + /// so we need to split texts into batches + /// https://platform.openai.com/docs/api-reference/embeddings + /// + public int EmbeddingBatchSize { get; init; } = 2048; + + /// + public int MaximumInputLength => ContextLengths.Get(EmbeddingModelId); + + #endregion + + #region Methods + + /// + public async Task EmbedQueryAsync( + string text, + CancellationToken cancellationToken = default) + { + var watch = Stopwatch.StartNew(); + + var embeddingOptions = new EmbeddingsOptions(Id, new[] { text }); + + var response = await Client.GetEmbeddingsAsync(embeddingOptions, cancellationToken).ConfigureAwait(false); + + var usage = GetUsage(response) with + { + Time = watch.Elapsed, + }; + lock (_usageLock) + { + TotalUsage += usage; + } + + return response.Value.Data[0].Embedding.ToArray(); + } + + /// + public async Task EmbedDocumentsAsync( + string[] texts, + CancellationToken cancellationToken = default) + { + texts = texts ?? throw new ArgumentNullException(nameof(texts)); + + var watch = Stopwatch.StartNew(); + + var index = 0; + var batches = new List(); + while (index < texts.Length) + { + batches.Add(texts.Skip(index).Take(EmbeddingBatchSize).ToArray()); + index += EmbeddingBatchSize; + } + + var results = await Task.WhenAll(batches.Select(async batch => + { + var watch = Stopwatch.StartNew(); + var embeddingOptions = new EmbeddingsOptions(Id, batch); + + var response = await Client.GetEmbeddingsAsync(embeddingOptions, cancellationToken).ConfigureAwait(false); + + var usage = GetUsage(response) with + { + Time = watch.Elapsed, + }; + lock (_usageLock) + { + TotalUsage += usage; + } + + return response.Value.Data + .Select(x => x.Embedding.ToArray()) + .ToArray(); + })).ConfigureAwait(false); + + var rr = results + .SelectMany(x => x.ToArray()) + .ToArray(); + return rr; + } + + + + private Usage GetUsage(Response? response) + { + if (response?.Value?.Usage == null!) + { + return Usage.Empty; + } + var tokens = response.Value?.Usage.PromptTokens ?? 0; + var priceInUsd = EmbeddingPrices.TryGet( + model: new EmbeddingModel(EmbeddingModelId), + tokens: tokens) ?? 0.0D; + + return Usage.Empty with + { + InputTokens = tokens, + PriceInUsd = priceInUsd, + }; + } + #endregion +} diff --git a/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs b/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs index 005c5175..4c3c13a3 100644 --- a/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs +++ b/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs @@ -7,8 +7,9 @@ namespace LangChain.Providers.Azure; /// /// Wrapper around Azure OpenAI large language models /// -public class AzureOpenAIModel : IChatModel +public partial class AzureOpenAIModel : IChatModel { + private readonly object _usageLock = new(); /// /// Azure OpenAI API Key /// @@ -29,7 +30,7 @@ public class AzureOpenAIModel : IChatModel /// Azure OpenAI Resource URI /// public string Endpoint { get; set; } - + public OpenAIClient Client { get; set; } private AzureOpenAIConfiguration Configurations { get; } #region Constructors @@ -46,6 +47,7 @@ public AzureOpenAIModel(string apiKey, string endpoint, string id) Id = id ?? throw new ArgumentNullException(nameof(id)); ApiKey = apiKey ?? throw new ArgumentNullException(nameof(apiKey)); Endpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint)); + Client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey)); } /// @@ -60,6 +62,7 @@ public AzureOpenAIModel(AzureOpenAIConfiguration configuration) ApiKey = configuration.ApiKey ?? throw new ArgumentException("ApiKey is not defined", nameof(configuration)); Id = configuration.Id ?? throw new ArgumentException("Deployment model Id is not defined", nameof(configuration)); Endpoint = configuration.Endpoint ?? throw new ArgumentException("Endpoint is not defined", nameof(configuration)); + Client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey)); } #endregion @@ -94,9 +97,8 @@ private async Task> CreateChatCompleteAsync(IReadOnlyC ChoiceCount = Configurations.ChoiceCount, Temperature = Configurations.Temperature, }; - - var client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey)); - return await client.GetChatCompletionsAsync(chatCompletionOptions, cancellationToken).ConfigureAwait(false); + + return await Client.GetChatCompletionsAsync(chatCompletionOptions, cancellationToken).ConfigureAwait(false); } private static ChatRequestMessage ToRequestMessage(Message message)