-
-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add Azure Embeddings Interface implementation (#116)
- Loading branch information
Showing
3 changed files
with
133 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
125 changes: 125 additions & 0 deletions
125
src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.Embeddings.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
using Azure; | ||
using Azure.AI.OpenAI; | ||
using OpenAI; | ||
using OpenAI.Constants; | ||
using OpenAI.Embeddings; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Diagnostics; | ||
using System.Linq; | ||
using System.Text; | ||
using System.Threading.Tasks; | ||
using static System.Net.Mime.MediaTypeNames; | ||
using OpenAIClient = Azure.AI.OpenAI.OpenAIClient; | ||
|
||
namespace LangChain.Providers.Azure; | ||
|
||
public partial class AzureOpenAIModel : IEmbeddingModel | ||
{ | ||
#region Properties | ||
|
||
/// <summary>
/// Identifier of the embedding model. Defaults to <c>EmbeddingModel.Ada002</c>.
/// Within this file it is used to look up the context length
/// (<see cref="MaximumInputLength"/>) and the price of a request.
/// </summary>
/// <remarks>
/// NOTE(review): the embedding requests in this file are built with <c>Id</c>,
/// not with this property — confirm that is intended.
/// </remarks>
public string EmbeddingModelId { get; init; } = EmbeddingModel.Ada002;

/// <summary>
/// Maximum number of texts sent in a single embeddings request.
/// The API accepts at most 2048 array elements per request, so larger inputs
/// are split into batches of this size.
/// https://platform.openai.com/docs/api-reference/embeddings
/// </summary>
/// <remarks>
/// NOTE(review): the linked limit is documented for the OpenAI platform API;
/// confirm that the Azure OpenAI deployment accepts the same batch size.
/// </remarks>
public int EmbeddingBatchSize { get; init; } = 2048;

/// <inheritdoc/>
public int MaximumInputLength => ContextLengths.Get(EmbeddingModelId);
|
||
#endregion | ||
|
||
#region Methods | ||
|
||
/// <inheritdoc/>
/// <exception cref="ArgumentNullException"><paramref name="text"/> is <c>null</c>.</exception>
public async Task<float[]> EmbedQueryAsync(
    string text,
    CancellationToken cancellationToken = default)
{
    // Fail fast on null input, the same way EmbedDocumentsAsync guards its argument.
    text = text ?? throw new ArgumentNullException(nameof(text));

    var watch = Stopwatch.StartNew();

    // The single text is sent as a one-element batch.
    // Id is declared in another part of this partial class
    // (presumably the Azure deployment name — TODO confirm).
    var embeddingOptions = new EmbeddingsOptions(Id, new[] { text });

    var response = await Client.GetEmbeddingsAsync(embeddingOptions, cancellationToken).ConfigureAwait(false);

    // Attach the elapsed wall-clock time to the usage derived from the response.
    var usage = GetUsage(response) with
    {
        Time = watch.Elapsed,
    };

    // TotalUsage is accumulated under _usageLock, as in EmbedDocumentsAsync.
    lock (_usageLock)
    {
        TotalUsage += usage;
    }

    // Exactly one input was sent, so the first data item is its embedding.
    return response.Value.Data[0].Embedding.ToArray();
}
|
||
/// <inheritdoc/>
/// <exception cref="ArgumentNullException"><paramref name="texts"/> is <c>null</c>.</exception>
/// <exception cref="InvalidOperationException">
/// <see cref="EmbeddingBatchSize"/> is zero or negative.
/// </exception>
public async Task<float[][]> EmbedDocumentsAsync(
    string[] texts,
    CancellationToken cancellationToken = default)
{
    texts = texts ?? throw new ArgumentNullException(nameof(texts));

    // A non-positive batch size would never advance the batching loop below.
    if (EmbeddingBatchSize <= 0)
    {
        throw new InvalidOperationException(
            nameof(EmbeddingBatchSize) + " must be greater than zero.");
    }

    // Slice the input into consecutive batches of at most EmbeddingBatchSize items.
    var batches = new List<string[]>();
    var offset = 0;
    while (offset < texts.Length)
    {
        var length = Math.Min(EmbeddingBatchSize, texts.Length - offset);
        var batch = new string[length];
        Array.Copy(texts, offset, batch, 0, length);
        batches.Add(batch);
        offset += length;
    }

    // One request per batch, all started together; Task.WhenAll returns the
    // results in the same order as the batches.
    var perBatch = await Task.WhenAll(batches.Select(async batch =>
    {
        // Each batch is timed separately so its usage carries its own duration.
        var watch = Stopwatch.StartNew();
        var embeddingOptions = new EmbeddingsOptions(Id, batch);

        var response = await Client.GetEmbeddingsAsync(embeddingOptions, cancellationToken).ConfigureAwait(false);

        var usage = GetUsage(response) with
        {
            Time = watch.Elapsed,
        };

        // Batches complete concurrently, so the shared total is updated under the lock.
        lock (_usageLock)
        {
            TotalUsage += usage;
        }

        return response.Value.Data
            .Select(x => x.Embedding.ToArray())
            .ToArray();
    })).ConfigureAwait(false);

    // Flatten the per-batch results into a single array of embeddings.
    return perBatch
        .SelectMany(batchEmbeddings => batchEmbeddings)
        .ToArray();
}
|
||
|
||
|
||
/// <summary>
/// Builds a <see cref="Usage"/> value from the usage section of an embeddings response.
/// </summary>
/// <param name="response">Embeddings response; may be <c>null</c>.</param>
/// <returns>
/// <c>Usage.Empty</c> when the response, its value, or its usage section is missing;
/// otherwise <c>Usage.Empty</c> with the prompt token count and the price looked up
/// for <see cref="EmbeddingModelId"/> (0 when no price is found).
/// </returns>
private Usage GetUsage(Response<Embeddings>? response)
{
    var reported = response?.Value?.Usage;
    if (reported == null)
    {
        return Usage.Empty;
    }

    var promptTokens = reported.PromptTokens;

    // Price lookup is keyed by the embedding model id; a missing entry counts as free.
    var cost = EmbeddingPrices.TryGet(
        model: new EmbeddingModel(EmbeddingModelId),
        tokens: promptTokens) ?? 0.0D;

    return Usage.Empty with
    {
        InputTokens = promptTokens,
        PriceInUsd = cost,
    };
}
#endregion | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters