Skip to content

Commit

Permalink
com.rest.huggingface 1.0.0-preview.12 (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenHodgson authored Dec 11, 2023
1 parent 7582258 commit 2e340d3
Show file tree
Hide file tree
Showing 19 changed files with 219 additions and 172 deletions.
10 changes: 6 additions & 4 deletions Documentation~/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ The recommended installation method is though the unity package manager and [Ope

There are 4 ways to provide your API keys, in order of precedence:

1. [Pass keys directly with constructor](#pass-keys-directly-with-constructor)
2. [Unity Scriptable Object](#unity-scriptable-object)
:warning: We recommended using the environment variables to load the API key instead of having it hard coded in your source. It is not recommended use this method in production, but only for accepting user credentials, local testing and quick start scenarios.

1. [Pass keys directly with constructor](#pass-keys-directly-with-constructor) :warning:
2. [Unity Scriptable Object](#unity-scriptable-object) :warning:
3. [Load key from configuration file](#load-key-from-configuration-file)
4. [Use System Environment Variables](#use-system-environment-variables)

Expand Down Expand Up @@ -93,7 +95,7 @@ To create a configuration file, create a new text file named `.huggingface` and
You can also load the file directly with known path by calling a static method in Authentication:

```csharp
var api = new HuggingFaceClient(HuggingFaceAuthentication.Default.LoadFromDirectory("your/path/to/.huggingface"));;
var api = new HuggingFaceClient(new HuggingFaceAuthentication().LoadFromDirectory("your/path/to/.huggingface"));;
```

#### Use System Environment Variables
Expand All @@ -103,7 +105,7 @@ Use your system's environment variables specify an api key to use.
- Use `HUGGING_FACE_API_KEY` for your api key.

```csharp
var api = new HuggingFaceClient(HuggingFaceAuthentication.Default.LoadFromEnvironment());
var api = new HuggingFaceClient(new HuggingFaceAuthentication().LoadFromEnvironment());
```

### Hub
Expand Down
54 changes: 31 additions & 23 deletions Runtime/HuggingFaceAuthentication.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace HuggingFace
{
public sealed class HuggingFaceAuthentication : AbstractAuthentication<HuggingFaceAuthentication, HuggingFaceAuthInfo>
public sealed class HuggingFaceAuthentication : AbstractAuthentication<HuggingFaceAuthentication, HuggingFaceAuthInfo, HuggingFaceConfiguration>
{
internal const string CONFIG_FILE = ".huggingface";
private const string HUGGING_FACE_API_KEY = nameof(HUGGING_FACE_API_KEY);
Expand All @@ -20,30 +20,35 @@ public sealed class HuggingFaceAuthentication : AbstractAuthentication<HuggingFa
public static implicit operator HuggingFaceAuthentication(string apiKey) => new HuggingFaceAuthentication(apiKey);

/// <summary>
/// Instantiates a new Authentication object that will load the default config.
/// Instantiates an empty Authentication object.
/// </summary>
public HuggingFaceAuthentication()
{
if (cachedDefault != null) { return; }

cachedDefault = (LoadFromAsset<HuggingFaceConfiguration>() ??
LoadFromDirectory()) ??
LoadFromDirectory(Environment.GetFolderPath(Environment.SpecialFolder.UserProfile)) ??
LoadFromEnvironment();
Info = cachedDefault?.Info;
}
public HuggingFaceAuthentication() { }

/// <summary>
/// Instantiates a new Authentication object with the given <paramref name="apiKey"/>, which may be <see langword="null"/>.
/// </summary>
/// <param name="apiKey">The API key, required to access the API endpoint.</param>
public HuggingFaceAuthentication(string apiKey) => Info = new HuggingFaceAuthInfo(apiKey);
public HuggingFaceAuthentication(string apiKey)
{
Info = new HuggingFaceAuthInfo(apiKey);
cachedDefault = this;
}

/// <summary>
/// Instantiates a new Authentication object with the given <paramref name="authInfo"/>, which may be <see langword="null"/>.
/// </summary>
/// <param name="authInfo"></param>
public HuggingFaceAuthentication(HuggingFaceAuthInfo authInfo) => Info = authInfo;
/// <param name="authInfo"><see cref="HuggingFaceAuthInfo"/>.</param>
public HuggingFaceAuthentication(HuggingFaceAuthInfo authInfo)
{
Info = authInfo;
cachedDefault = this;
}

/// <summary>
/// Instantiates a new Authentication object with the given <see cref="configuration"/>.
/// </summary>
/// <param name="configuration"><see cref="HuggingFaceConfiguration"/>.</param>
public HuggingFaceAuthentication(HuggingFaceConfiguration configuration) : this(configuration.ApiKey) { }

/// <inheritdoc />
public override HuggingFaceAuthInfo Info { get; }
Expand All @@ -57,18 +62,21 @@ public HuggingFaceAuthentication()
/// </summary>
public static HuggingFaceAuthentication Default
{
get => cachedDefault ?? new HuggingFaceAuthentication();
get => cachedDefault ??= new HuggingFaceAuthentication().LoadDefault();
internal set => cachedDefault = value;
}

/// <inheritdoc />
public override HuggingFaceAuthentication LoadFromAsset<T>()
=> Resources.LoadAll<T>(string.Empty)
.Where(asset => asset != null)
.Select(asset => asset is HuggingFaceConfiguration config && !string.IsNullOrWhiteSpace(config.ApiKey)
? new HuggingFaceAuthentication(config.ApiKey)
: null)
.FirstOrDefault();
public override HuggingFaceAuthentication LoadFromAsset(HuggingFaceConfiguration configuration = null)
{
if (configuration == null)
{
Debug.LogWarning($"This can be speed this up by passing a {nameof(HuggingFaceConfiguration)} to the {nameof(HuggingFaceAuthentication)}.ctr");
configuration = Resources.LoadAll<HuggingFaceConfiguration>(string.Empty).FirstOrDefault(o => o != null);
}

return configuration != null ? new HuggingFaceAuthentication(configuration) : null;
}

/// <inheritdoc />
public override HuggingFaceAuthentication LoadFromEnvironment()
Expand Down
29 changes: 21 additions & 8 deletions Runtime/HuggingFaceSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,21 @@ public class HuggingFaceSettings : ISettings<HuggingFaceSettingsInfo>
{
public HuggingFaceSettings()
{
if (cachedDefault != null) { return; }
Info = new HuggingFaceSettingsInfo();
cachedDefault = new HuggingFaceSettings(Info);
}

var config = Resources.LoadAll<HuggingFaceConfiguration>(string.Empty)
.FirstOrDefault(asset => asset != null);
public HuggingFaceSettings(HuggingFaceConfiguration configuration)
{
if (configuration == null)
{
Debug.LogWarning($"You can speed this up by passing a {nameof(HuggingFaceConfiguration)} to the {nameof(HuggingFaceSettings)}.ctr");
configuration = Resources.LoadAll<HuggingFaceConfiguration>(string.Empty).FirstOrDefault(asset => asset != null);
}

if (config != null)
if (configuration != null)
{
Info = new HuggingFaceSettingsInfo(config.ProxyDomain);
Info = new HuggingFaceSettingsInfo(configuration.ProxyDomain);
cachedDefault = new HuggingFaceSettings(Info);
}
else
Expand All @@ -28,16 +35,22 @@ public HuggingFaceSettings()
}

public HuggingFaceSettings(HuggingFaceSettingsInfo settingsInfo)
=> Info = settingsInfo;
{
Info = settingsInfo;
cachedDefault = this;
}

public HuggingFaceSettings(string domain)
=> Info = new HuggingFaceSettingsInfo(domain);
{
Info = new HuggingFaceSettingsInfo(domain);
cachedDefault = this;
}

private static HuggingFaceSettings cachedDefault;

public static HuggingFaceSettings Default
{
get => cachedDefault ?? new HuggingFaceSettings();
get => cachedDefault ?? new HuggingFaceSettings(configuration: null);
internal set => cachedDefault = value;
}

Expand Down
5 changes: 2 additions & 3 deletions Runtime/Inference/Audio/AudioToAudio/AudioToAudioResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ public override async Task DecodeAsync(CancellationToken cancellationToken = def
private static async Task DecodeAudioAsync(AudioToAudioResult result, CancellationToken cancellationToken)
{
await Rest.ValidateCacheDirectoryAsync();

Rest.TryGetDownloadCacheItem(result.Blob, out var guid);
var localFilePath = Path.Combine(Rest.DownloadCacheDirectory, $"{DateTime.UtcNow:yyyy-MM-ddTHH-mm-ssffff}-{guid}.jpg");
var localFilePath = Path.Combine(Rest.DownloadCacheDirectory, $"{DateTime.UtcNow:yyyy-MM-ddTHH-mm-ssffff}-{Path.GetFileName(guid)}.mp3");
var fileStream = new FileStream(localFilePath, FileMode.CreateNew, FileAccess.ReadWrite, FileShare.None);

try
Expand All @@ -58,7 +57,7 @@ private static async Task DecodeAudioAsync(AudioToAudioResult result, Cancellati
await fileStream.DisposeAsync();
}

result.AudioClip = await Rest.DownloadAudioClipAsync($"file://{localFilePath}", AudioType.WAV, parameters: null, cancellationToken: cancellationToken);
result.AudioClip = await Rest.DownloadAudioClipAsync($"file://{localFilePath}", AudioType.MPEG, parameters: null, cancellationToken: cancellationToken);
}
}
}
2 changes: 2 additions & 0 deletions Runtime/Inference/Audio/AudioToAudio/AudioToAudioTask.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,7 @@ public AudioToAudioTask(SingleSourceAudioInput input, ModelInfo model = null, In
}

public override string Id => "audio-to-audio";

public override string MimeType => "audio/mp3";
}
}
4 changes: 2 additions & 2 deletions Runtime/Inference/Audio/TextToSpeech/TextToSpeechResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public sealed class TextToSpeechResponse : BinaryInferenceTaskResponse
public override async Task DecodeAsync(Stream stream, CancellationToken cancellationToken = default)
{
await Rest.ValidateCacheDirectoryAsync();
var filePath = Path.Combine(Rest.DownloadCacheDirectory, $"{DateTime.UtcNow:yyyy-MM-ddTHH-mm-ssffff}.wav");
var filePath = Path.Combine(Rest.DownloadCacheDirectory, $"{DateTime.UtcNow:yyyy-MM-ddTHH-mm-ssffff}.mp3");
Debug.Log(filePath);
CachedPath = filePath;
var fileStream = new FileStream(filePath, FileMode.CreateNew, FileAccess.ReadWrite, FileShare.None);
Expand Down Expand Up @@ -46,7 +46,7 @@ public override async Task DecodeAsync(Stream stream, CancellationToken cancella
await fileStream.DisposeAsync();
}

AudioClip = await Rest.DownloadAudioClipAsync($"file://{filePath}", AudioType.WAV, parameters: null, cancellationToken: cancellationToken);
AudioClip = await Rest.DownloadAudioClipAsync($"file://{filePath}", AudioType.MPEG, parameters: null, cancellationToken: cancellationToken);
}
}
}
2 changes: 2 additions & 0 deletions Runtime/Inference/Audio/TextToSpeech/TextToSpeechTask.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,7 @@ public TextToSpeechTask(string input, ModelInfo model = null, InferenceOptions o
public string Input { get; }

public override string Id => "text-to-speech";

public override string MimeType => "audio/mp3";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ private static async Task DecodeImageAsync(ImageSegmentationResult result, Cance
{
await Rest.ValidateCacheDirectoryAsync();
Rest.TryGetDownloadCacheItem(result.Blob, out var guid);
var localFilePath = Path.Combine(Rest.DownloadCacheDirectory, $"{DateTime.UtcNow:yyyy-MM-ddTHH-mm-ssffff}-{guid}.jpg");
var localFilePath = Path.Combine(Rest.DownloadCacheDirectory, $"{DateTime.UtcNow:yyyy-MM-ddTHH-mm-ssffff}-{Path.GetFileName(guid)}.jpg");
var fileStream = new FileStream(localFilePath, FileMode.CreateNew, FileAccess.ReadWrite, FileShare.None);

try
Expand Down
41 changes: 26 additions & 15 deletions Runtime/Inference/InferenceEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@
using Newtonsoft.Json;
using System;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using UnityEngine;
using Utilities.Async;
using Utilities.WebRequestRest;

namespace HuggingFace.Inference
{
public sealed class InferenceEndpoint : HuggingFaceBaseEndpoint
{
public bool EnableLogging { get; set; } = true;

public int MaxRetryAttempts { get; set; } = 3;

public int Timeout { get; set; } = -1;
Expand All @@ -40,42 +38,50 @@ public async Task<TResponse> RunInferenceTaskAsync<TTask, TResponse>(TTask task,
}

var endpoint = GetInferenceUrl(task.Model.ModelId);

if (EnableDebug)
{
Debug.Log(endpoint);
}

Response response;
var attempt = 0;

async Task<Response> CallEndpointAsync()
{
try
{
var headers = client.DefaultRequestHeaders.ToDictionary(pair => pair.Key, pair => pair.Value);

if (!string.IsNullOrWhiteSpace(task.MimeType))
{
headers.Add("Accept", task.MimeType);
}

var jsonData = await task.ToJsonAsync(client.JsonSerializationOptions, cancellationToken).ConfigureAwait(true);

if (!string.IsNullOrWhiteSpace(jsonData))
{
if (EnableLogging)
if (EnableDebug)
{
Debug.Log(jsonData);
}

response = await Rest.PostAsync(endpoint, jsonData, parameters: new RestParameters(client.DefaultRequestHeaders, timeout: Timeout), cancellationToken);
response = await Rest.PostAsync(endpoint, jsonData, parameters: new RestParameters(headers, timeout: Timeout), cancellationToken);
}
else
{
var byteData = await task.ToByteArrayAsync(cancellationToken);
// TODO ensure proper accept headers are set here
response = await Rest.PostAsync(endpoint, byteData, parameters: new RestParameters(client.DefaultRequestHeaders, timeout: Timeout), cancellationToken);
response = await Rest.PostAsync(endpoint, byteData, parameters: new RestParameters(headers, timeout: Timeout), cancellationToken);
}

response.Validate(EnableLogging);
response.Validate(EnableDebug);
}
catch (RestException restEx)
{
if (restEx.Response.Code == 503 &&
task.Options.WaitForModel)
if (restEx.Response.Code == 503 && task.Options.WaitForModel)
{
if (++attempt == MaxRetryAttempts)
{
throw;
}
if (++attempt == MaxRetryAttempts) { throw; }

HuggingFaceError error;

Expand All @@ -89,7 +95,7 @@ async Task<Response> CallEndpointAsync()
throw restEx;
}

if (EnableLogging)
if (EnableDebug)
{
Debug.LogWarning($"Waiting for model for {error.EstimatedTime} seconds... attempt {attempt} of {MaxRetryAttempts}\n{restEx}");
}
Expand Down Expand Up @@ -126,6 +132,11 @@ async Task<Response> CallEndpointAsync()

if (binaryResponse is BinaryInferenceTaskResponse taskResponse)
{
if (response.Headers.TryGetValue("Content-Type", out var contentType))
{
Debug.Log($"{typeof(TResponse).Name} Content-Type: {contentType}");
}

await using var contentStream = new MemoryStream(response.Data);

try
Expand Down
9 changes: 8 additions & 1 deletion Runtime/Inference/InferenceTask.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,28 @@ public abstract class InferenceTask
{
protected InferenceTask() { }

protected InferenceTask(ModelInfo model, InferenceOptions options)
protected InferenceTask(ModelInfo model, InferenceOptions options, Action<string> streamCallback = null)
{
Model = model;
Options = options ?? new InferenceOptions();
Stream = streamCallback != null;
}

[JsonIgnore]
public abstract string Id { get; }

[JsonIgnore]
public virtual string MimeType { get; } = string.Empty;

[JsonIgnore]
public ModelInfo Model { get; internal set; }

[JsonProperty("options")]
public InferenceOptions Options { get; }

[JsonProperty("stream")]
public bool Stream { get; }

public virtual Task<string> ToJsonAsync(JsonSerializerSettings settings, CancellationToken cancellationToken)
=> Task.FromResult(string.Empty);

Expand Down
17 changes: 17 additions & 0 deletions Tests/AbstractTestFixture.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Licensed under the MIT License. See LICENSE in the project root for license information.

namespace HuggingFace.Tests
{
internal abstract class AbstractTestFixture
{
protected readonly HuggingFaceClient HuggingFaceClient;

public AbstractTestFixture()
{
var auth = new HuggingFaceAuthentication().LoadDefaultsReversed();
var settings = new HuggingFaceSettings();
HuggingFaceClient = new HuggingFaceClient(auth, settings);
//HuggingFaceClient.EnableDebug = true;
}
}
}
11 changes: 11 additions & 0 deletions Tests/AbstractTestFixture.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 2e340d3

Please sign in to comment.