Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: Add better support for file content parsing with Python Interpreter #805

Merged
merged 8 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ exec-backend:
exec-db:
docker exec -ti cohere-toolkit-db-1 bash

.PHONY: exec-terrarium
exec-terrarium:
docker exec -ti -u root cohere-toolkit-terrarium-1 /bin/sh

.PHONY: migration
migration:
docker compose run --build backend alembic -c src/backend/alembic.ini revision --autogenerate -m "$(message)"
Expand Down
2 changes: 0 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ services:
volumes:
# Mount alembic folder to sync migrations
- ./src/backend/alembic:/workspace/src/backend/alembic
# Mount data folder to sync uploaded files
- ./src/backend/data:/workspace/src/backend/data
# Mount configurations
- ./src/backend/config/secrets.yaml:/workspace/src/backend/config/secrets.yaml
- ./src/backend/config/configuration.yaml:/workspace/src/backend/config/configuration.yaml
Expand Down
6 changes: 3 additions & 3 deletions src/backend/chat/custom/tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from backend.schemas.context import Context
from backend.services.logger.utils import LoggerFactory

TIMEOUT = 300
TIMEOUT_SECONDS = 60

logger = LoggerFactory().get_logger()

Expand Down Expand Up @@ -60,13 +60,13 @@ async def _call_all_tools_async(
]
combined = asyncio.gather(*tasks)
try:
tool_results = await asyncio.wait_for(combined, timeout=TIMEOUT)
tool_results = await asyncio.wait_for(combined, timeout=TIMEOUT_SECONDS)
# Flatten a list of list of tool results
return [n for m in tool_results for n in m]
except asyncio.TimeoutError:
raise HTTPException(
status_code=500,
detail=f"Timeout while calling tools with timeout: {str(TIMEOUT)}",
detail=f"Timeout while calling tools with timeout: {TIMEOUT_SECONDS}",
)


