Skip to content

Commit

Permalink
feat: Added support for Azure OpenAI provider (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
kharedev247 authored Jan 5, 2024
1 parent 7f91993 commit a74358d
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 100 deletions.
6 changes: 6 additions & 0 deletions LangChain.sln
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google.UnitTests", "src\tests\LangChain.Providers.Google.UnitTests\LangChain.Providers.Google.UnitTests.csproj", "{DEAFA0CB-462D-4D74-B16F-68FD83FE3858}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Azure", "src\libs\Providers\LangChain.Providers.Azure\LangChain.Providers.Azure.csproj", "{18F5AAB1-1750-41BD-B623-6339CA5754D9}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -372,6 +374,10 @@ Global
{DEAFA0CB-462D-4D74-B16F-68FD83FE3858}.Debug|Any CPU.Build.0 = Debug|Any CPU
{DEAFA0CB-462D-4D74-B16F-68FD83FE3858}.Release|Any CPU.ActiveCfg = Release|Any CPU
{DEAFA0CB-462D-4D74-B16F-68FD83FE3858}.Release|Any CPU.Build.0 = Release|Any CPU
{18F5AAB1-1750-41BD-B623-6339CA5754D9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{18F5AAB1-1750-41BD-B623-6339CA5754D9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{18F5AAB1-1750-41BD-B623-6339CA5754D9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{18F5AAB1-1750-41BD-B623-6339CA5754D9}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

<ItemGroup>
<ProjectReference Include="..\..\src\libs\LangChain\LangChain.csproj" />
<ProjectReference Include="..\..\src\libs\Providers\LangChain.Providers.Azure\LangChain.Providers.Azure.csproj" />
</ItemGroup>

</Project>
17 changes: 8 additions & 9 deletions examples/LangChain.Samples.Azure/Program.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
// using LangChain.Providers;
// using LangChain.Providers.Azure;
//
// using var httpClient = new HttpClient();
// var model = new Gpt35TurboModel("apiKey", "endpoint", new HttpClient());
// var result = await model.GenerateAsync("What is a good name for a company that sells colourful socks?");
//
// Console.WriteLine(result);
Console.WriteLine("Not implemented");
using LangChain.Providers;
using LangChain.Providers.Azure;

var model = new AzureOpenAIModel("AZURE_OPEN_AI_KEY", "ENDPOINT", "DEPLOYMENT_NAME");

var result = await model.GenerateAsync("What is a good name for a company that sells colourful socks?");

Console.WriteLine(result);
1 change: 1 addition & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
<PackageVersion Include="Anyscale" Version="1.0.2" />
<PackageVersion Include="Aspose.PDF" Version="23.11.0" />
<PackageVersion Include="AWSSDK.Kendra" Version="3.7.300.5" />
<PackageVersion Include="Azure.AI.OpenAI" Version="1.0.0-beta.12" />
<PackageVersion Include="Docker.DotNet" Version="3.125.15" />
<PackageVersion Include="DotNet.ReproducibleBuilds" Version="1.1.1" />
<PackageVersion Include="FluentAssertions" Version="6.12.0" />
Expand Down
36 changes: 0 additions & 36 deletions src/libs/Providers/LangChain.Providers.Azure/AzureModel.cs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
namespace LangChain.Providers
{
/// <summary>
/// Configuration options for Azure OpenAI
/// </summary>
public class AzureOpenAIConfiguration
{
/// <summary>
/// Context size
/// How much tokens model will remember.
/// Most models have 2048
/// </summary>
public int ContextSize { get; set; } = 2048;

/// <summary>
/// Temperature
/// controls the apparent creativity of generated completions.
/// Has a valid range of 0.0 to 2.0
/// Defaults to 1.0 if not otherwise specified.
/// </summary>
public float Temperature { get; set; } = 0.7f;

/// <summary>
/// Gets the maximum number of tokens to generate. Has minimum of 0.
/// </summary>
public int MaxTokens { get; set; } = 800;

/// <summary>
/// Number of choices that should be generated per provided prompt.
/// Has a valid range of 1 to 128.
/// </summary>
public int ChoiceCount { get; set; } = 1;

/// <summary>
/// Azure OpenAI API Key
/// </summary>
public string? ApiKey { get; set; }

/// <summary>
/// Deployment name
/// </summary>
public string Id { get; set; }

Check warning on line 42 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable property 'Id' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 42 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable property 'Id' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 42 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable property 'Id' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 42 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable property 'Id' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 42 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable property 'Id' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

/// <summary>
/// Azure OpenAI Resource URI
/// </summary>
public string Endpoint { get; set; }

Check warning on line 47 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable property 'Endpoint' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 47 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable property 'Endpoint' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 47 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable property 'Endpoint' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 47 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable property 'Endpoint' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 47 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable property 'Endpoint' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
}
}
131 changes: 131 additions & 0 deletions src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
using Azure;
using Azure.AI.OpenAI;
using System.Diagnostics;

namespace LangChain.Providers.Azure;

/// <summary>
/// Wrapper around Azure OpenAI large language models
/// </summary>
public class AzureOpenAIModel : IChatModel
{
/// <summary>
/// Azure OpenAI API Key
/// </summary>
public string ApiKey { get; init; }

/// <inheritdoc/>
public Usage TotalUsage { get; private set; }

/// <summary>
/// Deployment name
/// </summary>
public string Id { get; init; }

/// <inheritdoc/>
public int ContextLength => Configurations.ContextSize;

/// <summary>
/// Azure OpenAI Resource URI
/// </summary>
public string Endpoint { get; set; }

private AzureOpenAIConfiguration Configurations { get; }

#region Constructors
/// <summary>
/// Wrapper around Azure OpenAI
/// </summary>
/// <param name="apiKey">API Key</param>
/// <param name="endpoint">Azure Open AI Resource URI</param>
/// <param name="id">Deployment Model name</param>
/// <exception cref="ArgumentNullException"></exception>
public AzureOpenAIModel(string apiKey, string endpoint, string id)
{
Configurations = new AzureOpenAIConfiguration();
Id = id ?? throw new ArgumentNullException(nameof(id));
ApiKey = apiKey ?? throw new ArgumentNullException(nameof(apiKey));
Endpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint));
}

/// <summary>
/// Wrapper around Azure OpenAI
/// </summary>
/// <param name="configuration">AzureOpenAIConfiguration</param>
/// <exception cref="ArgumentNullException"></exception>
/// <exception cref="ArgumentException"></exception>
public AzureOpenAIModel(AzureOpenAIConfiguration configuration)
{
Configurations = configuration ?? throw new ArgumentNullException(nameof(configuration));
ApiKey = configuration.ApiKey ?? throw new ArgumentException("ApiKey is not defined", nameof(configuration));
Id = configuration.Id ?? throw new ArgumentException("Deployment model Id is not defined", nameof(configuration));
Endpoint = configuration.Endpoint ?? throw new ArgumentException("Endpoint is not defined", nameof(configuration));
}
#endregion

#region Methods
/// <inheritdoc/>
public async Task<ChatResponse> GenerateAsync(ChatRequest request, CancellationToken cancellationToken = default)
{
var messages = request.Messages.ToList();
var watch = Stopwatch.StartNew();
var response = await CreateChatCompleteAsync(messages, cancellationToken).ConfigureAwait(false);

messages.Add(ToMessage(response.Value));

watch.Stop();

var usage = GetUsage(response.Value.Usage) with
{
Time = watch.Elapsed,
};
TotalUsage += usage;

return new ChatResponse(
Messages: messages,
Usage: usage);
}

private async Task<Response<ChatCompletions>> CreateChatCompleteAsync(IReadOnlyCollection<Message> messages, CancellationToken cancellationToken = default)
{
var chatCompletionOptions = new ChatCompletionsOptions(Id, messages.Select(ToRequestMessage))
{
MaxTokens = Configurations.MaxTokens,
ChoiceCount = Configurations.ChoiceCount,
Temperature = Configurations.Temperature,
};

var client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey));
return await client.GetChatCompletionsAsync(chatCompletionOptions, cancellationToken).ConfigureAwait(false);
}

private static ChatRequestMessage ToRequestMessage(Message message)
{
return message.Role switch
{
MessageRole.System => new ChatRequestSystemMessage(message.Content),
MessageRole.Ai => new ChatRequestAssistantMessage(message.Content),
MessageRole.Human => new ChatRequestUserMessage(message.Content),
MessageRole.FunctionCall => throw new NotImplementedException(),
MessageRole.FunctionResult => throw new NotImplementedException(),
_ => throw new NotImplementedException()
};
}

private static Message ToMessage(ChatCompletions message)
{
return new Message(
Content: message.Choices[0].Message.Content,
Role: MessageRole.Ai);
}

private static Usage GetUsage(CompletionsUsage usage)
{
return Usage.Empty with
{
InputTokens = usage.PromptTokens,
OutputTokens = usage.CompletionTokens
};
}
#endregion
}
18 changes: 0 additions & 18 deletions src/libs/Providers/LangChain.Providers.Azure/Gpt35Turbo16KModel.cs

This file was deleted.

18 changes: 0 additions & 18 deletions src/libs/Providers/LangChain.Providers.Azure/Gpt35TurboModel.cs

This file was deleted.

18 changes: 0 additions & 18 deletions src/libs/Providers/LangChain.Providers.Azure/Gpt4Model.cs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>net4.6.2;netstandard2.0;net6.0;net7.0</TargetFrameworks>
<TargetFrameworks>net4.6.2;netstandard2.0;net6.0;net7.0;net8.0</TargetFrameworks>
</PropertyGroup>

<ItemGroup Label="Usings">
Expand All @@ -15,6 +15,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" />
<PackageReference Include="tryAGI.OpenAI" />
</ItemGroup>

Expand Down

0 comments on commit a74358d

Please sign in to comment.