Skip to content

Commit

Permalink
feat: add OpenAI tools support
Browse files Browse the repository at this point in the history
  • Loading branch information
astappiev committed Sep 16, 2024
1 parent a3ec18a commit f071566
Show file tree
Hide file tree
Showing 17 changed files with 605 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;

import de.l3s.interweb.connector.openai.entity.CompletionBody;
import de.l3s.interweb.connector.openai.entity.CompletionResponse;
import de.l3s.interweb.connector.openai.entity.CompletionsBody;
import de.l3s.interweb.connector.openai.entity.CompletionsResponse;
import de.l3s.interweb.core.ConnectorException;

@Path("/openai/deployments")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@RegisterRestClient(configKey = "openai")
@ClientHeaderParam(name = "api-key", value = "${connector.openai.apikey}")
@ClientQueryParam(name = "api-version", value = "2024-02-01")
@ClientQueryParam(name = "api-version", value = "2024-06-01")
public interface OpenaiClient {

/**
Expand All @@ -28,7 +28,7 @@ public interface OpenaiClient {
*/
@POST
@Path("/{model}/chat/completions")
Uni<CompletionResponse> chatCompletions(@PathParam("model") String model, CompletionBody body);
Uni<CompletionsResponse> chatCompletions(@PathParam("model") String model, CompletionsBody body);

@ClientExceptionMapper
static RuntimeException toException(Response response) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import org.eclipse.microprofile.rest.client.inject.RestClient;
import org.jboss.logging.Logger;

import de.l3s.interweb.connector.openai.entity.CompletionBody;
import de.l3s.interweb.connector.openai.entity.CompletionsBody;
import de.l3s.interweb.core.ConnectorException;
import de.l3s.interweb.core.chat.ChatConnector;
import de.l3s.interweb.core.chat.CompletionsQuery;
Expand Down Expand Up @@ -56,7 +56,7 @@ public Uni<List<Model>> getModels() {

@Override
public Uni<CompletionsResults> completions(CompletionsQuery query) throws ConnectorException {
return openai.chatCompletions(query.getModel(), new CompletionBody(query)).map(response -> {
return openai.chatCompletions(query.getModel(), new CompletionsBody(query)).map(response -> {
CompletionsResults results = new CompletionsResults();
results.setModel(query.getModel());
results.setCreated(response.getCreated());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@

import de.l3s.interweb.core.chat.CompletionsQuery;
import de.l3s.interweb.core.chat.ResponseFormat;
import de.l3s.interweb.core.chat.Tool;

@RegisterForReflection
@JsonInclude(JsonInclude.Include.NON_NULL)
public final class CompletionBody {
public final class CompletionsBody {

private List<CompletionMessage> messages;
private List<OpenaiMessage> messages;

private Double temperature;

Expand All @@ -30,26 +31,25 @@ public final class CompletionBody {
@JsonProperty("max_tokens")
private Integer maxTokens;

/**
* How many completions to generate for each prompt. Minimum of 1 (default) and maximum of 128 allowed.
* Note: Because this parameter generates many completions, it can quickly consume your token quota.
*/
private Integer n;

/**
* If specified, our system will make the best effort to sample deterministically,
* such that repeated requests with the same seed and parameters should return the same result.
* Determinism isn't guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.
*/
private Integer seed;

@JsonProperty("response_format")
private ResponseFormat responseFormat;

private List<Tool> tools;

@JsonProperty("tool_choice")
private Object toolChoice;

@JsonProperty("parallel_tool_calls")
private Boolean parallelToolCalls;

private String[] stop;

public CompletionBody(CompletionsQuery query) {
this.messages = query.getMessages().stream().map(CompletionMessage::new).toList();
public CompletionsBody(CompletionsQuery query) {
this.messages = query.getMessages().stream().map(OpenaiMessage::new).toList();
this.temperature = query.getTemperature();
this.topP = query.getTopP();
this.frequencyPenalty = query.getPresencePenalty();
Expand All @@ -59,9 +59,12 @@ public CompletionBody(CompletionsQuery query) {
this.seed = query.getSeed();
this.responseFormat = query.getResponseFormat();
this.stop = query.getStop();
this.tools = query.getTools();
this.toolChoice = query.getToolChoice();
this.parallelToolCalls = query.getParallelToolCalls();
}

public List<CompletionMessage> getMessages() {
public List<OpenaiMessage> getMessages() {
return messages;
}

Expand Down Expand Up @@ -100,4 +103,16 @@ public ResponseFormat getResponseFormat() {
public String[] getStop() {
return stop;
}

public List<Tool> getTools() {
return tools;
}

public Object getToolChoice() {
return toolChoice;
}

public Boolean getParallelToolCalls() {
return parallelToolCalls;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import de.l3s.interweb.core.chat.Usage;

@RegisterForReflection
public class CompletionResponse {
public class CompletionsResponse {
private String id;
private String object;
private String model;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,45 +1,52 @@
package de.l3s.interweb.connector.openai.entity;

import java.util.List;

import io.quarkus.runtime.annotations.RegisterForReflection;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

import de.l3s.interweb.core.chat.Message;
import de.l3s.interweb.core.chat.ToolCall;

@RegisterForReflection
public final class CompletionMessage {
@JsonInclude(JsonInclude.Include.NON_NULL)
public final class OpenaiMessage {
private String role;
@JsonIgnore
private String name;
private String content;
private String refusal;
@JsonProperty("tool_calls")
private List<ToolCall> toolCalls;

public CompletionMessage(Message message) {
public OpenaiMessage(Message message) {
this.role = message.getRole().name();
this.name = message.getName();
this.content = message.getContent();
this.refusal = message.getRefusal();
this.toolCalls = message.getToolCalls();
}

public String getRole() {
return role;
}

public void setRole(String role) {
this.role = role;
}

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public String getContent() {
return content;
}

public void setContent(String content) {
this.content = content;
public String getRefusal() {
return refusal;
}

public List<ToolCall> getToolCalls() {
return toolCalls;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import static org.junit.jupiter.api.Assertions.*;

import java.util.List;

import jakarta.inject.Inject;

import io.quarkus.test.junit.QuarkusTest;
Expand All @@ -10,10 +12,7 @@
import org.junit.jupiter.api.Test;

import de.l3s.interweb.core.ConnectorException;
import de.l3s.interweb.core.chat.Choice;
import de.l3s.interweb.core.chat.CompletionsQuery;
import de.l3s.interweb.core.chat.CompletionsResults;
import de.l3s.interweb.core.chat.Role;
import de.l3s.interweb.core.chat.*;

@Disabled
@QuarkusTest
Expand Down Expand Up @@ -43,4 +42,27 @@ void completions() throws ConnectorException {
log.infov("assistant: {0}", result.getMessage().getContent());
}
}

/**
 * Verifies that tool definitions are forwarded to the model and that forcing
 * the tool choice yields a matching tool call in the returned completion.
 */
@Test
void completionsWithTool() throws ConnectorException {
    // Single function tool the model is required to invoke.
    Tool getWeather = Tool.functionBuilder()
        .name("get_weather")
        .description("Return the weather in a city.")
        .parameters(city -> city.name("city").type("string").description("The city name.").required())
        .build();

    CompletionsQuery request = new CompletionsQuery();
    request.setModel("gpt-4");
    request.setTools(List.of(getWeather));
    request.setToolChoice(ToolChoice.required);
    request.addMessage("You are Interweb Assistant, a helpful chat bot.", Role.system);
    request.addMessage("What is the weather in Hannover?", Role.user);

    CompletionsResults response = connector.completions(request).await().indefinitely();

    // Exactly one choice is expected, and it must carry the forced tool call.
    assertEquals(1, response.getChoices().size());
    for (Choice choice : response.getChoices()) {
        assertEquals("get_weather", choice.getMessage().getToolCalls().getFirst().getFunction().getName());
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package de.l3s.interweb.core.chat;

import java.io.Serial;
import java.io.Serializable;

import jakarta.validation.constraints.NotEmpty;

import io.quarkus.runtime.annotations.RegisterForReflection;

@RegisterForReflection
public class CallFunction implements Serializable {
    @Serial
    private static final long serialVersionUID = -2780720621585498099L;

    /** The name of the function to call. */
    @NotEmpty
    private String name;

    /**
     * The arguments to call the function with, as generated by the model in JSON format.
     * The model does not always emit valid JSON and may hallucinate parameters not defined
     * by the function schema — validate these arguments before invoking the function.
     */
    private String arguments;

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getArguments() {
        return arguments;
    }

    public void setArguments(String arguments) {
        this.arguments = arguments;
    }
}
13 changes: 13 additions & 0 deletions interweb-core/src/main/java/de/l3s/interweb/core/chat/Choice.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,22 @@
@RegisterForReflection
@JsonIgnoreProperties({"elapsed_time", "created"})
public class Choice extends ConnectorResults {
/**
* The index of the choice in the list of choices.
*/
private int index;

/**
* The reason the model stopped generating tokens.
* This will be stop if the model hit a natural stop point or a provided stop sequence, length if the maximum number of tokens specified
* in the request was reached, content_filter if content was omitted due to a flag from our content filters, tool_calls if the model called a tool.
*/
@JsonProperty("finish_reason")
private String finishReason;

/**
* A chat completion message generated by the model.
*/
private Message message;

public Choice() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.Size;

import io.quarkus.runtime.annotations.RegisterForReflection;

Expand Down Expand Up @@ -92,6 +93,35 @@ public class CompletionsQuery {
@JsonProperty("max_tokens")
private Integer maxTokens;

/**
* A list of tools the model may call. Currently, only functions are supported as a tool.
* Use this to provide a list of functions the model may generate JSON inputs for.
* <br/>
* A max of 128 functions are supported.
*/
@Size(max = 128)
@JsonProperty("tools")
private List<Tool> tools;

/**
* Controls which (if any) tool is called by the model.
* - none means the model will not call any tool and instead generates a message.
* - auto means the model can pick between generating a message or calling one or more tools.
* - required means the model must call one or more tools.
* Specifying a particular tool via {"type": "function", "function": {"name": "my_function"}} forces the model to call that tool.
* <br/>
* none is the default when no tools are present. auto is the default if tools are present.
*/
@JsonProperty("tool_choice")
private Object toolChoice;

/**
* Whether to enable parallel function calling during tool use.
* https://platform.openai.com/docs/guides/function-calling/parallel-function-calling
*/
@JsonProperty("parallel_tool_calls")
private Boolean parallelToolCalls;

/**
* Whether to incrementally stream the response using server-sent events. Defaults to false.
*/
Expand Down Expand Up @@ -230,6 +260,30 @@ public Integer getN() {
return n;
}

public void setTools(List<Tool> tools) {
this.tools = tools;
}

public List<Tool> getTools() {
return tools;
}

public void setToolChoice(Object toolChoice) {
this.toolChoice = toolChoice;
}

public Object getToolChoice() {
return toolChoice;
}

public void setParallelToolCalls(Boolean parallelToolCalls) {
this.parallelToolCalls = parallelToolCalls;
}

public Boolean getParallelToolCalls() {
return parallelToolCalls;
}

public void setSeed(Integer seed) {
this.seed = seed;
}
Expand Down
Loading

0 comments on commit f071566

Please sign in to comment.