diff --git a/src/Providers/Azure/src/AzureOpenAiImageGenerationModel.cs b/src/Providers/Azure/src/AzureOpenAiImageGenerationModel.cs
index 35a73f0e..1710f5f1 100644
--- a/src/Providers/Azure/src/AzureOpenAiImageGenerationModel.cs
+++ b/src/Providers/Azure/src/AzureOpenAiImageGenerationModel.cs
@@ -3,11 +3,23 @@
namespace LangChain.Providers.Azure;
-public class AzureOpenAiImageGenerationModel(
- string id,
- AzureOpenAiProvider provider)
- : ImageGenerationModel(id), IImageGenerationModel
+public class AzureOpenAiImageGenerationModel : ImageGenerationModel, IImageGenerationModel
{
+ private readonly AzureOpenAiProvider _provider;
+ private readonly ImageModels _model;
+
+ ///
+ ///
+ ///
+ ///
+ ///
+ public AzureOpenAiImageGenerationModel(AzureOpenAiProvider provider, string id)
+ : base(id)
+ {
+ _provider = provider;
+ _model = new(id);
+ }
+
///
/// Azure responds with a revised prompt if it changed it during generation, this property contains that prompt. Only relevant when Dall-E-3 model is used.
///
@@ -26,17 +38,19 @@ public async Task GenerateImageAsync(
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));
-
+
if (GenerationOptions != null && GenerationOptions.ImageCount != 1)
{
throw new NotSupportedException("Currently only 1 image is supported");
}
-
+
var usedSettings = OpenAiImageGenerationSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
- providerSettings: provider.ImageGenerationSettings);
- var response = await provider.Client.GetImageGenerationsAsync(GenerationOptions ?? new ImageGenerationOptions
+ providerSettings: _provider.ImageGenerationSettings,
+ defaultSettings: OpenAiImageGenerationSettings.GetDefault(_model));
+
+ var response = await _provider.Client.GetImageGenerationsAsync(GenerationOptions ?? new ImageGenerationOptions
{
DeploymentName = Id,
ImageCount = 1, //currently hardcoded to 1
@@ -50,7 +64,7 @@ public async Task GenerateImageAsync(
var usage = Usage.Empty with
{
//Todo: Usage might be off when setting different parameters in GenerationOptions
- PriceInUsd = ImageModels.DallE3.GetPriceInUsd(
+ PriceInUsd = _model.GetPriceInUsd(
resolution: usedSettings.Resolution!.Value,
quality: usedSettings.Quality!.Value),
};
@@ -58,11 +72,11 @@ public async Task GenerateImageAsync(
var firstImage = response.Value.Data[0];
RevisedPromptResult = firstImage.RevisedPrompt;
-
+
var bytes = Convert.FromBase64String(
firstImage.Base64Data ??
throw new InvalidOperationException("B64_json is null"));
-
+
return new ImageGenerationResponse
{
Bytes = bytes,
diff --git a/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationModel.cs b/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationModel.cs
index c03f1f8b..14ed89ee 100644
--- a/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationModel.cs
+++ b/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationModel.cs
@@ -5,16 +5,24 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers.OpenAI;
-///
-///
-///
-///
-///
-public class OpenAiImageGenerationModel(
- OpenAiProvider provider,
- string id)
- : ImageGenerationModel(id), IImageGenerationModel
+
+public class OpenAiImageGenerationModel : ImageGenerationModel, IImageGenerationModel
{
+ private readonly OpenAiProvider _provider;
+ private readonly ImageModels _model;
+
+ ///
+ ///
+ ///
+ ///
+ ///
+ public OpenAiImageGenerationModel(OpenAiProvider provider, string id)
+ : base(id)
+ {
+ _provider = provider;
+ _model = new(id);
+ }
+
///
public async Task GenerateImageAsync(
ImageGenerationRequest request,
@@ -22,14 +30,16 @@ public async Task GenerateImageAsync(
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));
-
+
OnPromptSent(request.Prompt);
var usedSettings = OpenAiImageGenerationSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
- providerSettings: provider.ImageGenerationSettings);
- var response = await provider.Api.ImagesEndPoint.GenerateImageAsync(
+ providerSettings: _provider.ImageGenerationSettings,
+ defaultSettings: OpenAiImageGenerationSettings.GetDefault(_model));
+
+ var response = await _provider.Api.ImagesEndPoint.GenerateImageAsync(
request: new global::OpenAI.Images.ImageGenerationRequest(
prompt: request.Prompt,
model: new global::OpenAI.Models.Model(Id, "openai"),
@@ -42,32 +52,32 @@ public async Task GenerateImageAsync(
var usage = Usage.Empty with
{
- PriceInUsd = ImageModels.DallE3.GetPriceInUsd(
+ PriceInUsd = _model.GetPriceInUsd(
resolution: usedSettings.Resolution!.Value,
quality: usedSettings.Quality),
};
AddUsage(usage);
- provider.AddUsage(usage);
+ _provider.AddUsage(usage);
switch (usedSettings.ResponseFormat)
{
case ResponseFormat.Url:
- {
- using var client = new HttpClient();
+ {
+ using var client = new HttpClient();
#if NET6_0_OR_GREATER
- var bytes = await client.GetByteArrayAsync(new Uri(response[0].Url), cancellationToken).ConfigureAwait(false);
+ var bytes = await client.GetByteArrayAsync(new Uri(response[0].Url), cancellationToken).ConfigureAwait(false);
#else
- var bytes = await client.GetByteArrayAsync(new Uri(response[0].Url)).ConfigureAwait(false);
+ var bytes = await client.GetByteArrayAsync(new Uri(response[0].Url)).ConfigureAwait(false);
#endif
-
- return new ImageGenerationResponse
- {
- Bytes = bytes,
- Usage = usage,
- UsedSettings = usedSettings,
- };
- }
-
+
+ return new ImageGenerationResponse
+ {
+ Bytes = bytes,
+ Usage = usage,
+ UsedSettings = usedSettings,
+ };
+ }
+
case ResponseFormat.B64_Json:
return new ImageGenerationResponse
{
@@ -77,7 +87,7 @@ public async Task GenerateImageAsync(
Usage = usage,
UsedSettings = usedSettings,
};
-
+
default:
throw new NotImplementedException("ResponseFormat not implemented.");
}
diff --git a/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationSettings.cs b/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationSettings.cs
index bee3a2d1..cdcbbd99 100644
--- a/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationSettings.cs
+++ b/src/Providers/OpenAI/src/ImageGeneration/OpenAiImageGenerationSettings.cs
@@ -12,14 +12,36 @@ public class OpenAiImageGenerationSettings : ImageGenerationSettings
///
///
///
- public new static OpenAiImageGenerationSettings Default { get; } = new()
+ public static OpenAiImageGenerationSettings GetDefault(string id)
{
- NumberOfResults = 1,
- Quality = ImageQualities.Standard,
- ResponseFormat = global::OpenAI.Images.ResponseFormat.B64_Json,
- Resolution = ImageResolutions._256x256,
- User = string.Empty,
- };
+ if (id == ImageModels.DallE2)
+ {
+ return new()
+ {
+ NumberOfResults = 1,
+ Quality = ImageQualities.Standard,
+ ResponseFormat = global::OpenAI.Images.ResponseFormat.B64_Json,
+ Resolution = ImageResolutions._256x256,
+ User = string.Empty,
+ };
+ }
+ else if (id == ImageModels.DallE3)
+ {
+ return new()
+ {
+ NumberOfResults = 1,
+ Quality = ImageQualities.Standard,
+ ResponseFormat = global::OpenAI.Images.ResponseFormat.B64_Json,
+ Resolution = ImageResolutions._1024x1024,
+ User = string.Empty,
+ };
+ }
+ else
+ {
+ throw new NotSupportedException($"OpenAI model {id} is not supported");
+ }
+ }
+
///
///
@@ -43,7 +65,7 @@ public class OpenAiImageGenerationSettings : ImageGenerationSettings
///
[CLSCompliant(false)]
public ImageResolutions? Resolution { get; init; }
-
+
///
///
///
@@ -55,16 +77,19 @@ public class OpenAiImageGenerationSettings : ImageGenerationSettings
///
///
///
+ ///
///
///
public static OpenAiImageGenerationSettings Calculate(
ImageGenerationSettings? requestSettings,
ImageGenerationSettings? modelSettings,
- ImageGenerationSettings? providerSettings)
+ ImageGenerationSettings? providerSettings,
+ ImageGenerationSettings? defaultSettings)
{
var requestSettingsCasted = requestSettings as OpenAiImageGenerationSettings;
var modelSettingsCasted = modelSettings as OpenAiImageGenerationSettings;
var providerSettingsCasted = providerSettings as OpenAiImageGenerationSettings;
+ var defaultSettingsCasted = defaultSettings as OpenAiImageGenerationSettings;
return new OpenAiImageGenerationSettings
{
@@ -72,31 +97,31 @@ public static OpenAiImageGenerationSettings Calculate(
requestSettingsCasted?.NumberOfResults ??
modelSettingsCasted?.NumberOfResults ??
providerSettingsCasted?.NumberOfResults ??
- Default.NumberOfResults ??
+ defaultSettingsCasted?.NumberOfResults ??
throw new InvalidOperationException("Default NumberOfResults is not set."),
Quality =
requestSettingsCasted?.Quality ??
modelSettingsCasted?.Quality ??
providerSettingsCasted?.Quality ??
- Default.Quality ??
+ defaultSettingsCasted?.Quality ??
throw new InvalidOperationException("Default Quality is not set."),
ResponseFormat =
requestSettingsCasted?.ResponseFormat ??
modelSettingsCasted?.ResponseFormat ??
providerSettingsCasted?.ResponseFormat ??
- Default.ResponseFormat ??
+ defaultSettingsCasted?.ResponseFormat ??
throw new InvalidOperationException("Default ResponseFormat is not set."),
Resolution =
requestSettingsCasted?.Resolution ??
modelSettingsCasted?.Resolution ??
providerSettingsCasted?.Resolution ??
- Default.Resolution ??
+ defaultSettingsCasted?.Resolution ??
throw new InvalidOperationException("Default Resolution is not set."),
User =
requestSettingsCasted?.User ??
modelSettingsCasted?.User ??
providerSettingsCasted?.User ??
- Default.User ??
+ defaultSettingsCasted?.User ??
throw new InvalidOperationException("Default User is not set."),
};
}
diff --git a/src/Providers/OpenAI/src/Predefined/ImageGenerationModels.cs b/src/Providers/OpenAI/src/Predefined/ImageGenerationModels.cs
index d42cccd6..db00b809 100644
--- a/src/Providers/OpenAI/src/Predefined/ImageGenerationModels.cs
+++ b/src/Providers/OpenAI/src/Predefined/ImageGenerationModels.cs
@@ -4,8 +4,8 @@ namespace LangChain.Providers.OpenAI.Predefined;
///
public class DallE2Model(OpenAiProvider provider)
- : OpenAiTextToSpeechModel(provider, id: ImageModels.DallE2);
+ : OpenAiImageGenerationModel(provider, id: ImageModels.DallE2);
///
public class DallE3Model(OpenAiProvider provider)
- : OpenAiTextToSpeechModel(provider, id: ImageModels.DallE3);
\ No newline at end of file
+ : OpenAiImageGenerationModel(provider, id: ImageModels.DallE3);
\ No newline at end of file