32 changes: 10 additions & 22 deletions ee/hogai/graph/base.py
@@ -1,13 +1,11 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from typing import Any, Generic, Literal, Union
from uuid import UUID

from django.conf import settings

from langchain_core.runnables import RunnableConfig
from langgraph.config import get_stream_writer
from langgraph.types import StreamWriter

from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage, ReasoningMessage

@@ -17,6 +15,7 @@

from ee.hogai.context import AssistantContextManager
from ee.hogai.graph.mixins import AssistantContextMixin, ReasoningNodeMixin
from ee.hogai.utils.dispatch import internal_dispatch
from ee.hogai.utils.exceptions import GenerationCanceled
from ee.hogai.utils.helpers import find_start_message
from ee.hogai.utils.state import LangGraphState
@@ -27,13 +26,13 @@
PartialStateType,
StateType,
)
from ee.hogai.utils.types.actions import AssistantAction
from ee.hogai.utils.types.composed import MaxNodeName
from ee.models import Conversation


class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMixin, ReasoningNodeMixin, ABC):
_writer: StreamWriter | None = None
config: RunnableConfig | None = None
_config: RunnableConfig | None = None
_context_manager: AssistantContextManager | None = None

def __init__(self, team: Team, user: User):
@@ -49,7 +48,7 @@ async def __call__(self, state: StateType, config: RunnableConfig) -> PartialSta
"""
Run the assistant node and handle cancelled conversation before the node is run.
"""
self.config = config
self._config = config
thread_id = (config.get("configurable") or {}).get("thread_id")
if thread_id and await self._is_conversation_cancelled(thread_id):
raise GenerationCanceled
@@ -67,34 +66,23 @@ def run(self, state: StateType, config: RunnableConfig) -> PartialStateType | No
async def arun(self, state: StateType, config: RunnableConfig) -> PartialStateType | None:
raise NotImplementedError

@property
def writer(self) -> StreamWriter | Callable[[Any], None]:
if self._writer:
return self._writer
try:
self._writer = get_stream_writer()
except RuntimeError:
# Not in a LangGraph context (e.g., during testing)
def noop(*args, **kwargs):
pass

return noop
return self._writer

@property
def context_manager(self) -> AssistantContextManager:
if self._context_manager is None:
if self.config is None:
if self._config is None:
# Only allow default config in test environments
if settings.TEST:
config = RunnableConfig(configurable={})
else:
raise ValueError("Config is required to create AssistantContextManager")
else:
config = self.config
config = self._config
self._context_manager = AssistantContextManager(self._team, self._user, config)
return self._context_manager

def dispatch(self, event: AssistantAction):
internal_dispatch(self.writer, self._config)(event)

async def _is_conversation_cancelled(self, conversation_id: UUID) -> bool:
conversation = await self._aget_conversation(conversation_id)
if not conversation:
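In `base.py`, the node's config becomes private (`_config`), the stream-writer lookup moves to `AssistantContextMixin`, and nodes gain a `dispatch()` helper that forwards actions through `internal_dispatch`. A minimal sketch of how a node might use it, assuming the usual `AssistantState` / `PartialAssistantState` pair; `ExampleNode` and its reasoning text are hypothetical:

```python
# Illustrative only: a hypothetical node emitting a reasoning update via the new
# dispatch() helper. dispatch() wraps internal_dispatch(self.writer, self._config);
# _config is populated when __call__ runs, so this only works inside a graph run
# (internal_dispatch raises if config is missing).
from langchain_core.runnables import RunnableConfig

from ee.hogai.graph.base import BaseAssistantNode
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.actions import UpdateReasoning


class ExampleNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
        self.dispatch(UpdateReasoning(content="Analyzing the question"))
        return None
```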
18 changes: 18 additions & 0 deletions ee/hogai/graph/mixins.py
@@ -1,11 +1,14 @@
import datetime
from abc import ABC
from collections.abc import Callable
from typing import Any, Optional, get_args, get_origin
from uuid import UUID

from django.utils import timezone

from langchain_core.runnables import RunnableConfig
from langgraph.config import get_stream_writer
from langgraph.types import StreamWriter

from posthog.schema import ReasoningMessage

@@ -21,6 +24,21 @@
class AssistantContextMixin(ABC):
_team: Team
_user: User
_writer: StreamWriter | None = None

@property
def writer(self) -> StreamWriter | Callable[[Any], None]:
if self._writer:
return self._writer
try:
self._writer = get_stream_writer()
except RuntimeError:
# Not in a LangGraph context (e.g., during testing)
def noop(*args, **kwargs):
pass

return noop
return self._writer

async def _aget_core_memory(self) -> CoreMemory | None:
try:
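The `writer` property now lives on `AssistantContextMixin`, so nodes and tools share the same lookup: inside a LangGraph run it returns the real `StreamWriter` from `get_stream_writer()`, and outside one (e.g. unit tests) it degrades to a no-op callable. A short sketch of that contract; `WriterConsumer` is hypothetical:

```python
# Sketch of the shared writer contract, not the mixin's code itself: callers can
# invoke self.writer(...) unconditionally, because outside a LangGraph run the
# property returns a do-nothing callable instead of raising.
from ee.hogai.graph.mixins import AssistantContextMixin


class WriterConsumer(AssistantContextMixin):
    def emit(self, payload: dict) -> None:
        self.writer(payload)  # streams in a graph run, silently no-ops in tests
```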
2 changes: 1 addition & 1 deletion ee/hogai/graph/root/nodes.py
@@ -472,7 +472,7 @@ async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAs
root_tool_calls_count=tool_call_count + 1,
)
elif ToolClass := get_contextual_tool_class(tool_call.name):
tool_class = ToolClass(team=self._team, user=self._user, state=state)
tool_class = ToolClass(team=self._team, user=self._user, state=state, config=config)
try:
result = await tool_class.ainvoke(tool_call.model_dump(), config)
except Exception as e:
10 changes: 6 additions & 4 deletions ee/hogai/tool.py
@@ -16,7 +16,9 @@

from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.mixins import AssistantContextMixin
from ee.hogai.utils.dispatch import internal_dispatch
from ee.hogai.utils.types import AssistantState
from ee.hogai.utils.types.actions import AssistantAction

CONTEXTUAL_TOOL_NAME_TO_TOOL: dict[AssistantContextualTool, type["MaxTool"]] = {}

@@ -59,7 +61,6 @@ class MaxTool(AssistantContextMixin, BaseTool):

show_tool_call_message: bool = Field(description="Whether to show tool call messages.", default=True)

_context: dict[str, Any]
_config: RunnableConfig
_state: AssistantState
_context_manager: AssistantContextManager
@@ -131,9 +132,10 @@ async def _arun(self, *args, config: RunnableConfig, **kwargs):

@property
def context(self) -> dict:
if not hasattr(self, "_context"):
raise AttributeError("Tool has not been run yet")
return self._context
return self._context_manager.get_contextual_tools().get(self.get_name(), {})

def dispatch(self, event: AssistantAction):
internal_dispatch(self.writer, self._config)(event)

def format_system_prompt_injection(self, context: dict[str, Any]) -> str:
formatted_context = {
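In `MaxTool`, `context` is no longer a `_context` attribute set at run time; it is read from the shared `AssistantContextManager`, keyed by the tool's name, and falls back to `{}`. Tools also gain the same `dispatch()` path as nodes, which is why the root node now passes `config=config` when constructing contextual tools. A hedged fragment; `ExampleTool` and `run_example` are hypothetical and omit the usual tool boilerplate:

```python
# Illustrative fragment only: inside a contextual tool, reading context and emitting
# an action. context, dispatch(), and UpdateReasoning come from this PR; the rest is
# made up for the sketch.
from ee.hogai.tool import MaxTool
from ee.hogai.utils.types.actions import UpdateReasoning


class ExampleTool(MaxTool):
    async def run_example(self) -> dict:
        self.dispatch(UpdateReasoning(content="Reading dashboard context"))
        # Contextual-tool payload for this tool's name, or {} if none was provided.
        return self.context
```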
16 changes: 16 additions & 0 deletions ee/hogai/utils/dispatch.py
@@ -0,0 +1,16 @@
from collections.abc import Callable

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from ee.hogai.utils.types.actions import AssistantAction, DispatchedAction


def internal_dispatch(writer: StreamWriter, config: RunnableConfig | None) -> Callable[[AssistantAction], None]:
if not config:
raise AttributeError("Config is required to dispatch actions")

def dispatch(action: AssistantAction) -> None:
writer(DispatchedAction(langgraph_node=config.get("langgraph_node"), action=action))

return dispatch
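`internal_dispatch` is a small factory: given a writer and a config, it returns a `dispatch` callable that wraps each `AssistantAction` in a `DispatchedAction` tagged with the emitting `langgraph_node`, and it refuses to work without a config. A self-contained sketch exercising it outside LangGraph; the list-as-writer and hand-built config are test scaffolding only (in a real run LangGraph supplies both):

```python
# Minimal sketch: drive internal_dispatch with a list standing in for the stream
# writer and a hand-built config. For illustration only; not how production code
# obtains writer/config.
from ee.hogai.utils.dispatch import internal_dispatch
from ee.hogai.utils.types.actions import DispatchedAction, UpdateReasoning

collected: list[DispatchedAction] = []
config = {"configurable": {}, "langgraph_node": "root"}  # fake config for the sketch

dispatch = internal_dispatch(collected.append, config)
dispatch(UpdateReasoning(content="hello"))

assert collected[0].langgraph_node == "root"
assert collected[0].action.content == "hello"
```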
15 changes: 15 additions & 0 deletions ee/hogai/utils/types/actions.py
@@ -0,0 +1,15 @@
from typing import Any, Union

from pydantic import BaseModel


class UpdateReasoning(BaseModel):
content: str


AssistantAction = Union[UpdateReasoning]


class DispatchedAction(BaseModel):
langgraph_node: Any
action: AssistantAction
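`AssistantAction` is a one-member union for now, presumably so further action types can sit alongside `UpdateReasoning` later, and `DispatchedAction` is the envelope the stream writer receives: the action plus the node that emitted it. A sketch of how a consumer might discriminate actions on the other side of the stream; `handle_event` is hypothetical:

```python
# Illustrative consumer of DispatchedAction envelopes. Only UpdateReasoning exists in
# this PR; the fall-through case is where future AssistantAction members would go.
from ee.hogai.utils.types.actions import DispatchedAction, UpdateReasoning


def handle_event(event: DispatchedAction) -> None:
    match event.action:
        case UpdateReasoning(content=content):
            print(f"[{event.langgraph_node}] reasoning: {content}")
        case _:
            pass
```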