Expand Down
2 changes: 1 addition & 1 deletion src/backend/config/configuration.template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ tools:
enabled_tools:
- wikipedia
- search_file
- read_document
- read_file
- toolkit_python_interpreter
- toolkit_calculator
- hybrid_web_search
Expand Down
6 changes: 5 additions & 1 deletion src/backend/config/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ class ToolName(StrEnum):
implementation=PythonInterpreter,
parameter_definitions={
"code": {
"description": "Python code to execute using an interpreter",
"description": (
"Python code to execute using the Python interpreter with no internet access. "
"Do not generate code that tries to open files directly, instead use file contents passed to the interpreter, "
"then print output or save output to a file."
),
"type": "str",
"required": True,
}
Expand Down
2 changes: 1 addition & 1 deletion src/backend/routers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ async def delete_agent_file(
"""
user_id = ctx.get_user_id()
_ = validate_agent_exists(session, agent_id)
validate_file(session, file_id, user_id, agent_id)
validate_file(session, file_id, user_id)

# Delete the File DB object
get_file_service().delete_agent_file_by_id(session, agent_id, file_id, user_id, ctx)
Expand Down
2 changes: 1 addition & 1 deletion src/backend/routers/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ async def delete_file(
"""
user_id = ctx.get_user_id()
_ = validate_conversation(session, conversation_id, user_id)
validate_file(session, file_id, user_id, conversation_id, ctx)
validate_file(session, file_id, user_id )

# Delete the File DB object
get_file_service().delete_conversation_file_by_id(
Expand Down
4 changes: 0 additions & 4 deletions src/backend/schemas/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,6 @@ class ChatResponseEvent(BaseModel):


class BaseChatRequest(BaseModel):
# user_id: str = Field(
# title="A user id to store to store the conversation under.", exclude=True
# )
message: str = Field(
title="The message to send to the chatbot.",
)
Expand All @@ -313,7 +310,6 @@ class BaseChatRequest(BaseModel):
default_factory=lambda: str(uuid4()),
title="To store a conversation then create a conversation id and use it for every related request",
)

tools: List[Tool] | None = Field(
default_factory=list,
title="""
Expand Down
1 change: 1 addition & 0 deletions src/backend/schemas/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class Config:

class ConversationFilePublic(BaseModel):
id: str
user_id: str
created_at: datetime.datetime
updated_at: datetime.datetime

Expand Down
92 changes: 47 additions & 45 deletions src/backend/services/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def get_files_by_message_id(

# Misc
def validate_file(
session: DBSessionDep, file_id: str, user_id: str, index: str, ctx: Context
session: DBSessionDep, file_id: str, user_id: str
) -> File:
"""
Validates if a file exists and belongs to the user
Expand Down Expand Up @@ -366,6 +366,52 @@ def attach_conversation_id_to_files(
return results



def read_excel(file_contents: bytes) -> str:
"""Reads the text from an Excel file using Pandas

Args:
file_contents (bytes): The file contents

Returns:
str: The text extracted from the Excel
"""
excel = pd.read_excel(io.BytesIO(file_contents), engine="calamine")
return excel.to_string()


def read_docx(file_contents: bytes) -> str:
"""Reads the text from a DOCX file

Args:
file_contents (bytes): The file contents

Returns:
str: The text extracted from the DOCX file, with each paragraph separated by a newline
"""
document = Document(io.BytesIO(file_contents))
text = ""

for paragraph in document.paragraphs:
text += paragraph.text + "\n"

return text


def read_parquet(file_contents: bytes) -> str:
"""Reads the text from a Parquet file using Pandas

Args:
file_contents (bytes): The file contents

Returns:
str: The text extracted from the Parquet
"""
parquet = pd.read_parquet(io.BytesIO(file_contents), engine="pyarrow")
return parquet.to_string()



def get_file_extension(file_name: str) -> str:
"""Returns the file extension

Expand Down Expand Up @@ -411,47 +457,3 @@ async def get_file_content(file: FastAPIUploadFile) -> str:
return read_excel(file_contents)

raise ValueError(f"File extension {file_extension} is not supported")


def read_excel(file_contents: bytes) -> str:
"""Reads the text from an Excel file using Pandas

Args:
file_contents (bytes): The file contents

Returns:
str: The text extracted from the Excel
"""
excel = pd.read_excel(io.BytesIO(file_contents), engine="calamine")
return excel.to_string()


def read_docx(file_contents: bytes) -> str:
"""Reads the text from a DOCX file

Args:
file_contents (bytes): The file contents

Returns:
str: The text extracted from the DOCX file, with each paragraph separated by a newline
"""
document = Document(io.BytesIO(file_contents))
text = ""

for paragraph in document.paragraphs:
text += paragraph.text + "\n"

return text


def read_parquet(file_contents: bytes) -> str:
"""Reads the text from a Parquet file using Pandas

Args:
file_contents (bytes): The file contents

Returns:
str: The text extracted from the Parquet
"""
parquet = pd.read_parquet(io.BytesIO(file_contents), engine="pyarrow")
return parquet.to_string()
2 changes: 1 addition & 1 deletion src/backend/tests/unit/chat/test_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def call(
]


@patch("backend.chat.custom.tool_calls.TIMEOUT", 1)
@patch("backend.chat.custom.tool_calls.TIMEOUT_SECONDS", 1)
def test_async_call_tools_timeout() -> None:
class MockCalculator(BaseTool):
NAME = "toolkit_calculator"
Expand Down
3 changes: 2 additions & 1 deletion src/backend/tools/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ReadFileTool(BaseTool):
Tool to read a file from the file system.
"""

NAME = "read_document"
NAME = "read_file"
MAX_NUM_CHUNKS = 10
SEARCH_LIMIT = 5

Expand Down Expand Up @@ -75,6 +75,7 @@ async def call(

file_ids = [file_id for _, file_id in files]
retrieved_files = file_crud.get_files_by_ids(session, file_ids, user_id)

if not retrieved_files:
return []

Expand Down
21 changes: 1 addition & 20 deletions src/backend/tools/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,13 @@

import requests
from dotenv import load_dotenv
from langchain_core.tools import Tool as LangchainTool
from pydantic.v1 import BaseModel, Field

from backend.config.settings import Settings
from backend.tools.base import BaseTool

load_dotenv()


class LangchainPythonInterpreterToolInput(BaseModel):
code: str = Field(description="Python code to execute.")


class PythonInterpreter(BaseTool):
"""
This class calls arbitrary code against a Python interpreter.
Expand All @@ -35,8 +29,8 @@ async def call(self, parameters: dict, ctx: Any, **kwargs: Any):

code = parameters.get("code", "")
res = requests.post(self.INTERPRETER_URL, json={"code": code})
clean_res = self._clean_response(res.json())

clean_res = self._clean_response(res.json())
return clean_res

def _clean_response(self, result: Any) -> Dict[str, str]:
Expand Down Expand Up @@ -74,16 +68,3 @@ def _clean_response(self, result: Any) -> Dict[str, str]:
r[key] = str(value)

return result_list

# langchain does not return a dict as a parameter, only a code string
def langchain_call(self, code: str):
return self.call({"code": code}, ctx=None)

def to_langchain_tool(self) -> LangchainTool:
tool = LangchainTool(
name="python_interpreter",
description="Executes python code and returns the result. The code runs in a static sandbox without interactive mode, so print output or save output to a file.",
func=self.langchain_call,
)
tool.args_schema = LangchainPythonInterpreterToolInput
return tool
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ const Login: React.FC = () => {

return (
<div className="flex flex-col items-center justify-center">
<Text
as="h1"
styleAs="h3"
>
<Text as="h1" styleAs="h3">
Log in
</Text>
<div className="mt-10 flex w-full flex-col items-center gap-1">
Expand Down
Loading
Loading