Skip to content

Commit

Permalink
[C#] feat: Switch tokenizer to use the Microsoft.ML.Tokenizers library (
Browse files Browse the repository at this point in the history
#1466)

fixes #1467

## Details

Switch to the Microsoft.ML.Tokenizers library for tokenizer support.

#### Change details

> Describe your changes, with screenshots and code snippets as
appropriate

**code snippets**:

**screenshots**:

## Attestation Checklist

- [x] My code follows the style guidelines of this project

- I have checked for/fixed spelling, linting, and other errors
- I have commented my code for clarity
- I have made corresponding changes to the documentation (updating the
doc strings in the code is sufficient)
- My changes generate no new warnings
- I have added tests that validate my changes and provide sufficient
test coverage. I have tested with:
  - Local testing
  - E2E testing in Teams
- New and existing unit tests pass locally with my changes
  • Loading branch information
tarekgh authored Apr 1, 2024
1 parent 3af475e commit 91c532d
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using Microsoft.Teams.AI.AI.Tokenizers;

namespace Microsoft.Teams.AI.Tests.UtilitiesTests
{
    /// <summary>
    /// Round-trip tests for <see cref="ITokenizer"/> implementations.
    /// </summary>
    public class TokenizerTests
    {
        /// <summary>
        /// Supplies the tokenizer instances the theory below is run against:
        /// one built with the parameterless constructor and one built from
        /// the "gpt-4" model name.
        /// </summary>
        /// <returns>One single-element argument array per tokenizer.</returns>
        public static IEnumerable<object[]> TokenizersObjects()
        {
            // Default constructor.
            yield return new object[] { new GPTTokenizer() };

            // Explicit model name.
            yield return new object[] { new GPTTokenizer("gpt-4") };
        }

        /// <summary>
        /// Encodes a fixed sample string, checks the produced token ids against
        /// the expected values, then decodes them and checks the original text
        /// is recovered.
        /// </summary>
        /// <param name="tokenizer">Tokenizer under test.</param>
        [Theory]
        [MemberData(nameof(TokenizersObjects))]
        public void ValidateResults(ITokenizer tokenizer)
        {
            string input = "Hello, World";
            int[] expectedIds = { 9906, 11, 4435 };

            Assert.NotNull(tokenizer);

            IReadOnlyList<int> ids = tokenizer.Encode(input);
            Assert.Equal(expectedIds, ids);

            string roundTripped = tokenizer.Decode(ids);
            Assert.Equal(input, roundTripped);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class TextDataSource : IDataSource
public string Name { get; }

private readonly string _text;
private List<int> _tokens = new();
private IReadOnlyList<int> _tokens = new List<int>();

/// <summary>
/// Creates instance of `TextDataSource`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class ActionAugmentationSection : PromptSection
public readonly Dictionary<string, ChatCompletionAction> Actions;

private readonly string _text;
private List<int>? _tokens;
private IReadOnlyList<int>? _tokens;

private class ActionMap
{
Expand Down Expand Up @@ -64,7 +64,7 @@ public override async Task<RenderedPromptSection<List<ChatMessage>>> RenderAsMes
this._tokens = tokenizer.Encode(this._text);
}

List<int> tokens = this._tokens;
IReadOnlyList<int> tokens = this._tokens;
bool tooLong = false;

if (this._tokens.Count > maxTokens)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public override async Task<RenderedPromptSection<string>> RenderAsTextAsync(ITur
// truncate
if (this.Tokens > 1 && length > this.Tokens)
{
List<int> encoded = tokenizer.Encode(text);
IReadOnlyList<int> encoded = tokenizer.Encode(text);
text = tokenizer.Decode(encoded.Take(this.Tokens).ToList());
length = this.Tokens;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public virtual async Task<RenderedPromptSection<string>> RenderAsTextAsync(ITurn
// truncate
if (this.Tokens > 1 && length > this.Tokens)
{
List<int> encoded = tokenizer.Encode(text);
IReadOnlyList<int> encoded = tokenizer.Encode(text);
text = tokenizer.Decode(encoded.Take(this.Tokens).ToList());
length = this.Tokens;
}
Expand All @@ -148,7 +148,7 @@ protected RenderedPromptSection<List<ChatMessage>> TruncateMessages(List<ChatMes
foreach (ChatMessage message in messages)
{
string text = this.GetMessageText(message);
List<int> encoded = tokenizer.Encode(text);
IReadOnlyList<int> encoded = tokenizer.Encode(text);

if (len + encoded.Count > budget)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using SharpToken;
using Microsoft.ML.Tokenizers;

namespace Microsoft.Teams.AI.AI.Tokenizers
{
Expand All @@ -7,52 +7,37 @@ namespace Microsoft.Teams.AI.AI.Tokenizers
/// </summary>
public class GPTTokenizer : ITokenizer
{
private readonly GptEncoding _encoding;
private readonly Tokenizer _encoding;

/// <summary>
/// Creates an instance of `GPTTokenizer` using the `cl100k_base` encoding by default
/// Creates an instance of `GPTTokenizer` for the "gpt-4" model by default, which uses the `cl100k_base` encoding
/// </summary>
public GPTTokenizer()
{
this._encoding = GptEncoding.GetEncoding("cl100k_base");
}
public GPTTokenizer() => _encoding = Tokenizer.CreateTiktokenForModel("gpt-4");

/// <summary>
/// Creates an instance of `GPTTokenizer`
/// </summary>
/// <param name="encoding">encoding to use</param>
public GPTTokenizer(GptEncoding encoding)
{
this._encoding = encoding;
}
public GPTTokenizer(Tokenizer encoding) => this._encoding = encoding;

/// <summary>
/// Creates an instance of `GPTTokenizer`
/// </summary>
/// <param name="model">model to encode/decode for</param>
public GPTTokenizer(string model)
{
this._encoding = GptEncoding.GetEncodingForModel(model);
}
public GPTTokenizer(string model) => this._encoding = Tokenizer.CreateTiktokenForModel(model);

/// <summary>
/// Encode
/// </summary>
/// <param name="text">text to encode</param>
/// <returns>encoded tokens</returns>
public List<int> Encode(string text)
{
return this._encoding.Encode(text);
}
public IReadOnlyList<int> Encode(string text) => this._encoding.EncodeToIds(text);

/// <summary>
/// Decode
/// </summary>
/// <param name="tokens">tokens to decode</param>
/// <returns>decoded text</returns>
public string Decode(List<int> tokens)
{
return this._encoding.Decode(tokens);
}
public string Decode(IEnumerable<int> tokens) => this._encoding.Decode(tokens)!;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ public interface ITokenizer
/// </summary>
/// <param name="text">text to encode</param>
/// <returns>encoded tokens</returns>
public List<int> Encode(string text);
public IReadOnlyList<int> Encode(string text);

/// <summary>
/// Decode
/// </summary>
/// <param name="tokens">tokens to decode</param>
/// <returns>decoded string</returns>
public string Decode(List<int> tokens);
public string Decode(IEnumerable<int> tokens);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using Microsoft.Teams.AI.Exceptions;
using Microsoft.Teams.AI.State;
using System.Runtime.CompilerServices;
using System.Threading;

[assembly: InternalsVisibleTo("Microsoft.Teams.AI.Tests")]
namespace Microsoft.Teams.AI
Expand Down Expand Up @@ -133,6 +132,9 @@ public async Task SignOutUserAsync(ITurnContext context, TState state, Cancellat
await UserTokenClientWrapper.SignoutUserAsync(context, _settings.ConnectionName, cancellationToken);
}

/// <summary>
/// Get user token
/// </summary>
protected virtual async Task<TokenResponse> GetUserToken(ITurnContext context, string connectionName, CancellationToken cancellationToken = default)
{
return await UserTokenClientWrapper.GetUserTokenAsync(context, connectionName, "", cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
<PackageReference Include="Microsoft.Bot.Builder.Integration.AspNet.Core" Version="4.21.1" />
<PackageReference Include="Microsoft.Identity.Client" Version="4.59.0" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="SharpToken" Version="1.2.17" />
<PackageReference Include="Microsoft.ML.Tokenizers" Version="0.22.0-preview.24179.1" />
<PackageReference Include="System.Text.Json" Version="7.0.4" />
</ItemGroup>

Expand Down
7 changes: 7 additions & 0 deletions dotnet/packages/Microsoft.TeamsAI/nuget.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<configuration>
<packageSources>
<add key="dotnet-local-feed" value="https://pkgs.dev.azure.com/dnceng/public/_packaging/dotnet-libraries/nuget/v3/index.json" />
<add key="nuget" value="https://api.nuget.org/v3/index.json" />
</packageSources>
</configuration>

0 comments on commit 91c532d

Please sign in to comment.