Skip to content

Commit

Permalink
streaming runs
Browse files Browse the repository at this point in the history
  • Loading branch information
phact committed Mar 15, 2024
1 parent a9f5b1b commit d82a728
Show file tree
Hide file tree
Showing 217 changed files with 18,637 additions and 3,254 deletions.
110 changes: 110 additions & 0 deletions examples/python/streaming_runs/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import time
from openai import OpenAI
from dotenv import load_dotenv
from openai.types.beta.assistant_stream_event import ThreadMessageDelta
from streaming_assistants import patch
from openai.lib.streaming import AssistantEventHandler
from typing_extensions import override


# Load environment variables from whichever .env is present: the current
# working directory first, then the repository root relative to this example.
load_dotenv("./.env")
load_dotenv("../../../.env")

def run_with_assistant(assistant, client):
    """Attach the sample PDF to *assistant*, ask one question on a fresh
    thread, and stream the answer to stdout.

    Args:
        assistant: An assistant object previously created via the client.
        client: A (patched) OpenAI client.
    """
    print(f"created assistant: {assistant.name}")
    print("Uploading file:")
    # Upload the file; a context manager guarantees the handle is closed
    # even if the upload raises (the original leaked the open file object).
    with open(
        "./examples/python/language_models_are_unsupervised_multitask_learners.pdf",
        "rb",
    ) as pdf_file:
        file = client.files.create(file=pdf_file, purpose="assistants")
    print("adding file id to assistant")
    # Enable retrieval over the uploaded file.
    assistant = client.beta.assistants.update(
        assistant.id,
        tools=[{"type": "retrieval"}],
        file_ids=[file.id],
    )
    user_message = "What are some cool math concepts behind this ML paper pdf? Explain in two sentences."
    print("creating persistent thread and message")
    thread = client.beta.threads.create()
    client.beta.threads.messages.create(
        thread_id=thread.id, role="user", content=user_message
    )
    print(f"> {user_message}")

    class EventHandler(AssistantEventHandler):
        @override
        def on_text_delta(self, delta, snapshot):
            # Print each streamed text fragment as it arrives.
            print(delta.value, end="", flush=True)

    print("creating run")
    # Stream the run: text deltas are printed by the handler above; every
    # other event type (run status changes, etc.) is dumped in full.
    with client.beta.threads.runs.create_and_stream(
        thread_id=thread.id,
        assistant_id=assistant.id,
        event_handler=EventHandler(),
    ) as stream:
        for part in stream:
            if not isinstance(part, ThreadMessageDelta):
                print(f'received event: {part}\n')

    print("\n")


client = patch(OpenAI())

instructions = "You are a personal math tutor. Answer thoroughly. The system will provide relevant context from files, use the context to respond and share the exact snippets from the file at the end of your response."

# One assistant per provider/model. The original repeated the same
# create-and-run stanza five times; a loop keeps the models in one place
# and makes adding another provider a one-line change.
models = [
    "gpt-3.5-turbo",
    "cohere/command",
    "perplexity/mixtral-8x7b-instruct",
    "anthropic.claude-v2",
    "gemini/gemini-pro",
]

for model in models:
    name = f"{model} Math Tutor"
    assistant = client.beta.assistants.create(
        name=name,
        instructions=instructions,
        model=model,
    )
    run_with_assistant(assistant, client)
23 changes: 12 additions & 11 deletions impl/astra_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
)
from pydantic import BaseModel, Field

from impl.model.assistant_object import AssistantObject
from impl.model.assistant_object_tools_inner import AssistantObjectToolsInner
from impl.model.message_object import MessageObject
from impl.model.open_ai_file import OpenAIFile
from impl.model.run_object import RunObject
from impl.models import (
DocumentChunk,
DocumentChunkMetadata,
Expand All @@ -34,16 +39,11 @@
QueryWithEmbedding,
)
from impl.services.inference_utils import get_embeddings
from openapi_server.models.assistant_object import AssistantObject
from openapi_server.models.message_content_text_object import MessageContentTextObject
from openapi_server.models.message_content_text_object_text import MessageContentTextObjectText
from openapi_server.models.run_object_required_action import RunObjectRequiredAction
from openapi_server.models.message_object import MessageObject
from openapi_server.models.message_object_content_inner import MessageObjectContentInner
from openapi_server.models.open_ai_file import OpenAIFile
from openapi_server.models.run_object import RunObject
from openapi_server.models.thread_object import ThreadObject
from openapi_server.models.assistant_object_tools_inner import AssistantObjectToolsInner



