-
-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,8 @@ | ||
// using LangChain.Providers; | ||
// using LangChain.Providers.Azure; | ||
// | ||
// using var httpClient = new HttpClient(); | ||
// var model = new Gpt35TurboModel("apiKey", "endpoint", new HttpClient()); | ||
// var result = await model.GenerateAsync("What is a good name for a company that sells colourful socks?"); | ||
// | ||
// Console.WriteLine(result); | ||
Console.WriteLine("Not implemented"); | ||
using LangChain.Providers; | ||
using LangChain.Providers.Azure; | ||
|
||
var model = new AzureOpenAIModel("AZURE_OPEN_AI_KEY", "ENDPOINT", "DEPLOYMENT_NAME"); | ||
|
||
var result = await model.GenerateAsync("What is a good name for a company that sells colourful socks?"); | ||
|
||
Console.WriteLine(result); |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
namespace LangChain.Providers | ||
{ | ||
/// <summary> | ||
/// Configuration options for Azure OpenAI | ||
/// </summary> | ||
public class AzureOpenAIConfiguration | ||
{ | ||
/// <summary> | ||
/// Context size | ||
/// How much tokens model will remember. | ||
/// Most models have 2048 | ||
/// </summary> | ||
public int ContextSize { get; set; } = 2048; | ||
|
||
/// <summary> | ||
/// Temperature | ||
/// controls the apparent creativity of generated completions. | ||
/// Has a valid range of 0.0 to 2.0 | ||
/// Defaults to 1.0 if not otherwise specified. | ||
/// </summary> | ||
public float Temperature { get; set; } = 0.7f; | ||
|
||
/// <summary> | ||
/// Gets the maximum number of tokens to generate. Has minimum of 0. | ||
/// </summary> | ||
public int MaxTokens { get; set; } = 800; | ||
|
||
/// <summary> | ||
/// Number of choices that should be generated per provided prompt. | ||
/// Has a valid range of 1 to 128. | ||
/// </summary> | ||
public int ChoiceCount { get; set; } = 1; | ||
|
||
/// <summary> | ||
/// Azure OpenAI API Key | ||
/// </summary> | ||
public string? ApiKey { get; set; } | ||
|
||
/// <summary> | ||
/// Deployment name | ||
/// </summary> | ||
public string Id { get; set; } | ||
Check warning on line 42 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs
|
||
|
||
/// <summary> | ||
/// Azure OpenAI Resource URI | ||
/// </summary> | ||
public string Endpoint { get; set; } | ||
Check warning on line 47 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs
|
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
using Azure; | ||
using Azure.AI.OpenAI; | ||
using System.Diagnostics; | ||
|
||
namespace LangChain.Providers.Azure; | ||
|
||
/// <summary> | ||
/// Wrapper around Azure OpenAI large language models | ||
/// </summary> | ||
public class AzureOpenAIModel : IChatModel | ||
{ | ||
/// <summary> | ||
/// Azure OpenAI API Key | ||
/// </summary> | ||
public string ApiKey { get; init; } | ||
|
||
/// <inheritdoc/> | ||
public Usage TotalUsage { get; private set; } | ||
|
||
/// <summary> | ||
/// Deployment name | ||
/// </summary> | ||
public string Id { get; init; } | ||
|
||
/// <inheritdoc/> | ||
public int ContextLength => Configurations.ContextSize; | ||
|
||
/// <summary> | ||
/// Azure OpenAI Resource URI | ||
/// </summary> | ||
public string Endpoint { get; set; } | ||
|
||
private AzureOpenAIConfiguration Configurations { get; } | ||
|
||
#region Constructors | ||
/// <summary> | ||
/// Wrapper around Azure OpenAI | ||
/// </summary> | ||
/// <param name="apiKey">API Key</param> | ||
/// <param name="endpoint">Azure Open AI Resource URI</param> | ||
/// <param name="id">Deployment Model name</param> | ||
/// <exception cref="ArgumentNullException"></exception> | ||
public AzureOpenAIModel(string apiKey, string endpoint, string id) | ||
{ | ||
Configurations = new AzureOpenAIConfiguration(); | ||
Id = id ?? throw new ArgumentNullException(nameof(id)); | ||
ApiKey = apiKey ?? throw new ArgumentNullException(nameof(apiKey)); | ||
Endpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint)); | ||
} | ||
|
||
/// <summary> | ||
/// Wrapper around Azure OpenAI | ||
/// </summary> | ||
/// <param name="configuration">AzureOpenAIConfiguration</param> | ||
/// <exception cref="ArgumentNullException"></exception> | ||
/// <exception cref="ArgumentException"></exception> | ||
public AzureOpenAIModel(AzureOpenAIConfiguration configuration) | ||
{ | ||
Configurations = configuration ?? throw new ArgumentNullException(nameof(configuration)); | ||
ApiKey = configuration.ApiKey ?? throw new ArgumentException("ApiKey is not defined", nameof(configuration)); | ||
Id = configuration.Id ?? throw new ArgumentException("Deployment model Id is not defined", nameof(configuration)); | ||
Endpoint = configuration.Endpoint ?? throw new ArgumentException("Endpoint is not defined", nameof(configuration)); | ||
} | ||
#endregion | ||
|
||
#region Methods | ||
/// <inheritdoc/> | ||
public async Task<ChatResponse> GenerateAsync(ChatRequest request, CancellationToken cancellationToken = default) | ||
{ | ||
var messages = request.Messages.ToList(); | ||
var watch = Stopwatch.StartNew(); | ||
var response = await CreateChatCompleteAsync(messages, cancellationToken).ConfigureAwait(false); | ||
|
||
messages.Add(ToMessage(response.Value)); | ||
|
||
watch.Stop(); | ||
|
||
var usage = GetUsage(response.Value.Usage) with | ||
{ | ||
Time = watch.Elapsed, | ||
}; | ||
TotalUsage += usage; | ||
|
||
return new ChatResponse( | ||
Messages: messages, | ||
Usage: usage); | ||
} | ||
|
||
private async Task<Response<ChatCompletions>> CreateChatCompleteAsync(IReadOnlyCollection<Message> messages, CancellationToken cancellationToken = default) | ||
{ | ||
var chatCompletionOptions = new ChatCompletionsOptions(Id, messages.Select(ToRequestMessage)) | ||
{ | ||
MaxTokens = Configurations.MaxTokens, | ||
ChoiceCount = Configurations.ChoiceCount, | ||
Temperature = Configurations.Temperature, | ||
}; | ||
|
||
var client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey)); | ||
return await client.GetChatCompletionsAsync(chatCompletionOptions, cancellationToken).ConfigureAwait(false); | ||
} | ||
|
||
private static ChatRequestMessage ToRequestMessage(Message message) | ||
{ | ||
return message.Role switch | ||
{ | ||
MessageRole.System => new ChatRequestSystemMessage(message.Content), | ||
MessageRole.Ai => new ChatRequestAssistantMessage(message.Content), | ||
MessageRole.Human => new ChatRequestUserMessage(message.Content), | ||
MessageRole.FunctionCall => throw new NotImplementedException(), | ||
MessageRole.FunctionResult => throw new NotImplementedException(), | ||
_ => throw new NotImplementedException() | ||
}; | ||
} | ||
|
||
private static Message ToMessage(ChatCompletions message) | ||
{ | ||
return new Message( | ||
Content: message.Choices[0].Message.Content, | ||
Role: MessageRole.Ai); | ||
} | ||
|
||
private static Usage GetUsage(CompletionsUsage usage) | ||
{ | ||
return Usage.Empty with | ||
{ | ||
InputTokens = usage.PromptTokens, | ||
OutputTokens = usage.CompletionTokens | ||
}; | ||
} | ||
#endregion | ||
} |
This file was deleted.
This file was deleted.
This file was deleted.