Skip to content

Commit

Permalink
[C#] feat: AssistantsPlanner support for Azure Assistants API (#1609)
Browse files Browse the repository at this point in the history
## Linked issues

closes: #1499 (issue number)

## Details

Introduces Azure OpenAI support to the `AssistantPlanner`

#### Change details

* Replaced internal client with `AssistantsClient` from
`Microsoft.AI.OpenAI.Assistants` library
* Removed `OpenAIClient.Assistant.cs`, `OpenAIClient.Thread.cs` and all
internal models used to facilitate call to Assistants API. Marked public
models as `Obsolete`.

## Attestation Checklist

- [x] My code follows the style guidelines of this project

- I have checked for/fixed spelling, linting, and other errors
- I have commented my code for clarity
- I have made corresponding changes to the documentation (updating the
doc strings in the code is sufficient)
- My changes generate no new warnings
- I have added tests that validates my changes, and provides sufficient
test coverage. I have tested with:
  - Local testing
  - E2E testing in Teams
- New and existing unit tests pass locally with my changes

### Additional information

> Feel free to add other relevant information below
  • Loading branch information
singhk97 authored May 7, 2024
1 parent 15eac52 commit 8ee2c3b
Show file tree
Hide file tree
Showing 31 changed files with 490 additions and 1,317 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Moq;
using System.Reflection;
using Microsoft.Teams.AI.AI.Planners;
using Azure.AI.OpenAI.Assistants;

namespace Microsoft.Teams.AI.Tests.AITests
{
Expand All @@ -22,7 +23,7 @@ public async Task Test_BeginTaskAsync_Assistant_Single_Reply()
{
PollingInterval = TimeSpan.FromMilliseconds(100)
});
planner.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
planner.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
var turnContextMock = new Mock<ITurnContext>();
var turnState = await _CreateAssistantsState();
turnState.Temp!.Input = "hello";
Expand Down Expand Up @@ -53,7 +54,7 @@ public async Task Test_BeginTaskAsync_Assistant_WaitForCurrentRun()
{
PollingInterval = TimeSpan.FromMilliseconds(100)
});
planner.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
planner.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
var turnContextMock = new Mock<ITurnContext>();
var turnState = await _CreateAssistantsState();
turnState.Temp!.Input = "hello";
Expand Down Expand Up @@ -85,7 +86,7 @@ public async Task Test_BeginTaskAsync_Assistant_WaitForPreviousRun()
{
PollingInterval = TimeSpan.FromMilliseconds(100)
});
planner.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
planner.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
var turnContextMock = new Mock<ITurnContext>();
var turnState = await _CreateAssistantsState();
turnState.Temp!.Input = "hello";
Expand All @@ -97,8 +98,8 @@ public async Task Test_BeginTaskAsync_Assistant_WaitForPreviousRun()
testClient.RemainingRunStatus.Enqueue("completed");
testClient.RemainingMessages.Enqueue("welcome");

var thread = await testClient.CreateThreadAsync(new(), CancellationToken.None);
await testClient.CreateRunAsync(thread.Id, new(), CancellationToken.None);
AssistantThread thread = await testClient.CreateThreadAsync(new(), CancellationToken.None);
await testClient.CreateRunAsync(thread.Id, AssistantsModelFactory.CreateRunOptions(), CancellationToken.None);
turnState.ThreadId = thread.Id;

// Act
Expand All @@ -121,7 +122,7 @@ public async Task Test_BeginTaskAsync_Assistant_RunCancelled()
{
PollingInterval = TimeSpan.FromMilliseconds(100)
});
planner.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
planner.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
var turnContextMock = new Mock<ITurnContext>();
var turnState = await _CreateAssistantsState();
turnState.Temp!.Input = "hello";
Expand Down Expand Up @@ -150,7 +151,7 @@ public async Task Test_BeginTaskAsync_Assistant_RunExpired()
{
PollingInterval = TimeSpan.FromMilliseconds(100)
});
planner.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
planner.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
var turnContextMock = new Mock<ITurnContext>();
var turnState = await _CreateAssistantsState();
turnState.Temp!.Input = "hello";
Expand Down Expand Up @@ -181,7 +182,7 @@ public async Task Test_BeginTaskAsync_Assistant_RunFailed()
{
PollingInterval = TimeSpan.FromMilliseconds(100)
});
planner.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
planner.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
var turnContextMock = new Mock<ITurnContext>();
var turnState = await _CreateAssistantsState();
turnState.Temp!.Input = "hello";
Expand Down Expand Up @@ -210,32 +211,18 @@ public async Task Test_ContinueTaskAsync_Assistant_RequiresAction()
{
PollingInterval = TimeSpan.FromMilliseconds(100)
});
planner.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
planner.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
var turnContextMock = new Mock<ITurnContext>();
var turnState = await _CreateAssistantsState();
turnState.Temp!.Input = "hello";

var aiOptions = new AIOptions<AssistantsState>(planner);
var ai = new AI<AssistantsState>(aiOptions);

testClient.RemainingActions.Enqueue(new()
{
SubmitToolOutputs = new()
{
ToolCalls = new()
{
new()
{
Id = "test-tool-id",
Function = new()
{
Name = "test-action",
Arguments = "{}"
}
}
}
}
});
var functionToolCall = AssistantsModelFactory.RequiredFunctionToolCall("test-tool-id", "test-action", "{}");
var requiredAction = AssistantsModelFactory.SubmitToolOutputsAction(new List<RequiredToolCall>{ functionToolCall });

testClient.RemainingActions.Enqueue(requiredAction);
testClient.RemainingRunStatus.Enqueue("requires_action");
testClient.RemainingRunStatus.Enqueue("in_progress");
testClient.RemainingRunStatus.Enqueue("completed");
Expand Down Expand Up @@ -271,7 +258,7 @@ public async Task Test_ContinueTaskAsync_Assistant_IgnoreRedundantAction()
{
PollingInterval = TimeSpan.FromMilliseconds(100)
});
planner.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
planner.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
var turnContextMock = new Mock<ITurnContext>();
var turnState = await _CreateAssistantsState();
turnState.Temp!.Input = "hello";
Expand All @@ -280,24 +267,10 @@ public async Task Test_ContinueTaskAsync_Assistant_IgnoreRedundantAction()
var aiOptions = new AIOptions<AssistantsState>(planner);
var ai = new AI<AssistantsState>(aiOptions);

testClient.RemainingActions.Enqueue(new()
{
SubmitToolOutputs = new()
{
ToolCalls = new()
{
new()
{
Id = "test-tool-id",
Function = new()
{
Name = "test-action",
Arguments = "{}"
}
}
}
}
});
var functionToolCall = AssistantsModelFactory.RequiredFunctionToolCall("test-tool-id", "test-action", "{}");
var requiredAction = AssistantsModelFactory.SubmitToolOutputsAction(new List<RequiredToolCall> { functionToolCall });

testClient.RemainingActions.Enqueue(requiredAction);
testClient.RemainingRunStatus.Enqueue("requires_action");
testClient.RemainingRunStatus.Enqueue("in_progress");
testClient.RemainingRunStatus.Enqueue("completed");
Expand Down Expand Up @@ -334,7 +307,7 @@ public async Task Test_ContinueTaskAsync_Assistant_MultipleMessages()
{
PollingInterval = TimeSpan.FromMilliseconds(100)
});
planner.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
planner.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(planner, testClient);
var turnContextMock = new Mock<ITurnContext>();
var turnState = await _CreateAssistantsState();
turnState.Temp!.Input = "hello";
Expand Down

This file was deleted.

Loading

0 comments on commit 8ee2c3b

Please sign in to comment.