From 2f75e608f16e9768e23d2858838ab48e11046750 Mon Sep 17 00:00:00 2001 From: Sashir Estela Date: Sat, 14 Sep 2024 00:19:38 +0000 Subject: [PATCH] Fix for Azure OpenAI issues --- .../openai/demo/ChatAnyscaleDemo.java | 100 ++---------- .../openai/demo/ChatAzureDemo.java | 143 +++--------------- .../sashirestela/openai/demo/ChatDemo.java | 31 ++-- .../openai/SimpleOpenAIAzure.java | 12 +- .../openai/SimpleOpenAIAzureTest.java | 4 +- 5 files changed, 60 insertions(+), 230 deletions(-) diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleDemo.java index 7211a081..4c8acf12 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleDemo.java @@ -1,104 +1,26 @@ package io.github.sashirestela.openai.demo; +import io.github.sashirestela.openai.BaseSimpleOpenAI; import io.github.sashirestela.openai.SimpleOpenAIAnyscale; -import io.github.sashirestela.openai.common.function.FunctionDef; -import io.github.sashirestela.openai.common.function.FunctionExecutor; -import io.github.sashirestela.openai.demo.ChatDemo.Product; -import io.github.sashirestela.openai.demo.ChatDemo.RunAlarm; -import io.github.sashirestela.openai.demo.ChatDemo.Weather; -import io.github.sashirestela.openai.domain.chat.Chat; -import io.github.sashirestela.openai.domain.chat.ChatMessage; -import io.github.sashirestela.openai.domain.chat.ChatMessage.SystemMessage; -import io.github.sashirestela.openai.domain.chat.ChatMessage.ToolMessage; -import io.github.sashirestela.openai.domain.chat.ChatMessage.UserMessage; -import io.github.sashirestela.openai.domain.chat.ChatRequest; -import java.util.ArrayList; +public class ChatAnyscaleDemo extends ChatDemo { -public class ChatAnyscaleDemo extends AbstractDemo { - - public static final String MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"; - - private ChatRequest chatRequest; - - public ChatAnyscaleDemo(String apiKey, String model) { - super(SimpleOpenAIAnyscale.builder().apiKey(apiKey).build()); - chatRequest = ChatRequest.builder() - .model(model) - .message(SystemMessage.of("You are an expert in AI.")) - .message(UserMessage.of("Write a technical article about ChatGPT, no more than 100 words.")) - .temperature(0.0) - .maxTokens(300) - .build(); - } - - public void demoCallChatStreaming() { - var futureChat = openAI.chatCompletions().createStream(chatRequest); - var chatResponse = futureChat.join(); - chatResponse.filter(chatResp -> chatResp.firstContent() != null) - .map(Chat::firstContent) - .forEach(System.out::print); - System.out.println(); - } - - public void demoCallChatBlocking() { - var futureChat = openAI.chatCompletions().create(chatRequest); - var chatResponse = futureChat.join(); - System.out.println(chatResponse.firstContent()); - } - - public void demoCallChatWithFunctions() { - var functionExecutor = new FunctionExecutor(); - functionExecutor.enrollFunction( - FunctionDef.builder() - .name("get_weather") - .description("Get the current weather of a location") - .functionalClass(Weather.class) - .build()); - functionExecutor.enrollFunction( - FunctionDef.builder() - .name("product") - .description("Get the product of two numbers") - .functionalClass(Product.class) - .build()); - functionExecutor.enrollFunction( - FunctionDef.builder() - .name("run_alarm") - .description("Run an alarm") - .functionalClass(RunAlarm.class) - .build()); - var messages = new ArrayList(); - messages.add(UserMessage.of("What is the product of 123 and 456?")); - var chatRequest = ChatRequest.builder() - .model(MODEL) - .messages(messages) - .tools(functionExecutor.getToolFunctions()) - .build(); - var futureChat = openAI.chatCompletions().create(chatRequest); - var chatResponse = futureChat.join(); - var chatMessage = chatResponse.firstMessage(); - var chatToolCall = chatMessage.getToolCalls().get(0); - var result = functionExecutor.execute(chatToolCall.getFunction()); - messages.add(chatMessage); - messages.add(ToolMessage.of(result.toString(), chatToolCall.getId())); - chatRequest = ChatRequest.builder() - .model(MODEL) - .messages(messages) - .tools(functionExecutor.getToolFunctions()) - .build(); - futureChat = openAI.chatCompletions().create(chatRequest); - chatResponse = futureChat.join(); - System.out.println(chatResponse.firstContent()); + public ChatAnyscaleDemo(BaseSimpleOpenAI openAI, String model) { + super(openAI, model); } public static void main(String[] args) { - var apiKey = System.getenv("ANYSCALE_API_KEY"); - - var demo = new ChatAnyscaleDemo(apiKey, MODEL); + var openAI = SimpleOpenAIAnyscale.builder() + .apiKey(System.getenv("ANYSCALE_API_KEY")) + .build(); + var demo = new ChatAnyscaleDemo(openAI, "mistralai/Mixtral-8x7B-Instruct-v0.1"); demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming); demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking); demo.addTitleAction("Call Chat with Functions", demo::demoCallChatWithFunctions); + //demo.addTitleAction("Call Chat with Vision (External image)", demo::demoCallChatWithVisionExternalImage); + //demo.addTitleAction("Call Chat with Vision (Local image)", demo::demoCallChatWithVisionLocalImage); + //demo.addTitleAction("Call Chat with Structured Outputs", demo::demoCallChatWithStructuredOutputs); demo.run(); } diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureDemo.java index b56b1886..be50f0a8 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureDemo.java @@ -1,99 +1,25 @@ package io.github.sashirestela.openai.demo; +import io.github.sashirestela.openai.BaseSimpleOpenAI; import io.github.sashirestela.openai.SimpleOpenAIAzure; import io.github.sashirestela.openai.common.content.ContentPart.ContentPartImageUrl; import io.github.sashirestela.openai.common.content.ContentPart.ContentPartImageUrl.ImageUrl; import io.github.sashirestela.openai.common.content.ContentPart.ContentPartText; -import io.github.sashirestela.openai.common.function.FunctionDef; -import io.github.sashirestela.openai.common.function.FunctionExecutor; -import io.github.sashirestela.openai.demo.ChatDemo.Product; -import io.github.sashirestela.openai.demo.ChatDemo.RunAlarm; -import io.github.sashirestela.openai.demo.ChatDemo.Weather; -import io.github.sashirestela.openai.domain.chat.ChatMessage; -import io.github.sashirestela.openai.domain.chat.ChatMessage.SystemMessage; -import io.github.sashirestela.openai.domain.chat.ChatMessage.ToolMessage; import io.github.sashirestela.openai.domain.chat.ChatMessage.UserMessage; import io.github.sashirestela.openai.domain.chat.ChatRequest; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.Base64; import java.util.List; -public class ChatAzureDemo extends AbstractDemo { +public class ChatAzureDemo extends ChatDemo { - private ChatRequest chatRequest; - - public ChatAzureDemo(String baseUrl, String apiKey, String apiVersion) { - super(SimpleOpenAIAzure.builder() - .apiKey(apiKey) - .baseUrl(baseUrl) - .apiVersion(apiVersion) - .build()); - chatRequest = ChatRequest.builder() - .model("N/A") - .message(SystemMessage.of("You are an expert in AI.")) - .message(UserMessage.of("Write a technical article about ChatGPT, no more than 100 words.")) - .temperature(0.0) - .maxTokens(300) - .build(); - } - - public void demoCallChatBlocking() { - var futureChat = openAI.chatCompletions().create(chatRequest); - var chatResponse = futureChat.join(); - System.out.println(chatResponse.firstContent()); - } - - public void demoCallChatWithFunctions() { - var functionExecutor = new FunctionExecutor(); - functionExecutor.enrollFunction( - FunctionDef.builder() - .name("get_weather") - .description("Get the current weather of a location") - .functionalClass(Weather.class) - .build()); - functionExecutor.enrollFunction( - FunctionDef.builder() - .name("product") - .description("Get the product of two numbers") - .functionalClass(Product.class) - .build()); - functionExecutor.enrollFunction( - FunctionDef.builder() - .name("run_alarm") - .description("Run an alarm") - .functionalClass(RunAlarm.class) - .build()); - var messages = new ArrayList(); - messages.add(UserMessage.of("What is the product of 123 and 456?")); - chatRequest = ChatRequest.builder() - .model("N/A") - .messages(messages) - .tools(functionExecutor.getToolFunctions()) - .build(); - var futureChat = openAI.chatCompletions().create(chatRequest); - var chatResponse = futureChat.join(); - var chatMessage = chatResponse.firstMessage(); - var chatToolCall = chatMessage.getToolCalls().get(0); - var result = functionExecutor.execute(chatToolCall.getFunction()); - messages.add(chatMessage); - messages.add(ToolMessage.of(result.toString(), chatToolCall.getId())); - chatRequest = ChatRequest.builder() - .model("N/A") - .messages(messages) - .tools(functionExecutor.getToolFunctions()) - .build(); - futureChat = openAI.chatCompletions().create(chatRequest); - chatResponse = futureChat.join(); - System.out.println(chatResponse.firstContent()); + public ChatAzureDemo(BaseSimpleOpenAI openAI, String model) { + super(openAI, model); } + @Override public void demoCallChatWithVisionExternalImage() { var chatRequest = ChatRequest.builder() - .model("N/A") + .model(model) .messages(List.of( UserMessage.of(List.of( ContentPartText.of( @@ -103,16 +29,14 @@ public void demoCallChatWithVisionExternalImage() { .temperature(0.0) .maxTokens(500) .build(); - var chatResponse = openAI.chatCompletions().create(chatRequest).join(); System.out.println(chatResponse.firstContent()); - System.out.println(); - } + @Override public void demoCallChatWithVisionLocalImage() { var chatRequest = ChatRequest.builder() - .model("N/A") + .model(model) .messages(List.of( UserMessage.of(List.of( ContentPartText.of( @@ -125,45 +49,22 @@ public void demoCallChatWithVisionLocalImage() { System.out.println(chatResponse.firstContent()); } - private static ImageUrl loadImageAsBase64(String imagePath) { - try { - Path path = Paths.get(imagePath); - byte[] imageBytes = Files.readAllBytes(path); - String base64String = Base64.getEncoder().encodeToString(imageBytes); - var extension = imagePath.substring(imagePath.lastIndexOf('.') + 1); - var prefix = "data:image/" + extension + ";base64,"; - return ImageUrl.of(prefix + base64String); - } catch (Exception e) { - e.printStackTrace(); - return null; - } - } - - private static void chatWithFunctionsDemo(String apiVersion) { - var baseUrl = System.getenv("AZURE_OPENAI_BASE_URL"); - var apiKey = System.getenv("AZURE_OPENAI_API_KEY"); - var chatDemo = new ChatAzureDemo(baseUrl, apiKey, apiVersion); - chatDemo.addTitleAction("Call Chat (Blocking Approach)", chatDemo::demoCallChatBlocking); - chatDemo.addTitleAction("Call Chat with Functions", chatDemo::demoCallChatWithFunctions); - - chatDemo.run(); - } - - private static void chatWithVisionDemo(String apiVersion) { - var baseUrl = System.getenv("AZURE_OPENAI_BASE_URL_VISION"); - var apiKey = System.getenv("AZURE_OPENAI_API_KEY_VISION"); - var visionDemo = new ChatAzureDemo(baseUrl, apiKey, apiVersion); - visionDemo.addTitleAction("Call Chat with Vision (External image)", - visionDemo::demoCallChatWithVisionExternalImage); - visionDemo.addTitleAction("Call Chat with Vision (Local image)", visionDemo::demoCallChatWithVisionLocalImage); - visionDemo.run(); - } - public static void main(String[] args) { - var apiVersion = System.getenv("AZURE_OPENAI_API_VERSION"); + var openAI = SimpleOpenAIAzure.builder() + .apiKey(System.getenv("AZURE_OPENAI_API_KEY")) + .apiVersion(System.getenv("AZURE_OPENAI_API_VERSION")) + .baseUrl(System.getenv("AZURE_OPENAI_BASE_URL")) + .build(); + var demo = new ChatAzureDemo(openAI, "N/A"); + + demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming); + demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking); + demo.addTitleAction("Call Chat with Functions", demo::demoCallChatWithFunctions); + demo.addTitleAction("Call Chat with Vision (External image)", demo::demoCallChatWithVisionExternalImage); + demo.addTitleAction("Call Chat with Vision (Local image)", demo::demoCallChatWithVisionLocalImage); + demo.addTitleAction("Call Chat with Structured Outputs", demo::demoCallChatWithStructuredOutputs); - chatWithFunctionsDemo(apiVersion); - chatWithVisionDemo(apiVersion); + demo.run(); } } diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ChatDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ChatDemo.java index 13c10cbf..354d2e43 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/ChatDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/ChatDemo.java @@ -2,6 +2,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import io.github.sashirestela.openai.BaseSimpleOpenAI; +import io.github.sashirestela.openai.SimpleOpenAI; import io.github.sashirestela.openai.common.ResponseFormat; import io.github.sashirestela.openai.common.ResponseFormat.JsonSchema; import io.github.sashirestela.openai.common.content.ContentPart.ContentPartImageUrl; @@ -26,13 +28,14 @@ public class ChatDemo extends AbstractDemo { - private ChatRequest chatRequest; - private String modelIdToUse; + protected ChatRequest chatRequest; + protected String model; - public ChatDemo() { - modelIdToUse = "gpt-4o-mini"; + public ChatDemo(BaseSimpleOpenAI openAI, String model) { + super(openAI); + this.model = model; chatRequest = ChatRequest.builder() - .model(modelIdToUse) + .model(model) .message(SystemMessage.of("You are an expert in AI.")) .message(UserMessage.of("Write a technical article about ChatGPT, no more than 100 words.")) .temperature(0.0) @@ -78,7 +81,7 @@ public void demoCallChatWithFunctions() { var messages = new ArrayList(); messages.add(UserMessage.of("What is the product of 123 and 456?")); chatRequest = ChatRequest.builder() - .model(modelIdToUse) + .model(model) .messages(messages) .tools(functionExecutor.getToolFunctions()) .build(); @@ -90,7 +93,7 @@ public void demoCallChatWithFunctions() { messages.add(chatMessage); messages.add(ToolMessage.of(result.toString(), chatToolCall.getId())); chatRequest = ChatRequest.builder() - .model(modelIdToUse) + .model(model) .messages(messages) .tools(functionExecutor.getToolFunctions()) .build(); @@ -101,7 +104,7 @@ public void demoCallChatWithFunctions() { public void demoCallChatWithVisionExternalImage() { var chatRequest = ChatRequest.builder() - .model(modelIdToUse) + .model(model) .messages(List.of( UserMessage.of(List.of( ContentPartText.of( @@ -118,7 +121,7 @@ public void demoCallChatWithVisionExternalImage() { public void demoCallChatWithVisionLocalImage() { var chatRequest = ChatRequest.builder() - .model(modelIdToUse) + .model(model) .messages(List.of( UserMessage.of(List.of( ContentPartText.of( @@ -134,7 +137,7 @@ public void demoCallChatWithVisionLocalImage() { public void demoCallChatWithStructuredOutputs() { var chatRequest = ChatRequest.builder() - .model(modelIdToUse) + .model(model) .message(SystemMessage .of("You are a helpful math tutor. Guide the user through the solution step by step.")) .message(UserMessage.of("How can I solve 8x + 7 = -23")) @@ -148,7 +151,7 @@ public void demoCallChatWithStructuredOutputs() { System.out.println(); } - private ImageUrl loadImageAsBase64(String imagePath) { + protected ImageUrl loadImageAsBase64(String imagePath) { try { Path path = Paths.get(imagePath); byte[] imageBytes = Files.readAllBytes(path); @@ -235,7 +238,11 @@ public static class Step { } public static void main(String[] args) { - var demo = new ChatDemo(); + var openAI = SimpleOpenAI.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .organizationId(System.getenv("OPENAI_ORGANIZATION_ID")) + .build(); + var demo = new ChatDemo(openAI, "gpt-4o-mini"); demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming); demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking); diff --git a/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java index 9640befb..6bc72b37 100644 --- a/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java +++ b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java @@ -63,18 +63,18 @@ private static String getNewUrl(String url, String apiVersion) { } private static Object getBodyForJson(String url, String body, String deployment) { - final String MODEL_REGEX = ",?\"model\":\"[^\"]*\",?"; - final String EMPTY_REGEX = "\"\""; - final String QUOTED_COMMA = "\",\""; + final String MODEL_ENTRY_REGEX = "\"model\"\\s*:\\s*\"[^\"]+\"\\s*,?\\s*"; + final String TRAILING_COMMA_REGEX = ",\\s*}"; + final String CLOSING_BRACE = "}"; final String MODEL_LITERAL = "model"; final String ASSISTANTS_LITERAL = "/assistants"; var model = ""; if (url.contains(ASSISTANTS_LITERAL)) { - model = "\"" + MODEL_LITERAL + "\":\"" + deployment + "\""; + model = "\"" + MODEL_LITERAL + "\":\"" + deployment + "\","; } - body = body.replaceFirst(MODEL_REGEX, model); - body = body.replaceFirst(EMPTY_REGEX, QUOTED_COMMA); + body = body.replaceFirst(MODEL_ENTRY_REGEX, model); + body = body.replaceFirst(TRAILING_COMMA_REGEX, CLOSING_BRACE); return body; } diff --git a/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAzureTest.java b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAzureTest.java index 4a5da944..01038cb9 100644 --- a/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAzureTest.java +++ b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAzureTest.java @@ -34,13 +34,13 @@ void shouldInterceptUrlCorrectlyWhenBodyIsJson() { .url(baseUrl + "/chat/completions") .contentType(ContentType.APPLICATION_JSON) .headers(Map.of(Constant.AZURE_APIKEY_HEADER, "the-api-key")) - .body("{\"model\":\"model1\"}") + .body("{\"messages\":[],\"model\":\"model1\",\"stream\":false}") .build(); var expectedRequest = HttpRequestData.builder() .url(baseUrl + "/chat/completions?" + Constant.AZURE_API_VERSION + "=12-34-5678") .contentType(ContentType.APPLICATION_JSON) .headers(Map.of(Constant.AZURE_APIKEY_HEADER, "the-api-key")) - .body("{}") + .body("{\"messages\":[],\"stream\":false}") .build(); var args = SimpleOpenAIAzure.prepareBaseSimpleOpenAIArgs( "the-api-key",