Skip to content

Commit

Permalink
Merge pull request #65 from sashirestela/64-streamming-in-the-assista…
Browse files Browse the repository at this point in the history
…nts-api

Streaming in the Assistants API
  • Loading branch information
sashirestela authored Mar 24, 2024
2 parents 09ab2b2 + 92a1c0e commit 6ff3d59
Show file tree
Hide file tree
Showing 18 changed files with 298 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .sdkmanrc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Enable auto-env through the sdkman_auto_env config
# Add key=value pairs of SDKs to use below
java=11.0.15-tem
java=11.0.22-tem
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Full support for all of the OpenAI services:
* Embeddings
* Fine tuning
* Assistants API (Beta)
* Assistant Stream Events (Beta) 📣

![Services](media/supported_services.png)

Expand Down
Binary file modified media/assistants_api.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>io.github.sashirestela</groupId>
<artifactId>simple-openai</artifactId>
<version>2.0.1</version>
<version>2.1.0</version>
<packaging>jar</packaging>

<name>simple-openai</name>
Expand Down Expand Up @@ -52,7 +52,7 @@
<maven.compiler.release>11</maven.compiler.release>
<!-- Dependencies Versions -->
<slf4j.version>2.0.12</slf4j.version>
<cleverclient.version>1.1.0</cleverclient.version>
<cleverclient.version>1.3.3</cleverclient.version>
<lombok.version>1.18.30</lombok.version>
<jackson.version>2.16.1</jackson.version>
<json.schema.version>4.33.1</json.schema.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

import io.github.sashirestela.openai.domain.assistant.AssistantRequest;
import io.github.sashirestela.openai.domain.assistant.AssistantTool;
import io.github.sashirestela.openai.domain.assistant.Events;
import io.github.sashirestela.openai.domain.assistant.ImageFileContent;
import io.github.sashirestela.openai.domain.assistant.TextContent;
import io.github.sashirestela.openai.domain.assistant.ThreadMessage;
import io.github.sashirestela.openai.domain.assistant.ThreadMessageDelta;
import io.github.sashirestela.openai.domain.assistant.ThreadMessageRequest;
import io.github.sashirestela.openai.domain.assistant.ThreadRequest;
import io.github.sashirestela.openai.domain.assistant.ThreadRun;
import io.github.sashirestela.openai.domain.assistant.ThreadRunRequest;
import io.github.sashirestela.openai.domain.file.FileRequest;
import io.github.sashirestela.openai.domain.file.PurposeType;

Expand Down Expand Up @@ -94,6 +98,15 @@ public void demoRunThreadAndWaitUntilComplete() {
System.out.println(messages);
}

public void demoRunThreadAndStream() {
var request = ThreadRunRequest.builder().assistantId(assistantId).build();
var response = openAI.threads().createRunStream(threadId, request).join();
response.filter(e -> e.getName().equals(Events.THREAD_MESSAGE_DELTA))
.map(e -> ((TextContent) ((ThreadMessageDelta) e.getData()).getDelta().getContent().get(0)).getValue())
.forEach(System.out::print);
System.out.println();
}

