Skip to content

Commit

Permalink
[C#] feat: AssistantsPlanner file upload and download support (#1945)
Browse files Browse the repository at this point in the history
## Linked issues

fixes #1919 

## Details
* `AssistantsPlanner` now supports passing in attached files and images
within the user message. Files will be uploaded through the OpenAI/Azure
OpenAI `Files` API and the `file_id` is attached to the message.
* `AssistantsPlanner` now supports downloading files and images
generated in a single run. Only image files are attached to the outgoing
activity as a list of `Attachments` in the `PredictedSayCommand` default
action.

#### Change details
* Created an `AssistantsMessage` class that extends `ChatMessage`. It
stores a single `MessageContent` and files generated with in it in the
`AttachedFiles` property.
* Added a `FileClient` field to the `AssistantsPlanner`. It wraps around
the `Files` api.
* Added `FileName` field in `InputFile.cs` class. A filename is required
to upload a file to `Files` api.

**Samples Updates**
* `OrderBot` is now configured with the `file_search` tool. A vector
store is created, and the `menu.pdf` file is uploaded to it and the
store is attached to the assistant on creation. Users can ask for the
menu items or prices and the assistant will be using the `file_search`
tool under the hood to get that information.
* `MathBot` has no updates - but it can be used to get the assistant to
generate a png image of a graph.

## Attestation Checklist

- [x] My code follows the style guidelines of this project

- I have checked for/fixed spelling, linting, and other errors
- I have commented my code for clarity
- I have made corresponding changes to the documentation (updating the
doc strings in the code is sufficient)
- My changes generate no new warnings
- I have added tests that validates my changes, and provides sufficient
test coverage. I have tested with:
  - Local testing
  - E2E testing in Teams
- New and existing unit tests pass locally with my changes

### Additional information

> Feel free to add other relevant information below
  • Loading branch information
singhk97 authored Sep 9, 2024
1 parent 7fc1b05 commit dc986b5
Show file tree
Hide file tree
Showing 16 changed files with 463 additions and 60 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using Microsoft.Teams.AI.AI.Models;
using System.ClientModel;
using System.ClientModel.Primitives;
using Microsoft.Teams.AI.AI.Models;
using Microsoft.Teams.AI.Tests.TestUtils;
using Moq;
using OpenAI.Assistants;
using OpenAI.Files;

namespace Microsoft.Teams.AI.Tests.AITests
{
Expand All @@ -11,19 +15,27 @@ public void Test_Constructor()
{
// Arrange
MessageContent content = OpenAIModelFactory.CreateMessageContent("message", "fileId");
Mock<FileClient> fileClientMock = new Mock<FileClient>();
fileClientMock.Setup(fileClient => fileClient.DownloadFileAsync("fileId", It.IsAny<CancellationToken>())).Returns(() =>
{
return Task.FromResult(ClientResult.FromValue(BinaryData.FromString("test"), new Mock<PipelineResponse>().Object));
});
fileClientMock.Setup(fileClient => fileClient.GetFileAsync("fileId", It.IsAny<CancellationToken>())).Returns(() =>
{
return Task.FromResult(ClientResult.FromValue(OpenAIModelFactory.CreateOpenAIFileInfo("fileId"), new Mock<PipelineResponse>().Object));
});

// Act
AssistantsMessage assistantMessage = new AssistantsMessage(content);
AssistantsMessage assistantMessage = new AssistantsMessage(content, fileClientMock.Object);

// Assert
Assert.Equal(assistantMessage.MessageContent, content);
Assert.Equal(content, assistantMessage.MessageContent);
Assert.Equal("message", assistantMessage.Content);
Assert.Equal(1, assistantMessage.AttachedFiles!.Count);
Assert.Equal("fileId", assistantMessage.AttachedFiles[0].FileInfo.Id);

ChatMessage chatMessage = assistantMessage;
Assert.NotNull(chatMessage);
Assert.Equal(chatMessage.Content, "message");
Assert.Equal(chatMessage.Context!.Citations[0].Url, "fileId");
Assert.Equal(chatMessage.Context.Citations[0].Title, "");
Assert.Equal(chatMessage.Context.Citations[0].Content, "");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace Microsoft.Teams.AI.Tests.AITests.Models
{
internal class ChatCompletionToolCallTests
internal sealed class ChatCompletionToolCallTests
{
[Fact]
public void Test_ChatCompletionsToolCall_ToFunctionToolCall()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using OpenAI.Assistants;
using Microsoft.VisualStudio.TestPlatform.CommunicationUtilities;
using OpenAI.Assistants;
using OpenAI.Files;
using System.ClientModel;
using System.ClientModel.Primitives;

Expand Down Expand Up @@ -69,6 +71,12 @@ public static MessageContent CreateMessageContent(string message, string fileId)
""file_citation"": {{
""file_id"": ""{fileId}""
}}
}},
{{
""type"": ""file_path"",
""file_path"": {{
""file_id"": ""{fileId}""
}}
}}
]
}}
Expand All @@ -82,6 +90,22 @@ public static MessageContent CreateMessageContent(string message, string fileId)
return threadMessage.Content[0];
}

public static OpenAIFileInfo CreateOpenAIFileInfo(string fileId)
{
var json = @$"{{
""id"": ""{fileId}"",
""object"": ""file"",
""bytes"": 120000,
""created_at"": 16761602,
""filename"": ""salesOverview.pdf"",
""purpose"": ""assistants""
}}";

var fileInfo = ModelReaderWriter.Read<OpenAIFileInfo>(BinaryData.FromString(json))!;

return fileInfo;
}

public static ThreadRun CreateThreadRun(string threadId, string runStatus, string? runId = null, IList<RequiredAction> requiredActions = null!)
{
var raJson = "{}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,19 @@ public async Task<string> SayCommandAsync([ActionTurnContext] ITurnContext turnC
entity.Citation = referencedCitations;
}

List<Attachment>? attachments = new();
if (command.Response.Attachments != null)
{
attachments = command.Response.Attachments;
}

await turnContext.SendActivityAsync(new Activity()
{
Type = ActivityTypes.Message,
Text = contentText,
ChannelData = channelData,
Entities = new List<Entity>() { entity }
Entities = new List<Entity>() { entity },
Attachments = attachments
}, cancellationToken);

return string.Empty;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
using OpenAI.Assistants;
using System.ClientModel;
using Microsoft.Bot.Schema;
using OpenAI.Assistants;
using OpenAI.Files;


namespace Microsoft.Teams.AI.AI.Models
{
Expand All @@ -12,46 +16,189 @@ public class AssistantsMessage : ChatMessage
/// </summary>
public MessageContent MessageContent;

/// <summary>
/// Files attached to the assistants api message.
/// </summary>
public List<OpenAIFile>? AttachedFiles { get; }

/// <summary>
/// Creates an AssistantMessage.
/// </summary>
/// <param name="content">The Assistants API thread message.</param>
public AssistantsMessage(MessageContent content) : base(ChatRole.Assistant)
/// <param name="fileClient">The OpenAI File client.</param>
public AssistantsMessage(MessageContent content, FileClient? fileClient = null) : base(ChatRole.Assistant)
{
this.MessageContent = content;

if (content != null)
if (content == null)
{
throw new ArgumentNullException(nameof(content));
}

string textContent = content.Text ?? "";
MessageContext context = new();

List<Task<ClientResult<BinaryData>>> fileContentDownloadTasks = new();
List<Task<ClientResult<OpenAIFileInfo>>> fileInfoDownloadTasks = new();

for (int i = 0; i < content.TextAnnotations.Count; i++)
{
TextAnnotation annotation = content.TextAnnotations[i];
if (annotation?.TextToReplace != null)
{
textContent = textContent.Replace(annotation.TextToReplace, $"[{i + 1}]");
}

if (annotation?.InputFileId != null)
{
// Retrieve file info object
// Neither `content` or `title` is provided in the annotations.
context.Citations.Add(new("Content not available", $"File {i + 1}", annotation.InputFileId));
}

if (annotation?.OutputFileId != null && fileClient != null)
{
// Files generated by code interpretor tool.
fileContentDownloadTasks.Add(fileClient.DownloadFileAsync(annotation.OutputFileId));
fileInfoDownloadTasks.Add(fileClient.GetFileAsync(annotation.OutputFileId));
}
}

List<OpenAIFile> attachedFiles = new();
if (fileContentDownloadTasks.Count > 0)
{
Task.WaitAll(fileContentDownloadTasks.ToArray());
Task.WaitAll(fileInfoDownloadTasks.ToArray());

// Create attachments out of these downloaded files
// Wait for tasks to complete
ClientResult<BinaryData>[] downloadedFileContent = fileContentDownloadTasks.Select((task) => task.Result).ToArray();
ClientResult<OpenAIFileInfo>[] downloadedFileInfo = fileInfoDownloadTasks.Select((task) => task.Result).ToArray();

for (int i = 0; i < downloadedFileContent.Length; i++)
{
attachedFiles.Add(new OpenAIFile(downloadedFileInfo[i], downloadedFileContent[i]));
}
}

this.AttachedFiles = attachedFiles;
this.Attachments = _ConvertAttachedImagesToActivityAttachments(attachedFiles);

this.Content = textContent;
this.Context = context;
}

private List<Attachment> _ConvertAttachedImagesToActivityAttachments(List<OpenAIFile> attachedFiles)
{
List<Attachment> attachments = new();

foreach (OpenAIFile file in attachedFiles)
{
string? textContent = content.Text;
if (content.Text != null && content.Text != string.Empty)
string? mimetype = file.GetMimeType();
string[] imageMimeTypes = new string[] { "image/png", "image/jpg", "image/jpeg", "image/gif" };
if (mimetype == null)
{
this.Content = content.Text;
continue;
}

MessageContext context = new();
for (int i = 0; i < content.TextAnnotations.Count; i++)
if (!imageMimeTypes.Contains(mimetype))
{
TextAnnotation annotation = content.TextAnnotations[i];
if (annotation?.TextToReplace != null)
{
textContent.Replace(annotation.TextToReplace, $"[{i}]");
}

if (annotation?.InputFileId != null)
{
// Retrieve file info object
// Neither `content` or `title` is provided in the annotations
context.Citations.Add(new("", "", annotation.InputFileId));
}

if (annotation?.OutputFileId != null)
{
// TODO: Download files or provide link to end user.
// Files were generated by code interpretor tool.
}
// Skip non image file types
continue;
}

Context = context;
string imageBase64String = Convert.ToBase64String(file.FileContent.ToArray());
attachments.Add(new Attachment
{
Name = file.FileInfo.Filename,
ContentType = mimetype,
ContentUrl = $"data:image/png;base64,{imageBase64String}",
});
}

return attachments;
}
}

