Skip to content

Commit

Permalink
agent log parsing using marvin
Browse files Browse the repository at this point in the history
  • Loading branch information
minump committed Dec 7, 2023
1 parent 07c600f commit e24d717
Showing 1 changed file with 29 additions and 15 deletions.
44 changes: 29 additions & 15 deletions ai_ta_backend/agents/customcallbacks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict, Union, Any, List
import re
import os

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish
Expand All @@ -8,7 +8,23 @@
from langchain.llms import OpenAI
from langchain import hub
from utils import SupabaseDB, get_langsmith_id

import marvin
from marvin import ai_model
from pydantic import BaseModel, Field

from dotenv import load_dotenv
# Load variables from the local .env file. override=True lets values from .env
# replace variables that are already set in the process environment.
load_dotenv(override=True, dotenv_path='.env')

# Hand the OpenAI key to marvin. os.getenv returns None when the variable is
# unset, so a missing key is not reported at this point.
marvin.settings.openai.api_key = os.getenv('OPENAI_API_KEY')
# Structured view of a single agent step. The model is built from the raw
# agent-action payload (see its use in on_agent_action) and exposes the
# pieces below as plain string attributes; every field is required.
# NOTE(review): documented with comments instead of a class docstring on
# purpose -- a docstring on a pydantic/marvin model is presumably surfaced in
# the schema handed to the LLM, which would change the prompt. Confirm before
# converting these comments into a docstring.
@ai_model
class AgentActionParser(BaseModel):
    # Cleaned-up log text.
    log: str = Field(
        ...,
        description="clean logs information from input, without new line breaks and markdowns",
    )

    # The action and its input/output as found in the log.
    action: str = Field(
        ...,
        description="parse action field from logs",
    )
    action_input: str = Field(
        ...,
        description="parse action input from logs",
    )
    action_output: str = Field(
        ...,
        description="parse action output from logs",
    )

    # The tool and its input/output.
    tool: str = Field(
        ...,
        description="any tool information from input",
    )
    tool_input: str = Field(
        ...,
        description="any tool input",
    )
    tool_output: str = Field(
        ...,
        description="parse tool output",
    )

class CustomCallbackHandler(BaseCallbackHandler):
"""A callback handler that stores the LLM's context and action in memory."""
Expand All @@ -30,7 +46,6 @@ def __init__(self, run_id=None, image_name=None):
def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
"""Callback for when a tool starts.
Set tool_in_progress to True and store the tool's name in memory."""

self.tool_in_progress['status'] = True
self.tool_in_progress['name'] = serialized['name']

Expand All @@ -45,10 +60,10 @@ def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: An
response = self.db.upsert_field_in_db("on_tool_start", [serialized])



def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Callback for when a tool ends.
Use this to store the tool's output in a database. Use tool start parameters to identify the tool."""

if self.tool_in_progress['status']:
tool_name = self.tool_in_progress['name']
output = {"name": tool_name, "output": output}
Expand Down Expand Up @@ -78,20 +93,20 @@ def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
    """Run when the LLM predicts an action; parse it and store it in the database.

    The action is converted to a dict and passed to ``AgentActionParser``. The
    parsed fields are then written, as a plain dict, to the
    ``on_agent_action`` field: appended to the stored list when one exists,
    otherwise written as a new single-element list.

    Args:
        action: The action predicted by the agent.
        **kwargs: Additional callback arguments; not used here.

    Returns:
        None.
    """
    parsed_action = AgentActionParser(action.dict())

    # Store a plain dict rather than the model instance; the keys mirror the
    # parser's fields one-to-one.
    action_data = {
        "log": parsed_action.log,
        "action": parsed_action.action,
        "action_input": parsed_action.action_input,
        "action_output": parsed_action.action_output,
        "tool": parsed_action.tool,
        "tool_input": parsed_action.tool_input,
        "tool_output": parsed_action.tool_output,
    }

    if not self.db.is_exists_image():
        # Presumably no record exists for this image yet, so create one with
        # this action as its first entry -- verify against SupabaseDB.
        self.db.upsert_field_in_db("on_agent_action", [action_data])
        return

    # A falsy fetch result (nothing stored so far) starts a fresh list, so
    # each action is written exactly once in either case.
    data = self.db.fetch_field_from_db("on_agent_action") or []
    data.append(action_data)
    self.db.update_field_in_db("on_agent_action", data)

def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run when LLM finishes. Store the finish in a database."""
Expand All @@ -108,11 +123,10 @@ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
response = self.db.upsert_field_in_db("on_agent_finish", finish)


# def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
# print(f"on_llm_start {serialized}")
# def agent_output():
#   TODO: the agent's output is not yet stored in the db; add a callback that saves it.

# def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
# print(f"on_new_token {token}")

def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Callback for when a chain starts."""
Expand Down

0 comments on commit e24d717

Please sign in to comment.