Skip to content

Commit

Permalink
feat(ollama-connector): add stream chat request
Browse files Browse the repository at this point in the history
  • Loading branch information
astappiev committed Jul 17, 2024
1 parent c9b1f9c commit 4b0002e
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package de.l3s.interweb.connector.ollama;

import de.l3s.interweb.connector.ollama.entity.ChatStreamBody;

import io.smallrye.mutiny.Multi;

import jakarta.ws.rs.*;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
Expand All @@ -13,6 +17,8 @@
import de.l3s.interweb.connector.ollama.entity.TagsResponse;
import de.l3s.interweb.core.ConnectorException;

import org.jboss.resteasy.reactive.common.util.RestMediaType;

@Path("")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
Expand All @@ -25,7 +31,12 @@ public interface OllamaClient {
*/
@POST
@Path("/api/chat")
Uni<ChatResponse> chatCompletions(ChatBody body);
Uni<ChatResponse> chat(ChatBody body);

/**
 * Streaming variant of the chat call: POSTs the given body to {@code /api/chat}
 * and declares newline-delimited JSON as the media type for this method.
 * The result is exposed as a {@link Multi} of {@link ChatResponse} items
 * (presumably one item per NDJSON line sent by the server — confirm).
 *
 * @param body the chat request body for the streaming call
 * @return a {@link Multi} emitting {@link ChatResponse} items
 */
@POST
@Path("/api/chat")
@Produces(RestMediaType.APPLICATION_NDJSON)
Multi<ChatResponse> chatStream(ChatStreamBody body);

@GET
@Path("/api/tags")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package de.l3s.interweb.connector.ollama;

import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import de.l3s.interweb.connector.ollama.entity.ChatResponse;

import de.l3s.interweb.connector.ollama.entity.ChatStreamBody;

import io.smallrye.mutiny.Multi;

import jakarta.enterprise.context.Dependent;

import io.smallrye.mutiny.Uni;
Expand Down Expand Up @@ -66,30 +70,14 @@ public UsagePrice getPrice(String model) {

/**
 * Runs a non-streaming chat completion and converts the reply to {@link CompletionResults}.
 *
 * NOTE(review): this span looks like a rendered diff without +/- markers, so two
 * versions of the body are shown together: the multi-line lambda passed to
 * {@code ollama.chatCompletions(...)} and the shorter {@code ollama.chat(body)} call
 * mapped with {@code ChatResponse::toCompletionResults}. Presumably only the shorter
 * version remains after the commit — confirm against the actual file.
 *
 * @param query the completion query wrapped into a {@link ChatBody}
 * @return the asynchronous completion result
 * @throws ConnectorException declared by the signature; no throw site is visible in this span
 */
@Override
public Uni<CompletionResults> complete(CompletionQuery query) throws ConnectorException {
return ollama.chatCompletions(new ChatBody(query)).map(response -> {
Usage usage = new Usage(
response.getPromptEvalCount(),
response.getEvalCount()
);

List<Choice> choices = List.of(
new Choice(
0,
response.getDoneReason(),
new Message(
Message.Role.assistant,
response.getMessage().getContent()
)
)
);

CompletionResults results = new CompletionResults();
results.setModel(response.getModel());
results.setUsage(usage);
results.setChoices(choices);
results.setCreated(Instant.now());
return results;
});
final ChatBody body = new ChatBody(query);
return ollama.chat(body).map(ChatResponse::toCompletionResults);
}

/**
 * Streaming counterpart of {@code complete}: wraps the query in a
 * {@link ChatStreamBody}, sends it through the Ollama client and converts
 * every emitted {@link ChatResponse} into a {@link CompletionResults}.
 *
 * @param query the completion query to send
 * @return a stream with one {@link CompletionResults} per emitted response
 * @throws ConnectorException declared by the signature; no throw site is visible in this method
 */
@Override
public Multi<CompletionResults> completeStream(CompletionQuery query) throws ConnectorException {
    final Multi<ChatResponse> chunks = ollama.chatStream(new ChatStreamBody(query));
    return chunks.map(ChatResponse::toCompletionResults);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

@RegisterForReflection
@JsonInclude(JsonInclude.Include.NON_NULL)
public final class ChatBody {
public class ChatBody {

private String model;
private List<OllamaMessage> messages;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
package de.l3s.interweb.connector.ollama.entity;

import java.time.Instant;
import java.util.List;

import io.quarkus.runtime.annotations.RegisterForReflection;

import com.fasterxml.jackson.annotation.JsonProperty;

import de.l3s.interweb.core.completion.Choice;
import de.l3s.interweb.core.completion.CompletionResults;
import de.l3s.interweb.core.completion.Message;
import de.l3s.interweb.core.completion.Usage;

@RegisterForReflection
public class ChatResponse {

Expand Down Expand Up @@ -114,4 +122,25 @@ public String getCreatedAt() {
public void setCreatedAt(String createdAt) {
this.createdAt = createdAt;
}

/**
 * Converts this response into a {@link CompletionResults}.
 *
 * <p>The model is always copied. The remaining parts are optional and are
 * only set when the corresponding fields are present:
 * <ul>
 *   <li>usage — only when both {@code promptEvalCount} and {@code evalCount} are non-null;</li>
 *   <li>choices — a single assistant choice with index 0, only when {@code message} is non-null;</li>
 *   <li>created — parsed with {@link Instant#parse(CharSequence)}, only when {@code createdAt} is non-null.</li>
 * </ul>
 *
 * @return a newly created results object populated from this response
 */
public CompletionResults toCompletionResults() {
    final CompletionResults converted = new CompletionResults();
    converted.setModel(model);

    final boolean hasTokenCounts = promptEvalCount != null && evalCount != null;
    if (hasTokenCounts) {
        converted.setUsage(new Usage(promptEvalCount, evalCount));
    }

    if (message != null) {
        final Message reply = new Message(Message.Role.assistant, message.getContent());
        converted.setChoices(List.of(new Choice(0, doneReason, reply)));
    }

    if (createdAt != null) {
        final Instant created = Instant.parse(createdAt);
        converted.setCreated(created);
    }

    return converted;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package de.l3s.interweb.connector.ollama.entity;

import de.l3s.interweb.core.completion.CompletionQuery;

/**
 * A {@link ChatBody} that additionally exposes a {@code stream} getter which
 * always returns {@code true}.
 */
public final class ChatStreamBody extends ChatBody {

    /**
     * Creates the body by delegating to the {@link ChatBody} constructor.
     *
     * @param query the completion query handed to the superclass constructor
     */
    public ChatStreamBody(CompletionQuery query) {
        super(query);
    }

    /**
     * @return always {@code true}
     */
    public Boolean getStream() {
        return Boolean.TRUE;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import static org.junit.jupiter.api.Assertions.*;

import java.util.List;

import jakarta.inject.Inject;

import io.quarkus.test.junit.QuarkusTest;
Expand Down Expand Up @@ -31,22 +33,45 @@ void validate() throws ConnectorException {

/**
 * Sends a two-message chat query for model {@code llama3} through the connector,
 * blocks until the completion arrives, checks that exactly one choice came back
 * and logs the duration, the last user message and the assistant reply.
 *
 * NOTE(review): this span looks like a rendered diff without +/- markers, so removed
 * and added lines appear together (e.g. the two near-identical user-message lines and
 * the repeated {@code setModel} call). Confirm against the actual file which ones remain.
 */
@Test
void complete() throws ConnectorException {
connector.validate();
CompletionQuery query = new CompletionQuery();
query.setModel("llama3");
query.addMessage("You are Interweb Assistant, a helpful chat bot. Your name is not Claude it is Interweb Assistant.", Message.Role.system);
query.addMessage("What is your name?.", Message.Role.user);
query.addMessage("What is your name?", Message.Role.user);
query.setMaxTokens(100);
query.setTemperature(20.0);
query.setTopP(1.0);

query.setModel("llama3");

long start = System.currentTimeMillis();
CompletionResults results = connector.complete(query).await().indefinitely();
log.infov("duration: {0} ms", System.currentTimeMillis() - start);

assertEquals(1, results.getChoices().size());
log.infov("user: {0}", query.getMessages().getLast().getContent());
for (Choice result : results.getChoices()) {
log.infov("assistant: {0}", result.getMessage().getContent());
}
}

/**
 * Streams a completion for a two-message chat query on model {@code llama3},
 * logging the elapsed time for every emitted item, then collects all items,
 * expects more than ten of them and logs the concatenated assistant text.
 */
@Test
void completeStream() throws ConnectorException {
    final CompletionQuery query = new CompletionQuery();
    query.setModel("llama3");
    query.addMessage("You are Interweb Assistant, a helpful chat bot. Your name is not Claude it is Interweb Assistant.", Message.Role.system);
    query.addMessage("What is your name?", Message.Role.user);

    final long startedAt = System.currentTimeMillis();
    List<CompletionResults> chunks = connector.completeStream(query)
            .onItem().invoke(() -> log.infov("message after: {0} ms", System.currentTimeMillis() - startedAt))
            .collect().asList()
            .await().indefinitely();
    log.infov("duration: {0} ms", System.currentTimeMillis() - startedAt);
    assertTrue(chunks.size() > 10);

    // Stitch the streamed fragments back together in arrival order.
    final StringBuilder reply = new StringBuilder();
    chunks.forEach(chunk ->
            chunk.getChoices().forEach(choice -> reply.append(choice.getMessage().getContent())));
    log.info(reply.toString());
}
}

0 comments on commit 4b0002e

Please sign in to comment.