Skip to content

Commit aef2409

Browse files
Bart Veenstrabartveenstra
Bart Veenstra
authored andcommitted
feat: enhance AzureOpenAiResponseFormat to support JSON schema and builder pattern
Signed-off-by: Bart Veenstra <[email protected]>
1 parent 82b46d2 commit aef2409

File tree

3 files changed

+259
-20
lines changed

3 files changed

+259
-20
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

+13-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.ai.azure.openai;
1818

19+
import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormat;
20+
import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormatJsonSchema;
1921
import java.util.ArrayList;
2022
import java.util.Base64;
2123
import java.util.Collections;
@@ -58,6 +60,8 @@
5860
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
5961
import org.slf4j.Logger;
6062
import org.slf4j.LoggerFactory;
63+
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.JsonSchema;
64+
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type;
6165
import reactor.core.publisher.Flux;
6266

6367
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -115,6 +119,7 @@
115119
* @author Jihoon Kim
116120
* @author Ilayaperumal Gopinathan
117121
* @author Alexandros Pappas
122+
* @author Bart Veenstra
118123
* @see ChatModel
119124
* @see com.azure.ai.openai.OpenAIClient
120125
* @since 1.0.0
@@ -278,7 +283,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
278283
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
279284
.prompt(prompt)
280285
.provider(AiProvider.AZURE_OPENAI.value())
281-
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
282286
.build();
283287

284288
ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
@@ -334,7 +338,6 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
334338
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
335339
.prompt(prompt)
336340
.provider(AiProvider.AZURE_OPENAI.value())
337-
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
338341
.build();
339342