/// <summary>
/// Represents an OpenAI File.
/// </summary>
public class OpenAIFile
{
/// <summary>
/// Represents an OpenAI File information
/// </summary>
public OpenAIFileInfo FileInfo;

/// <summary>
/// Represents the contents of an OpenAI File
/// </summary>
public BinaryData FileContent;

private static readonly Dictionary<string, string> MimeTypes = new(StringComparer.OrdinalIgnoreCase)
{
{ "c", "text/x-c" },
{ "cs", "text/x-csharp" },
{ "cpp", "text/x-c++" },
{ "doc", "application/msword" },
{ "docx", "application/vnd.openxmlformats-officedocument.wordprocessingml.document" },
{ "html", "text/html" },
{ "java", "text/x-java" },
{ "json", "application/json" },
{ "md", "text/markdown" },
{ "pdf", "application/pdf" },
{ "php", "text/x-php" },
{ "pptx", "application/vnd.openxmlformats-officedocument.presentationml.presentation" },
{ "py", "text/x-python" },
{ "rb", "text/x-ruby" },
{ "tex", "text/x-tex" },
{ "txt", "text/plain" },
{ "css", "text/css" },
{ "js", "text/javascript" },
{ "sh", "application/x-sh" },
{ "ts", "application/typescript" },
{ "csv", "application/csv" },
{ "jpeg", "image/jpeg" },
{ "jpg", "image/jpeg" },
{ "gif", "image/gif" },
{ "png", "image/png" },
{ "tar", "application/x-tar" },
{ "xlsx", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" },
{ "xml", "application/xml" }, // or "text/xml"
{ "zip", "application/zip" }
};

/// <summary>
/// Initializes an instance of OpenAIFile
/// </summary>
/// <param name="fileInfo">The OpenAI File</param>
/// <param name="fileContent">The OpenAI File contents</param>
public OpenAIFile(OpenAIFileInfo fileInfo, BinaryData fileContent)
{
FileInfo = fileInfo;
FileContent = fileContent;
}

/// <summary>
/// Gets the file's mime type
/// </summary>
/// <returns>The file's mime type</returns>
public string? GetMimeType()
{
bool hasExtension = FileInfo.Filename.Contains(".");
if (!hasExtension)
{
return null;
}

string fileExtension = FileInfo.Filename.Split(new char[] { '.' }).Last();
if (MimeTypes.TryGetValue(fileExtension, out string mimeType))
{
return mimeType;
}
else
{
return null;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Azure.AI.OpenAI;
using Azure.AI.OpenAI.Chat;
using Microsoft.Bot.Schema;
using Microsoft.Teams.AI.Exceptions;
using Microsoft.Teams.AI.Utilities;
using OpenAI.Chat;
Expand Down Expand Up @@ -49,6 +50,10 @@ public class ChatMessage
/// </summary>
public IList<ChatCompletionsToolCall>? ToolCalls { get; set; }

/// <summary>
/// Attachments for the bot to send back.
/// </summary>
public List<Attachment>? Attachments { get; set; }

/// <summary>
/// Gets the content with the given type.
Expand Down
Loading

0 comments on commit dc986b5

Please sign in to comment.