Skip to content

Commit

Permalink
Merge pull request #123 from sashirestela/122-new-feature-stream_options
Browse files Browse the repository at this point in the history
Support for the new `stream_options` feature
  • Loading branch information
sashirestela authored May 7, 2024
2 parents b90aaa2 + 99cfd71 commit cd3f1cc
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 169 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -695,9 +695,9 @@ Examples for each OpenAI service have been created in the folder [demo](https://
* ChatAzure

* ```[debug]``` Is optional and creates the ```demo.log``` file where you can see log details for each execution.
* For example, to run the chat demo with a log file: ```./rundemo.sh Chat debug```

* Indications for Azure OpenAI demos

Azure OpenAI requires a separate deployment for each model. The Azure OpenAI demos require
two models.
Expand Down
30 changes: 19 additions & 11 deletions src/demo/java/io/github/sashirestela/openai/demo/ChatDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ public ChatDemo() {
/**
 * Demonstrates a streaming chat completion: each chunk is printed as it
 * arrives, and the final usage summary (enabled via stream_options) is
 * printed by the shared chunk processor.
 */
public void demoCallChatStreaming() {
    var futureChat = openAI.chatCompletions().createStream(chatRequest);
    var chatResponse = futureChat.join();
    // A Stream is one-shot: consume it with exactly one terminal operation.
    // The old filter/map/forEach pipeline was superseded by processResponseChunk,
    // which also reports token usage from the final chunk.
    chatResponse.forEach(ChatDemo::processResponseChunk);
}

public void demoCallChatBlocking() {
Expand Down Expand Up @@ -110,9 +107,7 @@ public void demoCallChatWithVisionExternalImage() {
.maxTokens(500)
.build();
var chatResponse = openAI.chatCompletions().createStream(chatRequest).join();
chatResponse.filter(chatResp -> chatResp.firstContent() != null)
.map(chatResp -> chatResp.firstContent())
.forEach(System.out::print);
chatResponse.forEach(ChatDemo::processResponseChunk);
System.out.println();
}

Expand All @@ -128,13 +123,11 @@ public void demoCallChatWithVisionLocalImage() {
.maxTokens(500)
.build();
var chatResponse = openAI.chatCompletions().createStream(chatRequest).join();
chatResponse.filter(chatResp -> chatResp.firstContent() != null)
.map(chatResp -> chatResp.firstContent())
.forEach(System.out::print);
chatResponse.forEach(ChatDemo::processResponseChunk);
System.out.println();
}

private static ImageUrl loadImageAsBase64(String imagePath) {
private ImageUrl loadImageAsBase64(String imagePath) {
try {
Path path = Paths.get(imagePath);
byte[] imageBytes = Files.readAllBytes(path);
Expand All @@ -148,6 +141,21 @@ private static ImageUrl loadImageAsBase64(String imagePath) {
}
}

/**
 * Handles one streamed chat chunk: prints any delta content to stdout and,
 * when the chunk carries usage data (the final chunk when
 * stream_options.include_usage is set), prints the usage summary.
 *
 * @param responseChunk a single chunk of the streamed chat response
 */
private static void processResponseChunk(Chat responseChunk) {
    var choices = responseChunk.getChoices();
    if (!choices.isEmpty()) {
        var delta = choices.get(0).getMessage();
        // Content can be null for role-only or tool-call deltas.
        if (delta.getContent() != null) {
            System.out.print(delta.getContent());
        }
    }
    // Usage is only present on the final chunk (and only when requested).
    var usage = responseChunk.getUsage();
    if (usage != null) {
        System.out.println("\n");
        System.out.println(usage);
    }
}

public static class Weather implements Functional {

@JsonPropertyDescription("City and state, for example: León, Guanajuato")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,24 @@ public CompletionDemo() {
/**
 * Demonstrates a streaming text completion: prints each text delta as it
 * arrives, then the usage summary from the final chunk, then a newline.
 */
public void demoCallCompletionStreaming() {
    var futureCompletion = openAI.completions().createStream(completionRequest);
    var completionResponse = futureCompletion.join();
    // One terminal operation only: the old filter/map/forEach pipeline was
    // superseded by processResponseChunk. (Also dropped a stray ';' empty statement.)
    completionResponse.forEach(CompletionDemo::processResponseChunk);
    System.out.println();
}

/**
 * Handles one streamed completion chunk: prints the first choice's text
 * delta and, when the chunk carries usage data (the final chunk when
 * stream_options.include_usage is set), prints the usage summary.
 *
 * @param responseChunk a single chunk of the streamed completion response
 */
private static void processResponseChunk(Completion responseChunk) {
    var choices = responseChunk.getChoices();
    if (!choices.isEmpty()) {
        var delta = choices.get(0).getText();
        System.out.print(delta);
    }
    // Usage is only present on the final chunk (and only when requested).
    var usage = responseChunk.getUsage();
    if (usage != null) {
        System.out.println("\n");
        System.out.println(usage);
    }
}

public void demoCallCompletionBlocking() {
var futureCompletion = openAI.completions().create(completionRequest);
var completionResponse = futureCompletion.join();
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/io/github/sashirestela/openai/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import io.github.sashirestela.openai.common.DeletedObject;
import io.github.sashirestela.openai.common.Generic;
import io.github.sashirestela.openai.common.Page;
import io.github.sashirestela.openai.common.StreamOptions;
import io.github.sashirestela.openai.common.tool.ToolChoiceOption;
import io.github.sashirestela.openai.domain.audio.AudioResponseFormat;
import io.github.sashirestela.openai.domain.audio.SpeechRequest;
Expand Down Expand Up @@ -331,7 +332,7 @@ default CompletableFuture<Completion> create(@Body CompletionRequest completionR
* @return Response is delivered as a continuous flow of tokens.
*/
default CompletableFuture<Stream<Completion>> createStream(@Body CompletionRequest completionRequest) {
    // Force streaming and ask the service to append a usage chunk at the end
    // of the stream (stream_options.include_usage). The request object is
    // immutable; with* returns a modified copy.
    var request = completionRequest.withStream(Boolean.TRUE).withStreamOptions(StreamOptions.of(Boolean.TRUE));
    return createStreamPrimitive(request);
}

Expand Down Expand Up @@ -736,6 +737,9 @@ static AudioResponseFormat getResponseFormat(AudioResponseFormat currValue, Audi

static ChatRequest updateRequest(ChatRequest chatRequest, Boolean useStream) {
var updatedChatRequest = chatRequest.withStream(useStream);
if (Boolean.TRUE.equals(useStream)) {
updatedChatRequest = updatedChatRequest.withStreamOptions(StreamOptions.of(useStream));
}
if (!isNullOrEmpty(chatRequest.getTools()) && chatRequest.getToolChoice() == null) {
updatedChatRequest = updatedChatRequest.withToolChoice(ToolChoiceOption.AUTO);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.github.sashirestela.openai.common;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;

/**
 * Options that control the behavior of streamed API responses. Serialized as
 * the {@code stream_options} request field (snake_case naming via Jackson).
 * Instances are immutable; create them with {@link #of(Boolean)}.
 */
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@Getter
@JsonInclude(Include.NON_EMPTY)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class StreamOptions {

    // Serialized as "include_usage": when true, the service emits a final
    // chunk containing token usage statistics for the whole request.
    private Boolean includeUsage;

    /**
     * Static factory for stream options.
     *
     * @param includeUsage whether the final stream chunk should carry usage data
     * @return a new immutable {@code StreamOptions} instance
     */
    public static StreamOptions of(Boolean includeUsage) {
        return new StreamOptions(includeUsage);
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.github.sashirestela.openai.common.ResponseFormat;
import io.github.sashirestela.openai.common.StreamOptions;
import io.github.sashirestela.openai.common.tool.Tool;
import io.github.sashirestela.openai.common.tool.ToolChoice;
import io.github.sashirestela.openai.common.tool.ToolChoiceOption;
Expand Down Expand Up @@ -61,6 +62,9 @@ public class ChatRequest {
@With
private Boolean stream;

@With
private StreamOptions streamOptions;

@Range(min = 0.0, max = 2.0)
private Double temperature;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.github.sashirestela.openai.common.StreamOptions;
import io.github.sashirestela.slimvalidator.constraints.ObjectType;
import io.github.sashirestela.slimvalidator.constraints.Range;
import io.github.sashirestela.slimvalidator.constraints.Required;
Expand Down Expand Up @@ -59,6 +60,9 @@ public class CompletionRequest {
@With
private Boolean stream;

@With
private StreamOptions streamOptions;

private String suffix;

@Range(min = 0.0, max = 2.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,20 @@ static void setup() {
/**
 * Verifies that a streamed chat completion can be fully consumed: content
 * deltas are printed and the usage chunk (requested via stream_options) is
 * handled when present.
 */
void testChatCompletionsCreateStream() throws IOException {
    DomainTestingHelper.get().mockForStream(httpClient, "src/test/resources/chatcompletions_create_stream.txt");
    var chatResponse = openAI.chatCompletions().createStream(chatTextRequest).join();
    // A Stream supports only one terminal operation; the superseded
    // filter/map/forEach pipeline has been removed.
    chatResponse.forEach(responseChunk -> {
        var choices = responseChunk.getChoices();
        if (!choices.isEmpty()) {
            var delta = choices.get(0).getMessage();
            // Content can be null for role-only deltas.
            if (delta.getContent() != null) {
                System.out.print(delta.getContent());
            }
        }
        // Usage arrives only on the final chunk.
        var usage = responseChunk.getUsage();
        if (usage != null) {
            System.out.println("\n");
            System.out.println(usage);
        }
    });
    assertNotNull(chatResponse);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,18 @@ static void setup() {
/**
 * Verifies that a streamed text completion can be fully consumed: text
 * deltas are printed and the usage chunk (requested via stream_options) is
 * handled when present.
 */
void testCompletionsCreateStream() throws IOException {
    DomainTestingHelper.get().mockForStream(httpClient, "src/test/resources/completions_create_stream.txt");
    var completionResponse = openAI.completions().createStream(completionRequest).join();
    // A Stream supports only one terminal operation; the superseded
    // filter/map/forEach pipeline has been removed.
    completionResponse.forEach(responseChunk -> {
        var choices = responseChunk.getChoices();
        if (!choices.isEmpty()) {
            var delta = choices.get(0).getText();
            System.out.print(delta);
        }
        // Usage arrives only on the final chunk.
        var usage = responseChunk.getUsage();
        if (usage != null) {
            System.out.println("\n");
            System.out.println(usage);
        }
    });
    assertNotNull(completionResponse);
}

Expand Down
Loading

0 comments on commit cd3f1cc

Please sign in to comment.