340343
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
@@ -940,9 +943,16 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
940943
* @return Azure response format
941944
*/
942945
private ChatCompletionsResponseFormat toAzureResponseFormat(AzureOpenAiResponseFormat responseFormat) {
943-
if (responseFormat == AzureOpenAiResponseFormat.JSON) {
946+
if (responseFormat.getType() == Type.JSON_OBJECT) {
944947
return new ChatCompletionsJsonResponseFormat();
945948
}
949+
if (responseFormat.getType() == Type.JSON_SCHEMA) {
950+
JsonSchema jsonSchema = responseFormat.getJsonSchema();
951+
var responseFormatJsonSchema = new ChatCompletionsJsonSchemaResponseFormatJsonSchema(jsonSchema.getName());
952+
String jsonString = ModelOptionsUtils.toJsonString(jsonSchema.getSchema());
953+
responseFormatJsonSchema.setSchema(BinaryData.fromString(jsonString));
954+
return new ChatCompletionsJsonSchemaResponseFormat(responseFormatJsonSchema);
955+
}
946956
return new ChatCompletionsTextResponseFormat();
947957
}
948958

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java

+243-15
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,255 @@
1616

1717
package org.springframework.ai.azure.openai;
1818

19+
import com.fasterxml.jackson.annotation.JsonInclude;
20+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
21+
import com.fasterxml.jackson.annotation.JsonProperty;
22+
import java.util.Map;
23+
import java.util.Objects;
24+
import org.springframework.ai.model.ModelOptionsUtils;
25+
import org.springframework.util.StringUtils;
26+
1927
/**
2028
* Utility enumeration for representing the response format that may be requested from the
2129
* Azure OpenAI model. Please check <a href=
2230
* "https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format">OpenAI
2331
* API documentation</a> for more details.
2432
*/
25-
public enum AzureOpenAiResponseFormat {
26-
27-
// default value used by OpenAI
28-
TEXT,
29-
/*
30-
* From the OpenAI API documentation: Compatability: Compatible with GPT-4 Turbo and
31-
* all GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Caveats: This enables JSON
32-
* mode, which guarantees the message the model generates is valid JSON. Important:
33-
* when using JSON mode, you must also instruct the model to produce JSON yourself via
34-
* a system or user message. Without this, the model may generate an unending stream
35-
* of whitespace until the generation reaches the token limit, resulting in a
36-
* long-running and seemingly "stuck" request. Also note that the message content may
37-
* be partially cut off if finish_reason="length", which indicates the generation
38-
* exceeded max_tokens or the conversation exceeded the max context length.
33+
@JsonInclude(Include.NON_NULL)
34+
public class AzureOpenAiResponseFormat {
35+
36+
/**
37+
* Type Must be one of 'text', 'json_object' or 'json_schema'.
38+
*/
39+
@JsonProperty("type")
40+
private Type type;
41+
42+
/**
43+
* JSON schema object that describes the format of the JSON object. Only applicable
44+
* when type is 'json_schema'.
3945
*/
40-
JSON
46+
@JsonProperty("json_schema")
47+
private JsonSchema jsonSchema = null;
48+
49+
private String schema;
50+
51+
public AzureOpenAiResponseFormat() {
52+
53+
}
54+
55+
public Type getType() {
56+
return this.type;
57+
}
58+
59+
public void setType(Type type) {
60+
this.type = type;
61+
}
62+
63+
public JsonSchema getJsonSchema() {
64+
return this.jsonSchema;
65+
}
66+
67+
public void setJsonSchema(JsonSchema jsonSchema) {
68+
this.jsonSchema = jsonSchema;
69+
}
70+
71+
public String getSchema() {
72+
return this.schema;
73+
}
74+
75+
public void setSchema(String schema) {
76+
this.schema = schema;
77+
if (schema != null) {
78+
this.jsonSchema = JsonSchema.builder().schema(schema).strict(true).build();
79+
}
80+
}
81+
82+
private AzureOpenAiResponseFormat(Type type, JsonSchema jsonSchema) {
83+
this.type = type;
84+
this.jsonSchema = jsonSchema;
85+
}
86+
87+
public AzureOpenAiResponseFormat(Type type, String schema) {
88+
this(type, StringUtils.hasText(schema) ? JsonSchema.builder().schema(schema).strict(true).build() : null);
89+
}
90+
91+
public static Builder builder() {
92+
return new Builder();
93+
}
94+
95+
@Override
96+
public boolean equals(Object o) {
97+
if (this == o) {
98+
return true;
99+
}
100+
if (o == null || getClass() != o.getClass()) {
101+
return false;
102+
}
103+
AzureOpenAiResponseFormat that = (AzureOpenAiResponseFormat) o;
104+
return this.type == that.type && Objects.equals(this.jsonSchema, that.jsonSchema);
105+
}
106+
107+
@Override
108+
public int hashCode() {
109+
return Objects.hash(this.type, this.jsonSchema);
110+
}
111+
112+
@Override
113+
public String toString() {
114+
return "ResponseFormat{" + "type=" + this.type + ", jsonSchema=" + this.jsonSchema + '}';
115+
}
116+
117+
public static final class Builder {
118+
119+
private Type type;
120+
121+
private JsonSchema jsonSchema;
122+
123+
private Builder() {
124+
}
125+
126+
public Builder type(Type type) {
127+
this.type = type;
128+
return this;
129+
}
130+
131+
public Builder jsonSchema(JsonSchema jsonSchema) {
132+
this.jsonSchema = jsonSchema;
133+
return this;
134+
}
135+
136+
public Builder jsonSchema(String jsonSchema) {
137+
this.jsonSchema = JsonSchema.builder().schema(jsonSchema).build();
138+
return this;
139+
}
140+
141+
public AzureOpenAiResponseFormat build() {
142+
return new AzureOpenAiResponseFormat(this.type, this.jsonSchema);
143+
}
144+
145+
}
146+
147+
public enum Type {
148+
149+
/**
150+
* Generates a text response. (default)
151+
*/
152+
@JsonProperty("text")
153+
TEXT,
154+
155+
/**
156+
* Enables JSON mode, which guarantees the message the model generates is valid
157+
* JSON.
158+
*/
159+
@JsonProperty("json_object")
160+
JSON_OBJECT,
161+
162+
/**
163+
* Enables Structured Outputs which guarantees the model will match your supplied
164+
* JSON schema.
165+
*/
166+
@JsonProperty("json_schema")
167+
JSON_SCHEMA
168+
169+
}
170+
171+
/**
172+
* JSON schema object that describes the format of the JSON object. Applicable for the
173+
* 'json_schema' type only.
174+
*/
175+
@JsonInclude(Include.NON_NULL)
176+
public static class JsonSchema {
177+
178+
@JsonProperty("name")
179+
private String name;
180+
181+
@JsonProperty("schema")
182+
private Map<String, Object> schema;
183+
184+
@JsonProperty("strict")
185+
private Boolean strict;
186+
187+
public JsonSchema() {
188+
189+
}
190+
191+
public String getName() {
192+
return this.name;
193+
}
194+
195+
public Map<String, Object> getSchema() {
196+
return this.schema;
197+
}
198+
199+
public Boolean getStrict() {
200+
return this.strict;
201+
}
202+
203+
private JsonSchema(String name, Map<String, Object> schema, Boolean strict) {
204+
this.name = name;
205+
this.schema = schema;
206+
this.strict = strict;
207+
}
208+
209+
public static Builder builder() {
210+
return new Builder();
211+
}
212+
213+
@Override
214+
public int hashCode() {
215+
return Objects.hash(this.name, this.schema, this.strict);
216+
}
217+
218+
@Override
219+
public boolean equals(Object o) {
220+
if (this == o) {
221+
return true;
222+
}
223+
if (o == null || getClass() != o.getClass()) {
224+
return false;
225+
}
226+
JsonSchema that = (JsonSchema) o;
227+
return Objects.equals(this.name, that.name) && Objects.equals(this.schema, that.schema)
228+
&& Objects.equals(this.strict, that.strict);
229+
}
230+
231+
public static final class Builder {
232+
233+
private String name = "custom_schema";
234+
235+
private Map<String, Object> schema;
236+
237+
private Boolean strict = true;
238+
239+
private Builder() {
240+
}
241+
242+
public Builder name(String name) {
243+
this.name = name;
244+
return this;
245+
}
246+
247+
public Builder schema(Map<String, Object> schema) {
248+
this.schema = schema;
249+
return this;
250+
}
251+
252+
public Builder schema(String schema) {
253+
this.schema = ModelOptionsUtils.jsonToMap(schema);
254+
return this;
255+
}
256+
257+
public Builder strict(Boolean strict) {
258+
this.strict = strict;
259+
return this;
260+
}
261+
262+
public JsonSchema build() {
263+
return new JsonSchema(this.name, this.schema, this.strict);
264+
}
265+
266+
}
267+
268+
}
41269

42270
}

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.junit.jupiter.params.provider.MethodSource;
3131
import org.mockito.Mockito;
3232

33+
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type;
3334
import org.springframework.ai.chat.prompt.Prompt;
3435

3536
import static org.assertj.core.api.Assertions.assertThat;
@@ -68,7 +69,7 @@ public void createRequestWithChatOptions() {
6869
.logprobs(true)
6970
.topLogprobs(5)
7071
.enhancements(mockAzureChatEnhancementConfiguration)
71-
.responseFormat(AzureOpenAiResponseFormat.TEXT)
72+
.responseFormat(AzureOpenAiResponseFormat.builder().type(Type.TEXT).build())
7273
.build();
7374

7475
var client = AzureOpenAiChatModel.builder()
@@ -114,7 +115,7 @@ public void createRequestWithChatOptions() {
114115
.logprobs(true)
115116
.topLogprobs(4)
116117
.enhancements(anotherMockAzureChatEnhancementConfiguration)
117-
.responseFormat(AzureOpenAiResponseFormat.JSON)
118+
.responseFormat(AzureOpenAiResponseFormat.builder().type(Type.JSON_OBJECT).build())
118119
.build();
119120

120121
requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content", runtimeOptions));

0 commit comments

Comments
 (0)