Skip to content

Commit

Permalink
Fix for Azure OpenAI issues
Browse files Browse the repository at this point in the history
  • Loading branch information
sashirestela committed Sep 14, 2024
1 parent 77f5100 commit 2f75e60
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 230 deletions.
100 changes: 11 additions & 89 deletions src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleDemo.java
Original file line number Diff line number Diff line change
@@ -1,104 +1,26 @@
package io.github.sashirestela.openai.demo;

import io.github.sashirestela.openai.BaseSimpleOpenAI;
import io.github.sashirestela.openai.SimpleOpenAIAnyscale;
import io.github.sashirestela.openai.common.function.FunctionDef;
import io.github.sashirestela.openai.common.function.FunctionExecutor;
import io.github.sashirestela.openai.demo.ChatDemo.Product;
import io.github.sashirestela.openai.demo.ChatDemo.RunAlarm;
import io.github.sashirestela.openai.demo.ChatDemo.Weather;
import io.github.sashirestela.openai.domain.chat.Chat;
import io.github.sashirestela.openai.domain.chat.ChatMessage;
import io.github.sashirestela.openai.domain.chat.ChatMessage.SystemMessage;
import io.github.sashirestela.openai.domain.chat.ChatMessage.ToolMessage;
import io.github.sashirestela.openai.domain.chat.ChatMessage.UserMessage;
import io.github.sashirestela.openai.domain.chat.ChatRequest;

