Skip to content

Commit

Permalink
Merge branch 'main' of github.com:curlyfro/LangChain
Browse files Browse the repository at this point in the history
  • Loading branch information
curlyfro committed Feb 28, 2024
2 parents 80d558a + 0e96b06 commit bb4dd2c
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 55 deletions.
36 changes: 25 additions & 11 deletions src/Providers/Azure/src/AzureOpenAiImageGenerationModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,23 @@

namespace LangChain.Providers.Azure;

public class AzureOpenAiImageGenerationModel(
string id,
AzureOpenAiProvider provider)
: ImageGenerationModel(id), IImageGenerationModel
public class AzureOpenAiImageGenerationModel : ImageGenerationModel, IImageGenerationModel
{
private readonly AzureOpenAiProvider _provider;
private readonly ImageModels _model;

/// <summary>
///
/// </summary>
/// <param name="provider"></param>
/// <param name="id"></param>
public AzureOpenAiImageGenerationModel(AzureOpenAiProvider provider, string id)
: base(id)
{
_provider = provider;
_model = new(id);
}

/// <summary>
/// Azure responds with a revised prompt if it changed it during generation, this property contains that prompt. Only relevant when Dall-E-3 model is used.
/// </summary>
Expand All @@ -26,17 +38,19 @@ public async Task<ImageGenerationResponse> GenerateImageAsync(
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

if (GenerationOptions != null && GenerationOptions.ImageCount != 1)
{
throw new NotSupportedException("Currently only 1 image is supported");
}

var usedSettings = OpenAiImageGenerationSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.ImageGenerationSettings);
var response = await provider.Client.GetImageGenerationsAsync(GenerationOptions ?? new ImageGenerationOptions
providerSettings: _provider.ImageGenerationSettings,
defaultSettings: OpenAiImageGenerationSettings.GetDefault(_model));

var response = await _provider.Client.GetImageGenerationsAsync(GenerationOptions ?? new ImageGenerationOptions
{
DeploymentName = Id,
ImageCount = 1, //currently hardcoded to 1
Expand All @@ -50,19 +64,19 @@ public async Task<ImageGenerationResponse> GenerateImageAsync(
var usage = Usage.Empty with
{
//Todo: Usage might be off when setting different parameters in GenerationOptions
PriceInUsd = ImageModels.DallE3.GetPriceInUsd(
PriceInUsd = _model.GetPriceInUsd(
resolution: usedSettings.Resolution!.Value,
quality: usedSettings.Quality!.Value),
};
AddUsage(usage);

var firstImage = response.Value.Data[0];
RevisedPromptResult = firstImage.RevisedPrompt;

var bytes = Convert.FromBase64String(
firstImage.Base64Data ??
throw new InvalidOperationException("B64_json is null"));

return new ImageGenerationResponse
{
Bytes = bytes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,41 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers.OpenAI;

/// <summary>
///
/// </summary>
/// <param name="provider"></param>
/// <param name="id"></param>
public class OpenAiImageGenerationModel(
OpenAiProvider provider,
string id)
: ImageGenerationModel(id), IImageGenerationModel

public class OpenAiImageGenerationModel : ImageGenerationModel, IImageGenerationModel
{
private readonly OpenAiProvider _provider;
private readonly ImageModels _model;

/// <summary>
///
/// </summary>
/// <param name="provider"></param>
/// <param name="id"></param>
public OpenAiImageGenerationModel(OpenAiProvider provider, string id)
: base(id)
{
_provider = provider;
_model = new(id);
}

/// <inheritdoc/>
public async Task<ImageGenerationResponse> GenerateImageAsync(
ImageGenerationRequest request,
ImageGenerationSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

OnPromptSent(request.Prompt);

var usedSettings = OpenAiImageGenerationSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.ImageGenerationSettings);
var response = await provider.Api.ImagesEndPoint.GenerateImageAsync(
providerSettings: _provider.ImageGenerationSettings,
defaultSettings: OpenAiImageGenerationSettings.GetDefault(_model));

var response = await _provider.Api.ImagesEndPoint.GenerateImageAsync(
request: new global::OpenAI.Images.ImageGenerationRequest(
prompt: request.Prompt,
model: new global::OpenAI.Models.Model(Id, "openai"),
Expand All @@ -42,32 +52,32 @@ public async Task<ImageGenerationResponse> GenerateImageAsync(

var usage = Usage.Empty with
{
PriceInUsd = ImageModels.DallE3.GetPriceInUsd(
PriceInUsd = _model.GetPriceInUsd(
resolution: usedSettings.Resolution!.Value,
quality: usedSettings.Quality),
};
AddUsage(usage);
provider.AddUsage(usage);
_provider.AddUsage(usage);

switch (usedSettings.ResponseFormat)
{
case ResponseFormat.Url:
{
using var client = new HttpClient();
{
using var client = new HttpClient();
#if NET6_0_OR_GREATER
var bytes = await client.GetByteArrayAsync(new Uri(response[0].Url), cancellationToken).ConfigureAwait(false);
var bytes = await client.GetByteArrayAsync(new Uri(response[0].Url), cancellationToken).ConfigureAwait(false);
#else
var bytes = await client.GetByteArrayAsync(new Uri(response[0].Url)).ConfigureAwait(false);
var bytes = await client.GetByteArrayAsync(new Uri(response[0].Url)).ConfigureAwait(false);
#endif
return new ImageGenerationResponse
{
Bytes = bytes,
Usage = usage,
UsedSettings = usedSettings,
};
}

return new ImageGenerationResponse
{
Bytes = bytes,
Usage = usage,
UsedSettings = usedSettings,
};
}

case ResponseFormat.B64_Json:
return new ImageGenerationResponse
{
Expand All @@ -77,7 +87,7 @@ public async Task<ImageGenerationResponse> GenerateImageAsync(
Usage = usage,
UsedSettings = usedSettings,
};

default:
throw new NotImplementedException("ResponseFormat not implemented.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,36 @@ public class OpenAiImageGenerationSettings : ImageGenerationSettings
/// <summary>
///
/// </summary>
public new static OpenAiImageGenerationSettings Default { get; } = new()
public static OpenAiImageGenerationSettings GetDefault(string id)

Check warning on line 15 in src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The property name 'Default' is confusing given the existence of method 'GetDefault'. Rename or remove one of these members. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1721)

Check warning on line 15 in src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The property name 'Default' is confusing given the existence of method 'GetDefault'. Rename or remove one of these members. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1721)

Check warning on line 15 in src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The property name 'Default' is confusing given the existence of method 'GetDefault'. Rename or remove one of these members. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1721)
{
NumberOfResults = 1,
Quality = ImageQualities.Standard,
ResponseFormat = global::OpenAI.Images.ResponseFormat.B64_Json,
Resolution = ImageResolutions._256x256,
User = string.Empty,
};
if (id == ImageModels.DallE2)
{
return new()
{
NumberOfResults = 1,
Quality = ImageQualities.Standard,
ResponseFormat = global::OpenAI.Images.ResponseFormat.B64_Json,
Resolution = ImageResolutions._256x256,
User = string.Empty,
};
}
else if (id == ImageModels.DallE3)
{
return new()
{
NumberOfResults = 1,
Quality = ImageQualities.Standard,
ResponseFormat = global::OpenAI.Images.ResponseFormat.B64_Json,
Resolution = ImageResolutions._1024x1024,
User = string.Empty,
};
}
else
{
throw new NotSupportedException($"OpenAI model {id} is not supported");
}
}


/// <summary>
///
Expand All @@ -43,7 +65,7 @@ public class OpenAiImageGenerationSettings : ImageGenerationSettings
/// </summary>
[CLSCompliant(false)]
public ImageResolutions? Resolution { get; init; }

/// <summary>
///
/// </summary>
Expand All @@ -55,48 +77,51 @@ public class OpenAiImageGenerationSettings : ImageGenerationSettings
/// <param name="requestSettings"></param>
/// <param name="modelSettings"></param>
/// <param name="providerSettings"></param>
/// <param name="defaultSettings"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public static OpenAiImageGenerationSettings Calculate(
ImageGenerationSettings? requestSettings,
ImageGenerationSettings? modelSettings,
ImageGenerationSettings? providerSettings)
ImageGenerationSettings? providerSettings,
ImageGenerationSettings? defaultSettings)
{
var requestSettingsCasted = requestSettings as OpenAiImageGenerationSettings;
var modelSettingsCasted = modelSettings as OpenAiImageGenerationSettings;
var providerSettingsCasted = providerSettings as OpenAiImageGenerationSettings;
var defaultSettingsCasted = defaultSettings as OpenAiImageGenerationSettings;

return new OpenAiImageGenerationSettings
{
NumberOfResults =
requestSettingsCasted?.NumberOfResults ??
modelSettingsCasted?.NumberOfResults ??
providerSettingsCasted?.NumberOfResults ??
Default.NumberOfResults ??
defaultSettingsCasted?.NumberOfResults ??
throw new InvalidOperationException("Default NumberOfResults is not set."),
Quality =
requestSettingsCasted?.Quality ??
modelSettingsCasted?.Quality ??
providerSettingsCasted?.Quality ??
Default.Quality ??
defaultSettingsCasted?.Quality ??
throw new InvalidOperationException("Default Quality is not set."),
ResponseFormat =
requestSettingsCasted?.ResponseFormat ??
modelSettingsCasted?.ResponseFormat ??
providerSettingsCasted?.ResponseFormat ??
Default.ResponseFormat ??
defaultSettingsCasted?.ResponseFormat ??
throw new InvalidOperationException("Default ResponseFormat is not set."),
Resolution =
requestSettingsCasted?.Resolution ??
modelSettingsCasted?.Resolution ??
providerSettingsCasted?.Resolution ??
Default.Resolution ??
defaultSettingsCasted?.Resolution ??
throw new InvalidOperationException("Default Resolution is not set."),
User =
requestSettingsCasted?.User ??
modelSettingsCasted?.User ??
providerSettingsCasted?.User ??
Default.User ??
defaultSettingsCasted?.User ??
throw new InvalidOperationException("Default User is not set."),
};
}
Expand Down
4 changes: 2 additions & 2 deletions src/Providers/OpenAI/src/Predefined/ImageGenerationModels.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ namespace LangChain.Providers.OpenAI.Predefined;

/// <inheritdoc cref="ImageModels.DallE2" />
public class DallE2Model(OpenAiProvider provider)
: OpenAiTextToSpeechModel(provider, id: ImageModels.DallE2);
: OpenAiImageGenerationModel(provider, id: ImageModels.DallE2);

/// <inheritdoc cref="ImageModels.DallE3" />
public class DallE3Model(OpenAiProvider provider)
: OpenAiTextToSpeechModel(provider, id: ImageModels.DallE3);
: OpenAiImageGenerationModel(provider, id: ImageModels.DallE3);

0 comments on commit bb4dd2c

Please sign in to comment.