public void demoGetAssistantMessages() {
List<ThreadMessage> messages = openAI.threads().getMessageList(threadId).join();
ThreadMessage assistant = messages.get(0);
Expand Down Expand Up @@ -143,6 +156,7 @@ public static void main(String[] args) {
demo.addTitleAction("Demo Call Assistant File Upload", demo::demoUploadAssistantFile);
demo.addTitleAction("Demo Call Assistant Thread Create", demo::demoCreateThread);
demo.addTitleAction("Demo Call Assistant Thread Run", demo::demoRunThreadAndWaitUntilComplete);
demo.addTitleAction("Demo Call Assistant Thread Run Stream", demo::demoRunThreadAndStream);
demo.addTitleAction("Demo Call Assistant Messages Get", demo::demoGetAssistantMessages);
demo.addTitleAction("Demo Call Assistant Delete", demo::demoDeleteAssistant);

Expand Down
84 changes: 78 additions & 6 deletions src/main/java/io/github/sashirestela/openai/OpenAI.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.github.sashirestela.openai;

import io.github.sashirestela.cleverclient.Event;
import io.github.sashirestela.cleverclient.annotation.Body;
import io.github.sashirestela.cleverclient.annotation.DELETE;
import io.github.sashirestela.cleverclient.annotation.GET;
Expand All @@ -16,6 +17,7 @@
import io.github.sashirestela.openai.domain.assistant.Assistant;
import io.github.sashirestela.openai.domain.assistant.AssistantFile;
import io.github.sashirestela.openai.domain.assistant.AssistantRequest;
import io.github.sashirestela.openai.domain.assistant.AssistantStreamEvents;
import io.github.sashirestela.openai.domain.assistant.FilePath;
import io.github.sashirestela.openai.domain.assistant.Thread;
import io.github.sashirestela.openai.domain.assistant.ThreadCreateAndRunRequest;
Expand Down Expand Up @@ -827,8 +829,9 @@ CompletableFuture<Page<ThreadMessageFile>> getMessageFileList(@Path("threadId")
* @return the queued run object
*/
default CompletableFuture<ThreadRun> createRun(String threadId, String assistantId) {
return createRun(threadId, ThreadRunRequest.builder()
return __createRun(threadId, ThreadRunRequest.builder()
.assistantId(assistantId)
.stream(Boolean.FALSE)
.build());
}

Expand All @@ -839,8 +842,32 @@ default CompletableFuture<ThreadRun> createRun(String threadId, String assistant
* @param request The requested run.
* @return the queued run object
*/
default CompletableFuture<ThreadRun> createRun(@Path("threadId") String threadId,
@Body ThreadRunRequest request) {
var newRequest = request.withStream(Boolean.FALSE);
return __createRun(threadId, newRequest);
}

@POST("/{threadId}/runs")
CompletableFuture<ThreadRun> createRun(@Path("threadId") String threadId, @Body ThreadRunRequest request);
CompletableFuture<ThreadRun> __createRun(@Path("threadId") String threadId, @Body ThreadRunRequest request);

/**
* Create a run and stream the response.
*
* @param threadId The ID of the thread to run.
* @param request The requested run.
* @return A stream of events.
*/
default CompletableFuture<Stream<Event>> createRunStream(@Path("threadId") String threadId,
@Body ThreadRunRequest request) {
var newRequest = request.withStream(Boolean.TRUE);
return __createRunStream(threadId, newRequest);
}

@POST("/{threadId}/runs")
@AssistantStreamEvents
CompletableFuture<Stream<Event>> __createRunStream(@Path("threadId") String threadId,
@Body ThreadRunRequest request);

/**
* Retrieves a run.
Expand Down Expand Up @@ -884,7 +911,7 @@ default CompletableFuture<Page<ThreadRun>> getRunList(String threadId) {
CompletableFuture<Page<ThreadRun>> getRunList(@Path("threadId") String threadId, @Query PageRequest page);

/**
* Submit tool outputs to run
* Submit tool outputs to run.
*
* @param threadId The ID of the thread to which this run belongs.
* @param runId The ID of the run that requires the tool output submission.
Expand All @@ -899,17 +926,42 @@ default CompletableFuture<ThreadRun> submitToolOutputs(String threadId, String r
}

/**
* Submit tool outputs to run
* Submit tool outputs to run.
*
* @param threadId The ID of the thread to which this run belongs.
* @param runId The ID of the run that requires the tool output submission.
* @param toolOutputs The tool output submission.
* @return The modified run object matching the specified ID.
*/
default CompletableFuture<ThreadRun> submitToolOutputs(@Path("threadId") String threadId,
@Path("runId") String runId, @Body ToolOutputSubmission toolOutputs) {
var newToolOutputs = toolOutputs.withStream(Boolean.FALSE);
return __submitToolOutputs(threadId, runId, newToolOutputs);
}

@POST("/{threadId}/runs/{runId}/submit_tool_outputs")
CompletableFuture<ThreadRun> submitToolOutputs(@Path("threadId") String threadId, @Path("runId") String runId,
CompletableFuture<ThreadRun> __submitToolOutputs(@Path("threadId") String threadId, @Path("runId") String runId,
@Body ToolOutputSubmission toolOutputs);

/**
* Submit tool outputs to run and stream the response.
*
* @param threadId The ID of the thread to which this run belongs.
* @param runId The ID of the run that requires the tool output submission.
* @param toolOutputs The tool output submission.
* @return A stream of events.
*/
default CompletableFuture<Stream<Event>> submitToolOutputsStream(@Path("threadId") String threadId,
@Path("runId") String runId, @Body ToolOutputSubmission toolOutputs) {
var newToolOutputs = toolOutputs.withStream(Boolean.TRUE);
return __submitToolOutputsStream(threadId, runId, newToolOutputs);
}

@POST("/{threadId}/runs/{runId}/submit_tool_outputs")
@AssistantStreamEvents
CompletableFuture<Stream<Event>> __submitToolOutputsStream(@Path("threadId") String threadId,
@Path("runId") String runId, @Body ToolOutputSubmission toolOutputs);

/**
* Cancels a run that is {@code in_progress}.
*
Expand All @@ -926,8 +978,28 @@ CompletableFuture<ThreadRun> submitToolOutputs(@Path("threadId") String threadId
* @param request The thread request create and to run.
* @return A created run object.
*/
default CompletableFuture<ThreadRun> createThreadAndRun(@Body ThreadCreateAndRunRequest request) {
var newRequest = request.withStream(Boolean.FALSE);
return __createThreadAndRun(newRequest);
}

@POST("/runs")
CompletableFuture<ThreadRun> __createThreadAndRun(@Body ThreadCreateAndRunRequest request);

/**
* Create a thread and run it in one request and stream the response.
*
* @param request The thread request create and to run.
* @return A stream of events.
*/
default CompletableFuture<Stream<Event>> createThreadAndRunStream(@Body ThreadCreateAndRunRequest request) {
var newRequest = request.withStream(Boolean.TRUE);
return __createThreadAndRunStream(newRequest);
}

@POST("/runs")
CompletableFuture<ThreadRun> createThreadAndRun(@Body ThreadCreateAndRunRequest request);
@AssistantStreamEvents
CompletableFuture<Stream<Event>> __createThreadAndRunStream(@Body ThreadCreateAndRunRequest request);

/**
* Retrieves a run step.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package io.github.sashirestela.openai.domain.assistant;

import io.github.sashirestela.cleverclient.annotation.StreamType;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import static io.github.sashirestela.openai.domain.assistant.Events.ERROR;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_CREATED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_MESSAGE_COMPLETED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_MESSAGE_CREATED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_MESSAGE_DELTA;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_MESSAGE_INCOMPLETE;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_MESSAGE_IN_PROGRESS;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_CANCELLED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_CANCELLING;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_COMPLETED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_CREATED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_EXPIRED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_FAILED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_IN_PROGRESS;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_QUEUED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_REQUIRES_ACTION;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_STEP_CANCELLED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_STEP_COMPLETED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_STEP_CREATED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_STEP_DELTA;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_STEP_EXPIRED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_STEP_FAILED;
import static io.github.sashirestela.openai.domain.assistant.Events.THREAD_RUN_STEP_IN_PROGRESS;

@StreamType(type = Thread.class, events = { THREAD_CREATED })
@StreamType(type = ThreadRun.class, events = { THREAD_RUN_CREATED, THREAD_RUN_QUEUED, THREAD_RUN_IN_PROGRESS,
THREAD_RUN_REQUIRES_ACTION, THREAD_RUN_COMPLETED, THREAD_RUN_FAILED, THREAD_RUN_CANCELLING,
THREAD_RUN_CANCELLED, THREAD_RUN_EXPIRED })
@StreamType(type = ThreadRunStep.class, events = { THREAD_RUN_STEP_CREATED, THREAD_RUN_STEP_IN_PROGRESS,
THREAD_RUN_STEP_COMPLETED, THREAD_RUN_STEP_FAILED, THREAD_RUN_STEP_CANCELLED, THREAD_RUN_STEP_EXPIRED })
@StreamType(type = ThreadRunStepDelta.class, events = { THREAD_RUN_STEP_DELTA })
@StreamType(type = ThreadMessage.class, events = { THREAD_MESSAGE_CREATED, THREAD_MESSAGE_IN_PROGRESS,
THREAD_MESSAGE_COMPLETED, THREAD_MESSAGE_INCOMPLETE })
@StreamType(type = ThreadMessageDelta.class, events = { THREAD_MESSAGE_DELTA })
@StreamType(type = String.class, events = { ERROR })
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface AssistantStreamEvents {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.github.sashirestela.openai.domain.assistant;

public interface Events {

static final String THREAD_CREATED = "thread.created";

static final String THREAD_RUN_CREATED = "thread.run.created";
static final String THREAD_RUN_QUEUED = "thread.run.queued";
static final String THREAD_RUN_IN_PROGRESS = "thread.run.in_progress";
static final String THREAD_RUN_REQUIRES_ACTION = "thread.run.requires_action";
static final String THREAD_RUN_COMPLETED = "thread.run.completed";
static final String THREAD_RUN_FAILED = "thread.run.failed";
static final String THREAD_RUN_CANCELLING = "thread.run.cancelling";
static final String THREAD_RUN_CANCELLED = "thread.run.cancelled";
static final String THREAD_RUN_EXPIRED = "thread.run.expired";

static final String THREAD_RUN_STEP_CREATED = "thread.run.step.created";
static final String THREAD_RUN_STEP_IN_PROGRESS = "thread.run.step.in_progress";
static final String THREAD_RUN_STEP_COMPLETED = "thread.run.step.completed";
static final String THREAD_RUN_STEP_FAILED = "thread.run.step.failed";
static final String THREAD_RUN_STEP_CANCELLED = "thread.run.step.cancelled";
static final String THREAD_RUN_STEP_EXPIRED = "thread.run.step.expired";

static final String THREAD_RUN_STEP_DELTA = "thread.run.step.delta";

static final String THREAD_MESSAGE_CREATED = "thread.message.created";
static final String THREAD_MESSAGE_IN_PROGRESS = "thread.message.in_progress";
static final String THREAD_MESSAGE_COMPLETED = "thread.message.completed";
static final String THREAD_MESSAGE_INCOMPLETE = "thread.message.incomplete";

static final String THREAD_MESSAGE_DELTA = "thread.message.delta";

static final String ERROR = "error";

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class FileCitationAnnotation implements TextContentAnnotation {

private Integer index;
private String text;
private FileCitation fileCitation;
private int startIndex;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class FilePathAnnotation implements TextContentAnnotation {

private Integer index;
private String text;
private FilePath filePath;
private int startIndex;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class ImageFileContent implements ThreadMessageContent {

private Integer index;
private FilePath imageFile;

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
@ToString
public class TextContent implements ThreadMessageContent {

private Integer index;
private Text text;

public String getValue() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import lombok.Builder;
import lombok.Getter;
import lombok.Singular;
import lombok.With;

import java.util.List;
import java.util.Map;
Expand All @@ -24,5 +25,7 @@ public class ThreadCreateAndRunRequest {
@Singular
private List<AssistantTool> tools;
private Map<String, String> metadata;
@With
private boolean stream;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package io.github.sashirestela.openai.domain.assistant;

import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;

import java.util.List;

@NoArgsConstructor
@Getter
@ToString
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class ThreadMessageDelta {

private String id;
private String object;
private ThreadMessageDeltaDetail delta;

@NoArgsConstructor
@Getter
@ToString
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class ThreadMessageDeltaDetail {

private String role;
private List<ThreadMessageContent> content;
private List<String> fileIds;

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import lombok.Builder;
import lombok.Getter;
import lombok.Singular;
import lombok.With;

import java.util.List;
import java.util.Map;
Expand All @@ -24,5 +25,7 @@ public class ThreadRunRequest {
@Singular
private List<AssistantTool> tools;
private Map<String, String> metadata;
@With
private boolean stream;

}
Loading

0 comments on commit 6ff3d59

Please sign in to comment.