Skip to content

Commit

Permalink
feat: Added ImageToText abstractions and HuggingFace implementation. (#…
Browse files Browse the repository at this point in the history
…152)

* fix: changed Titan's TextToImage to support images

* ImageToText working.  needs to get refactored and cleaned

* feat: Added ImageToText abstractions and HuggingFace implementation.  also added example to HF sample

* fix: remove postgres tests from bedrock tests

* feat: Added ImageToTextGenerationChain
  • Loading branch information
curlyfro authored Mar 2, 2024
1 parent e503560 commit 3290566
Show file tree
Hide file tree
Showing 21 changed files with 434 additions and 22 deletions.
25 changes: 22 additions & 3 deletions examples/LangChain.Samples.HuggingFace/Program.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,29 @@
using LangChain.Providers.HuggingFace;
using LangChain.Providers;
using LangChain.Providers.HuggingFace;
using LangChain.Providers.HuggingFace.Predefined;

using var client = new HttpClient();
var provider = new HuggingFaceProvider(apiKey: string.Empty, client);
var gpt2Model = new Gpt2Model(provider);

var response = await gpt2Model.GenerateAsync("What would be a good company name be for name a company that makes colorful socks?");
var gp2ModelResponse = await gpt2Model.GenerateAsync("What would be a good company name be for name a company that makes colorful socks?");

Console.WriteLine(response);
Console.WriteLine("### GP2 Response");
Console.WriteLine(gp2ModelResponse);

const string imageToTextModel = "Salesforce/blip-image-captioning-base";
var model = new HuggingFaceImageToTextModel(provider, imageToTextModel);

var path = Path.Combine(Path.GetTempPath(), "solar_system.png");
var imageData = await File.ReadAllBytesAsync(path);
var binaryData = new BinaryData(imageData, "image/jpg");

var imageToTextResponse = await model.GenerateTextFromImageAsync(new ImageToTextRequest
{
Image = binaryData
});

Console.WriteLine("\n\n### ImageToText Response");
Console.WriteLine(imageToTextResponse.Text);

Console.ReadLine();
16 changes: 16 additions & 0 deletions src/Core/src/Chains/Chain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using LangChain.Chains.StackableChains.Agents.Crew;
using LangChain.Chains.StackableChains.Files;
using LangChain.Chains.StackableChains.ImageGeneration;
using LangChain.Chains.StackableChains.ImageToTextGeneration;
using LangChain.Chains.StackableChains.ReAct;
using LangChain.Indexes;
using LangChain.Memory;
Expand Down Expand Up @@ -298,4 +299,19 @@ public static ExtractCodeChain ExtractCode(
{
return new ExtractCodeChain(inputKey, outputKey);
}

/// <summary>
///
/// </summary>
/// <param name="model"></param>
/// <param name="image"></param>
/// <param name="outputKey"></param>
/// <returns></returns>
public static ImageToTextGenerationChain GenerateImageToText(
IImageToTextModel model,
BinaryData image,
string outputKey = "text")
{
return new ImageToTextGenerationChain(model, image, outputKey);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using LangChain.Abstractions.Schema;
using LangChain.Chains.HelperChains;
using LangChain.Providers;

namespace LangChain.Chains.StackableChains.ImageToTextGeneration;

/// <summary>
///
/// </summary>
public class ImageToTextGenerationChain : BaseStackableChain
{
private readonly IImageToTextModel _model;
private readonly BinaryData _image;

/// <summary>
///
/// </summary>
/// <param name="model"></param>
/// <param name="image"></param>
/// <param name="outputKey"></param>
public ImageToTextGenerationChain(
IImageToTextModel model,
BinaryData image,
string outputKey = "text")
{
_model = model;
_image = image;
OutputKeys = new[] { outputKey };
}

/// <inheritdoc />
protected override async Task<IChainValues> InternalCall(IChainValues values)
{
values = values ?? throw new ArgumentNullException(nameof(values));

var text = await _model.GenerateTextFromImageAsync(new ImageToTextRequest { Image = _image }).ConfigureAwait(false);
values.Value[OutputKeys[0]] = text;
return values;
}
}
1 change: 1 addition & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageVersion>
<PackageVersion Include="StackExchange.Redis" Version="2.7.20" />
<PackageVersion Include="System.Memory.Data" Version="8.0.0" />
<PackageVersion Include="System.Net.Http" Version="4.3.4" />
<PackageVersion Include="System.Text.Json" Version="8.0.0" />
<PackageVersion Include="System.ValueTuple" Version="4.5.0" />
Expand Down
3 changes: 3 additions & 0 deletions src/Providers/Abstractions/src/Common/Provider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,7 @@ public abstract class Provider(string id) : Model(id), IProvider

/// <inheritdoc />
public TextToSpeechSettings? TextToSpeechSettings { get; init; }

/// <inheritdoc />
public ImageToTextSettings? ImageToTextSettings { get; init; }
}
19 changes: 19 additions & 0 deletions src/Providers/Abstractions/src/ImageToText/IImageToTextModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
namespace LangChain.Providers;

/// <summary>
/// Defines a large language model that can be used for image to text generation.
/// </summary>
public interface IImageToTextModel : IModel<ImageToTextSettings>

Check warning on line 6 in src/Providers/Abstractions/src/ImageToText/IImageToTextModel.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Symbol 'IImageToTextModel' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)
{
/// <summary>
/// Run the LLM on the given image.
/// </summary>
/// <param name="request"></param>
/// <param name="settings"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public Task<ImageToTextResponse> GenerateTextFromImageAsync(
ImageToTextRequest request,
ImageToTextSettings? settings = null,
CancellationToken cancellationToken = default);
}
19 changes: 19 additions & 0 deletions src/Providers/Abstractions/src/ImageToText/IImageToTextModel`2.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
namespace LangChain.Providers;

/// <summary>
/// Defines a large language model that can be used for image to text generation.
/// </summary>
public interface IImageToTextModel<in TRequest, TResponse, in TSettings> : IImageToTextModel

Check warning on line 6 in src/Providers/Abstractions/src/ImageToText/IImageToTextModel`2.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Symbol 'IImageToTextModel<TRequest, TResponse, TSettings>' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)
{
/// <summary>
/// Run the LLM on the image.
/// </summary>
/// <param name="request"></param>
/// <param name="settings"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public Task<TResponse> GenerateTextFromImageAsync(
TRequest request,
TSettings? settings = default,
CancellationToken cancellationToken = default);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using System.Text.Json.Serialization;

