diff --git a/.github/workflows/test-mito-ai.yml b/.github/workflows/test-mito-ai.yml index 14004d9c6..a82786314 100644 --- a/.github/workflows/test-mito-ai.yml +++ b/.github/workflows/test-mito-ai.yml @@ -15,6 +15,7 @@ jobs: strategy: matrix: python-version: ['3.8', '3.10', '3.11'] + use-mito-ai-server: [true, false] fail-fast: false steps: @@ -54,7 +55,7 @@ jobs: jupyter lab --config jupyter_server_test_config.py & npm run test:mitoai env: - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + OPENAI_API_KEY: ${{ matrix.use-mito-ai-server && '' || secrets.OPENAI_API_KEY }} - name: Upload test-results uses: actions/upload-artifact@v3 if: failure() diff --git a/mito-ai/mito-ai/OpenAICompletionHandler.py b/mito-ai/mito-ai/OpenAICompletionHandler.py index 7dce283af..a8229c2c7 100644 --- a/mito-ai/mito-ai/OpenAICompletionHandler.py +++ b/mito-ai/mito-ai/OpenAICompletionHandler.py @@ -1,9 +1,7 @@ -import os -import openai from jupyter_server.base.handlers import APIHandler from tornado import web import json -from openai import OpenAI +from .utils.open_ai_utils import get_open_ai_completion # This handler is responsible for the mito_ai/completion endpoint. 
@@ -17,33 +15,10 @@ def post(self): data = self.get_json_body() messages = data.get('messages', '') - # Get the OpenAI API key from environment variables - openai_api_key = os.getenv('OPENAI_API_KEY') - if not openai_api_key: - # If the API key is not set, return a 401 unauthorized error - self.set_status(401) - self.finish(json.dumps({"response": "OPENAI_API_KEY not set"})) - return - - # Set up the OpenAI client - openai.api_key = openai_api_key - client = OpenAI() - try: # Query OpenAI API - response = client.chat.completions.create( - model="gpt-4o-mini", - messages=messages - ) - - response_dict = response.to_dict() - - # Send the response back to the frontend - # TODO: In the future, instead of returning the raw response, - # return a cleaned up version of the response so we can support - # multiple models - - self.finish(json.dumps(response_dict)) + response = get_open_ai_completion(messages) + self.finish(json.dumps(response)) except Exception as e: self.set_status(500) self.finish(json.dumps({"response": f"Error: {str(e)}"})) \ No newline at end of file diff --git a/mito-ai/mito-ai/utils/__init__.py b/mito-ai/mito-ai/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mito-ai/mito-ai/utils/open_ai_utils.py b/mito-ai/mito-ai/utils/open_ai_utils.py new file mode 100644 index 000000000..616177738 --- /dev/null +++ b/mito-ai/mito-ai/utils/open_ai_utils.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +# coding: utf-8 + +# Copyright (c) Saga Inc. 
+ +import os +import requests +from typing import Any, Dict, List + +OPEN_AI_URL = 'https://api.openai.com/v1/chat/completions' +MITO_AI_URL = 'https://ogtzairktg.execute-api.us-east-1.amazonaws.com/Prod/completions/' + +OPEN_SOURCE_AI_COMPLETIONS_LIMIT = 100 + +def _get_ai_completion_data(messages: List[Dict[str, Any]]) -> Dict[str, Any]: + return { + "model": "gpt-4o-mini", + "messages": messages, + "temperature": 0, + } + +__user_email = None +__user_id = None +__num_usages = None + +def _get_ai_completion_from_mito_server(ai_completion_data: Dict[str, Any]) -> Dict[str, Any]: + + data = { + 'email': __user_email, + 'user_id': __user_id, + 'data': ai_completion_data + } + + headers = { + 'Content-Type': 'application/json', + } + + try: + res = requests.post(MITO_AI_URL, headers=headers, json=data) + + # If the response status code is in the 200s, this does nothing + # If the response status code indicates an error (4xx or 5xx), + # raise an HTTPError exception with details about what went wrong + res.raise_for_status() + + # The lambda function returns a dictionary with a completion entry in it, + # so we just return that. 
+ return res.json() + + except Exception as e: + print('Error using mito server', e) + raise e + + +def get_open_ai_completion(messages: List[Dict[str, Any]]) -> Dict[str, Any]: + + OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY') + ai_completion_data = _get_ai_completion_data(messages) + + if OPENAI_API_KEY is None: + # If they don't have an Open AI key, + # use the mito server to get a completion + completion = _get_ai_completion_from_mito_server(ai_completion_data) + return completion + + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {OPENAI_API_KEY}' + } + + try: + res = requests.post(OPEN_AI_URL, headers=headers, json=ai_completion_data) + + # If the response status code is in the 200s, this does nothing + # If the response status code indicates an error (4xx or 5xx), + # raise an HTTPError exception with details about what went wrong + res.raise_for_status() + + completion = res.json()['choices'][0]['message']['content'] + return {'completion': completion} + except Exception as e: + raise e + \ No newline at end of file diff --git a/mito-ai/src/Extensions/AiChat/ChatTaskpane.tsx b/mito-ai/src/Extensions/AiChat/ChatTaskpane.tsx index b038586d8..85a85e219 100644 --- a/mito-ai/src/Extensions/AiChat/ChatTaskpane.tsx +++ b/mito-ai/src/Extensions/AiChat/ChatTaskpane.tsx @@ -165,12 +165,11 @@ const ChatTaskpane: React.FC = ({ }); if (apiResponse.type === 'success') { - - const response = apiResponse.response; - const aiMessage = response.choices[0].message; + const aiMessage = apiResponse.response; newChatHistoryManager.addAIMessageFromResponse(aiMessage); setChatHistoryManager(newChatHistoryManager); + aiRespone = aiMessage } else { newChatHistoryManager.addAIMessageFromMessageContent(apiResponse.errorMessage, true) diff --git a/mito-ai/src/utils/handler.ts b/mito-ai/src/utils/handler.ts index d6543633a..a10ee8076 100644 --- a/mito-ai/src/utils/handler.ts +++ b/mito-ai/src/utils/handler.ts @@ -1,11 +1,11 @@ import { URLExt } from 
'@jupyterlab/coreutils'; import { ServerConnection } from '@jupyterlab/services'; -import OpenAI from 'openai'; +import OpenAI from "openai"; export type SuccessfulAPIResponse = { 'type': 'success', - response: OpenAI.Chat.ChatCompletion + response: OpenAI.Chat.Completions.ChatCompletionMessage } export type FailedAPIResponse = { type: 'error', @@ -41,8 +41,8 @@ export async function requestAPI( // Merge default headers with any provided headers init.headers = { - ...defaultHeaders, - ...init.headers, + ...defaultHeaders, + ...init.headers, }; // Make the request @@ -83,9 +83,18 @@ export async function requestAPI( try { data = JSON.parse(data); + + // TODO: Update the lambda function to return the entire message instead of + // just the content so we don't have to recreate the message here. + const aiMessage: OpenAI.Chat.Completions.ChatCompletionMessage = { + role: 'assistant', + content: data['completion'], + refusal: null + } + return { type: 'success', - response: aiMessage } } catch (error) { console.error('Not a JSON response body.', response);