Skip to content

Commit a83b590

Browse files
committed
feat(zhipuai): Add usage field to ChatCompletionChunk and update tests
- Add `usage` field to `ChatCompletionChunk` - Update integration and unit tests to use free model (`glm-4-flash`) and verify `usage` field Fixed #4609 Signed-off-by: YunKui Lu <[email protected]>
1 parent 3fc1ed6 commit a83b590

File tree

4 files changed

+12
-7
lines changed

4 files changed

+12
-7
lines changed

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
284284
})
285285
.concatMapIterable(window -> {
286286
Mono<ChatCompletionChunk> monoChunk = window
287-
.reduce(new ChatCompletionChunk(null, null, null, null, null, null), this.chunkMerger::merge);
287+
.reduce(new ChatCompletionChunk(null, null, null, null, null, null, null), this.chunkMerger::merge);
288288
return List.of(monoChunk);
289289
})
290290
.flatMap(mono -> mono);
@@ -1110,7 +1110,8 @@ public record ChatCompletionChunk(// @formatter:off
11101110
@JsonProperty("created") Long created,
11111111
@JsonProperty("model") String model,
11121112
@JsonProperty("system_fingerprint") String systemFingerprint,
1113-
@JsonProperty("object") String object) { // @formatter:on
1113+
@JsonProperty("object") String object,
1114+
@JsonProperty("usage") Usage usage) { // @formatter:on
11141115

11151116
/**
11161117
* Chat completion choice.

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,14 @@ public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChu
5858
String systemFingerprint = (current.systemFingerprint() != null ? current.systemFingerprint()
5959
: previous.systemFingerprint());
6060
String object = (current.object() != null ? current.object() : previous.object());
61+
ZhiPuAiApi.Usage usage = (current.usage() != null ? current.usage() : previous.usage());
6162

6263
ChunkChoice previousChoice0 = (CollectionUtils.isEmpty(previous.choices()) ? null : previous.choices().get(0));
6364
ChunkChoice currentChoice0 = (CollectionUtils.isEmpty(current.choices()) ? null : current.choices().get(0));
6465

6566
ChunkChoice choice = merge(previousChoice0, currentChoice0);
6667
List<ChunkChoice> chunkChoices = choice == null ? List.of() : List.of(choice);
67-
return new ChatCompletionChunk(id, chunkChoices, created, model, systemFingerprint, object);
68+
return new ChatCompletionChunk(id, chunkChoices, created, model, systemFingerprint, object, usage);
6869
}
6970

7071
private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {

models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,24 @@ void chatCompletionEntity() {
5757
void chatCompletionEntityWithMoreParams() {
5858
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
5959
ResponseEntity<ChatCompletion> response = this.zhiPuAiApi
60-
.chatCompletionEntity(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 1024, null,
60+
.chatCompletionEntity(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-4-flash", 1024, null,
6161
false, 0.95, 0.7, null, null, null, "test_request_id", false, null, null));
6262

6363
assertThat(response).isNotNull();
6464
assertThat(response.getBody()).isNotNull();
65+
assertThat(response.getBody().usage()).isNotNull();
6566
}
6667

6768
@Test
6869
void chatCompletionStream() {
6970
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
7071
Flux<ChatCompletionChunk> response = this.zhiPuAiApi
71-
.chatCompletionStream(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 0.7, true));
72+
.chatCompletionStream(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-4-flash", 0.7, true));
7273

7374
assertThat(response).isNotNull();
74-
assertThat(response.collectList().block()).isNotNull();
75+
List<ChatCompletionChunk> chunks = response.collectList().block();
76+
assertThat(chunks).isNotNull();
77+
assertThat(chunks.get(chunks.size() - 1).usage()).isNotNull();
7578
}
7679

7780
@Test

models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ public void zhiPuAiChatStreamTransientError() {
133133
var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0,
134134
new ChatCompletionMessage("Response", Role.ASSISTANT), null);
135135
ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666L, "model", null,
136-
null);
136+
null, null);
137137

138138
given(this.zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
139139
.willThrow(new TransientAiException("Transient Error 1"))

0 commit comments

Comments
 (0)