diff --git a/ee/hogai/graph/base.py b/ee/hogai/graph/base.py index 31f9677316fa7..d397643139b95 100644 --- a/ee/hogai/graph/base.py +++ b/ee/hogai/graph/base.py @@ -1,13 +1,11 @@ from abc import ABC, abstractmethod -from collections.abc import Callable, Sequence +from collections.abc import Sequence from typing import Any, Generic, Literal, Union from uuid import UUID from django.conf import settings from langchain_core.runnables import RunnableConfig -from langgraph.config import get_stream_writer -from langgraph.types import StreamWriter from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage, ReasoningMessage @@ -17,6 +15,7 @@ from ee.hogai.context import AssistantContextManager from ee.hogai.graph.mixins import AssistantContextMixin, ReasoningNodeMixin +from ee.hogai.utils.dispatch import internal_dispatch from ee.hogai.utils.exceptions import GenerationCanceled from ee.hogai.utils.helpers import find_start_message from ee.hogai.utils.state import LangGraphState @@ -27,13 +26,13 @@ PartialStateType, StateType, ) +from ee.hogai.utils.types.actions import AssistantAction from ee.hogai.utils.types.composed import MaxNodeName from ee.models import Conversation class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMixin, ReasoningNodeMixin, ABC): - _writer: StreamWriter | None = None - config: RunnableConfig | None = None + _config: RunnableConfig | None = None _context_manager: AssistantContextManager | None = None def __init__(self, team: Team, user: User): @@ -49,7 +48,7 @@ async def __call__(self, state: StateType, config: RunnableConfig) -> PartialSta """ Run the assistant node and handle cancelled conversation before the node is run. """ - self.config = config + self._config = config thread_id = (config.get("configurable") or {}).get("thread_id") if thread_id and await self._is_conversation_cancelled(thread_id): raise GenerationCanceled @@ -67,34 +66,23 @@ def run(self, state: StateType, config: RunnableConfig) -> PartialStateType | No async def arun(self, state: StateType, config: RunnableConfig) -> PartialStateType | None: raise NotImplementedError - @property - def writer(self) -> StreamWriter | Callable[[Any], None]: - if self._writer: - return self._writer - try: - self._writer = get_stream_writer() - except RuntimeError: - # Not in a LangGraph context (e.g., during testing) - def noop(*args, **kwargs): - pass - - return noop - return self._writer - @property def context_manager(self) -> AssistantContextManager: if self._context_manager is None: - if self.config is None: + if self._config is None: # Only allow default config in test environments if settings.TEST: config = RunnableConfig(configurable={}) else: raise ValueError("Config is required to create AssistantContextManager") else: - config = self.config + config = self._config self._context_manager = AssistantContextManager(self._team, self._user, config) return self._context_manager + def dispatch(self, event: AssistantAction): + internal_dispatch(self.writer, self._config)(event) + async def _is_conversation_cancelled(self, conversation_id: UUID) -> bool: conversation = await self._aget_conversation(conversation_id) if not conversation: diff --git a/ee/hogai/graph/mixins.py b/ee/hogai/graph/mixins.py index 3e11b365a50d7..613cd2c9f8281 100644 --- a/ee/hogai/graph/mixins.py +++ b/ee/hogai/graph/mixins.py @@ -1,11 +1,14 @@ import datetime from abc import ABC +from collections.abc import Callable from typing import Any, Optional, get_args, get_origin from uuid import UUID from django.utils import timezone from langchain_core.runnables import RunnableConfig +from langgraph.config import get_stream_writer +from langgraph.types import StreamWriter from posthog.schema import ReasoningMessage @@ -21,6 +24,21 @@ class AssistantContextMixin(ABC): _team: Team _user: User + _writer: StreamWriter | None = None + + @property + def writer(self) -> StreamWriter | Callable[[Any], None]: + if self._writer: + return self._writer + try: + self._writer = get_stream_writer() + except RuntimeError: + # Not in a LangGraph context (e.g., during testing) + def noop(*args, **kwargs): + pass + + return noop + return self._writer async def _aget_core_memory(self) -> CoreMemory | None: try: diff --git a/ee/hogai/graph/root/nodes.py b/ee/hogai/graph/root/nodes.py index 2487fee021171..544f29a12aa95 100644 --- a/ee/hogai/graph/root/nodes.py +++ b/ee/hogai/graph/root/nodes.py @@ -472,7 +472,7 @@ async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAs root_tool_calls_count=tool_call_count + 1, ) elif ToolClass := get_contextual_tool_class(tool_call.name): - tool_class = ToolClass(team=self._team, user=self._user, state=state) + tool_class = ToolClass(team=self._team, user=self._user, state=state, config=config) try: result = await tool_class.ainvoke(tool_call.model_dump(), config) except Exception as e: diff --git a/ee/hogai/tool.py b/ee/hogai/tool.py index 989e3576feb09..058ed7282bbfa 100644 --- a/ee/hogai/tool.py +++ b/ee/hogai/tool.py @@ -16,7 +16,9 @@ from ee.hogai.context.context import AssistantContextManager from ee.hogai.graph.mixins import AssistantContextMixin +from ee.hogai.utils.dispatch import internal_dispatch from ee.hogai.utils.types import AssistantState +from ee.hogai.utils.types.actions import AssistantAction CONTEXTUAL_TOOL_NAME_TO_TOOL: dict[AssistantContextualTool, type["MaxTool"]] = {} @@ -59,7 +61,6 @@ class MaxTool(AssistantContextMixin, BaseTool): show_tool_call_message: bool = Field(description="Whether to show tool call messages.", default=True) - _context: dict[str, Any] _config: RunnableConfig _state: AssistantState _context_manager: AssistantContextManager @@ -131,9 +132,10 @@ async def _arun(self, *args, config: RunnableConfig, **kwargs): @property def context(self) -> dict: - if not hasattr(self, "_context"): - raise AttributeError("Tool has not been run yet") - return self._context + return self._context_manager.get_contextual_tools().get(self.get_name(), {}) + + def dispatch(self, event: AssistantAction): + internal_dispatch(self.writer, self._config)(event) def format_system_prompt_injection(self, context: dict[str, Any]) -> str: formatted_context = { diff --git a/ee/hogai/utils/dispatch.py b/ee/hogai/utils/dispatch.py new file mode 100644 index 0000000000000..376d47c903280 --- /dev/null +++ b/ee/hogai/utils/dispatch.py @@ -0,0 +1,16 @@ +from collections.abc import Callable + +from langchain_core.runnables import RunnableConfig +from langgraph.types import StreamWriter + +from ee.hogai.utils.types.actions import AssistantAction, DispatchedAction + + +def internal_dispatch(writer: StreamWriter, config: RunnableConfig | None) -> Callable[[AssistantAction], None]: + if not config: + raise AttributeError("Config is required to dispatch actions") + + def dispatch(action: AssistantAction) -> None: + writer(DispatchedAction(langgraph_node=config.get("langgraph_node"), action=action)) + + return dispatch diff --git a/ee/hogai/utils/types/actions.py b/ee/hogai/utils/types/actions.py new file mode 100644 index 0000000000000..93a2855477405 --- /dev/null +++ b/ee/hogai/utils/types/actions.py @@ -0,0 +1,15 @@ +from typing import Any, Union + +from pydantic import BaseModel + + +class UpdateReasoning(BaseModel): + content: str + + +AssistantAction = Union[UpdateReasoning] + + +class DispatchedAction(BaseModel): + langgraph_node: Any + action: AssistantAction