Skip to content

Commit

Permalink
feat: deserialize CallFunction from string
Browse files Browse the repository at this point in the history
  • Loading branch information
astappiev committed Sep 17, 2024
1 parent fd2b04b commit 1c84878
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class CompletionsResponse {
@JsonProperty("system_fingerprint")
private String systemFingerprint;
private Instant created;
private List<OpenaiChoice> choices;
private List<Choice> choices;

public String getId() {
return id;
Expand Down Expand Up @@ -68,38 +68,19 @@ public void setCreated(Instant created) {
this.created = created;
}

public List<OpenaiChoice> getChoices() {
public List<Choice> getChoices() {
return choices;
}

public void setChoices(List<OpenaiChoice> choices) {
public void setChoices(List<Choice> choices) {
this.choices = choices;
}

public CompletionsResults toCompletionResults() {
CompletionsResults results = new CompletionsResults();
results.setModel(model);
results.setCreated(created);
results.setChoices(choices.stream().map(openaiChoice -> {
Message message = new Message(Role.assistant);
message.setContent(openaiChoice.getMessage().getContent());
message.setRefusal(openaiChoice.getMessage().getRefusal());
if (openaiChoice.getMessage().getToolCalls() != null) {
message.setToolCalls(openaiChoice.getMessage().getToolCalls().stream().map(openaiCallTool -> {
CallTool callTool = new CallTool();
callTool.setId(openaiCallTool.getId());
callTool.setType(openaiCallTool.getType());
callTool.setFunction(openaiCallTool.getFunction().toCallFunction());
return callTool;
}).toList());
}

Choice choice = new Choice();
choice.setIndex(openaiChoice.getIndex());
choice.setFinishReason(openaiChoice.getFinishReason());
choice.setMessage(message);
return choice;
}).toList());
results.setChoices(choices);
results.setUsage(usage);
results.setObject(object);
results.setSystemFingerprint(systemFingerprint);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
package de.l3s.interweb.connector.openai.entity;

import java.io.Serial;
import java.io.Serializable;
import java.util.HashMap;

import io.quarkus.runtime.annotations.RegisterForReflection;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import de.l3s.interweb.core.ConnectorException;
import de.l3s.interweb.core.chat.CallFunction;
import io.quarkus.runtime.annotations.RegisterForReflection;

import java.io.Serial;
import java.io.Serializable;

@RegisterForReflection
public class OpenaiCallFunction implements Serializable {
Expand Down Expand Up @@ -48,20 +44,4 @@ public void setArguments(String arguments) {
public String getArguments() {
return arguments;
}

public CallFunction toCallFunction() {
try {
CallFunction function = new CallFunction();
function.setName(name);

if (arguments != null && !arguments.isBlank()) {
TypeReference<HashMap<String, String>> typeRef = new TypeReference<>() {};
function.setArguments(new ObjectMapper().readValue(arguments, typeRef));
}

return function;
} catch (JsonProcessingException e) {
throw new ConnectorException("Failed to parse function arguments", e);
}
}
}

This file was deleted.

8 changes: 0 additions & 8 deletions interweb-client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,6 @@
<groupId>de.l3s.interweb</groupId>
<artifactId>interweb-core</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
</dependency>

<dependency>
<groupId>io.quarkus</groupId>
Expand Down
8 changes: 8 additions & 0 deletions interweb-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
</dependency>

<dependency>
<groupId>io.quarkus</groupId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
package de.l3s.interweb.core.chat;

import java.io.IOException;
import java.io.Serial;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

import com.fasterxml.jackson.databind.*;

import jakarta.validation.constraints.NotEmpty;

import io.quarkus.runtime.annotations.RegisterForReflection;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;

@RegisterForReflection
public class CallFunction implements Serializable {
@Serial
Expand All @@ -23,6 +32,7 @@ public class CallFunction implements Serializable {
* The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON,
* and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.
*/
@JsonDeserialize(using = ArgumentsDeserializer.class)
private Map<String, String> arguments;

public void setName(String name) {
Expand All @@ -40,4 +50,31 @@ public void setArguments(Map<String, String> arguments) {
public Map<String, String> getArguments() {
return arguments;
}

public static class ArgumentsDeserializer extends StdDeserializer<Map<String, String>> {

public ArgumentsDeserializer() {
this(null);
}

public ArgumentsDeserializer(Class<?> vc) {
super(vc);
}

@Override
public Map<String, String> deserialize(JsonParser jp, DeserializationContext ctxt) throws IOException, JsonProcessingException {
JavaType typeRef = ctxt.getTypeFactory().constructMapType(HashMap.class, String.class, String.class);

JsonNode node = jp.getCodec().readTree(jp);
if (node.isTextual()) {
try {
return new ObjectMapper().readValue(node.asText(), typeRef);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}

return ctxt.readTreeAsValue(node, typeRef);
}
}
}

0 comments on commit 1c84878

Please sign in to comment.