From 64ab652aad6595652b2bcda087072ce2efd0221f Mon Sep 17 00:00:00 2001 From: Oleh Astappiev Date: Mon, 2 Oct 2023 18:52:27 +0200 Subject: [PATCH] refactor: chat completion requests, separate by users, title autogeneration --- .../connector/openai/ErrorResponse.java | 12 +- .../connector/openai/OpenaiClient.java | 6 +- .../connector/openai/OpenaiConnector.java | 10 +- .../openai/entity/CompletionResponse.java | 67 +++++++++++ .../openai/entity/CompletionsBody.java | 57 ++++++++++ interweb-client/pom.xml | 4 + .../java/de/l3s/interweb/client/Interweb.java | 73 +++++++++--- .../interweb/client/InterwebException.java | 3 - .../client/InterwebCompletionTest.java | 81 +++++++++++++ .../interweb/client/InterwebDescribeTest.java | 39 +++++++ .../interweb/client/InterwebSearchTest.java | 49 ++++++++ .../interweb/client/InterwebSuggestTest.java | 47 ++++++++ .../de/l3s/interweb/client/InterwebTest.java | 107 ------------------ .../l3s/interweb/core/completion/Choice.java | 17 +-- .../core/completion/CompletionQuery.java | 52 +++++++-- .../core/completion/CompletionResults.java | 40 +++++-- .../core/completion/Conversation.java | 51 +++++++++ .../de/l3s/interweb/server/chat/Chat.java | 26 ++--- .../l3s/interweb/server/chat/ChatMessage.java | 4 + .../interweb/server/chat/ChatResource.java | 67 +++++++---- .../l3s/interweb/server/chat/ChatService.java | 14 ++- .../src/main/resources/application.properties | 1 + 22 files changed, 623 insertions(+), 204 deletions(-) create mode 100644 connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/entity/CompletionResponse.java create mode 100644 connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/entity/CompletionsBody.java create mode 100644 interweb-client/src/test/java/de/l3s/interweb/client/InterwebCompletionTest.java create mode 100644 interweb-client/src/test/java/de/l3s/interweb/client/InterwebDescribeTest.java create mode 100644 
interweb-client/src/test/java/de/l3s/interweb/client/InterwebSearchTest.java create mode 100644 interweb-client/src/test/java/de/l3s/interweb/client/InterwebSuggestTest.java delete mode 100644 interweb-client/src/test/java/de/l3s/interweb/client/InterwebTest.java create mode 100644 interweb-core/src/main/java/de/l3s/interweb/core/completion/Conversation.java diff --git a/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/ErrorResponse.java b/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/ErrorResponse.java index 7a4b609d..d58fa5ff 100644 --- a/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/ErrorResponse.java +++ b/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/ErrorResponse.java @@ -1,13 +1,13 @@ package de.l3s.interweb.connector.openai; public class ErrorResponse { - public Error error; + Error error; public static class Error { - public String message; - public Object type; - public String param; - public String code; - public Integer status; + String message; + Object type; + String param; + String code; + Integer status; } } diff --git a/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/OpenaiClient.java b/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/OpenaiClient.java index 2b78f630..f2e99c37 100644 --- a/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/OpenaiClient.java +++ b/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/OpenaiClient.java @@ -9,9 +9,9 @@ import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam; import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; +import de.l3s.interweb.connector.openai.entity.CompletionResponse; +import de.l3s.interweb.connector.openai.entity.CompletionsBody; import de.l3s.interweb.core.ConnectorException; -import de.l3s.interweb.core.completion.CompletionQuery; -import 
de.l3s.interweb.core.completion.CompletionResults; @Path("/openai/deployments") @Consumes(MediaType.APPLICATION_JSON) @@ -22,7 +22,7 @@ public interface OpenaiClient { @POST @Path("/{model}/chat/completions") - Uni chatCompletions(@PathParam("model") String model, @QueryParam("api-version") String apiVersion, CompletionQuery body); + Uni chatCompletions(@PathParam("model") String model, @QueryParam("api-version") String apiVersion, CompletionsBody body); @ClientExceptionMapper static RuntimeException toException(Response response) { diff --git a/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/OpenaiConnector.java b/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/OpenaiConnector.java index e20a29ff..a5659b19 100644 --- a/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/OpenaiConnector.java +++ b/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/OpenaiConnector.java @@ -7,6 +7,7 @@ import io.smallrye.mutiny.Uni; import org.eclipse.microprofile.rest.client.inject.RestClient; +import de.l3s.interweb.connector.openai.entity.CompletionsBody; import de.l3s.interweb.core.ConnectorException; import de.l3s.interweb.core.completion.CompletionConnector; import de.l3s.interweb.core.completion.CompletionQuery; @@ -49,6 +50,13 @@ public UsagePrice getPrice(String model) { @Override public Uni complete(CompletionQuery query) throws ConnectorException { - return openai.chatCompletions(query.getModel(), version, query); + return openai.chatCompletions(query.getModel(), version, new CompletionsBody(query)).map(response -> { + CompletionResults results = new CompletionResults(); + results.setModel(query.getModel()); + results.setCreated(response.getCreated()); + results.setChoices(response.getChoices()); + results.setUsage(response.getUsage()); + return results; + }); } } diff --git 
a/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/entity/CompletionResponse.java b/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/entity/CompletionResponse.java new file mode 100644 index 00000000..de98fe9d --- /dev/null +++ b/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/entity/CompletionResponse.java @@ -0,0 +1,67 @@ +package de.l3s.interweb.connector.openai.entity; + +import java.time.Instant; +import java.util.List; + +import io.quarkus.runtime.annotations.RegisterForReflection; + +import de.l3s.interweb.core.completion.Choice; +import de.l3s.interweb.core.completion.Usage; + +@RegisterForReflection +public class CompletionResponse { + private String id; + private String object; + private String model; + private Usage usage; + private Instant created; + private List choices; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getObject() { + return object; + } + + public void setObject(String object) { + this.object = object; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public List getChoices() { + return choices; + } + + public void setChoices(List choices) { + this.choices = choices; + } + + public Usage getUsage() { + return usage; + } + + public void setUsage(Usage usage) { + this.usage = usage; + } + + public Instant getCreated() { + return created; + } + + public void setCreated(Instant created) { + this.created = created; + } +} diff --git a/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/entity/CompletionsBody.java b/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/entity/CompletionsBody.java new file mode 100644 index 00000000..57cc296b --- /dev/null +++ b/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/entity/CompletionsBody.java @@ -0,0 +1,57 @@ 
+package de.l3s.interweb.connector.openai.entity; + +import java.util.List; + +import io.quarkus.runtime.annotations.RegisterForReflection; + +import de.l3s.interweb.core.completion.CompletionQuery; +import de.l3s.interweb.core.completion.Message; + +@RegisterForReflection +public final class CompletionsBody { + + private List messages; + + private Double temperature; + + private Double topP; + + private Double frequencyPenalty; + + private Double presencePenalty; + + private Integer maxTokens; + + public CompletionsBody(CompletionQuery query) { + this.messages = query.getMessages(); + this.temperature = query.getTemperature(); + this.topP = query.getTopP(); + this.frequencyPenalty = query.getFrequencyPenalty(); + this.presencePenalty = query.getPresencePenalty(); + this.maxTokens = query.getMaxTokens(); + } + + public List getMessages() { + return messages; + } + + public Double getTemperature() { + return temperature; + } + + public Double getTopP() { + return topP; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public Integer getMaxTokens() { + return maxTokens; + } +} diff --git a/interweb-client/pom.xml b/interweb-client/pom.xml index 6c17002a..f6c342bb 100644 --- a/interweb-client/pom.xml +++ b/interweb-client/pom.xml @@ -17,6 +17,10 @@ Interweb Client + + 11 + + de.l3s.interweb diff --git a/interweb-client/src/main/java/de/l3s/interweb/client/Interweb.java b/interweb-client/src/main/java/de/l3s/interweb/client/Interweb.java index c2bacc7d..e4fdf45c 100644 --- a/interweb-client/src/main/java/de/l3s/interweb/client/Interweb.java +++ b/interweb-client/src/main/java/de/l3s/interweb/client/Interweb.java @@ -1,16 +1,16 @@ package de.l3s.interweb.client; import java.io.IOException; -import java.io.Serial; import java.io.Serializable; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import
java.util.List; import java.util.Map; -import java.util.TreeMap; import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategies; @@ -18,6 +18,7 @@ import de.l3s.interweb.core.completion.CompletionQuery; import de.l3s.interweb.core.completion.CompletionResults; +import de.l3s.interweb.core.completion.Conversation; import de.l3s.interweb.core.describe.DescribeQuery; import de.l3s.interweb.core.describe.DescribeResults; import de.l3s.interweb.core.search.SearchQuery; @@ -27,7 +28,6 @@ import de.l3s.interweb.core.util.StringUtils; public class Interweb implements Serializable { - @Serial private static final long serialVersionUID = 7231324400348062196L; private final ObjectMapper mapper; @@ -72,11 +72,31 @@ public DescribeResults describe(String link) throws InterwebException { return sendRequest("/describe", params, DescribeResults.class); } - public CompletionResults chatCompletions(CompletionQuery query) throws InterwebException { + public List conversations(String user) throws InterwebException { + return sendRequest("/chat", Map.of("user", user), new TypeReference<>() {}); + } + + public CompletionResults completion(CompletionQuery query) throws InterwebException { return sendRequest("/chat/completions", query, CompletionResults.class); } - private URI createRequestUri(final String apiPath, final TreeMap params) { + public void completion(Conversation conversation) throws InterwebException { + CompletionResults results = sendRequest("/chat/completions", conversation, CompletionResults.class); + if (results.getLastMessage() != null) { + conversation.addMessage(results.getLastMessage()); + } + if (results.getChatTitle() != null) { + conversation.setTitle(results.getChatTitle()); + } + if (results.getCost() != null) { + 
conversation.setEstimatedCost(results.getCost().getChat()); + } + if (results.getUsage() != null) { + conversation.setUsedTokens(results.getUsage().getTotalTokens()); + } + } + + private URI createRequestUri(final String apiPath, final Map params) { StringBuilder sb = new StringBuilder(); sb.append(serverUrl); @@ -96,29 +116,48 @@ private URI createRequestUri(final String apiPath, final TreeMap return URI.create(sb.toString()); } + public T sendRequest(final String apiPath, final Map params, TypeReference valueType) throws InterwebException { + try { + final URI uri = createRequestUri(apiPath, params); + HttpRequest.Builder builder = HttpRequest.newBuilder().uri(uri).GET(); + + HttpResponse response = sendRequest(builder); + return mapper.readValue(response.body(), valueType); + } catch (IOException e) { + throw new InterwebException("An error occurred during Interweb request " + apiPath, e); + } + } + public T sendRequest(final String apiPath, final Object query, Class valueType) throws InterwebException { try { - final URI uri = createRequestUri(apiPath, null); String body = mapper.writeValueAsString(query); - HttpRequest request = HttpRequest.newBuilder().POST(HttpRequest.BodyPublishers.ofString(body)) - .uri(uri) - .header("Accept", "application/json") - .header("Content-Type", "application/json") - .header("Api-Key", apikey) - .build(); + final URI uri = createRequestUri(apiPath, null); + HttpRequest.Builder builder = HttpRequest.newBuilder().uri(uri); + builder.POST(HttpRequest.BodyPublishers.ofString(body)).header("Content-Type", "application/json"); + + HttpResponse response = sendRequest(builder); + return mapper.readValue(response.body(), valueType); + } catch (IOException e) { + throw new InterwebException("An error occurred during Interweb request " + query, e); + } + } + + public HttpResponse sendRequest(final HttpRequest.Builder builder) throws InterwebException { + try { + builder.header("Api-Key", apikey); + builder.header("Accept", 
"application/json"); HttpClient client = HttpClient.newHttpClient(); - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); - String responseBody = response.body(); + HttpResponse response = client.send(builder.build(), HttpResponse.BodyHandlers.ofString()); if (response.statusCode() != 200) { - throw new InterwebException("Interweb request failed, response: " + responseBody); + throw new InterwebException("Interweb request failed, response: " + response.body()); } - return mapper.readValue(responseBody, valueType); + return response; } catch (IOException | InterruptedException e) { - throw new InterwebException("An error occurred during Interweb request " + query, e); + throw new InterwebException("An error occurred during Interweb request", e); } } } diff --git a/interweb-client/src/main/java/de/l3s/interweb/client/InterwebException.java b/interweb-client/src/main/java/de/l3s/interweb/client/InterwebException.java index 429e97f8..17899281 100644 --- a/interweb-client/src/main/java/de/l3s/interweb/client/InterwebException.java +++ b/interweb-client/src/main/java/de/l3s/interweb/client/InterwebException.java @@ -1,9 +1,6 @@ package de.l3s.interweb.client; -import java.io.Serial; - public class InterwebException extends Exception { - @Serial private static final long serialVersionUID = 1648272342540671760L; public InterwebException(String message) { diff --git a/interweb-client/src/test/java/de/l3s/interweb/client/InterwebCompletionTest.java b/interweb-client/src/test/java/de/l3s/interweb/client/InterwebCompletionTest.java new file mode 100644 index 00000000..fff35930 --- /dev/null +++ b/interweb-client/src/test/java/de/l3s/interweb/client/InterwebCompletionTest.java @@ -0,0 +1,81 @@ +package de.l3s.interweb.client; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.List; + +import jakarta.inject.Inject; + +import io.quarkus.test.junit.QuarkusTest; +import org.eclipse.microprofile.config.inject.ConfigProperty; 
+import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import de.l3s.interweb.core.completion.*; + +@Disabled +@QuarkusTest +class InterwebCompletionTest { + + private final Interweb interweb; + + @Inject + public InterwebCompletionTest(@ConfigProperty(name = "interweb.server") String server, @ConfigProperty(name = "interweb.apikey") String apikey) { + this.interweb = new Interweb(server, apikey); + } + + @Test + void conversationsTest() throws InterwebException { + List response = interweb.conversations("user1"); + + assertFalse(response.isEmpty()); + assertFalse(response.get(0).getTitle().isEmpty()); + } + + @Test + void chatCompletionsTest() throws InterwebException { + CompletionQuery query = new CompletionQuery(); + query.setUser("user1"); + query.setGenerateTitle(true); + query.setModel("gpt-35-turbo"); + query.addMessage("You are Interweb Assistant, a helpful chat bot.", Message.Role.system); + query.addMessage("What is your name?.", Message.Role.user); + + CompletionResults response = interweb.completion(query); + assertFalse(response.getResults().isEmpty()); + + for (Choice result : response.getResults()) { + assertNotNull(result.getMessage()); + System.out.println(result.getMessage().getContent()); + } + } + + @Test + void conversationTest() throws InterwebException { + Conversation query = new Conversation(); + query.setUser("user1"); + query.setGenerateTitle(true); + query.setModel("gpt-35-turbo"); + query.addMessage("You are Interweb Assistant, a helpful chat bot.", Message.Role.system); + query.addMessage("What is your name?.", Message.Role.user); + + assertNull(query.getTitle()); + assertNull(query.getEstimatedCost()); + assertEquals(2, query.getMessages().size()); + + interweb.completion(query); + + assertNotNull(query.getTitle()); + assertNotNull(query.getEstimatedCost()); + assertEquals(3, query.getMessages().size()); + + query.addMessage("That's time now?", Message.Role.user); + interweb.completion(query); + + 
assertEquals(5, query.getMessages().size()); + + for (Message result : query.getMessages()) { + System.out.println(result.getContent()); + } + } +} diff --git a/interweb-client/src/test/java/de/l3s/interweb/client/InterwebDescribeTest.java b/interweb-client/src/test/java/de/l3s/interweb/client/InterwebDescribeTest.java new file mode 100644 index 00000000..cb856745 --- /dev/null +++ b/interweb-client/src/test/java/de/l3s/interweb/client/InterwebDescribeTest.java @@ -0,0 +1,39 @@ +package de.l3s.interweb.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import jakarta.inject.Inject; + +import io.quarkus.test.junit.QuarkusTest; +import org.eclipse.microprofile.config.inject.ConfigProperty; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import de.l3s.interweb.core.describe.DescribeQuery; +import de.l3s.interweb.core.describe.DescribeResults; + +@Disabled +@QuarkusTest +class InterwebDescribeTest { + + private final Interweb interweb; + + @Inject + public InterwebDescribeTest(@ConfigProperty(name = "interweb.server") String server, @ConfigProperty(name = "interweb.apikey") String apikey) { + this.interweb = new Interweb(server, apikey); + } + + @Test + void describeTest() throws InterwebException { + DescribeQuery query = new DescribeQuery(); + // query.setLink("https://vimeo.com/524933864"); + query.setId("524933864"); + query.setServices("vimeo"); + + DescribeResults response = interweb.describe(query); + + assertEquals("524933864", response.getEntity().getId()); + assertEquals("Vimeo | Video Power", response.getEntity().getTitle()); + assertEquals("https://vimeo.com/524933864", response.getEntity().getUrl()); + } +} diff --git a/interweb-client/src/test/java/de/l3s/interweb/client/InterwebSearchTest.java b/interweb-client/src/test/java/de/l3s/interweb/client/InterwebSearchTest.java new file mode 100644 index 00000000..e267bbab --- /dev/null +++ 
b/interweb-client/src/test/java/de/l3s/interweb/client/InterwebSearchTest.java @@ -0,0 +1,49 @@ +package de.l3s.interweb.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.test.junit.QuarkusTest; +import org.eclipse.microprofile.config.inject.ConfigProperty; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import de.l3s.interweb.core.search.*; + +@Disabled +@QuarkusTest +class InterwebSearchTest { + + private final Interweb interweb; + + @Inject + public InterwebSearchTest(@ConfigProperty(name = "interweb.server") String server, @ConfigProperty(name = "interweb.apikey") String apikey) { + this.interweb = new Interweb(server, apikey); + } + + @Test + void searchTest() throws InterwebException { + SearchQuery query = new SearchQuery(); + query.setQuery("hannover"); + query.setLanguage("en"); + query.setContentTypes(ContentType.video); + query.setServices("Vimeo", "YouTube"); + query.setPerPage(32); + query.setPage(1); + query.setExtras(SearchExtra.duration, SearchExtra.tags); + + SearchResults response = interweb.search(query); + assertEquals(response.getResults().size(), 2); + + for (SearchConnectorResults result : response.getResults()) { + assertTrue(result.getTotalResults() > 0); + + for (SearchItem item : result.getItems()) { + System.out.println(item.getTitle() + " [" + item.getDuration() + "]"); + System.out.println(item.getUrl()); + } + } + } +} diff --git a/interweb-client/src/test/java/de/l3s/interweb/client/InterwebSuggestTest.java b/interweb-client/src/test/java/de/l3s/interweb/client/InterwebSuggestTest.java new file mode 100644 index 00000000..29e2c72f --- /dev/null +++ b/interweb-client/src/test/java/de/l3s/interweb/client/InterwebSuggestTest.java @@ -0,0 +1,47 @@ +package de.l3s.interweb.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static 
org.junit.jupiter.api.Assertions.assertFalse; + +import jakarta.inject.Inject; + +import io.quarkus.test.junit.QuarkusTest; +import org.eclipse.microprofile.config.inject.ConfigProperty; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import de.l3s.interweb.core.suggest.SuggestConnectorResults; +import de.l3s.interweb.core.suggest.SuggestQuery; +import de.l3s.interweb.core.suggest.SuggestResults; + +@Disabled +@QuarkusTest +class InterwebSuggestTest { + + private final Interweb interweb; + + @Inject + public InterwebSuggestTest(@ConfigProperty(name = "interweb.server") String server, @ConfigProperty(name = "interweb.apikey") String apikey) { + this.interweb = new Interweb(server, apikey); + } + + @Test + void suggestTest() throws InterwebException { + SuggestQuery query = new SuggestQuery(); + query.setQuery("hannover"); + query.setLanguage("de"); + + SuggestResults response = interweb.suggest(query); + assertEquals(response.getResults().size(), 2); + + for (SuggestConnectorResults result : response.getResults()) { + assertFalse(result.getItems().isEmpty()); + + System.out.println(result.getService() + ":"); + for (String item : result.getItems()) { + System.out.println(item); + } + System.out.println(); + } + } +} diff --git a/interweb-client/src/test/java/de/l3s/interweb/client/InterwebTest.java b/interweb-client/src/test/java/de/l3s/interweb/client/InterwebTest.java deleted file mode 100644 index 67b4f661..00000000 --- a/interweb-client/src/test/java/de/l3s/interweb/client/InterwebTest.java +++ /dev/null @@ -1,107 +0,0 @@ -package de.l3s.interweb.client; - -import static org.junit.jupiter.api.Assertions.*; - -import jakarta.inject.Inject; - -import io.quarkus.test.junit.QuarkusTest; -import org.eclipse.microprofile.config.inject.ConfigProperty; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -import de.l3s.interweb.core.completion.Choice; -import de.l3s.interweb.core.completion.CompletionQuery; -import 
de.l3s.interweb.core.completion.CompletionResults; -import de.l3s.interweb.core.completion.Message; -import de.l3s.interweb.core.describe.DescribeQuery; -import de.l3s.interweb.core.describe.DescribeResults; -import de.l3s.interweb.core.search.*; -import de.l3s.interweb.core.suggest.SuggestConnectorResults; -import de.l3s.interweb.core.suggest.SuggestQuery; -import de.l3s.interweb.core.suggest.SuggestResults; - -@Disabled -@QuarkusTest -class InterwebTest { - - private final Interweb interweb; - - @Inject - public InterwebTest(@ConfigProperty(name = "interweb.server") String server, @ConfigProperty(name = "interweb.apikey") String apikey) { - this.interweb = new Interweb(server, apikey); - } - - @Test - void searchTest() throws InterwebException { - SearchQuery query = new SearchQuery(); - query.setQuery("hannover"); - query.setLanguage("en"); - query.setContentTypes(ContentType.video); - query.setServices("Vimeo", "YouTube"); - query.setPerPage(32); - query.setPage(1); - query.setExtras(SearchExtra.duration, SearchExtra.tags); - - SearchResults response = interweb.search(query); - assertEquals(response.getResults().size(), 2); - - for (SearchConnectorResults result : response.getResults()) { - assertTrue(result.getTotalResults() > 0); - - for (SearchItem item : result.getItems()) { - System.out.println(item.getTitle() + " [" + item.getDuration() + "]"); - System.out.println(item.getUrl()); - } - } - } - - @Test - void suggestTest() throws InterwebException { - SuggestQuery query = new SuggestQuery(); - query.setQuery("hannover"); - query.setLanguage("de"); - - SuggestResults response = interweb.suggest(query); - assertEquals(response.getResults().size(), 2); - - for (SuggestConnectorResults result : response.getResults()) { - assertFalse(result.getItems().isEmpty()); - - System.out.println(result.getService() + ":"); - for (String item : result.getItems()) { - System.out.println(item); - } - System.out.println(); - } - } - - @Test - void describeTest() throws 
InterwebException { - DescribeQuery query = new DescribeQuery(); - // query.setLink("https://vimeo.com/524933864"); - query.setId("524933864"); - query.setServices("vimeo"); - - DescribeResults response = interweb.describe(query); - - assertEquals("524933864", response.getEntity().getId()); - assertEquals("Vimeo | Video Power", response.getEntity().getTitle()); - assertEquals("https://vimeo.com/524933864", response.getEntity().getUrl()); - } - - @Test - void chatCompletionsTest() throws InterwebException { - CompletionQuery query = new CompletionQuery(); - query.setModel("gpt-35-turbo"); - query.addMessage("You are Interweb Assistant, a helpful chat bot.", Message.Role.system); - query.addMessage("What is your name?.", Message.Role.user); - - CompletionResults response = interweb.chatCompletions(query); - assertFalse(response.getResults().isEmpty()); - - for (Choice result : response.getResults()) { - assertNotNull(result.getMessage()); - System.out.println(result.getMessage().getContent()); - } - } -} diff --git a/interweb-core/src/main/java/de/l3s/interweb/core/completion/Choice.java b/interweb-core/src/main/java/de/l3s/interweb/core/completion/Choice.java index 4a1c64b6..66f05082 100644 --- a/interweb-core/src/main/java/de/l3s/interweb/core/completion/Choice.java +++ b/interweb-core/src/main/java/de/l3s/interweb/core/completion/Choice.java @@ -1,15 +1,14 @@ package de.l3s.interweb.core.completion; -import java.time.Instant; - import io.quarkus.runtime.annotations.RegisterForReflection; -import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import de.l3s.interweb.core.ConnectorResults; @RegisterForReflection +@JsonIgnoreProperties({"elapsed_time", "created"}) public class Choice extends ConnectorResults { private int index; @JsonProperty("finish_reason") @@ -39,16 +38,4 @@ public Message getMessage() { public void setMessage(Message message) { 
this.message = message; } - - @Override - @JsonIgnore - public long getElapsedTime() { - return super.getElapsedTime(); - } - - @Override - @JsonIgnore - public Instant getCreated() { - return super.getCreated(); - } } diff --git a/interweb-core/src/main/java/de/l3s/interweb/core/completion/CompletionQuery.java b/interweb-core/src/main/java/de/l3s/interweb/core/completion/CompletionQuery.java index b8e42cda..acb5a4fb 100644 --- a/interweb-core/src/main/java/de/l3s/interweb/core/completion/CompletionQuery.java +++ b/interweb-core/src/main/java/de/l3s/interweb/core/completion/CompletionQuery.java @@ -20,19 +20,28 @@ public class CompletionQuery { * Defaults to "gpt-35-turbo". */ @NotEmpty - @JsonProperty(access = JsonProperty.Access.WRITE_ONLY) + @JsonProperty("model") private String model = "gpt-35-turbo"; /** * ID of the chat to continue. */ - @JsonProperty(access = JsonProperty.Access.WRITE_ONLY) - private UUID chatId; + @JsonProperty("id") + private UUID id; + + /** + * Participant involved in the conversation can be denoted by its ID. + * When provided, it can be used to filter the conversations. + * Applied automatically when using chatId. + */ + @JsonProperty(value = "user") + private String user; /** * A list of messages comprising the conversation so far. */ @NotEmpty + @JsonProperty("messages") private List messages = new ArrayList<>(); /** @@ -43,6 +52,7 @@ public class CompletionQuery { */ @Min(0) @Max(2) + @JsonProperty("temperature") private Double temperature = 1.0; /** @@ -82,6 +92,12 @@ public class CompletionQuery { @JsonProperty("max_tokens") private Integer maxTokens = 800; + /** + * Whether the conversation should be summarized into a title. Defaults to false. 
+ */ + @JsonProperty(value = "generate_title") + private boolean generateTitle; + public String getModel() { return model; } @@ -90,12 +106,12 @@ public void setModel(String model) { this.model = model; } - public UUID getChatId() { - return chatId; + public UUID getId() { + return id; } - public void setChatId(UUID chatId) { - this.chatId = chatId; + public void setId(UUID id) { + this.id = id; } public List getMessages() { @@ -106,8 +122,20 @@ public void setMessages(final List messages) { this.messages = messages; } + public void addMessage(final Message message) { + this.messages.add(message); + } + public void addMessage(final String message, final Message.Role role) { - this.messages.add(new Message(role, message)); + addMessage(new Message(role, message)); + } + + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; } public Double getTemperature() { @@ -149,4 +177,12 @@ public Integer getMaxTokens() { public void setMaxTokens(final Integer maxTokens) { this.maxTokens = maxTokens; } + + public boolean isGenerateTitle() { + return generateTitle; + } + + public void setGenerateTitle(boolean generateTitle) { + this.generateTitle = generateTitle; + } } diff --git a/interweb-core/src/main/java/de/l3s/interweb/core/completion/CompletionResults.java b/interweb-core/src/main/java/de/l3s/interweb/core/completion/CompletionResults.java index d45f9236..ad81202e 100644 --- a/interweb-core/src/main/java/de/l3s/interweb/core/completion/CompletionResults.java +++ b/interweb-core/src/main/java/de/l3s/interweb/core/completion/CompletionResults.java @@ -7,14 +7,20 @@ import io.quarkus.runtime.annotations.RegisterForReflection; import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyOrder; import de.l3s.interweb.core.Results; @RegisterForReflection 
+@JsonIgnoreProperties("results") +@JsonPropertyOrder({"id", "title", "model", "choices", "usage", "cost", "elapsed_time", "created"}) public class CompletionResults extends Results { - @JsonProperty(value = "id", access = JsonProperty.Access.READ_ONLY) + @JsonProperty(value = "id") private UUID chatId; + @JsonProperty(value = "title") + private String chatTitle; private String model; private Usage usage; private UsageCost cost; @@ -28,12 +34,12 @@ public void setChatId(UUID chatId) { this.chatId = chatId; } - public Instant getCreated() { - return created; + public String getChatTitle() { + return chatTitle; } - public void setCreated(Instant created) { - this.created = created; + public void setChatTitle(String chatTitle) { + this.chatTitle = chatTitle; } public String getModel() { @@ -44,16 +50,20 @@ public void setModel(String model) { this.model = model; } - @Override - @JsonIgnore - public List getResults() { - return super.getResults(); - } - + @JsonProperty public List getChoices() { return getResults(); } + @JsonIgnore + public Message getLastMessage() { + if (getResults().isEmpty()) { + return null; + } + + return getResults().get(0).getMessage(); + } + @JsonProperty public void setChoices(List choices) { super.add(choices); @@ -71,6 +81,14 @@ public UsageCost getCost() { return cost; } + public Instant getCreated() { + return created; + } + + public void setCreated(Instant created) { + this.created = created; + } + public void updateCosts(UsagePrice price) { double promptCost = (usage.getPromptTokens() / 1000d) * price.getPrompt(); double completionCost = (usage.getCompletionTokens() / 1000d) * price.getCompletion(); diff --git a/interweb-core/src/main/java/de/l3s/interweb/core/completion/Conversation.java b/interweb-core/src/main/java/de/l3s/interweb/core/completion/Conversation.java new file mode 100644 index 00000000..5b9b7efa --- /dev/null +++ b/interweb-core/src/main/java/de/l3s/interweb/core/completion/Conversation.java @@ -0,0 +1,51 @@ +package 
de.l3s.interweb.core.completion; + +import java.time.Instant; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public class Conversation extends CompletionQuery { + @JsonProperty("title") + private String title; + + @JsonProperty("used_tokens") + private Integer usedTokens; + + @JsonProperty("estimated_cost") + private Double estimatedCost; + + @JsonProperty("created") + private Instant created; + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public Integer getUsedTokens() { + return usedTokens; + } + + public void setUsedTokens(Integer usedTokens) { + this.usedTokens = usedTokens; + } + + public Double getEstimatedCost() { + return estimatedCost; + } + + public void setEstimatedCost(Double estimatedCost) { + this.estimatedCost = estimatedCost; + } + + public Instant getCreated() { + return created; + } + + public void setCreated(Instant created) { + this.created = created; + } +} diff --git a/interweb-server/src/main/java/de/l3s/interweb/server/chat/Chat.java b/interweb-server/src/main/java/de/l3s/interweb/server/chat/Chat.java index ffe47886..33819190 100644 --- a/interweb-server/src/main/java/de/l3s/interweb/server/chat/Chat.java +++ b/interweb-server/src/main/java/de/l3s/interweb/server/chat/Chat.java @@ -21,7 +21,9 @@ @Entity @Cacheable -@Table(name = "chat") +@Table(name = "chat", indexes = { + @Index(name = "user_index", columnList = "user"), +}) public class Chat extends PanacheEntityBase { @Id @UuidGenerator @@ -33,17 +35,23 @@ public class Chat extends PanacheEntityBase { @ManyToOne(optional = false, fetch = FetchType.LAZY) public Consumer consumer; + @Size(max = 32) + public String user; + @NotNull @Size(max = 32) public String model; + @Size(max = 512) + public String title; + @NotNull @ColumnDefault("0") - public Integer used_tokens = 0; + public Integer usedTokens = 0; @NotNull @ColumnDefault("0") - public Double estimated_cost = 0d; + public Double estimatedCost = 
0d; @CreationTimestamp public Instant created; @@ -65,19 +73,11 @@ public List getMessages() { } public void addCosts(int tokens, double cost) { - this.used_tokens += tokens; - this.estimated_cost += cost; + this.usedTokens += tokens; + this.estimatedCost += cost; } public static Uni findById(UUID id) { return find("id", id).firstResult(); } - - public static Uni findByIdWithMessages(UUID id) { - return find("from Chat p left join fetch p.messages WHERE p.id = ?1", id).firstResult(); - } - - public static Uni> findByConsumer(Consumer consumer) { - return list("consumer.id", consumer.id); - } } diff --git a/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatMessage.java b/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatMessage.java index 9d844f29..7d1e7828 100644 --- a/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatMessage.java +++ b/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatMessage.java @@ -49,6 +49,10 @@ public ChatMessage(final Message message) { this.content = message.getContent(); } + public Message toMessage() { + return new Message(role, content); + } + public static Uni> listByChat(UUID id) { return list("chat.id", id); } diff --git a/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatResource.java b/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatResource.java index 657cceb2..7d1aded3 100644 --- a/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatResource.java +++ b/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatResource.java @@ -1,14 +1,13 @@ package de.l3s.interweb.server.chat; import java.util.List; +import java.util.Map; import java.util.UUID; +import java.util.stream.Collectors; import jakarta.inject.Inject; import jakarta.validation.Valid; -import jakarta.ws.rs.GET; -import jakarta.ws.rs.POST; -import jakarta.ws.rs.Path; -import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.*; import jakarta.ws.rs.core.Context; import 
io.quarkus.hibernate.reactive.panache.Panache; @@ -20,6 +19,7 @@ import de.l3s.interweb.core.completion.CompletionQuery; import de.l3s.interweb.core.completion.CompletionResults; import de.l3s.interweb.core.completion.Message; +import de.l3s.interweb.core.completion.UsagePrice; import de.l3s.interweb.server.principal.Consumer; @Path("/chat") @@ -33,8 +33,13 @@ public class ChatResource { @GET @Authenticated - public Uni> chats() { - return Chat.findByConsumer(securityIdentity.getCredential(Consumer.class)); + public Uni> chats(@QueryParam("user") String user) { + Consumer consumer = securityIdentity.getCredential(Consumer.class); + if (user == null) { + return Chat.list("consumer.id = ?1 AND user = NULL ORDER BY created DESC LIMIT 20", consumer.id); + } else { + return Chat.list("consumer.id = ?1 AND user = ?2 ORDER BY created DESC LIMIT 20", consumer.id, user); + } } @GET @@ -44,33 +49,55 @@ public Uni> chat(@PathParam("uuid") UUID id) { return ChatMessage.listByChat(id); } + @GET + @Authenticated + @Path("/models") + public Map models() { + return chatService.getModels(); + } + @POST @Authenticated @Path("/completions") public Uni completions(@Valid CompletionQuery query) { return getOrCreateChat(query).flatMap(chat -> { - long start = System.currentTimeMillis(); + //noinspection CodeBlock2Expr return chatService.completions(query).call(results -> { results.setChatId(chat.id); - results.setElapsedTime(System.currentTimeMillis() - start); return persistMessages(chat, query.getMessages(), results); - }); + }).invoke(results -> { + chat.addCosts(results.getUsage().getTotalTokens(), results.getCost().getResponse()); + results.getCost().setChat(chat.estimatedCost); + }).call(results -> { + if (chat.title == null && query.isGenerateTitle()) { + return generateTitle(chat).invoke(titleCompletion -> { + chat.title = titleCompletion.getLastMessage().getContent(); + results.setChatTitle(chat.title); + + chat.addCosts(titleCompletion.getUsage().getTotalTokens(), 
titleCompletion.getCost().getResponse()); + results.getCost().setChat(chat.estimatedCost); + }); + } else { + return Uni.createFrom().voidItem(); + } + }).call(() -> Chat.update("title = ?1, usedTokens = ?2, estimatedCost = ?3 where id = ?4", chat.title, chat.usedTokens, chat.estimatedCost, chat.id)); }); } private Uni getOrCreateChat(CompletionQuery query) { - if (query.getChatId() != null) { - return Chat.findById(query.getChatId()).call(chat -> Mutiny.fetch(chat.getMessages())); + if (query.getId() != null) { + return Chat.findById(query.getId()).call(chat -> Mutiny.fetch(chat.getMessages())); } Chat chat = new Chat(); chat.consumer = securityIdentity.getCredential(Consumer.class); chat.model = query.getModel(); + chat.user = query.getUser(); return Panache.withTransaction(chat::persist); } - private Uni persistMessages(final Chat chat, final List history, final CompletionResults results) { - for (Message message : history) { + private Uni persistMessages(final Chat chat, final List queryMessages, final CompletionResults results) { + for (Message message : queryMessages) { if (message.getId() == null) { ChatMessage cm = new ChatMessage(message); if (!chat.getMessages().contains(cm)) { @@ -79,12 +106,14 @@ private Uni persistMessages(final Chat chat, final List history, } } - if (results.getCost() != null) { - chat.addCosts(results.getUsage().getTotalTokens(), results.getCost().getResponse()); - results.getCost().setChat(chat.estimated_cost); - } - - chat.addMessage(new ChatMessage(results.getChoices().get(0).getMessage())); + chat.addMessage(new ChatMessage(results.getLastMessage())); return Panache.withTransaction(() -> ChatMessage.persist(chat.getMessages())); } + + private Uni generateTitle(final Chat chat) { + CompletionQuery query = new CompletionQuery(); + query.setMessages(chat.getMessages().stream().map(ChatMessage::toMessage).collect(Collectors.toList())); + query.addMessage("Give a short name for this conversation; don't use any formatting; length 
between 80 and 120 characters", Message.Role.user); + return chatService.completions(query); + } } diff --git a/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatService.java b/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatService.java index 82e4ae87..4126316a 100644 --- a/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatService.java +++ b/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatService.java @@ -15,6 +15,7 @@ import de.l3s.interweb.core.completion.CompletionConnector; import de.l3s.interweb.core.completion.CompletionQuery; import de.l3s.interweb.core.completion.CompletionResults; +import de.l3s.interweb.core.completion.UsagePrice; @ApplicationScoped public class ChatService { @@ -38,14 +39,25 @@ public Collection getConnectors() { return this.services.values(); } + public Map getModels() { + Map models = new HashMap<>(); + for (CompletionConnector connector : this.services.values()) { + for (String model : connector.getModels()) { + models.put(model, connector.getPrice(model)); + } + } + return models; + } + public Uni completions(CompletionQuery query) { return completions(query, services.get(query.getModel())); } private Uni completions(CompletionQuery query, CompletionConnector connector) { + long start = System.currentTimeMillis(); return connector.complete(query).map(results -> { - results.setModel(query.getModel()); results.updateCosts(connector.getPrice(query.getModel())); + results.setElapsedTime(System.currentTimeMillis() - start); return results; }); } diff --git a/interweb-server/src/main/resources/application.properties b/interweb-server/src/main/resources/application.properties index 369b6c98..f846c9a0 100644 --- a/interweb-server/src/main/resources/application.properties +++ b/interweb-server/src/main/resources/application.properties @@ -15,6 +15,7 @@ quarkus.datasource.db-kind=mariadb quarkus.datasource.reactive.max-size=20 quarkus.hibernate-orm.database.generation=update 
+quarkus.hibernate-orm.physical-naming-strategy=org.hibernate.boot.model.naming.CamelCaseToUnderscoresNamingStrategy %dev.quarkus.hibernate-orm.log.sql=true quarkus.health.openapi.included=false