# Create a logger for this module.
Expand Down Expand Up @@ -858,6 +858,7 @@ def upsert_run(
tools=tools,
file_ids=file_ids,
metadata=metadata,
usage=None,
)

def upsert_message(
Expand Down Expand Up @@ -944,8 +945,8 @@ def get_message(self, thread_id, message_id):
if file_ids is None:
file_ids = []

created_at = row["created_at"].timestamp() * 1000
return MessageObject(
created_at = int(row["created_at"].timestamp() * 1000)
message_object = MessageObject(
id=row['id'],
object=row['object'],
created_at=created_at,
Expand All @@ -957,12 +958,13 @@ def get_message(self, thread_id, message_id):
file_ids=file_ids,
metadata=metadata
)
return message_object

def upsert_content_only_file(
self, id, created_at, object, purpose, filename, format, bytes, content, **litellm_kwargs,
):
self.upsert_chunks_content_only(id, content, created_at)
status = "success"
status = "uploaded"
query_string = f"""insert into {CASSANDRA_KEYSPACE}.files (
id,
object,
Expand Down Expand Up @@ -998,7 +1000,7 @@ def upsert_file(
self, id, created_at, object, purpose, filename, format, bytes, chunks, model, **litellm_kwargs,
):
self.upsert_chunks(chunks, model, **litellm_kwargs)
status = "success"
status = "processed"

query_string = f"""insert into {CASSANDRA_KEYSPACE}.files (
id,
Expand Down Expand Up @@ -1157,7 +1159,6 @@ def upsert_assistant(

def __del__(self):
# close the connection when the client is destroyed
logger.info("shutdown")
self.session.shutdown()

# TODO: make these async
Expand Down
38 changes: 38 additions & 0 deletions impl/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import asyncio
import logging

logger = logging.getLogger(__name__)
# Strong references to in-flight background tasks so they are not garbage
# collected before completion; entries are removed by the done-callback.
background_task_set = set()

# NOTE(review): this loop is created at import time but never used in this
# module — confirm it is consumed elsewhere, otherwise remove it.
event_loop = asyncio.new_event_loop()

async def add_background_task(function, run_id, thread_id, astradb):
    """Schedule *function* (a coroutine) as a named background task.

    The task is named after *run_id*, kept in ``background_task_set`` so a
    strong reference survives until completion, and wired to
    ``on_task_completion`` for status bookkeeping.
    """
    logger.debug("Creating background task")
    task = asyncio.create_task(function, name=run_id)
    background_task_set.add(task)

    def _completed(finished_task):
        # Forward the finished task plus its identifying context.
        on_task_completion(
            finished_task,
            astradb=astradb,
            run_id=run_id,
            thread_id=thread_id,
        )

    task.add_done_callback(_completed)


def on_task_completion(task, astradb, run_id, thread_id):
    """Done-callback for background run tasks.

    Removes *task* from the tracking set and persists a terminal status:
    a cancelled or failed task marks the run ``failed`` in *astradb*;
    a successful task is only logged.

    Raises:
        BaseException: the task's own exception is re-raised so the event
        loop's exception handler still sees the failure.
    """
    background_task_set.remove(task)
    logger.debug(f"Task stopped for run_id: {run_id} and thread_id: {thread_id}")

    if task.cancelled():
        logger.warning(f"Task cancelled, setting status to failed for run_id: {run_id} and thread_id: {thread_id}")
        astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed")
        return
    try:
        # task.exception() itself raises CancelledError when the task was
        # cancelled; the check above normally prevents that, but this keeps
        # only the call that can raise inside the try (the original also
        # raised the task's exception inside this try, so a CancelledError
        # result could be swallowed and the status written twice).
        exception = task.exception()
    except asyncio.CancelledError:
        logger.warning("why wasn't this caught in task.cancelled()")
        logger.debug(f"Task cancelled, setting status to failed for run_id: {run_id} and thread_id: {thread_id}")
        astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed")
        return
    if exception is not None:
        logger.warning(f"Task raised an exception, setting status to failed for run_id: {run_id} and thread_id: {thread_id}")
        logger.error(exception)
        astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed")
        raise exception
    logger.debug(f"Task completed successfully for run_id: {run_id} and thread_id: {thread_id}")
24 changes: 18 additions & 6 deletions impl/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Callable, Sequence, Union

Expand All @@ -10,11 +11,12 @@
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_fastapi_instrumentator.metrics import Info

from impl.background import background_task_set
from impl.routes import assistants, files, health, stateless, threads

# Configure logging
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)',
format='%(asctime)s - %(levelname)s - %(message)s (%(module)s:%(filename)s:%(lineno)d)',
datefmt='%Y-%m-%d %H:%M:%S')

logger = logging.getLogger('cassandra')
Expand All @@ -24,12 +26,22 @@
logger = logging.getLogger(__name__)

app = FastAPI(
# TODO: Change these?
title="OpenAI API",
description="The OpenAI REST API. Please see https://platform.openai.com/docs/api-reference for more details.",
title="Astra Assistants API",
description="Drop in replacement for OpenAI Assistants API. .",
version="2.0.0",
)

@app.on_event("shutdown")
async def shutdown_event():
    """Cancel all tracked background tasks before the server exits."""
    logger.info("shutting down server")
    # Iterate over a snapshot: each task's done-callback removes it from
    # background_task_set, so iterating the live set while awaiting would
    # raise "Set changed size during iteration".
    for task in list(background_task_set):
        task.cancel()
        try:
            await task  # Give the task a chance to finish
        except asyncio.CancelledError:
            pass  # Cancellation is the expected outcome here


app.include_router(assistants.router, prefix="/v1")
app.include_router(files.router, prefix="/v1")
app.include_router(health.router, prefix="/v1")
Expand Down Expand Up @@ -163,5 +175,5 @@ async def unimplemented(request: Request, full_path: str):
)


if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
#if __name__ == "__main__":
# uvicorn.run(app, host="0.0.0.0", port=8000)
10 changes: 10 additions & 0 deletions impl/model/assistant_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Optional, Annotated, List

from pydantic import Field

from impl.model.assistant_object_tools_inner import AssistantObjectToolsInner
from openapi_server.models.assistant_object import AssistantObject as AssistantObjectGenerated


class AssistantObject(AssistantObjectGenerated):
    """Assistant model overriding the generated ``tools`` field.

    Each entry is validated as the project-local
    ``AssistantObjectToolsInner`` instead of the generated type.
    NOTE(review): max_length=20 on a List constrains item count in
    pydantic v2 only — confirm the pydantic version in use.
    """

    tools: Annotated[List[AssistantObjectToolsInner], Field(max_length=20)] = Field(description="The list of tools that the [assistant](/docs/api-reference/assistants) used for this run.")
30 changes: 30 additions & 0 deletions impl/model/assistant_object_tools_inner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# coding: utf-8

from __future__ import annotations
from datetime import date, datetime # noqa: F401

import re # noqa: F401
from typing import Any, Dict, List, Optional # noqa: F401

from pydantic import AnyUrl, BaseModel, EmailStr, Field, validator # noqa: F401
from openapi_server.models.assistant_tools_code import AssistantToolsCode
from openapi_server.models.assistant_tools_function import AssistantToolsFunction
from openapi_server.models.assistant_tools_retrieval import AssistantToolsRetrieval
from openapi_server.models.function_object import FunctionObject


class AssistantObjectToolsInner(BaseModel):
    """A single tool entry attached to an assistant.

    NOTE(review): the original docstring claimed this was auto-generated,
    but it lives in impl/model and loosens ``type`` to a plain optional
    string — it appears hand-modified so entries like
    ``{"type": "retrieval"}`` (no function payload) validate.

    Attributes:
        type: The tool kind (e.g. "retrieval", "function").
        function: Function definition, present only for function tools.
    """

    type: Optional[str] = Field(alias="type")
    function: Optional[FunctionObject] = Field(default=None, alias="function")

AssistantObjectToolsInner.update_forward_refs()
41 changes: 41 additions & 0 deletions impl/model/chat_completion_request_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# coding: utf-8

from __future__ import annotations
from datetime import date, datetime # noqa: F401

import re # noqa: F401
from typing import Any, Dict, List, Optional # noqa: F401

from pydantic import AnyUrl, BaseModel, EmailStr, Field, validator # noqa: F401
from openapi_server.models.chat_completion_message_tool_call import ChatCompletionMessageToolCall
from openapi_server.models.chat_completion_request_assistant_message import ChatCompletionRequestAssistantMessage
from openapi_server.models.chat_completion_request_assistant_message_function_call import ChatCompletionRequestAssistantMessageFunctionCall
from openapi_server.models.chat_completion_request_function_message import ChatCompletionRequestFunctionMessage
from openapi_server.models.chat_completion_request_system_message import ChatCompletionRequestSystemMessage
from openapi_server.models.chat_completion_request_tool_message import ChatCompletionRequestToolMessage
from openapi_server.models.chat_completion_request_user_message import ChatCompletionRequestUserMessage


class ChatCompletionRequestMessage(BaseModel):
    """A single chat-completion request message.

    Flat, hand-written variant of the OpenAPI-generated model: instead of
    the generated per-role union (see the imports above), one shape with
    optional role-specific fields.

    Attributes:
        content: The message text.
        role: The sender role (the generated models cover system, user,
            assistant, tool, and function messages).
        name: Optional participant name.
        tool_calls: Tool calls emitted by an assistant message, if any.
        function_call: Legacy function-call payload, if any.
        tool_call_id: Id of the tool call this message responds to.
    """

    content: str = Field(alias="content")
    role: str = Field(alias="role")
    name: Optional[str] = Field(alias="name", default=None)
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = Field(alias="tool_calls", default=None)
    function_call: Optional[ChatCompletionRequestAssistantMessageFunctionCall] = Field(alias="function_call", default=None)
    tool_call_id: Optional[str] = Field(alias="tool_call_id", default=None)

ChatCompletionRequestMessage.update_forward_refs()
9 changes: 9 additions & 0 deletions impl/model/client_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Literal

from openai.types.beta.threads import Run as ClientRun


class Run(ClientRun):
    """Client-side Run widening the SDK's ``status`` literal.

    Adds the non-standard "generating" state (presumably reported while
    tokens are streaming — confirm against the server implementation) on
    top of the statuses the OpenAI SDK's Run accepts.
    """

    status: Literal[
        "queued", "in_progress", "requires_action", "cancelling", "cancelled", "failed", "completed", "expired", "generating"
    ]
10 changes: 10 additions & 0 deletions impl/model/create_assistant_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Optional, Annotated, List

from pydantic import Field

from impl.model.assistant_object_tools_inner import AssistantObjectToolsInner
from openapi_server.models.create_assistant_request import CreateAssistantRequest as CreateAssistantRequestGenerated


class CreateAssistantRequest(CreateAssistantRequestGenerated):
    """Create-assistant request overriding the generated ``tools`` field
    to validate entries as the project-local ``AssistantObjectToolsInner``.
    NOTE(review): max_length=128 constrains item count in pydantic v2 only
    — confirm the pydantic version in use.
    """

    tools: Optional[Annotated[List[AssistantObjectToolsInner], Field(max_length=128)]] = Field(default=None, description="assistant_tools_param_description")
Loading

0 comments on commit d82a728

Please sign in to comment.