Skip to content

Commit

Permalink
refactor: chat completion requests, separate by users, title autogene…
Browse files Browse the repository at this point in the history
…ration
  • Loading branch information
astappiev committed Oct 2, 2023
1 parent eb68609 commit 64ab652
Show file tree
Hide file tree
Showing 22 changed files with 623 additions and 204 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package de.l3s.interweb.connector.openai;

public class ErrorResponse {
public Error error;
Error error;

public static class Error {
public String message;
public Object type;
public String param;
public String code;
public Integer status;
String message;
Object type;
String param;
String code;
Integer status;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;

import de.l3s.interweb.connector.openai.entity.CompletionResponse;
import de.l3s.interweb.connector.openai.entity.CompletionsBody;
import de.l3s.interweb.core.ConnectorException;
import de.l3s.interweb.core.completion.CompletionQuery;
import de.l3s.interweb.core.completion.CompletionResults;

@Path("/openai/deployments")
@Consumes(MediaType.APPLICATION_JSON)
Expand All @@ -22,7 +22,7 @@ public interface OpenaiClient {

@POST
@Path("/{model}/chat/completions")
Uni<CompletionResults> chatCompletions(@PathParam("model") String model, @QueryParam("api-version") String apiVersion, CompletionQuery body);
Uni<CompletionResponse> chatCompletions(@PathParam("model") String model, @QueryParam("api-version") String apiVersion, CompletionsBody body);

@ClientExceptionMapper
static RuntimeException toException(Response response) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.smallrye.mutiny.Uni;
import org.eclipse.microprofile.rest.client.inject.RestClient;

import de.l3s.interweb.connector.openai.entity.CompletionsBody;
import de.l3s.interweb.core.ConnectorException;
import de.l3s.interweb.core.completion.CompletionConnector;
import de.l3s.interweb.core.completion.CompletionQuery;
Expand Down Expand Up @@ -49,6 +50,13 @@ public UsagePrice getPrice(String model) {

@Override
public Uni<CompletionResults> complete(CompletionQuery query) throws ConnectorException {
return openai.chatCompletions(query.getModel(), version, query);
return openai.chatCompletions(query.getModel(), version, new CompletionsBody(query)).map(response -> {
CompletionResults results = new CompletionResults();
results.setModel(query.getModel());
results.setCreated(response.getCreated());
results.setChoices(response.getChoices());
results.setUsage(response.getUsage());
return results;
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package de.l3s.interweb.connector.openai.entity;

import java.time.Instant;
import java.util.List;

import io.quarkus.runtime.annotations.RegisterForReflection;

import de.l3s.interweb.core.completion.Choice;
import de.l3s.interweb.core.completion.Usage;

@RegisterForReflection
public class CompletionResponse {
private String id;
private String object;
private String model;
private Usage usage;
private Instant created;
private List<Choice> choices;

public String getId() {
return id;
}

public void setId(String id) {
this.id = id;
}

public String getObject() {
return object;
}

public void setObject(String object) {
this.object = object;
}

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

public List<Choice> getChoices() {
return choices;
}

public void setChoices(List<Choice> choices) {
this.choices = choices;
}

public Usage getUsage() {
return usage;
}

public void setUsage(Usage usage) {
this.usage = usage;
}

public Instant getCreated() {
return created;
}

public void setCreated(Instant created) {
this.created = created;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package de.l3s.interweb.connector.openai.entity;

import java.util.List;

import io.quarkus.runtime.annotations.RegisterForReflection;

import de.l3s.interweb.core.completion.CompletionQuery;
import de.l3s.interweb.core.completion.Message;

@RegisterForReflection
public final class CompletionsBody {

private List<Message> messages;

private Double temperature;

private Double topP;

private Double frequencyPenalty;

private Double presencePenalty;

private Integer maxTokens;

public CompletionsBody(CompletionQuery query) {
this.messages = query.getMessages();
this.temperature = query.getTemperature();
this.topP = query.getTopP();
this.frequencyPenalty = query.getPresencePenalty();
this.presencePenalty = query.getPresencePenalty();
this.maxTokens = query.getMaxTokens();
}

public List<Message> getMessages() {
return messages;
}

public Double getTemperature() {
return temperature;
}

public Double getTopP() {
return topP;
}

public Double getFrequencyPenalty() {
return frequencyPenalty;
}

public Double getPresencePenalty() {
return presencePenalty;
}

public Integer getMaxTokens() {
return maxTokens;
}
}
4 changes: 4 additions & 0 deletions interweb-client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

<name>Interweb Client</name>

<properties>
<maven.compiler.release>11</maven.compiler.release>
</properties>

<dependencies>
<dependency>
<groupId>de.l3s.interweb</groupId>
Expand Down
73 changes: 56 additions & 17 deletions interweb-client/src/main/java/de/l3s/interweb/client/Interweb.java
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
package de.l3s.interweb.client;

import java.io.IOException;
import java.io.Serial;
import java.io.Serializable;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;

import de.l3s.interweb.core.completion.CompletionQuery;
import de.l3s.interweb.core.completion.CompletionResults;
import de.l3s.interweb.core.completion.Conversation;
import de.l3s.interweb.core.describe.DescribeQuery;
import de.l3s.interweb.core.describe.DescribeResults;
import de.l3s.interweb.core.search.SearchQuery;
Expand All @@ -27,7 +28,6 @@
import de.l3s.interweb.core.util.StringUtils;

public class Interweb implements Serializable {
@Serial
private static final long serialVersionUID = 7231324400348062196L;

private final ObjectMapper mapper;
Expand Down Expand Up @@ -72,11 +72,31 @@ public DescribeResults describe(String link) throws InterwebException {
return sendRequest("/describe", params, DescribeResults.class);
}

public CompletionResults chatCompletions(CompletionQuery query) throws InterwebException {
public List<Conversation> conversations(String user) throws InterwebException {
return sendRequest("/chat", Map.of("user", user), new TypeReference<>() {});
}

public CompletionResults completion(CompletionQuery query) throws InterwebException {
return sendRequest("/chat/completions", query, CompletionResults.class);
}

private URI createRequestUri(final String apiPath, final TreeMap<String, String> params) {
public void completion(Conversation conversation) throws InterwebException {
CompletionResults results = sendRequest("/chat/completions", conversation, CompletionResults.class);
if (results.getLastMessage() != null) {
conversation.addMessage(results.getLastMessage());
}
if (results.getChatTitle() != null) {
conversation.setTitle(results.getChatTitle());
}
if (results.getCost() != null) {
conversation.setEstimatedCost(results.getCost().getChat());
}
if (results.getUsage() != null) {
conversation.setUsedTokens(results.getUsage().getTotalTokens());
}
}

private URI createRequestUri(final String apiPath, final Map<String, String> params) {
StringBuilder sb = new StringBuilder();

sb.append(serverUrl);
Expand All @@ -96,29 +116,48 @@ private URI createRequestUri(final String apiPath, final TreeMap<String, String>
return URI.create(sb.toString());
}

public <T> T sendRequest(final String apiPath, final Map<String, String> params, TypeReference<T> valueType) throws InterwebException {
try {
final URI uri = createRequestUri(apiPath, params);
HttpRequest.Builder builder = HttpRequest.newBuilder().uri(uri).GET();

HttpResponse<String> response = sendRequest(builder);
return mapper.readValue(response.body(), valueType);
} catch (IOException e) {
throw new InterwebException("An error occurred during Interweb request " + apiPath, e);
}
}

public <T> T sendRequest(final String apiPath, final Object query, Class<T> valueType) throws InterwebException {
try {
final URI uri = createRequestUri(apiPath, null);
String body = mapper.writeValueAsString(query);

HttpRequest request = HttpRequest.newBuilder().POST(HttpRequest.BodyPublishers.ofString(body))
.uri(uri)
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.header("Api-Key", apikey)
.build();
final URI uri = createRequestUri(apiPath, null);
HttpRequest.Builder builder = HttpRequest.newBuilder().uri(uri);
builder.POST(HttpRequest.BodyPublishers.ofString(body)).header("Content-Type", "application/json");

HttpResponse<String> response = sendRequest(builder);
return mapper.readValue(response.body(), valueType);
} catch (IOException e) {
throw new InterwebException("An error occurred during Interweb request " + query, e);
}
}

public HttpResponse<String> sendRequest(final HttpRequest.Builder builder) throws InterwebException {
try {
builder.header("Api-Key", apikey);
builder.header("Accept", "application/json");

HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
String responseBody = response.body();
HttpResponse<String> response = client.send(builder.build(), HttpResponse.BodyHandlers.ofString());

if (response.statusCode() != 200) {
throw new InterwebException("Interweb request failed, response: " + responseBody);
throw new InterwebException("Interweb request failed, response: " + response.body());
}

return mapper.readValue(responseBody, valueType);
return response;
} catch (IOException | InterruptedException e) {
throw new InterwebException("An error occurred during Interweb request " + query, e);
throw new InterwebException("An error occurred during Interweb request", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package de.l3s.interweb.client;

import java.io.Serial;

public class InterwebException extends Exception {
@Serial
private static final long serialVersionUID = 1648272342540671760L;

public InterwebException(String message) {
Expand Down
Loading

0 comments on commit 64ab652

Please sign in to comment.