diff --git a/LangChain.sln b/LangChain.sln
index d9565651..2e218ebf 100644
--- a/LangChain.sln
+++ b/LangChain.sln
@@ -465,6 +465,7 @@ Global
{B953ABEC-50DD-4A63-A12A-E82F124C7D5B} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{32FC123E-F269-4352-848C-0161B53093CC} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
{DEAFA0CB-462D-4D74-B16F-68FD83FE3858} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
+ {18F5AAB1-1750-41BD-B623-6339CA5754D9} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
{72B1E2CC-1A34-470E-A579-034CB0972BB7} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{4913844F-74EC-4E74-AE8A-EA825569E6BA} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
{BF4C7B87-0997-4208-84EF-D368DF7B9861} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
diff --git a/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.Embeddings.cs b/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.Embeddings.cs
new file mode 100644
index 00000000..c666cbc2
--- /dev/null
+++ b/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.Embeddings.cs
@@ -0,0 +1,125 @@
+using Azure;
+using Azure.AI.OpenAI;
+using OpenAI;
+using OpenAI.Constants;
+using OpenAI.Embeddings;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using static System.Net.Mime.MediaTypeNames;
+using OpenAIClient = Azure.AI.OpenAI.OpenAIClient;
+
+namespace LangChain.Providers.Azure;
+
+public partial class AzureOpenAIModel : IEmbeddingModel
+{
+ #region Properties
+
+ /// <inheritdoc/>
+ public string EmbeddingModelId { get; init; } = EmbeddingModel.Ada002;
+
+ /// <summary>
+ /// API has limit of 2048 elements in array per request
+ /// so we need to split texts into batches
+ /// https://platform.openai.com/docs/api-reference/embeddings
+ /// </summary>
+ public int EmbeddingBatchSize { get; init; } = 2048;
+
+ /// <inheritdoc/>
+ public int MaximumInputLength => ContextLengths.Get(EmbeddingModelId);
+
+ #endregion
+
+ #region Methods
+
+ /// <inheritdoc/>
+ public async Task<float[]> EmbedQueryAsync(
+ string text,
+ CancellationToken cancellationToken = default)
+ {
+ var watch = Stopwatch.StartNew();
+
+ var embeddingOptions = new EmbeddingsOptions(Id, new[] { text });
+
+ var response = await Client.GetEmbeddingsAsync(embeddingOptions, cancellationToken).ConfigureAwait(false);
+
+ var usage = GetUsage(response) with
+ {
+ Time = watch.Elapsed,
+ };
+ lock (_usageLock)
+ {
+ TotalUsage += usage;
+ }
+
+ return response.Value.Data[0].Embedding.ToArray();
+ }
+
+ /// <inheritdoc/>
+ public async Task<float[][]> EmbedDocumentsAsync(
+ string[] texts,
+ CancellationToken cancellationToken = default)
+ {
+ texts = texts ?? throw new ArgumentNullException(nameof(texts));
+
+ var watch = Stopwatch.StartNew();
+
+ var index = 0;
+ var batches = new List<string[]>();
+ while (index < texts.Length)
+ {
+ batches.Add(texts.Skip(index).Take(EmbeddingBatchSize).ToArray());
+ index += EmbeddingBatchSize;
+ }
+
+ var results = await Task.WhenAll(batches.Select(async batch =>
+ {
+ var watch = Stopwatch.StartNew();
+ var embeddingOptions = new EmbeddingsOptions(Id, batch);
+
+ var response = await Client.GetEmbeddingsAsync(embeddingOptions, cancellationToken).ConfigureAwait(false);
+
+ var usage = GetUsage(response) with
+ {
+ Time = watch.Elapsed,
+ };
+ lock (_usageLock)
+ {
+ TotalUsage += usage;
+ }
+
+ return response.Value.Data
+ .Select(x => x.Embedding.ToArray())
+ .ToArray();
+ })).ConfigureAwait(false);
+
+ var rr = results
+ .SelectMany(x => x.ToArray())
+ .ToArray();
+ return rr;
+ }
+
+
+
+ private Usage GetUsage(Response<Embeddings>? response)
+ {
+ if (response?.Value?.Usage == null!)
+ {
+ return Usage.Empty;
+ }
+ var tokens = response.Value?.Usage.PromptTokens ?? 0;
+ var priceInUsd = EmbeddingPrices.TryGet(
+ model: new EmbeddingModel(EmbeddingModelId),
+ tokens: tokens) ?? 0.0D;
+
+ return Usage.Empty with
+ {
+ InputTokens = tokens,
+ PriceInUsd = priceInUsd,
+ };
+ }
+ #endregion
+}
diff --git a/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs b/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs
index 005c5175..4c3c13a3 100644
--- a/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs
+++ b/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs
@@ -7,8 +7,9 @@ namespace LangChain.Providers.Azure;
/// <summary>
/// Wrapper around Azure OpenAI large language models
/// </summary>
-public class AzureOpenAIModel : IChatModel
+public partial class AzureOpenAIModel : IChatModel
{
+ private readonly object _usageLock = new();
/// <summary>
/// Azure OpenAI API Key
/// </summary>
@@ -29,7 +30,7 @@ public class AzureOpenAIModel : IChatModel
/// Azure OpenAI Resource URI
/// </summary>
public string Endpoint { get; set; }
-
+ public OpenAIClient Client { get; set; }
private AzureOpenAIConfiguration Configurations { get; }
#region Constructors
@@ -46,6 +47,7 @@ public AzureOpenAIModel(string apiKey, string endpoint, string id)
Id = id ?? throw new ArgumentNullException(nameof(id));
ApiKey = apiKey ?? throw new ArgumentNullException(nameof(apiKey));
Endpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint));
+ Client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey));
}
/// <summary>
@@ -60,6 +62,7 @@ public AzureOpenAIModel(AzureOpenAIConfiguration configuration)
ApiKey = configuration.ApiKey ?? throw new ArgumentException("ApiKey is not defined", nameof(configuration));
Id = configuration.Id ?? throw new ArgumentException("Deployment model Id is not defined", nameof(configuration));
Endpoint = configuration.Endpoint ?? throw new ArgumentException("Endpoint is not defined", nameof(configuration));
+ Client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey));
}
#endregion
@@ -94,9 +97,8 @@ private async Task<Response<ChatCompletions>> CreateChatCompleteAsync(IReadOnlyC
ChoiceCount = Configurations.ChoiceCount,
Temperature = Configurations.Temperature,
};
-
- var client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey));
- return await client.GetChatCompletionsAsync(chatCompletionOptions, cancellationToken).ConfigureAwait(false);
+
+ return await Client.GetChatCompletionsAsync(chatCompletionOptions, cancellationToken).ConfigureAwait(false);
}
private static ChatRequestMessage ToRequestMessage(Message message)