Skip to content

Commit

Permalink
feat: add Azure Embeddings Interface implementation (#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
IRooc authored Jan 26, 2024
1 parent d7d7756 commit 26b21ac
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 5 deletions.
1 change: 1 addition & 0 deletions LangChain.sln
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ Global
{B953ABEC-50DD-4A63-A12A-E82F124C7D5B} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{32FC123E-F269-4352-848C-0161B53093CC} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
{DEAFA0CB-462D-4D74-B16F-68FD83FE3858} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{18F5AAB1-1750-41BD-B623-6339CA5754D9} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
{72B1E2CC-1A34-470E-A579-034CB0972BB7} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{4913844F-74EC-4E74-AE8A-EA825569E6BA} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
{BF4C7B87-0997-4208-84EF-D368DF7B9861} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
using Azure;
using Azure.AI.OpenAI;
using OpenAI;
using OpenAI.Constants;
using OpenAI.Embeddings;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static System.Net.Mime.MediaTypeNames;
using OpenAIClient = Azure.AI.OpenAI.OpenAIClient;

namespace LangChain.Providers.Azure;

public partial class AzureOpenAIModel : IEmbeddingModel
{
#region Properties

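/// <summary>
/// Identifier of the embedding model. It is used to look up the maximum input length
/// and to price the reported token usage. Defaults to <c>EmbeddingModel.Ada002</c>.
/// </summary>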

Check warning on line 21 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.Embeddings.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

XML comment has cref attribute 'EmbeddingModelId' that could not be resolved

Check warning on line 21 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.Embeddings.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

XML comment has cref attribute 'EmbeddingModelId' that could not be resolved
public string EmbeddingModelId { get; init; } = EmbeddingModel.Ada002;

/// <summary>
/// Maximum number of texts sent in a single embeddings request.
/// The API has a limit of 2048 elements in the input array per request,
/// so larger inputs need to be split into batches of this size.
/// See https://platform.openai.com/docs/api-reference/embeddings
/// </summary>
public int EmbeddingBatchSize { get; init; } = 2048;

/// <inheritdoc/>
/// <remarks>Looked up via <c>ContextLengths.Get</c> using <see cref="EmbeddingModelId"/>.</remarks>
public int MaximumInputLength => ContextLengths.Get(EmbeddingModelId);

#endregion

#region Methods

/// <inheritdoc/>
/// <remarks>
/// Sends a single-element embeddings request for the deployment <c>Id</c>, adds the
/// usage reported by the response (together with the elapsed time) to <c>TotalUsage</c>,
/// and returns the first embedding of the response.
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="text"/> is <see langword="null"/>.</exception>
public async Task<float[]> EmbedQueryAsync(
    string text,
    CancellationToken cancellationToken = default)
{
    // Fail fast on a null input, mirroring the argument validation in EmbedDocumentsAsync.
    text = text ?? throw new ArgumentNullException(nameof(text));

    var watch = Stopwatch.StartNew();

    var embeddingOptions = new EmbeddingsOptions(Id, new[] { text });

    var response = await Client.GetEmbeddingsAsync(embeddingOptions, cancellationToken).ConfigureAwait(false);

    var usage = GetUsage(response) with
    {
        Time = watch.Elapsed,
    };

    // TotalUsage is updated under the shared lock because requests may complete concurrently.
    lock (_usageLock)
    {
        TotalUsage += usage;
    }

    return response.Value.Data[0].Embedding.ToArray();
}

/// <inheritdoc/>
/// <remarks>
/// The texts are split into consecutive batches of at most <see cref="EmbeddingBatchSize"/>
/// items. One embeddings request is issued per batch, all requests are awaited together,
/// and the per-batch results are concatenated in batch order. The usage reported by each
/// response (together with that request's elapsed time) is added to <c>TotalUsage</c>.
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="texts"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException"><see cref="EmbeddingBatchSize"/> is less than 1.</exception>
public async Task<float[][]> EmbedDocumentsAsync(
    string[] texts,
    CancellationToken cancellationToken = default)
{
    texts = texts ?? throw new ArgumentNullException(nameof(texts));

    // A non-positive batch size would never advance the offset below and loop forever.
    if (EmbeddingBatchSize <= 0)
    {
        throw new InvalidOperationException(
            $"{nameof(EmbeddingBatchSize)} must be greater than zero, but was {EmbeddingBatchSize}.");
    }

    var batches = new List<string[]>();
    for (var offset = 0; offset < texts.Length; offset += EmbeddingBatchSize)
    {
        batches.Add(texts.Skip(offset).Take(EmbeddingBatchSize).ToArray());
    }

    var results = await Task.WhenAll(batches.Select(async batch =>
    {
        // Each request is timed on its own so the recorded usage reflects that request only.
        var watch = Stopwatch.StartNew();
        var embeddingOptions = new EmbeddingsOptions(Id, batch);

        var response = await Client.GetEmbeddingsAsync(embeddingOptions, cancellationToken).ConfigureAwait(false);

        var usage = GetUsage(response) with
        {
            Time = watch.Elapsed,
        };

        // Batches complete concurrently, so the shared total is updated under the lock.
        lock (_usageLock)
        {
            TotalUsage += usage;
        }

        return response.Value.Data
            .Select(x => x.Embedding.ToArray())
            .ToArray();
    })).ConfigureAwait(false);

    // Task.WhenAll returns results in the order of the input tasks, so flattening keeps batch order.
    return results
        .SelectMany(batchEmbeddings => batchEmbeddings)
        .ToArray();
}



/// <summary>
/// Builds a <see cref="Usage"/> value from the usage section of an embeddings response.
/// </summary>
/// <param name="response">The embeddings response; may be <see langword="null"/>.</param>
/// <returns>
/// <c>Usage.Empty</c> when the response, its value, or its usage section is missing;
/// otherwise <c>Usage.Empty</c> with the prompt token count and the price obtained from
/// <c>EmbeddingPrices.TryGet</c> for <see cref="EmbeddingModelId"/> (0 when no price is returned).
/// </returns>
private Usage GetUsage(Response<Embeddings>? response)
{
    var reported = response?.Value?.Usage;
    if (reported is null)
    {
        return Usage.Empty;
    }

    var promptTokens = reported.PromptTokens;
    var model = new EmbeddingModel(EmbeddingModelId);
    var cost = EmbeddingPrices.TryGet(model: model, tokens: promptTokens) ?? 0.0D;

    return Usage.Empty with
    {
        InputTokens = promptTokens,
        PriceInUsd = cost,
    };
}
#endregion
}
12 changes: 7 additions & 5 deletions src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ namespace LangChain.Providers.Azure;
/// <summary>
/// Wrapper around Azure OpenAI large language models
/// </summary>
public class AzureOpenAIModel : IChatModel
public partial class AzureOpenAIModel : IChatModel
{
private readonly object _usageLock = new();
/// <summary>
/// Azure OpenAI API Key
/// </summary>
Expand All @@ -29,7 +30,7 @@ public class AzureOpenAIModel : IChatModel
/// Azure OpenAI Resource URI
/// </summary>
public string Endpoint { get; set; }

public OpenAIClient Client { get; set; }

Check warning on line 33 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Type of 'AzureOpenAIModel.Client' is not CLS-compliant

Check warning on line 33 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Type of 'AzureOpenAIModel.Client' is not CLS-compliant

Check warning on line 33 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Type of 'AzureOpenAIModel.Client' is not CLS-compliant
private AzureOpenAIConfiguration Configurations { get; }

#region Constructors
Expand All @@ -46,6 +47,7 @@ public AzureOpenAIModel(string apiKey, string endpoint, string id)
Id = id ?? throw new ArgumentNullException(nameof(id));
ApiKey = apiKey ?? throw new ArgumentNullException(nameof(apiKey));
Endpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint));
Client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey));
}

/// <summary>
Expand All @@ -60,6 +62,7 @@ public AzureOpenAIModel(AzureOpenAIConfiguration configuration)
ApiKey = configuration.ApiKey ?? throw new ArgumentException("ApiKey is not defined", nameof(configuration));
Id = configuration.Id ?? throw new ArgumentException("Deployment model Id is not defined", nameof(configuration));
Endpoint = configuration.Endpoint ?? throw new ArgumentException("Endpoint is not defined", nameof(configuration));
Client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey));
}
#endregion

Expand Down Expand Up @@ -94,9 +97,8 @@ private async Task<Response<ChatCompletions>> CreateChatCompleteAsync(IReadOnlyC
ChoiceCount = Configurations.ChoiceCount,
Temperature = Configurations.Temperature,
};

var client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey));
return await client.GetChatCompletionsAsync(chatCompletionOptions, cancellationToken).ConfigureAwait(false);

return await Client.GetChatCompletionsAsync(chatCompletionOptions, cancellationToken).ConfigureAwait(false);
}

private static ChatRequestMessage ToRequestMessage(Message message)
Expand Down

0 comments on commit 26b21ac

Please sign in to comment.