Skip to content

Commit

Permalink
[C#] feat: Add managed identity support for OpenAIModel and `OpenAI…
Browse files Browse the repository at this point in the history
…Embeddings` classes. (#1869)

## Linked issues

closes: #1496  (issue number)

## Details

Added support for authentication `OpenAIModel` and `OpenAIEmbeddings`
classes using `Azure.Core.TokenCredential`. This in turn supports
authentication calls using user or system assigned managed identity aad
auth.

#### Change details

* Updated `OpenAIModel` and `OpenAIEmbeddings` classes
* Added unit tests
* Updated `08.datasource.azureopenai` sample to work with managed
identity auth.

## Attestation Checklist

- [x] My code follows the style guidelines of this project

- I have checked for/fixed spelling, linting, and other errors
- I have commented my code for clarity
- I have made corresponding changes to the documentation (updating the
doc strings in the code is sufficient)
- My changes generate no new warnings
- I have added tests that validates my changes, and provides sufficient
test coverage. I have tested with:
  - Local testing
  - E2E testing in Teams
- New and existing unit tests pass locally with my changes
  • Loading branch information
singhk97 authored Jul 31, 2024
1 parent 135235e commit 0197a00
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
using System.Reflection;
using ChatMessage = Microsoft.Teams.AI.AI.Models.ChatMessage;
using ChatRole = Microsoft.Teams.AI.AI.Models.ChatRole;
using Azure.Identity;

namespace Microsoft.Teams.AI.Tests.AITests.Models
{
Expand Down Expand Up @@ -52,6 +53,16 @@ public void Test_Constructor_AzureOpenAI_InvalidAzureApiVersion()
Assert.Equal("Model created with an unsupported API version of `2023-12-01-preview`.", exception.Message);
}

[Fact]
public void Test_Constructor_AzureOpenAI_ManagedIdentityAuth()
{
// Arrange
var options = new AzureOpenAIModelOptions(new DefaultAzureCredential(), "test-deployment", "https://test.openai.azure.com/");

// Act
new OpenAIModel(options);
}

[Fact]
public async void Test_CompletePromptAsync_AzureOpenAI_Chat_PromptTooLong()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,23 @@
using System.ClientModel;
using Microsoft.Teams.AI.Tests.TestUtils;
using System.ClientModel.Primitives;
using Azure.Identity;

#pragma warning disable CS8604 // Possible null reference argument.
namespace Microsoft.Teams.AI.Tests.AITests
{
public class OpenAIEmbeddingsTests
{
[Fact]
public void Test_Constructor_AzureOpenAI_ManagedIdentityAuth()
{
// Arrange
var options = new AzureOpenAIEmbeddingsOptions(new DefaultAzureCredential(), "test-deployment", "https://test.openai.azure.com/");

// Act
new OpenAIEmbeddings(options);
}

[Fact]
public async void Test_OpenAI_CreateEmbeddings_ReturnEmbeddings()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" Version="2.0.0-beta.2" />
<PackageReference Include="Azure.Identity" Version="1.12.0" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="7.0.0" />
<PackageReference Include="Microsoft.Bot.Builder" Version="4.22.7" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.10.0" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.Teams.AI.Utilities;
using Azure.Core;
using Microsoft.Teams.AI.Utilities;

namespace Microsoft.Teams.AI.AI.Embeddings
{
Expand All @@ -10,7 +11,12 @@ public class AzureOpenAIEmbeddingsOptions : BaseOpenAIEmbeddingsOptions
/// <summary>
/// API key to use when making requests to Azure OpenAI.
/// </summary>
public string AzureApiKey { get; set; }
public string? AzureApiKey { get; set; }

/// <summary>
/// The token credential to use when making requests to Azure OpenAI.
/// </summary>
public TokenCredential? TokenCredential { get; set; }

/// <summary>
/// Name of the Azure OpenAI deployment (model) to use.
Expand Down Expand Up @@ -48,5 +54,23 @@ public AzureOpenAIEmbeddingsOptions(
this.AzureDeployment = azureDeployment;
this.AzureEndpoint = azureEndpoint;
}

/// <summary>
/// Initializes a new instance of the <see cref="AzureOpenAIEmbeddingsOptions"/> class.
/// </summary>
/// <param name="tokenCredential">token credential</param>
/// <param name="azureDefaultDeployment">the deployment name</param>
/// <param name="azureEndpoint">azure endpoint</param>
public AzureOpenAIEmbeddingsOptions(TokenCredential tokenCredential, string azureDefaultDeployment, string azureEndpoint)
{
Verify.ParamNotNull(tokenCredential);
Verify.ParamNotNull(azureDefaultDeployment);
Verify.ParamNotNull(azureEndpoint);

this.TokenCredential = tokenCredential;
this.AzureDeployment = azureDefaultDeployment;
this.AzureEndpoint = azureEndpoint;
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,17 @@ public class OpenAIEmbeddings : IEmbeddingsModel
public OpenAIEmbeddings(OpenAIEmbeddingsOptions options, ILoggerFactory? loggerFactory = null, HttpClient? httpClient = null)
{
Verify.ParamNotNull(options);
Verify.ParamNotNull(options.ApiKey, "OpenAIEmbeddingsOptions.ApiKey");
Verify.ParamNotNull(options.Model, "OpenAIEmbeddingsOptions.Model");

_options = new OpenAIEmbeddingsOptions(options.ApiKey, options.Model)
{
Organization = options.Organization,
LogRequests = options.LogRequests ?? false,
RetryPolicy = options.RetryPolicy ?? new List<TimeSpan> { TimeSpan.FromMilliseconds(2000), TimeSpan.FromMilliseconds(5000) },
};

options.LogRequests = options.LogRequests ?? false;
options.RetryPolicy = options.RetryPolicy ?? new List<TimeSpan> { TimeSpan.FromMilliseconds(2000), TimeSpan.FromMilliseconds(5000) };
_logger = loggerFactory == null ? NullLogger.Instance : loggerFactory.CreateLogger<OpenAIModel>();

OpenAIEmbeddingsOptions embeddingsOptions = (OpenAIEmbeddingsOptions)_options;

Check warning on line 44 in dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Embeddings/OpenAIEmbeddings.cs

View workflow job for this annotation

GitHub Actions / Build/Test/Lint (6.0)

Converting null literal or possible null value to non-nullable type.

Check warning on line 44 in dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Embeddings/OpenAIEmbeddings.cs

View workflow job for this annotation

GitHub Actions / Build/Test/Lint (6.0)

Converting null literal or possible null value to non-nullable type.

Check warning on line 44 in dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Embeddings/OpenAIEmbeddings.cs

View workflow job for this annotation

GitHub Actions / Analyze

Converting null literal or possible null value to non-nullable type.

Check warning on line 44 in dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Embeddings/OpenAIEmbeddings.cs

View workflow job for this annotation

GitHub Actions / Analyze

Converting null literal or possible null value to non-nullable type.

Check warning on line 44 in dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Embeddings/OpenAIEmbeddings.cs

View workflow job for this annotation

GitHub Actions / Build/Test/Lint (7.0)

Converting null literal or possible null value to non-nullable type.

Check warning on line 44 in dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Embeddings/OpenAIEmbeddings.cs

View workflow job for this annotation

GitHub Actions / Build/Test/Lint (7.0)

Converting null literal or possible null value to non-nullable type.

Check warning on line 44 in dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Embeddings/OpenAIEmbeddings.cs

View workflow job for this annotation

GitHub Actions / Build/Test/Lint / Build/Test/Lint (7.0)

Converting null literal or possible null value to non-nullable type.

Check warning on line 44 in dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Embeddings/OpenAIEmbeddings.cs

View workflow job for this annotation

GitHub Actions / Build/Test/Lint / Build/Test/Lint (7.0)

Converting null literal or possible null value to non-nullable type.

Check warning on line 44 in dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Embeddings/OpenAIEmbeddings.cs

View workflow job for this annotation

GitHub Actions / Build/Test/Lint / Build/Test/Lint (6.0)

Converting null literal or possible null value to non-nullable type.

Check warning on line 44 in dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Embeddings/OpenAIEmbeddings.cs

View workflow job for this annotation

GitHub Actions / Build/Test/Lint / Build/Test/Lint (6.0)

Converting null literal or possible null value to non-nullable type.

Check warning on line 44 in dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Embeddings/OpenAIEmbeddings.cs

View workflow job for this annotation

GitHub Actions / Publish (7.0)

Converting null literal or possible null value to non-nullable type.

Check warning on line 44 in dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Embeddings/OpenAIEmbeddings.cs

View workflow job for this annotation

GitHub Actions / Publish (6.0)

Converting null literal or possible null value to non-nullable type.
OpenAIClientOptions openAIClientOptions = new()
{
RetryPolicy = new SequentialDelayRetryPolicy(embeddingsOptions.RetryPolicy!, embeddingsOptions.RetryPolicy!.Count)
RetryPolicy = new SequentialDelayRetryPolicy(options.RetryPolicy!, options.RetryPolicy!.Count)
};

openAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("User-Agent", _userAgent), PipelinePosition.PerCall);
Expand All @@ -57,13 +53,14 @@ public OpenAIEmbeddings(OpenAIEmbeddingsOptions options, ILoggerFactory? loggerF
openAIClientOptions.Transport = new HttpClientPipelineTransport(httpClient);
}

if (!string.IsNullOrEmpty(embeddingsOptions.Organization))
if (!string.IsNullOrEmpty(options.Organization))
{
openAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("OpenAI-Organization", options.Organization!), PipelinePosition.PerCall);
}
_openAIClient = new OpenAIClient(new ApiKeyCredential(embeddingsOptions.ApiKey), openAIClientOptions);
_openAIClient = new OpenAIClient(new ApiKeyCredential(options.ApiKey), openAIClientOptions);

_deploymentName = options.Model;
_options = options;
}

/// <summary>
Expand All @@ -75,7 +72,6 @@ public OpenAIEmbeddings(OpenAIEmbeddingsOptions options, ILoggerFactory? loggerF
public OpenAIEmbeddings(AzureOpenAIEmbeddingsOptions options, ILoggerFactory? loggerFactory = null, HttpClient? httpClient = null)
{
Verify.ParamNotNull(options);
Verify.ParamNotNull(options.AzureApiKey, "AzureOpenAIEmbeddingsOptions.AzureApiKey");
Verify.ParamNotNull(options.AzureDeployment, "AzureOpenAIEmbeddingsOptions.AzureDeployment");
Verify.ParamNotNull(options.AzureEndpoint, "AzureOpenAIEmbeddingsOptions.AzureEndpoint");

Expand All @@ -86,19 +82,15 @@ public OpenAIEmbeddings(AzureOpenAIEmbeddingsOptions options, ILoggerFactory? lo
throw new ArgumentException($"Model created with an unsupported API version of `{apiVersion}`.");
}

_options = new AzureOpenAIEmbeddingsOptions(options.AzureApiKey, options.AzureDeployment, options.AzureEndpoint)
{
AzureApiVersion = apiVersion,
LogRequests = options.LogRequests ?? false,
RetryPolicy = options.RetryPolicy ?? new List<TimeSpan> { TimeSpan.FromMilliseconds(2000), TimeSpan.FromMilliseconds(5000) }
};

options.LogRequests = options.LogRequests ?? false;
options.RetryPolicy = options.RetryPolicy ?? new List<TimeSpan> { TimeSpan.FromMilliseconds(2000), TimeSpan.FromMilliseconds(5000) };
_logger = loggerFactory == null ? NullLogger.Instance : loggerFactory.CreateLogger<OpenAIModel>();


AzureOpenAIEmbeddingsOptions azureEmbeddingsOptions = (AzureOpenAIEmbeddingsOptions)_options;
AzureOpenAIClientOptions azureOpenAIClientOptions = new(serviceVersion.Value)
{
RetryPolicy = new SequentialDelayRetryPolicy(_options.RetryPolicy, _options.RetryPolicy.Count)
RetryPolicy = new SequentialDelayRetryPolicy(options.RetryPolicy, options.RetryPolicy.Count)
};

azureOpenAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("User-Agent", _userAgent), PipelinePosition.PerCall);
Expand All @@ -107,8 +99,22 @@ public OpenAIEmbeddings(AzureOpenAIEmbeddingsOptions options, ILoggerFactory? lo
azureOpenAIClientOptions.Transport = new HttpClientPipelineTransport(httpClient);
}

_openAIClient = new AzureOpenAIClient(new Uri(azureEmbeddingsOptions.AzureEndpoint), new ApiKeyCredential(azureEmbeddingsOptions.AzureApiKey), azureOpenAIClientOptions);
Uri uri = new(options.AzureEndpoint);
if (options.TokenCredential != null)
{
_openAIClient = new AzureOpenAIClient(uri, options.TokenCredential, azureOpenAIClientOptions);
}
else if (options.AzureApiKey != null && options.AzureApiKey != string.Empty)
{
_openAIClient = new AzureOpenAIClient(uri, new ApiKeyCredential(options.AzureApiKey), azureOpenAIClientOptions);
}
else
{
throw new ArgumentException("token credential or api key is required.");
}

_deploymentName = options.AzureDeployment;
_options = options;
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.Teams.AI.Utilities;
using Azure.Core;
using Microsoft.Teams.AI.Utilities;

namespace Microsoft.Teams.AI.AI.Models
{
Expand All @@ -10,7 +11,12 @@ public class AzureOpenAIModelOptions : BaseOpenAIModelOptions
/// <summary>
/// API key to use when making requests to Azure OpenAI.
/// </summary>
public string AzureApiKey { get; set; }
public string? AzureApiKey { get; set; }

/// <summary>
/// The token credential to use when making requests to Azure OpenAI.
/// </summary>
public TokenCredential? TokenCredential { get; set; }

/// <summary>
/// Default name of the Azure OpenAI deployment (model) to use.
Expand Down Expand Up @@ -49,5 +55,22 @@ public AzureOpenAIModelOptions(
this.AzureDefaultDeployment = azureDefaultDeployment;
this.AzureEndpoint = azureEndpoint;
}

/// <summary>
/// Initializes a new instance of the <see cref="AzureOpenAIModelOptions"/> class.
/// </summary>
/// <param name="tokenCredential">token credential</param>
/// <param name="azureDefaultDeployment">the deployment name</param>
/// <param name="azureEndpoint">azure endpoint</param>
public AzureOpenAIModelOptions(TokenCredential tokenCredential, string azureDefaultDeployment, string azureEndpoint)
{
Verify.ParamNotNull(tokenCredential);
Verify.ParamNotNull(azureDefaultDeployment);
Verify.ParamNotNull(azureEndpoint);

this.TokenCredential = tokenCredential;
this.AzureDefaultDeployment = azureDefaultDeployment;
this.AzureEndpoint = azureEndpoint;
}
}
}
Loading

0 comments on commit 0197a00

Please sign in to comment.