import java.util.ArrayList;
public class ChatAnyscaleDemo extends ChatDemo {

public class ChatAnyscaleDemo extends AbstractDemo {

public static final String MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1";

private ChatRequest chatRequest;

public ChatAnyscaleDemo(String apiKey, String model) {
super(SimpleOpenAIAnyscale.builder().apiKey(apiKey).build());
chatRequest = ChatRequest.builder()
.model(model)
.message(SystemMessage.of("You are an expert in AI."))
.message(UserMessage.of("Write a technical article about ChatGPT, no more than 100 words."))
.temperature(0.0)
.maxTokens(300)
.build();
}

public void demoCallChatStreaming() {
var futureChat = openAI.chatCompletions().createStream(chatRequest);
var chatResponse = futureChat.join();
chatResponse.filter(chatResp -> chatResp.firstContent() != null)
.map(Chat::firstContent)
.forEach(System.out::print);
System.out.println();
}

public void demoCallChatBlocking() {
var futureChat = openAI.chatCompletions().create(chatRequest);
var chatResponse = futureChat.join();
System.out.println(chatResponse.firstContent());
}

public void demoCallChatWithFunctions() {
var functionExecutor = new FunctionExecutor();
functionExecutor.enrollFunction(
FunctionDef.builder()
.name("get_weather")
.description("Get the current weather of a location")
.functionalClass(Weather.class)
.build());
functionExecutor.enrollFunction(
FunctionDef.builder()
.name("product")
.description("Get the product of two numbers")
.functionalClass(Product.class)
.build());
functionExecutor.enrollFunction(
FunctionDef.builder()
.name("run_alarm")
.description("Run an alarm")
.functionalClass(RunAlarm.class)
.build());
var messages = new ArrayList<ChatMessage>();
messages.add(UserMessage.of("What is the product of 123 and 456?"));
var chatRequest = ChatRequest.builder()
.model(MODEL)
.messages(messages)
.tools(functionExecutor.getToolFunctions())
.build();
var futureChat = openAI.chatCompletions().create(chatRequest);
var chatResponse = futureChat.join();
var chatMessage = chatResponse.firstMessage();
var chatToolCall = chatMessage.getToolCalls().get(0);
var result = functionExecutor.execute(chatToolCall.getFunction());
messages.add(chatMessage);
messages.add(ToolMessage.of(result.toString(), chatToolCall.getId()));
chatRequest = ChatRequest.builder()
.model(MODEL)
.messages(messages)
.tools(functionExecutor.getToolFunctions())
.build();
futureChat = openAI.chatCompletions().create(chatRequest);
chatResponse = futureChat.join();
System.out.println(chatResponse.firstContent());
public ChatAnyscaleDemo(BaseSimpleOpenAI openAI, String model) {
super(openAI, model);
}

public static void main(String[] args) {
var apiKey = System.getenv("ANYSCALE_API_KEY");

var demo = new ChatAnyscaleDemo(apiKey, MODEL);
var openAI = SimpleOpenAIAnyscale.builder()
.apiKey(System.getenv("ANYSCALE_API_KEY"))
.build();
var demo = new ChatAnyscaleDemo(openAI, "mistralai/Mixtral-8x7B-Instruct-v0.1");

demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming);
demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking);
demo.addTitleAction("Call Chat with Functions", demo::demoCallChatWithFunctions);
//demo.addTitleAction("Call Chat with Vision (External image)", demo::demoCallChatWithVisionExternalImage);
//demo.addTitleAction("Call Chat with Vision (Local image)", demo::demoCallChatWithVisionLocalImage);
//demo.addTitleAction("Call Chat with Structured Outputs", demo::demoCallChatWithStructuredOutputs);

demo.run();
}
Expand Down
143 changes: 22 additions & 121 deletions src/demo/java/io/github/sashirestela/openai/demo/ChatAzureDemo.java
Original file line number Diff line number Diff line change
@@ -1,99 +1,25 @@
package io.github.sashirestela.openai.demo;

import io.github.sashirestela.openai.BaseSimpleOpenAI;
import io.github.sashirestela.openai.SimpleOpenAIAzure;
import io.github.sashirestela.openai.common.content.ContentPart.ContentPartImageUrl;
import io.github.sashirestela.openai.common.content.ContentPart.ContentPartImageUrl.ImageUrl;
import io.github.sashirestela.openai.common.content.ContentPart.ContentPartText;
import io.github.sashirestela.openai.common.function.FunctionDef;
import io.github.sashirestela.openai.common.function.FunctionExecutor;
import io.github.sashirestela.openai.demo.ChatDemo.Product;
import io.github.sashirestela.openai.demo.ChatDemo.RunAlarm;
import io.github.sashirestela.openai.demo.ChatDemo.Weather;
import io.github.sashirestela.openai.domain.chat.ChatMessage;
import io.github.sashirestela.openai.domain.chat.ChatMessage.SystemMessage;
import io.github.sashirestela.openai.domain.chat.ChatMessage.ToolMessage;
import io.github.sashirestela.openai.domain.chat.ChatMessage.UserMessage;
import io.github.sashirestela.openai.domain.chat.ChatRequest;

import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;

public class ChatAzureDemo extends AbstractDemo {
public class ChatAzureDemo extends ChatDemo {

private ChatRequest chatRequest;

public ChatAzureDemo(String baseUrl, String apiKey, String apiVersion) {
super(SimpleOpenAIAzure.builder()
.apiKey(apiKey)
.baseUrl(baseUrl)
.apiVersion(apiVersion)
.build());
chatRequest = ChatRequest.builder()
.model("N/A")
.message(SystemMessage.of("You are an expert in AI."))
.message(UserMessage.of("Write a technical article about ChatGPT, no more than 100 words."))
.temperature(0.0)
.maxTokens(300)
.build();
}

public void demoCallChatBlocking() {
var futureChat = openAI.chatCompletions().create(chatRequest);
var chatResponse = futureChat.join();
System.out.println(chatResponse.firstContent());
}

public void demoCallChatWithFunctions() {
var functionExecutor = new FunctionExecutor();
functionExecutor.enrollFunction(
FunctionDef.builder()
.name("get_weather")
.description("Get the current weather of a location")
.functionalClass(Weather.class)
.build());
functionExecutor.enrollFunction(
FunctionDef.builder()
.name("product")
.description("Get the product of two numbers")
.functionalClass(Product.class)
.build());
functionExecutor.enrollFunction(
FunctionDef.builder()
.name("run_alarm")
.description("Run an alarm")
.functionalClass(RunAlarm.class)
.build());
var messages = new ArrayList<ChatMessage>();
messages.add(UserMessage.of("What is the product of 123 and 456?"));
chatRequest = ChatRequest.builder()
.model("N/A")
.messages(messages)
.tools(functionExecutor.getToolFunctions())
.build();
var futureChat = openAI.chatCompletions().create(chatRequest);
var chatResponse = futureChat.join();
var chatMessage = chatResponse.firstMessage();
var chatToolCall = chatMessage.getToolCalls().get(0);
var result = functionExecutor.execute(chatToolCall.getFunction());
messages.add(chatMessage);
messages.add(ToolMessage.of(result.toString(), chatToolCall.getId()));
chatRequest = ChatRequest.builder()
.model("N/A")
.messages(messages)
.tools(functionExecutor.getToolFunctions())
.build();
futureChat = openAI.chatCompletions().create(chatRequest);
chatResponse = futureChat.join();
System.out.println(chatResponse.firstContent());
public ChatAzureDemo(BaseSimpleOpenAI openAI, String model) {
super(openAI, model);
}

@Override
public void demoCallChatWithVisionExternalImage() {
var chatRequest = ChatRequest.builder()
.model("N/A")
.model(model)
.messages(List.of(
UserMessage.of(List.of(
ContentPartText.of(
Expand All @@ -103,16 +29,14 @@ public void demoCallChatWithVisionExternalImage() {
.temperature(0.0)
.maxTokens(500)
.build();

var chatResponse = openAI.chatCompletions().create(chatRequest).join();
System.out.println(chatResponse.firstContent());
System.out.println();

}

@Override
public void demoCallChatWithVisionLocalImage() {
var chatRequest = ChatRequest.builder()
.model("N/A")
.model(model)
.messages(List.of(
UserMessage.of(List.of(
ContentPartText.of(
Expand All @@ -125,45 +49,22 @@ public void demoCallChatWithVisionLocalImage() {
System.out.println(chatResponse.firstContent());
}

private static ImageUrl loadImageAsBase64(String imagePath) {
try {
Path path = Paths.get(imagePath);
byte[] imageBytes = Files.readAllBytes(path);
String base64String = Base64.getEncoder().encodeToString(imageBytes);
var extension = imagePath.substring(imagePath.lastIndexOf('.') + 1);
var prefix = "data:image/" + extension + ";base64,";
return ImageUrl.of(prefix + base64String);
} catch (Exception e) {
e.printStackTrace();
return null;
}
}

private static void chatWithFunctionsDemo(String apiVersion) {
var baseUrl = System.getenv("AZURE_OPENAI_BASE_URL");
var apiKey = System.getenv("AZURE_OPENAI_API_KEY");
var chatDemo = new ChatAzureDemo(baseUrl, apiKey, apiVersion);
chatDemo.addTitleAction("Call Chat (Blocking Approach)", chatDemo::demoCallChatBlocking);
chatDemo.addTitleAction("Call Chat with Functions", chatDemo::demoCallChatWithFunctions);

chatDemo.run();
}

private static void chatWithVisionDemo(String apiVersion) {
var baseUrl = System.getenv("AZURE_OPENAI_BASE_URL_VISION");
var apiKey = System.getenv("AZURE_OPENAI_API_KEY_VISION");
var visionDemo = new ChatAzureDemo(baseUrl, apiKey, apiVersion);
visionDemo.addTitleAction("Call Chat with Vision (External image)",
visionDemo::demoCallChatWithVisionExternalImage);
visionDemo.addTitleAction("Call Chat with Vision (Local image)", visionDemo::demoCallChatWithVisionLocalImage);
visionDemo.run();
}

public static void main(String[] args) {
var apiVersion = System.getenv("AZURE_OPENAI_API_VERSION");
var openAI = SimpleOpenAIAzure.builder()
.apiKey(System.getenv("AZURE_OPENAI_API_KEY"))
.apiVersion(System.getenv("AZURE_OPENAI_API_VERSION"))
.baseUrl(System.getenv("AZURE_OPENAI_BASE_URL"))
.build();
var demo = new ChatAzureDemo(openAI, "N/A");

demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming);
demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking);
demo.addTitleAction("Call Chat with Functions", demo::demoCallChatWithFunctions);
demo.addTitleAction("Call Chat with Vision (External image)", demo::demoCallChatWithVisionExternalImage);
demo.addTitleAction("Call Chat with Vision (Local image)", demo::demoCallChatWithVisionLocalImage);
demo.addTitleAction("Call Chat with Structured Outputs", demo::demoCallChatWithStructuredOutputs);

chatWithFunctionsDemo(apiVersion);
chatWithVisionDemo(apiVersion);
demo.run();
}

}
31 changes: 19 additions & 12 deletions src/demo/java/io/github/sashirestela/openai/demo/ChatDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import io.github.sashirestela.openai.BaseSimpleOpenAI;
import io.github.sashirestela.openai.SimpleOpenAI;
import io.github.sashirestela.openai.common.ResponseFormat;
import io.github.sashirestela.openai.common.ResponseFormat.JsonSchema;
import io.github.sashirestela.openai.common.content.ContentPart.ContentPartImageUrl;
Expand All @@ -26,13 +28,14 @@

public class ChatDemo extends AbstractDemo {

private ChatRequest chatRequest;
private String modelIdToUse;
protected ChatRequest chatRequest;
protected String model;

public ChatDemo() {
modelIdToUse = "gpt-4o-mini";
public ChatDemo(BaseSimpleOpenAI openAI, String model) {
super(openAI);
this.model = model;
chatRequest = ChatRequest.builder()
.model(modelIdToUse)
.model(model)
.message(SystemMessage.of("You are an expert in AI."))
.message(UserMessage.of("Write a technical article about ChatGPT, no more than 100 words."))
.temperature(0.0)
Expand Down Expand Up @@ -78,7 +81,7 @@ public void demoCallChatWithFunctions() {
var messages = new ArrayList<ChatMessage>();
messages.add(UserMessage.of("What is the product of 123 and 456?"));
chatRequest = ChatRequest.builder()
.model(modelIdToUse)
.model(model)
.messages(messages)
.tools(functionExecutor.getToolFunctions())
.build();
Expand All @@ -90,7 +93,7 @@ public void demoCallChatWithFunctions() {
messages.add(chatMessage);
messages.add(ToolMessage.of(result.toString(), chatToolCall.getId()));
chatRequest = ChatRequest.builder()
.model(modelIdToUse)
.model(model)
.messages(messages)
.tools(functionExecutor.getToolFunctions())
.build();
Expand All @@ -101,7 +104,7 @@ public void demoCallChatWithFunctions() {

public void demoCallChatWithVisionExternalImage() {
var chatRequest = ChatRequest.builder()
.model(modelIdToUse)
.model(model)
.messages(List.of(
UserMessage.of(List.of(
ContentPartText.of(
Expand All @@ -118,7 +121,7 @@ public void demoCallChatWithVisionExternalImage() {

public void demoCallChatWithVisionLocalImage() {
var chatRequest = ChatRequest.builder()
.model(modelIdToUse)
.model(model)
.messages(List.of(
UserMessage.of(List.of(
ContentPartText.of(
Expand All @@ -134,7 +137,7 @@ public void demoCallChatWithVisionLocalImage() {

public void demoCallChatWithStructuredOutputs() {
var chatRequest = ChatRequest.builder()
.model(modelIdToUse)
.model(model)
.message(SystemMessage
.of("You are a helpful math tutor. Guide the user through the solution step by step."))
.message(UserMessage.of("How can I solve 8x + 7 = -23"))
Expand All @@ -148,7 +151,7 @@ public void demoCallChatWithStructuredOutputs() {
System.out.println();
}

private ImageUrl loadImageAsBase64(String imagePath) {
protected ImageUrl loadImageAsBase64(String imagePath) {
try {
Path path = Paths.get(imagePath);
byte[] imageBytes = Files.readAllBytes(path);
Expand Down Expand Up @@ -235,7 +238,11 @@ public static class Step {
}

public static void main(String[] args) {
var demo = new ChatDemo();
var openAI = SimpleOpenAI.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.build();
var demo = new ChatDemo(openAI, "gpt-4o-mini");

demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming);
demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking);
Expand Down
Loading

0 comments on commit 2f75e60

Please sign in to comment.