namespace LangChain.Providers;

public class ImageToTextGenerationResponse : List<ImageToTextGenerationResponse.GeneratedTextItem>

Check warning on line 5 in src/Providers/Abstractions/src/ImageToText/ImageToTextGenerationResponse.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Symbol 'ImageToTextGenerationResponse' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)

Check warning on line 5 in src/Providers/Abstractions/src/ImageToText/ImageToTextGenerationResponse.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Symbol 'implicit constructor for 'ImageToTextGenerationResponse'' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)
{
public sealed class GeneratedTextItem

Check warning on line 7 in src/Providers/Abstractions/src/ImageToText/ImageToTextGenerationResponse.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Do not nest type GeneratedTextItem. Alternatively, change its accessibility so that it is not externally visible. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1034)

Check warning on line 7 in src/Providers/Abstractions/src/ImageToText/ImageToTextGenerationResponse.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Symbol 'GeneratedTextItem' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)

Check warning on line 7 in src/Providers/Abstractions/src/ImageToText/ImageToTextGenerationResponse.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Symbol 'implicit constructor for 'GeneratedTextItem'' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)
{
/// <summary>
/// The continuated string
/// </summary>
[JsonPropertyName("generated_text")]
public string? GeneratedText { get; set; }
}
}
10 changes: 10 additions & 0 deletions src/Providers/Abstractions/src/ImageToText/ImageToTextModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers;

public abstract class ImageToTextModel(string id) : Model<ImageToTextSettings>(id), IImageToTextModel<ImageToTextRequest, ImageToTextResponse, ImageToTextSettings>

Check warning on line 4 in src/Providers/Abstractions/src/ImageToText/ImageToTextModel.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

{
public abstract Task<ImageToTextResponse> GenerateTextFromImageAsync(
ImageToTextRequest request,
ImageToTextSettings? settings = default,
CancellationToken cancellationToken = default);
}
13 changes: 13 additions & 0 deletions src/Providers/Abstractions/src/ImageToText/ImageToTextRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers;

/// <summary>
/// Base class for image to text requests.
/// </summary>
public class ImageToTextRequest

Check warning on line 7 in src/Providers/Abstractions/src/ImageToText/ImageToTextRequest.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Symbol 'ImageToTextRequest' is not part of the declared public API (https://github.com/dotnet/roslyn-analyzers/blob/main/src/PublicApiAnalyzers/PublicApiAnalyzers.Help.md)
{
/// <summary>
/// Image to upload.
/// </summary>
public required BinaryData Image { get; init; }
}
27 changes: 27 additions & 0 deletions src/Providers/Abstractions/src/ImageToText/ImageToTextResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// ReSharper disable once CheckNamespace
// ReSharper disable ConditionalAccessQualifierIsNonNullableAccordingToAPIContract
namespace LangChain.Providers;

#pragma warning disable CA2225

/// <summary>
///
/// </summary>
public class ImageToTextResponse
{
/// <summary>
///
/// </summary>
public required ImageToTextSettings UsedSettings { get; init; }

/// <summary>
///
/// </summary>
public Usage Usage { get; init; } = Usage.Empty;


/// <summary>
/// Generated text
/// </summary>
public string? Text { get; set; }
}
55 changes: 55 additions & 0 deletions src/Providers/Abstractions/src/ImageToText/ImageToTextSettings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers;

/// <summary>
/// Base class for image to text request settings.
/// </summary>
public class ImageToTextSettings
{
public static ImageToTextSettings Default { get; } = new()
{
User = string.Empty,
Endpoint = "https://api-inference.huggingface.co/models/"
};

/// <summary>
/// Unique user identifier.
/// </summary>
public string? User { get; init; }

/// <summary>
/// Endpoint url for api.
/// </summary>
public string Endpoint { get; set; }

Check warning on line 23 in src/Providers/Abstractions/src/ImageToText/ImageToTextSettings.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable property 'Endpoint' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.


/// <summary>
/// Calculate the settings to use for the request.
/// </summary>
/// <param name="requestSettings"></param>
/// <param name="modelSettings"></param>
/// <param name="providerSettings"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public static ImageToTextSettings Calculate(
ImageToTextSettings? requestSettings,
ImageToTextSettings? modelSettings,
ImageToTextSettings? providerSettings)
{
return new ImageToTextSettings
{
User =
requestSettings?.User ??
modelSettings?.User ??
providerSettings?.User ??
Default.User ??
throw new InvalidOperationException("Default User is not set."),
Endpoint =
requestSettings?.Endpoint ??
modelSettings?.Endpoint ??
providerSettings?.Endpoint ??
Default.Endpoint ??
throw new InvalidOperationException("Default Endpoint is not set."),
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="System.Memory.Data" />
</ItemGroup>

</Project>
Loading

0 comments on commit 3290566

Please sign in to comment.