From 6c25f99b58b1ebccf3c27d283220ba46606d8d47 Mon Sep 17 00:00:00 2001 From: Peter James Date: Mon, 26 Feb 2024 20:42:51 -0800 Subject: [PATCH] fix: handful of bugs in OpenAI image generation --- .../src/AzureOpenAiImageGenerationModel.cs | 36 ++++++---- .../OpenAiImageGenerationModel.cs | 66 +++++++++++-------- .../OpenAiImageGenerationSettings.cs | 53 +++++++++++---- .../src/Predefined/ImageGenerationModels.cs | 4 +- 4 files changed, 104 insertions(+), 55 deletions(-) diff --git a/src/Providers/Azure/src/AzureOpenAiImageGenerationModel.cs b/src/Providers/Azure/src/AzureOpenAiImageGenerationModel.cs index 35a73f0e..1710f5f1 100644 --- a/src/Providers/Azure/src/AzureOpenAiImageGenerationModel.cs +++ b/src/Providers/Azure/src/AzureOpenAiImageGenerationModel.cs @@ -3,11 +3,23 @@ namespace LangChain.Providers.Azure; -public class AzureOpenAiImageGenerationModel( - string id, - AzureOpenAiProvider provider) - : ImageGenerationModel(id), IImageGenerationModel +public class AzureOpenAiImageGenerationModel : ImageGenerationModel, IImageGenerationModel { + private readonly AzureOpenAiProvider _provider; + private readonly ImageModels _model; + + /// + /// + /// + /// + /// + public AzureOpenAiImageGenerationModel(AzureOpenAiProvider provider, string id) + : base(id) + { + _provider = provider; + _model = new(id); + } + /// /// Azure responds with a revised prompt if it changed it during generation, this property contains that prompt. Only relevant when Dall-E-3 model is used. /// @@ -26,17 +38,19 @@ public async Task GenerateImageAsync( CancellationToken cancellationToken = default) { request = request ?? throw new ArgumentNullException(nameof(request)); - + if (GenerationOptions != null && GenerationOptions.ImageCount != 1) { throw new NotSupportedException("Currently only 1 image is supported"); } - + var usedSettings = OpenAiImageGenerationSettings.Calculate( requestSettings: settings, modelSettings: Settings, - providerSettings: provider.ImageGenerationSettings); - var response = await provider.Client.GetImageGenerationsAsync(GenerationOptions ?? new ImageGenerationOptions + providerSettings: _provider.ImageGenerationSettings, + defaultSettings: OpenAiImageGenerationSettings.GetDefault(_model)); + + var response = await _provider.Client.GetImageGenerationsAsync(GenerationOptions ?? new ImageGenerationOptions { DeploymentName = Id, ImageCount = 1, //currently hardcoded to 1 @@ -50,7 +64,7 @@ public async Task GenerateImageAsync( var usage = Usage.Empty with { //Todo: Usage might be off when setting different parameters in GenerationOptions - PriceInUsd = ImageModels.DallE3.GetPriceInUsd( + PriceInUsd = _model.GetPriceInUsd( resolution: usedSettings.Resolution!.Value, quality: usedSettings.Quality!.Value), }; @@ -58,11 +72,11 @@ public async Task GenerateImageAsync( var firstImage = response.Value.Data[0]; RevisedPromptResult = firstImage.RevisedPrompt; - + var bytes = Convert.FromBase64String( firstImage.Base64Data ?? throw new InvalidOperationException("B64_json is null")); - + return new ImageGenerationResponse { Bytes = bytes, diff --git a/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationModel.cs b/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationModel.cs index c03f1f8b..14ed89ee 100644 --- a/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationModel.cs +++ b/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationModel.cs @@ -5,16 +5,24 @@ // ReSharper disable once CheckNamespace namespace LangChain.Providers.OpenAI; -/// -/// -/// -/// -/// -public class OpenAiImageGenerationModel( - OpenAiProvider provider, - string id) - : ImageGenerationModel(id), IImageGenerationModel + +public class OpenAiImageGenerationModel : ImageGenerationModel, IImageGenerationModel { + private readonly OpenAiProvider _provider; + private readonly ImageModels _model; + + /// + /// + /// + /// + /// + public OpenAiImageGenerationModel(OpenAiProvider provider, string id) + : base(id) + { + _provider = provider; + _model = new(id); + } + /// public async Task GenerateImageAsync( ImageGenerationRequest request, @@ -22,14 +30,16 @@ public async Task GenerateImageAsync( CancellationToken cancellationToken = default) { request = request ?? throw new ArgumentNullException(nameof(request)); - + OnPromptSent(request.Prompt); var usedSettings = OpenAiImageGenerationSettings.Calculate( requestSettings: settings, modelSettings: Settings, - providerSettings: provider.ImageGenerationSettings); - var response = await provider.Api.ImagesEndPoint.GenerateImageAsync( + providerSettings: _provider.ImageGenerationSettings, + defaultSettings: OpenAiImageGenerationSettings.GetDefault(_model)); + + var response = await _provider.Api.ImagesEndPoint.GenerateImageAsync( request: new global::OpenAI.Images.ImageGenerationRequest( prompt: request.Prompt, model: new global::OpenAI.Models.Model(Id, "openai"), @@ -42,32 +52,32 @@ public async Task GenerateImageAsync( var usage = Usage.Empty with { - PriceInUsd = ImageModels.DallE3.GetPriceInUsd( + PriceInUsd = _model.GetPriceInUsd( resolution: usedSettings.Resolution!.Value, quality: usedSettings.Quality), }; AddUsage(usage); - provider.AddUsage(usage); + _provider.AddUsage(usage); switch (usedSettings.ResponseFormat) { case ResponseFormat.Url: - { - using var client = new HttpClient(); + { + using var client = new HttpClient(); #if NET6_0_OR_GREATER - var bytes = await client.GetByteArrayAsync(new Uri(response[0].Url), cancellationToken).ConfigureAwait(false); + var bytes = await client.GetByteArrayAsync(new Uri(response[0].Url), cancellationToken).ConfigureAwait(false); #else - var bytes = await client.GetByteArrayAsync(new Uri(response[0].Url)).ConfigureAwait(false); + var bytes = await client.GetByteArrayAsync(new Uri(response[0].Url)).ConfigureAwait(false); #endif - - return new ImageGenerationResponse - { - Bytes = bytes, - Usage = usage, - UsedSettings = usedSettings, - }; - } - + + return new ImageGenerationResponse + { + Bytes = bytes, + Usage = usage, + UsedSettings = usedSettings, + }; + } + case ResponseFormat.B64_Json: return new ImageGenerationResponse { @@ -77,7 +87,7 @@ public async Task GenerateImageAsync( Usage = usage, UsedSettings = usedSettings, }; - + default: throw new NotImplementedException("ResponseFormat not implemented."); } diff --git a/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationSettings.cs b/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationSettings.cs index bee3a2d1..cdcbbd99 100644 --- a/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationSettings.cs +++ b/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationSettings.cs @@ -12,14 +12,36 @@ public class OpenAiImageGenerationSettings : ImageGenerationSettings /// /// /// - public new static OpenAiImageGenerationSettings Default { get; } = new() + public static OpenAiImageGenerationSettings GetDefault(string id) { - NumberOfResults = 1, - Quality = ImageQualities.Standard, - ResponseFormat = global::OpenAI.Images.ResponseFormat.B64_Json, - Resolution = ImageResolutions._256x256, - User = string.Empty, - }; + if (id == ImageModels.DallE2) + { + return new() + { + NumberOfResults = 1, + Quality = ImageQualities.Standard, + ResponseFormat = global::OpenAI.Images.ResponseFormat.B64_Json, + Resolution = ImageResolutions._256x256, + User = string.Empty, + }; + } + else if (id == ImageModels.DallE3) + { + return new() + { + NumberOfResults = 1, + Quality = ImageQualities.Standard, + ResponseFormat = global::OpenAI.Images.ResponseFormat.B64_Json, + Resolution = ImageResolutions._1024x1024, + User = string.Empty, + }; + } + else + { + throw new NotSupportedException($"OpenAI model {id} is not supported"); + } + } + /// /// @@ -43,7 +65,7 @@ public class OpenAiImageGenerationSettings : ImageGenerationSettings /// [CLSCompliant(false)] public ImageResolutions? Resolution { get; init; } - + /// /// /// @@ -55,16 +77,19 @@ public class OpenAiImageGenerationSettings : ImageGenerationSettings /// /// /// + /// /// /// public static OpenAiImageGenerationSettings Calculate( ImageGenerationSettings? requestSettings, ImageGenerationSettings? modelSettings, - ImageGenerationSettings? providerSettings) + ImageGenerationSettings? providerSettings, + ImageGenerationSettings? defaultSettings) { var requestSettingsCasted = requestSettings as OpenAiImageGenerationSettings; var modelSettingsCasted = modelSettings as OpenAiImageGenerationSettings; var providerSettingsCasted = providerSettings as OpenAiImageGenerationSettings; + var defaultSettingsCasted = defaultSettings as OpenAiImageGenerationSettings; return new OpenAiImageGenerationSettings { @@ -72,31 +97,31 @@ public static OpenAiImageGenerationSettings Calculate( requestSettingsCasted?.NumberOfResults ?? modelSettingsCasted?.NumberOfResults ?? providerSettingsCasted?.NumberOfResults ?? - Default.NumberOfResults ?? + defaultSettingsCasted?.NumberOfResults ?? throw new InvalidOperationException("Default NumberOfResults is not set."), Quality = requestSettingsCasted?.Quality ?? modelSettingsCasted?.Quality ?? providerSettingsCasted?.Quality ?? - Default.Quality ?? + defaultSettingsCasted?.Quality ?? throw new InvalidOperationException("Default Quality is not set."), ResponseFormat = requestSettingsCasted?.ResponseFormat ?? modelSettingsCasted?.ResponseFormat ?? providerSettingsCasted?.ResponseFormat ?? - Default.ResponseFormat ?? + defaultSettingsCasted?.ResponseFormat ?? throw new InvalidOperationException("Default ResponseFormat is not set."), Resolution = requestSettingsCasted?.Resolution ?? modelSettingsCasted?.Resolution ?? providerSettingsCasted?.Resolution ?? - Default.Resolution ?? + defaultSettingsCasted?.Resolution ?? throw new InvalidOperationException("Default Resolution is not set."), User = requestSettingsCasted?.User ?? modelSettingsCasted?.User ?? providerSettingsCasted?.User ?? - Default.User ?? + defaultSettingsCasted?.User ?? throw new InvalidOperationException("Default User is not set."), }; } diff --git a/src/Providers/OpenAI/src/Predefined/ImageGenerationModels.cs b/src/Providers/OpenAI/src/Predefined/ImageGenerationModels.cs index d42cccd6..db00b809 100644 --- a/src/Providers/OpenAI/src/Predefined/ImageGenerationModels.cs +++ b/src/Providers/OpenAI/src/Predefined/ImageGenerationModels.cs @@ -4,8 +4,8 @@ namespace LangChain.Providers.OpenAI.Predefined; /// public class DallE2Model(OpenAiProvider provider) - : OpenAiTextToSpeechModel(provider, id: ImageModels.DallE2); + : OpenAiImageGenerationModel(provider, id: ImageModels.DallE2); /// public class DallE3Model(OpenAiProvider provider) - : OpenAiTextToSpeechModel(provider, id: ImageModels.DallE3); \ No newline at end of file + : OpenAiImageGenerationModel(provider, id: ImageModels.DallE3); \ No newline at end of file