-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
217 changed files
with
18,637 additions
and
3,254 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import time | ||
from openai import OpenAI | ||
from dotenv import load_dotenv | ||
from openai.types.beta.assistant_stream_event import ThreadMessageDelta | ||
from streaming_assistants import patch | ||
from openai.lib.streaming import AssistantEventHandler | ||
from typing_extensions import override | ||
|
||
|
||
load_dotenv("./.env") | ||
load_dotenv("../../../.env") | ||
|
||
def run_with_assistant(assistant, client): | ||
print(f"created assistant: {assistant.name}") | ||
print("Uploading file:") | ||
# Upload the file | ||
file = client.files.create( | ||
file=open( | ||
"./examples/python/language_models_are_unsupervised_multitask_learners.pdf", | ||
"rb", | ||
), | ||
purpose="assistants", | ||
) | ||
print("adding file id to assistant") | ||
# Update Assistant | ||
assistant = client.beta.assistants.update( | ||
assistant.id, | ||
tools=[{"type": "retrieval"}], | ||
file_ids=[file.id], | ||
) | ||
user_message = "What are some cool math concepts behind this ML paper pdf? Explain in two sentences." | ||
print("creating persistent thread and message") | ||
thread = client.beta.threads.create() | ||
client.beta.threads.messages.create( | ||
thread_id=thread.id, role="user", content=user_message | ||
) | ||
print(f"> {user_message}") | ||
|
||
class EventHandler(AssistantEventHandler): | ||
@override | ||
def on_text_delta(self, delta, snapshot): | ||
# Increment the counter each time the method is called | ||
print(delta.value, end="", flush=True) | ||
|
||
print(f"creating run") | ||
with client.beta.threads.runs.create_and_stream( | ||
thread_id=thread.id, | ||
assistant_id=assistant.id, | ||
event_handler=EventHandler(), | ||
) as stream: | ||
for part in stream: | ||
if not isinstance(part, ThreadMessageDelta): | ||
print(f'received event: {part}\n') | ||
|
||
print("\n") | ||
|
||
|
||
client = patch(OpenAI()) | ||
|
||
instructions = "You are a personal math tutor. Answer thoroughly. The system will provide relevant context from files, use the context to respond and share the exact snippets from the file at the end of your response." | ||
|
||
model = "gpt-3.5-turbo" | ||
name = f"{model} Math Tutor" | ||
|
||
gpt3_assistant = client.beta.assistants.create( | ||
name=name, | ||
instructions=instructions, | ||
model=model, | ||
) | ||
run_with_assistant(gpt3_assistant, client) | ||
|
||
model = "cohere/command" | ||
name = f"{model} Math Tutor" | ||
|
||
cohere_assistant = client.beta.assistants.create( | ||
name=name, | ||
instructions=instructions, | ||
model=model, | ||
) | ||
run_with_assistant(cohere_assistant, client) | ||
|
||
model = "perplexity/mixtral-8x7b-instruct" | ||
name = f"{model} Math Tutor" | ||
|
||
perplexity_assistant = client.beta.assistants.create( | ||
name=name, | ||
instructions=instructions, | ||
model=model, | ||
) | ||
run_with_assistant(perplexity_assistant, client) | ||
|
||
model = "anthropic.claude-v2" | ||
name = f"{model} Math Tutor" | ||
|
||
claude_assistant = client.beta.assistants.create( | ||
name=name, | ||
instructions=instructions, | ||
model=model, | ||
) | ||
run_with_assistant(claude_assistant, client) | ||
|
||
model = "gemini/gemini-pro" | ||
name = f"{model} Math Tutor" | ||
|
||
gemini_assistant = client.beta.assistants.create( | ||
name=name, | ||
instructions=instructions, | ||
model=model, | ||
) | ||
run_with_assistant(gemini_assistant, client) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import asyncio | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
background_task_set = set() | ||
|
||
event_loop = asyncio.new_event_loop() | ||
|
||
async def add_background_task(function, run_id, thread_id, astradb): | ||
logger.debug("Creating background task") | ||
task = asyncio.create_task( | ||
function, name=run_id | ||
) | ||
background_task_set.add(task) | ||
task.add_done_callback(lambda t: on_task_completion(t, astradb=astradb, run_id=run_id, thread_id=thread_id)) | ||
|
||
|
||
def on_task_completion(task, astradb, run_id, thread_id): | ||
background_task_set.remove(task) | ||
logger.debug(f"Task stopped for run_id: {run_id} and thread_id: {thread_id}") | ||
|
||
if task.cancelled(): | ||
logger.warning(f"Task cancelled, setting status to failed for run_id: {run_id} and thread_id: {thread_id}") | ||
astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed"); | ||
return | ||
try: | ||
exception = task.exception() | ||
if exception is not None: | ||
logger.warning(f"Task raised an exception, setting status to failed for run_id: {run_id} and thread_id: {thread_id}") | ||
logger.error(exception) | ||
astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed"); | ||
raise exception | ||
else: | ||
logger.debug(f"Task completed successfully for run_id: {run_id} and thread_id: {thread_id}") | ||
except asyncio.CancelledError: | ||
logger.warning(f"why wasn't this caught in task.cancelled()") | ||
logger.debug(f"Task cancelled, setting status to failed for run_id: {run_id} and thread_id: {thread_id}") | ||
astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed"); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from typing import Optional, Annotated, List | ||
|
||
from pydantic import Field | ||
|
||
from impl.model.assistant_object_tools_inner import AssistantObjectToolsInner | ||
from openapi_server.models.assistant_object import AssistantObject as AssistantObjectGenerated | ||
|
||
|
||
class AssistantObject(AssistantObjectGenerated): | ||
tools: Annotated[List[AssistantObjectToolsInner], Field(max_length=20)] = Field(description="The list of tools that the [assistant](/docs/api-reference/assistants) used for this run.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# coding: utf-8 | ||
|
||
from __future__ import annotations | ||
from datetime import date, datetime # noqa: F401 | ||
|
||
import re # noqa: F401 | ||
from typing import Any, Dict, List, Optional # noqa: F401 | ||
|
||
from pydantic import AnyUrl, BaseModel, EmailStr, Field, validator # noqa: F401 | ||
from openapi_server.models.assistant_tools_code import AssistantToolsCode | ||
from openapi_server.models.assistant_tools_function import AssistantToolsFunction | ||
from openapi_server.models.assistant_tools_retrieval import AssistantToolsRetrieval | ||
from openapi_server.models.function_object import FunctionObject | ||
|
||
|
||
class AssistantObjectToolsInner(BaseModel): | ||
"""NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). | ||
Do not edit the class manually. | ||
AssistantObjectToolsInner - a model defined in OpenAPI | ||
# type: The type of this AssistantObjectToolsInner. | ||
# function: The function of this AssistantObjectToolsInner. | ||
""" | ||
|
||
type: Optional[str] = Field(alias="type") | ||
function: Optional[FunctionObject] = Field(default=None, alias="function") | ||
|
||
AssistantObjectToolsInner.update_forward_refs() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# coding: utf-8 | ||
|
||
from __future__ import annotations | ||
from datetime import date, datetime # noqa: F401 | ||
|
||
import re # noqa: F401 | ||
from typing import Any, Dict, List, Optional # noqa: F401 | ||
|
||
from pydantic import AnyUrl, BaseModel, EmailStr, Field, validator # noqa: F401 | ||
from openapi_server.models.chat_completion_message_tool_call import ChatCompletionMessageToolCall | ||
from openapi_server.models.chat_completion_request_assistant_message import ChatCompletionRequestAssistantMessage | ||
from openapi_server.models.chat_completion_request_assistant_message_function_call import ChatCompletionRequestAssistantMessageFunctionCall | ||
from openapi_server.models.chat_completion_request_function_message import ChatCompletionRequestFunctionMessage | ||
from openapi_server.models.chat_completion_request_system_message import ChatCompletionRequestSystemMessage | ||
from openapi_server.models.chat_completion_request_tool_message import ChatCompletionRequestToolMessage | ||
from openapi_server.models.chat_completion_request_user_message import ChatCompletionRequestUserMessage | ||
|
||
|
||
class ChatCompletionRequestMessage(BaseModel): | ||
"""NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). | ||
Do not edit the class manually. | ||
ChatCompletionRequestMessage - a model defined in OpenAPI | ||
content: The content of this ChatCompletionRequestMessage. | ||
role: The role of this ChatCompletionRequestMessage. | ||
name: The name of this ChatCompletionRequestMessage. | ||
tool_calls: The tool_calls of this ChatCompletionRequestMessage [Optional]. | ||
function_call: The function_call of this ChatCompletionRequestMessage [Optional]. | ||
tool_call_id: The tool_call_id of this ChatCompletionRequestMessage. | ||
""" | ||
|
||
content: str = Field(alias="content") | ||
role: str = Field(alias="role") | ||
name: Optional[str] = Field(alias="name", default=None) | ||
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = Field(alias="tool_calls", default=None) | ||
function_call: Optional[ChatCompletionRequestAssistantMessageFunctionCall] = Field(alias="function_call", default=None) | ||
tool_call_id: Optional[str] = Field(alias="tool_call_id", default=None) | ||
|
||
ChatCompletionRequestMessage.update_forward_refs() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from typing import Literal | ||
|
||
from openai.types.beta.threads import Run as ClientRun | ||
|
||
|
||
class Run(ClientRun): | ||
status: Literal[ | ||
"queued", "in_progress", "requires_action", "cancelling", "cancelled", "failed", "completed", "expired", "generating" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from typing import Optional, Annotated, List | ||
|
||
from pydantic import Field | ||
|
||
from impl.model.assistant_object_tools_inner import AssistantObjectToolsInner | ||
from openapi_server.models.create_assistant_request import CreateAssistantRequest as CreateAssistantRequestGenerated | ||
|
||
|
||
class CreateAssistantRequest(CreateAssistantRequestGenerated): | ||
tools: Optional[Annotated[List[AssistantObjectToolsInner], Field(max_length=128)]] = Field(default=None, description="assistant_tools_param_description") |
Oops, something went wrong.