diff --git a/pyproject.toml b/pyproject.toml
index 6463fc637490..609691310425 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -138,7 +138,8 @@ dependencies = [
     "cuga~=0.1.11",
    "agent-lifecycle-toolkit~=0.4.4",
    "astrapy>=2.1.0,<3.0.0",
-    "aioboto3>=15.2.0,<16.0.0"
+    "aioboto3>=15.2.0,<16.0.0",
+    "ag-ui-protocol>=0.1.10",
 ]
diff --git a/src/lfx/pyproject.toml b/src/lfx/pyproject.toml
index 0c59307455f4..5378dcb92056 100644
--- a/src/lfx/pyproject.toml
+++ b/src/lfx/pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
     "validators>=0.34.0,<1.0.0",
     "filelock>=3.20.0",
     "pypdf>=5.1.0",
+    "ag-ui-protocol>=0.1.10",
 ]
 
 [project.scripts]
diff --git a/src/lfx/src/lfx/events/observability/__init__.py b/src/lfx/src/lfx/events/observability/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/src/lfx/src/lfx/events/observability/lifecycle_events.py b/src/lfx/src/lfx/events/observability/lifecycle_events.py
new file mode 100644
index 000000000000..c1b4b78ee21e
--- /dev/null
+++ b/src/lfx/src/lfx/events/observability/lifecycle_events.py
@@ -0,0 +1,111 @@
+import functools
+from collections.abc import Awaitable, Callable
+from typing import Any
+
+from ag_ui.encoder import EventEncoder
+
+from lfx.log.logger import logger
+
+AsyncMethod = Callable[..., Awaitable[Any]]
+
+
+def observable(observed_method: AsyncMethod) -> AsyncMethod:
+    """Make an async method emit lifecycle events by invoking optional callback hooks on the hosting instance.
+
+    The hosting class may implement the following optional hooks to produce event payloads:
+    - before_callback_event(*args, **kwargs) -> dict: called before the decorated method runs.
+    - after_callback_event(result, *args, **kwargs) -> dict: called after the decorated method completes successfully.
+    - error_callback_event(exception, *args, **kwargs) -> dict: called if the decorated method raises an exception.
+
+    If a hook is implemented, its returned dictionary will be encoded via EventEncoder and prepared for publishing;
+    if a hook is absent, the corresponding event is skipped. Payloads may include custom metrics under the
+    'langflow' key inside a 'raw_event' dictionary.
+
+    Returns:
+        The wrapped async function that preserves the original method's behavior while invoking lifecycle hooks when available.
+    """
+
+    async def check_event_manager(self, **kwargs):
+        """Check whether an EventManager instance is present in the provided keyword arguments.
+
+        Parameters:
+            kwargs: Expects an 'event_manager' key whose value is the EventManager used for publishing lifecycle events.
+
+        Returns:
+            `True` if 'event_manager' exists in kwargs and is not None, `False` otherwise.
+        """
+        if "event_manager" not in kwargs or kwargs["event_manager"] is None:
+            await logger.awarning(
+                f"EventManager not available/provided, skipping observable event publishing "
+                f"from {self.__class__.__name__}"
+            )
+            return False
+        return True
+
+    async def before_callback(self, *args, **kwargs):
+        """Invoke the instance's pre-execution lifecycle hook to produce and encode an event payload.
+
+        Checks for a valid `event_manager` in `kwargs`; if absent the function returns without action.
+        If the hosting instance implements `before_callback_event(*args, **kwargs)`, calls it to obtain a payload,
+        encodes the payload with EventEncoder (and prepares it for publishing). If the hook is not implemented,
+        logs a warning and skips publishing.
+        """
+        if not await check_event_manager(self, **kwargs):
+            return
+
+        if hasattr(self, "before_callback_event"):
+            event_payload = self.before_callback_event(*args, **kwargs)
+            encoder = EventEncoder()
+            event_payload = encoder.encode(event_payload)
+            # TODO: Publish event
+        else:
+            await logger.awarning(
+                f"before_callback_event not implemented for {self.__class__.__name__}. Skipping event publishing."
+            )
+
+    async def after_callback(self, res: Any | None = None, *args, **kwargs):
+        """Invoke the instance's after_callback_event to produce and encode a post-execution event payload when an EventManager is provided.
+
+        Parameters:
+            res (Any | None): The result produced by the observed method; forwarded to `after_callback_event`.
+            *args: Positional arguments forwarded to `after_callback_event`.
+            **kwargs: Keyword arguments forwarded to `after_callback_event`. May include `event_manager` required to publish events; if no valid `event_manager` is present, the function returns without encoding or publishing.
+        """
+        if not await check_event_manager(self, **kwargs):
+            return
+        if hasattr(self, "after_callback_event"):
+            event_payload = self.after_callback_event(res, *args, **kwargs)
+            encoder = EventEncoder()
+            event_payload = encoder.encode(event_payload)
+            # TODO: Publish event
+        else:
+            await logger.awarning(
+                f"after_callback_event not implemented for {self.__class__.__name__}. Skipping event publishing."
+            )
+
+    @functools.wraps(observed_method)
+    async def wrapper(self, *args, **kwargs):
+        """Wraps the observed async method to emit lifecycle events before execution, after successful completion, and on error.
+
+        Calls the hosting instance's before_callback and after_callback helpers to produce and encode event payloads when available; if an exception occurs, encodes an error payload using the instance's error_callback_event when present, then re-raises the exception.
+
+        Returns:
+            The value returned by the wrapped observed method.
+
+        Raises:
+            Exception: Propagates any exception raised by the observed method after encoding the error event (if available).
+        """
+        await before_callback(self, *args, **kwargs)
+        result = None
+        try:
+            result = await observed_method(self, *args, **kwargs)
+            await after_callback(self, result, *args, **kwargs)
+        except Exception as e:
+            await logger.aerror(f"Exception in {self.__class__.__name__}: {e}")
+            if hasattr(self, "error_callback_event"):
+                error_payload = self.error_callback_event(e, *args, **kwargs)
+                encoder = EventEncoder()
+                encoder.encode(error_payload)
+                # TODO: Publish error event
+            raise
+        return result
+
+    return wrapper
diff --git a/src/lfx/src/lfx/graph/graph/base.py b/src/lfx/src/lfx/graph/graph/base.py
index b6e969465ed2..3317d74a762f 100644
--- a/src/lfx/src/lfx/graph/graph/base.py
+++ b/src/lfx/src/lfx/graph/graph/base.py
@@ -15,6 +15,9 @@
 from itertools import chain
 from typing import TYPE_CHECKING, Any, cast
 
+from ag_ui.core import RunFinishedEvent, RunStartedEvent
+
+from lfx.events.observability.lifecycle_events import observable
 from lfx.exceptions.component import ComponentBuildError
 from lfx.graph.edge.base import CycleEdge, Edge
 from lfx.graph.graph.constants import Finish, lazy_load_vertex_dict
@@ -68,12 +71,27 @@ def __init__(
         log_config: LogConfig | None = None,
         context: dict[str, Any] | None = None,
     ) -> None:
-        """Initializes a new Graph instance.
-
-        If both start and end components are provided, the graph is initialized and prepared for execution.
-        If only one is provided, a ValueError is raised. The context must be a dictionary if specified,
-        otherwise a TypeError is raised. Internal data structures for vertices, edges, state management,
-        run management, and tracing are set up during initialization.
+        """Create a new Graph instance and initialize its internal execution state.
+
+        Parameters:
+            start (Component | None): Optional start component for the graph; when provided together with `end`
+                the graph is added and prepared for execution.
+            end (Component | None): Optional end component for the graph; must be provided together with `start`.
+            flow_id (str | None): Optional identifier for the flow.
+            flow_name (str | None): Optional human-readable flow name.
+            description (str | None): Optional flow description.
+            user_id (str | None): Optional user identifier used when instantiating components.
+            log_config (LogConfig | None): Optional logging configuration; if provided, logging is configured.
+            context (dict[str, Any] | None): Optional execution context; must be a dictionary if provided.
+
+        Raises:
+            TypeError: If `context` is provided and is not a dict.
+            ValueError: If exactly one of `start` or `end` is provided.
+
+        Notes:
+            - When both `start` and `end` are provided the graph will be wired and prepared (prepare is called).
+            - The constructor initializes internal structures used for vertices, edges, run management, caching,
+              tracing, and snapshotting.
         """
         if log_config:
             configure(**log_config)
@@ -144,7 +162,13 @@ def __init__(
 
     @property
     def lock(self):
-        """Lazy initialization of asyncio.Lock to avoid event loop binding issues."""
+        """Provide a lazily-initialized asyncio Lock.
+
+        Initializes the lock on first access to avoid binding it to an event loop at object construction.
+
+        Returns:
+            lock (asyncio.Lock): The lock instance created on first access and reused thereafter.
+        """
         if self._lock is None:
             self._lock = asyncio.Lock()
         return self._lock
@@ -708,11 +732,15 @@ def define_vertices_lists(self) -> None:
             self._is_state_vertices.append(vertex.id)
 
     def _set_inputs(self, input_components: list[str], inputs: dict[str, str], input_type: InputType | None) -> None:
-        """Updates input vertices' parameters with the provided inputs, filtering by component list and input type.
-
-        Only vertices whose IDs or display names match the specified input components and whose IDs contain
-        the input type (unless input type is 'any' or None) are updated. Raises a ValueError if a specified
-        vertex is not found.
+        """Update input vertices' raw parameters from the provided inputs, filtering which vertices are updated by a list of component identifiers and an optional input type.
+
+        Parameters:
+            input_components (list[str]): Vertex IDs or display names to target; if empty, all input vertices are considered.
+            inputs (dict[str, str]): Mapping of input field names to values to set on each matched input vertex.
+            input_type (InputType | None): If `None` or `"any"`, do not filter by type; otherwise only update vertices whose ID contains this type (case-insensitive).
+
+        Raises:
+            ValueError: If a referenced input vertex cannot be found.
         """
         for vertex_id in self._is_input_vertices:
             vertex = self.get_vertex(vertex_id)
@@ -728,6 +756,7 @@ def _set_inputs(self, input_components: list[str], inputs: dict[str, str], input
                 raise ValueError(msg)
             vertex.update_raw_params(inputs, overwrite=True)
 
+    @observable
     async def _run(
         self,
         *,
@@ -740,20 +769,24 @@ async def _run(
         fallback_to_env_vars: bool,
         event_manager: EventManager | None = None,
     ) -> list[ResultData | None]:
-        """Runs the graph with the given inputs.
-
-        Args:
-            inputs (Dict[str, str]): The input values for the graph.
-            input_components (list[str]): The components to run for the inputs.
-            input_type: (Optional[InputType]): The input type.
-            outputs (list[str]): The outputs to retrieve from the graph.
-            stream (bool): Whether to stream the results or not.
-            session_id (str): The session ID for the graph.
-            fallback_to_env_vars (bool): Whether to fallback to environment variables.
-            event_manager (EventManager | None): The event manager for the graph.
+        """Execute the graph using the provided inputs and return the results collected from output vertices.
+
+        Parameters:
+            inputs (dict[str, str]): Map of input field names to string values used to populate graph inputs.
+            input_components (list[str]): IDs of components that should receive the inputs; empty list means no specific mapping.
+            input_type (InputType | None): Optional type hint for the provided inputs.
+            outputs (list[str]): List of output component IDs or display names to collect results from; empty list collects all output vertices.
+            stream (bool): If True, streaming outputs may be left as generators; if False, streaming generators will be consumed to produce final results.
+            session_id (str): Session identifier to attach to vertices that accept session-scoped parameters.
+            fallback_to_env_vars (bool): If True, allow components to read missing inputs from environment variables where supported.
+            event_manager (EventManager | None): Optional event manager used during processing for lifecycle or observability hooks.
 
         Returns:
-            List[Optional["ResultData"]]: The outputs of the graph.
+            list[ResultData | None]: A list of results corresponding to output vertices; each element is a vertex result (`ResultData`) or `None` if a vertex produced no result.
+
+        Raises:
+            TypeError: If an expected input value is not a string.
+            ValueError: If provided component lists are invalid, a referenced vertex is missing, or graph processing fails.
         """
         if input_components and not isinstance(input_components, list):
             msg = f"Invalid components value: {input_components}. Expected list"
@@ -1399,6 +1432,18 @@ async def astep(
         user_id: str | None = None,
         event_manager: EventManager | None = None,
     ):
+        """Advance the graph execution by building the next scheduled vertex and update run state.
+
+        Parameters:
+            inputs (InputValueRequest | None): Optional input values for the vertex being built.
+            files (list[str] | None): Optional list of file paths to provide to the vertex build.
+            user_id (str | None): Optional identifier of the user initiating the step.
+            event_manager (EventManager | None): Optional event manager used during vertex build.
+
+        Returns:
+            VertexBuildResult: Result of building the next vertex when a vertex was processed.
+            Finish: A sentinel `Finish` instance when the run queue is empty and execution is complete.
+        """
         if not self._prepared:
             msg = "Graph not prepared. Call prepare() first."
             raise ValueError(msg)
@@ -1421,6 +1466,11 @@ async def get_cache_func(*args, **kwargs):  # noqa: ARG001
             return None
 
         async def set_cache_func(*args, **kwargs) -> bool:  # noqa: ARG001
+            """No-op fallback cache setter that accepts any arguments and always reports success.
+
+            Returns:
+                `True` indicating the cache operation was accepted.
+            """
             return True
 
         vertex_build_result = await self.build_vertex(
@@ -2288,7 +2338,13 @@ def build_adjacency_maps(edges: list[CycleEdge]) -> tuple[dict[str, list[str]],
         return predecessor_map, successor_map
 
     def __to_dict(self) -> dict[str, dict[str, list[str]]]:
-        """Converts the graph to a dictionary."""
+        """Produce a mapping of each vertex ID to its successor and predecessor vertex ID lists.
+
+        Returns:
+            dict[str, dict[str, list[str]]]: A dictionary where each key is a vertex ID and each value is a dict with two keys:
+                - "successors": list of successor vertex IDs
+                - "predecessors": list of predecessor vertex IDs
+        """
         result: dict = {}
         for vertex in self.vertices:
             vertex_id = vertex.id
@@ -2296,3 +2352,47 @@ def __to_dict(self) -> dict[str, dict[str, list[str]]]:
             predecessors = [i.id for i in self.get_predecessors(vertex)]
             result |= {vertex_id: {"successors": sucessors, "predecessors": predecessors}}
         return result
+
+    def raw_event_metrics(self, optional_fields: dict | None = None) -> dict:
+        """Create a timestamped metrics payload merged with any provided fields.
+
+        Parameters:
+            optional_fields (dict | None): Additional key-value metrics to include in the payload.
+
+        Returns:
+            dict: A dictionary containing a "timestamp" key with the current POSIX time in seconds and the merged optional fields.
+        """
+        if optional_fields is None:
+            optional_fields = {}
+        import time
+
+        return {"timestamp": time.time(), **optional_fields}
+
+    def before_callback_event(self, *args, **kwargs) -> RunStartedEvent:  # noqa: ARG002
+        """Create a RunStartedEvent populated with the graph's run and flow identifiers and optional raw metrics.
+
+        If the Graph exposes `raw_event_metrics`, its output is included in the event's `raw_event` (it will include `total_components` keyed to the count of vertices).
+
+        Returns:
+            RunStartedEvent: Event with `run_id` set to the graph's `_run_id`, `thread_id` set to `flow_id`, and `raw_event` containing any metrics.
+        """
+        metrics = {}
+        if hasattr(self, "raw_event_metrics"):
+            metrics = self.raw_event_metrics({"total_components": len(self.vertices)})
+        return RunStartedEvent(run_id=self._run_id, thread_id=self.flow_id, raw_event=metrics)
+
+    def after_callback_event(self, result: Any = None, *args, **kwargs) -> RunFinishedEvent:  # noqa: ARG002
+        """Create a RunFinishedEvent representing the end of the current run.
+
+        Parameters:
+            result (Any): Final run result (currently unused when constructing the event).
+            *args: Ignored.
+            **kwargs: Ignored.
+
+        Returns:
+            RunFinishedEvent: Event containing `run_id`, `thread_id` (flow_id), `result` set to `None`, and optional `raw_event` metrics including `total_components` when available.
+        """
+        metrics = {}
+        if hasattr(self, "raw_event_metrics"):
+            metrics = self.raw_event_metrics({"total_components": len(self.vertices)})
+        return RunFinishedEvent(run_id=self._run_id, thread_id=self.flow_id, result=None, raw_event=metrics)
diff --git a/src/lfx/src/lfx/graph/vertex/base.py b/src/lfx/src/lfx/graph/vertex/base.py
index 1d8cdbb595ce..86c854045ab4 100644
--- a/src/lfx/src/lfx/graph/vertex/base.py
+++ b/src/lfx/src/lfx/graph/vertex/base.py
@@ -8,6 +8,9 @@
 from enum import Enum
 from typing import TYPE_CHECKING, Any
 
+from ag_ui.core import StepFinishedEvent, StepStartedEvent
+
+from lfx.events.observability.lifecycle_events import observable
 from lfx.exceptions.component import ComponentBuildError
 from lfx.graph.schema import INPUT_COMPONENTS, OUTPUT_COMPONENTS, InterfaceComponentTypes, ResultData
 from lfx.graph.utils import UnbuiltObject, UnbuiltResult, log_transaction
@@ -168,6 +171,13 @@ def get_built_result(self):
         # If the Vertex.type is a power component
         # then we need to return the built object
         # instead of the result dict
+        """Return the vertex's effective built result transformed for callers.
+
+        If this vertex represents an interface/component that produced a concrete built object, the built object (or its `content` attribute for model-like objects) is returned. If the built object is a string it will be promoted to the built result. If no concrete result is available, an empty dict is returned. Otherwise, if the stored built result is a dict it is returned as-is; non-dict results are wrapped in a dict under the key `"result"`.
+
+        Returns:
+            dict | Any: The effective result for consumers — either a dict of outputs, an empty dict when unbuilt, or a wrapped non-dict value (or the raw built object/model content for interface components).
+        """
         if self.is_interface_component and not isinstance(self.built_object, UnbuiltObject):
             result = self.built_object
             # if it is not a dict or a string and hasattr model_dump then
@@ -180,10 +190,14 @@ def get_built_result(self):
 
         if isinstance(self.built_result, UnbuiltResult):
             return {}
+
         return self.built_result if isinstance(self.built_result, dict) else {"result": self.built_result}
 
     def set_artifacts(self) -> None:
-        pass
+        """Placeholder hook to update this vertex's artifacts after a build.
+
+        Intended to be overridden by subclasses or implemented to extract and assign artifacts produced during the component build; currently a no-op.
+        """
 
     @property
     def edges(self) -> list[CycleEdge]:
@@ -375,22 +389,35 @@ def update_raw_params(self, new_params: Mapping[str, str | list[str]], *, overwr
         self.updated_raw_params = True
 
     def instantiate_component(self, user_id=None) -> None:
+        """Ensure the vertex has an associated component instance by creating and attaching one if missing.
+
+        Parameters:
+            user_id (Optional[str]): Identifier of the user requesting instantiation; forwarded to the component loader and may influence permission/context during creation.
+        """
         if not self.custom_component:
             self.custom_component, _ = initialize.loading.instantiate_class(
                 user_id=user_id,
                 vertex=self,
            )
 
+    @observable
     async def _build(
         self,
         fallback_to_env_vars,
         user_id=None,
         event_manager: EventManager | None = None,
     ) -> None:
-        """Initiate the build process."""
+        """Perform the vertex build: resolve dependent vertices, instantiate or reuse a custom component, obtain build results, validate the built object, and mark the vertex as built.
+
+        Parameters:
+            fallback_to_env_vars (bool): If True, allow parameter values to be resolved from environment variables when not provided.
+            user_id (Optional[str]): Identifier of the user performing the build, used when instantiating components.
+
+        Raises:
+            ValueError: If the vertex's base type is not found.
+        """
         await logger.adebug(f"Building {self.display_name}")
         await self._build_each_vertex_in_params_dict()
-
         if self.base_type is None:
             msg = f"Base type for vertex {self.display_name} not found"
             raise ValueError(msg)
@@ -819,8 +846,59 @@ def built_object_repr(self) -> str:
         return "Built successfully ✨" if self.built_object is not None else "Failed to build 😵‍💫"
 
     def apply_on_outputs(self, func: Callable[[Any], Any]) -> None:
-        """Applies a function to the outputs of the vertex."""
+        """Apply a function to each output provided by the vertex's component.
+
+        Parameters:
+            func (Callable[[Any], Any]): Function to invoke for each output value; return values are ignored.
+        """
         if not self.custom_component or not self.custom_component.outputs:
             return
         # Apply the function to each output
         [func(output) for output in self.custom_component.get_outputs_map().values()]
+
+    # AGUI/AG UI Event Streaming Callbacks/Methods - (Optional, see Observable decorator)
+    def raw_event_metrics(self, optional_fields: dict | None = None) -> dict:
+        """Builds an AGUI-compatible metrics payload containing a timestamp and any provided additional fields.
+
+        Parameters:
+            optional_fields (dict | None): Additional key/value pairs to include in the payload. If None, no extra fields are added.
+
+        Returns:
+            dict: A metrics dictionary with a "timestamp" (seconds since the epoch) merged with the provided optional fields.
+        """
+        if optional_fields is None:
+            optional_fields = {}
+        import time
+
+        return {"timestamp": time.time(), **optional_fields}
+
+    def before_callback_event(self, *args, **kwargs) -> StepStartedEvent:  # noqa: ARG002
+        """Create an AGUI-compatible StepStartedEvent for this vertex.
+
+        If the vertex exposes `raw_event_metrics`, its output is included under the `raw_event["langflow"]` key with `component_id` set to this vertex's id.
+
+        Returns:
+            StepStartedEvent: event with `step_name` set to the vertex's display name and `raw_event` containing the assembled metrics.
+        """
+        metrics = {}
+        if hasattr(self, "raw_event_metrics"):
+            metrics = self.raw_event_metrics({"component_id": self.id})
+
+        return StepStartedEvent(step_name=self.display_name, raw_event={"langflow": metrics})
+
+    def after_callback_event(self, result, *args, **kwargs) -> StepFinishedEvent:  # noqa: ARG002
+        """Create an AGUI-compatible StepFinishedEvent for this vertex.
+
+        Parameters:
+            result: The final result produced by the step (may be unused by this method).
+            *args: Additional positional arguments (ignored).
+            **kwargs: Additional keyword arguments (ignored).
+
+        Returns:
+            StepFinishedEvent: Event with `step_name` set to the vertex display name and `raw_event`
+                containing a `langflow` entry with metrics (including `component_id`).
+        """
+        metrics = {}
+        if hasattr(self, "raw_event_metrics"):
+            metrics = self.raw_event_metrics({"component_id": self.id})
+        return StepFinishedEvent(step_name=self.display_name, raw_event={"langflow": metrics})
diff --git a/src/lfx/tests/unit/events/observability/__init__.py b/src/lfx/tests/unit/events/observability/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/src/lfx/tests/unit/events/observability/test_lifecycle_events.py b/src/lfx/tests/unit/events/observability/test_lifecycle_events.py
new file mode 100644
index 000000000000..13ac7f915619
--- /dev/null
+++ b/src/lfx/tests/unit/events/observability/test_lifecycle_events.py
@@ -0,0 +1,337 @@
+import asyncio
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+# Import the actual decorator we want to test
+from lfx.events.observability.lifecycle_events import observable
+
+
+# Mock classes for dependencies
+class MockEventManager:
+    """Mock for lfx.events.event_manager.EventManager."""
+
+    def __init__(self):
+        # We'll use AsyncMock for publish
+        """Initialize the mock event manager with an asynchronous `publish` method.
+
+        Sets `self.publish` to an AsyncMock so tests can await and assert calls to the publish coroutine.
+        """
+        self.publish = AsyncMock()
+
+
+class MockLogger:
+    """Mock for lfx.log.logger.logger."""
+
+    def __init__(self):
+        """Create a mock logger exposing awaitable `awarning` and `aerror` callables.
+
+        The constructor initializes `awarning` and `aerror` as AsyncMock instances that can be awaited like asynchronous warning and error logging methods.
+        """
+        self.awarning = AsyncMock()
+        self.aerror = AsyncMock()
+
+
+# --- Pytest Fixtures ---
+
+
+@pytest.fixture
+def mock_dependencies():
+    """Create and patch mocked external dependencies for tests.
+
+    Yields:
+        dict: A mapping with keys:
+            - "event_manager": MockEventManager instance used to simulate event publishing.
+            - "logger": MockLogger instance with async warning/error helpers.
+            - "encoder_cls": A mock EventEncoder class whose instances provide an `encode` method that returns a dict with `encoded` and `original_event` keys.
+    """
+    # 1. Logger Mock
+    mock_logger_instance = MockLogger()
+
+    # 2. EventManager Mock
+    mock_event_manager = MockEventManager()
+
+    # 3. Encoder Mock - create a mock instance with a mocked encode method
+    mock_encoder_instance = MagicMock()
+    mock_encoder_instance.encode = MagicMock(side_effect=lambda payload: {"encoded": True, "original_event": payload})
+
+    # Mock the EventEncoder class to return our mock instance
+    mock_encoder_cls = MagicMock(return_value=mock_encoder_instance)
+
+    # Patch the actual imports in the lifecycle_events module
+    with (
+        patch("lfx.events.observability.lifecycle_events.logger", mock_logger_instance),
+        patch("lfx.events.observability.lifecycle_events.EventEncoder", mock_encoder_cls),
+    ):
+        yield {
+            "event_manager": mock_event_manager,
+            "logger": mock_logger_instance,
+            "encoder_cls": mock_encoder_cls,
+        }
+
+
+@pytest.fixture(autouse=True)
+def reset_mocks(mock_dependencies):
+    """Reset mock objects used by tests to a clean state.
+
+    Parameters:
+        mock_dependencies (dict): Mapping of mocked test dependencies created by the fixture.
+            Expected keys:
+            - "logger": object with `awarning` and `aerror` AsyncMock attributes.
+            - "encoder_cls": mock class whose instantiation should be reset.
+            - "event_manager": object with `publish` AsyncMock attribute.
+    """
+    # Ensure all mocks are reset before test execution
+    mock_dependencies["logger"].awarning.reset_mock()
+    mock_dependencies["logger"].aerror.reset_mock()
+    mock_dependencies["encoder_cls"].reset_mock()
+    mock_dependencies["event_manager"].publish.reset_mock()
+
+
+# --- Test Classes (remain largely the same, but now used by pytest functions) ---
+
+
+class TestClassWithCallbacks:
+    display_name = "ObservableTest"
+
+    def before_callback_event(self, *args, **kwargs):
+        """Construct the payload for a 'start' lifecycle observable event.
+
+        Parameters:
+            *args: Positional arguments passed to the observed method; only the count is used.
+            **kwargs: Keyword arguments passed to the observed method; only the set of keys is used.
+
+        Returns:
+            dict: Payload containing:
+                - `lifecycle` (str): fixed value "start".
+                - `args_len` (int): number of positional arguments.
+                - `kw_keys` (list[str]): list of keyword argument names.
+        """
+        return {"lifecycle": "start", "args_len": len(args), "kw_keys": list(kwargs.keys())}
+
+    def after_callback_event(self, result: Any, *args, **kwargs):  # noqa: ARG002
+        """Produce an event payload for the post-execution lifecycle including the call result and the names of keyword arguments.
+
+        Parameters:
+            result: The value returned by the observed function.
+            *args: Positional arguments passed to the observed function (ignored).
+            **kwargs: Keyword arguments passed to the observed function; their keys are captured.
+
+        Returns:
+            dict: A mapping containing:
+                - `lifecycle`: the string `"end"`.
+                - `result`: the provided `result` value.
+                - `kw_keys`: list of keyword argument names present in `kwargs`.
+        """
+        return {"lifecycle": "end", "result": result, "kw_keys": list(kwargs.keys())}
+
+    def error_callback_event(self, exception: Exception, *args, **kwargs):  # noqa: ARG002
+        """Builds an error lifecycle payload describing an exception.
+
+        Parameters:
+            exception (Exception): The exception that occurred.
+            *args: Positional arguments passed to the original call (ignored in the payload).
+            **kwargs: Keyword arguments passed to the original call; their keys are included in the payload.
+
+        Returns:
+            dict: A payload with the following keys:
+                - `lifecycle`: the string "error".
+                - `error`: the exception message (`str(exception)`).
+                - `error_type`: the exception class name.
+                - `kw_keys`: list of keyword argument names present in `kwargs`.
+        """
+        return {
+            "lifecycle": "error",
+            "error": str(exception),
+            "error_type": type(exception).__name__,
+            "kw_keys": list(kwargs.keys()),
+        }
+
+    # Mock observable method
+    @observable
+    async def run_success(self, event_manager: MockEventManager, data: str) -> str:  # noqa: ARG002
+        """Produce a processed string by prefixing the input with "Processed:".
+
+        Parameters:
+            data (str): Input payload to be processed.
+
+        Returns:
+            processed (str): The resulting string in the form "Processed:<data>".
+        """
+        await asyncio.sleep(0.001)
+        return f"Processed:{data}"
+
+    @observable
+    async def run_exception(self, event_manager: MockEventManager, data: str) -> str:  # noqa: ARG002
+        """Simulates asynchronous work then raises a ValueError to trigger error handling.
+
+        Raises:
+            ValueError: Always raised to simulate a failure during execution.
+        """
+        await asyncio.sleep(0.001)
+        err_msg = "Simulated failure"
+        raise ValueError(err_msg)
+
+
+class TestClassWithoutCallbacks:
+    display_name = "NonObservableTest"
+
+    @observable
+    async def run_success(self, event_manager: MockEventManager, data: str) -> str:  # noqa: ARG002
+        """Produce a processed string by prefixing the input with "Processed:".
+
+        Parameters:
+            data (str): Input payload to be processed.
+
+        Returns:
+            processed (str): The resulting string in the form "Processed:<data>".
+        """
+        await asyncio.sleep(0.001)
+        return f"Processed:{data}"
+
+
+# --- Pytest Test Functions ---
+
+
+# Use pytest.mark.asyncio for running async functions
+@pytest.mark.asyncio
+async def test_successful_run_with_callbacks(mock_dependencies):
+    """Verify that when an observable-decorated method with lifecycle callbacks completes successfully, the encoder is invoked for BOTH the BEFORE and AFTER events with correct payloads and no warnings/errors are logged.
+
+    Parameters:
+        mock_dependencies (dict): Fixture providing mocks: `logger` (with `awarning`/`aerror`), `event_manager`, `encoder_cls` (mocked encoder class whose `return_value` is the encoder instance with an `encode` method).
+    """
+    instance = TestClassWithCallbacks()
+    data = "test_data"
+
+    event_manager = mock_dependencies["event_manager"]
+
+    result = await instance.run_success(event_manager=event_manager, data=data)
+
+    # 1. Assert result
+    assert result == f"Processed:{data}"
+
+    # 2. Assert encoder was called twice (once for BEFORE, once for AFTER)
+    assert mock_dependencies["encoder_cls"].call_count == 2
+
+    # 3. Verify the encoder was called with the correct payloads
+    encoder_instance = mock_dependencies["encoder_cls"].return_value
+    assert encoder_instance.encode.call_count == 2
+
+    # Get the actual calls to encode
+    encode_calls = encoder_instance.encode.call_args_list
+
+    # First call should be the BEFORE event
+    before_payload = encode_calls[0][0][0]
+    assert before_payload["lifecycle"] == "start"
+    assert before_payload["args_len"] == 0
+    assert "event_manager" in before_payload["kw_keys"]
+    assert "data" in before_payload["kw_keys"]
+
+    # Second call should be the AFTER event
+    after_payload = encode_calls[1][0][0]
+    assert after_payload["lifecycle"] == "end"
+    assert after_payload["result"] == f"Processed:{data}"
+    assert "event_manager" in after_payload["kw_keys"]
+    assert "data" in after_payload["kw_keys"]
+
+    # 4. Assert no warnings or errors were logged
+    mock_dependencies["logger"].awarning.assert_not_called()
+    mock_dependencies["logger"].aerror.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_exception_run_with_callbacks(mock_dependencies):
+    """Verify that when an observable-wrapped method raises an exception, the decorator logs the error, encodes both the start and error lifecycle events, and re-raises the exception.
+
+    Asserts:
+        - A ValueError is raised by the wrapped method.
+        - logger.aerror was called once with the message "Exception in TestClassWithCallbacks: Simulated failure".
+        - The EventEncoder class was instantiated twice and its instance `encode` was called twice.
+        - The first encoded payload has `lifecycle` == 'start'.
+        - The second encoded payload has `lifecycle` == 'error', `error` == 'Simulated failure', and `error_type` == 'ValueError'.
+        - No warning logs were emitted via logger.awarning.
+
+    Parameters:
+        mock_dependencies (dict): Fixture-provided mocks; expected keys include "logger", "event_manager", and "encoder_cls".
+    """
+    instance = TestClassWithCallbacks()
+
+    event_manager = mock_dependencies["event_manager"]
+
+    # The decorator now re-raises the exception after logging and encoding the error event
+    with pytest.raises(ValueError, match="Simulated failure"):
+        await instance.run_exception(event_manager=event_manager, data="fail_data")
+
+    # 1. Assert error was logged
+    mock_dependencies["logger"].aerror.assert_called_once()
+    mock_dependencies["logger"].aerror.assert_called_with("Exception in TestClassWithCallbacks: Simulated failure")
+
+    # 2. Assert encoder was called twice (once for BEFORE event, once for ERROR event)
+    assert mock_dependencies["encoder_cls"].call_count == 2
+
+    # 3. Verify the encoder was called with the correct payloads
+    encoder_instance = mock_dependencies["encoder_cls"].return_value
+    assert encoder_instance.encode.call_count == 2
+
+    # Get the actual calls to encode
+    encode_calls = encoder_instance.encode.call_args_list
+
+    # First call should be the BEFORE event
+    before_payload = encode_calls[0][0][0]
+    assert before_payload["lifecycle"] == "start"
+
+    # Second call should be the ERROR event
+    error_payload = encode_calls[1][0][0]
+    assert error_payload["lifecycle"] == "error"
+    assert error_payload["error"] == "Simulated failure"
+    assert error_payload["error_type"] == "ValueError"
+
+    # 4. Assert no warnings were logged
+    mock_dependencies["logger"].awarning.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_run_without_event_manager(mock_dependencies):
+    instance = TestClassWithCallbacks()
+    data = "no_manager"
+
+    # No event_manager passed (or explicitly passed as None)
+    result = await instance.run_success(event_manager=None, data=data)
+
+    # 1. Assert result is correct
+    assert result == f"Processed:{data}"
+
+    # 2. Assert warning for missing EventManager was logged twice (once for before, once for after)
+    assert mock_dependencies["logger"].awarning.call_count == 2
+    mock_dependencies["logger"].awarning.assert_any_call(
+        "EventManager not available/provided, skipping observable event publishing from TestClassWithCallbacks"
+    )
+
+
+@pytest.mark.asyncio
+async def test_run_without_callbacks(mock_dependencies):
+    instance = TestClassWithoutCallbacks()
+    data = "no_callbacks"
+
+    event_manager = mock_dependencies["event_manager"]
+
+    # Run the method with a manager
+    result = await instance.run_success(event_manager=event_manager, data=data)
+
+    # 1. Assert result is correct
+    assert result == f"Processed:{data}"
+
+    # 2. Assert warnings for missing callbacks were logged
+    assert mock_dependencies["logger"].awarning.call_count == 2
+    mock_dependencies["logger"].awarning.assert_any_call(
+        "before_callback_event not implemented for TestClassWithoutCallbacks. Skipping event publishing."
+    )
+    mock_dependencies["logger"].awarning.assert_any_call(
+        "after_callback_event not implemented for TestClassWithoutCallbacks. Skipping event publishing."
+    )
+
+    # 3.
Assert no errors were logged + mock_dependencies["logger"].aerror.assert_not_called() diff --git a/src/lfx/tests/unit/graph/graph/test_base.py b/src/lfx/tests/unit/graph/graph/test_base.py index dddb61689a01..4081f4c89d8f 100644 --- a/src/lfx/tests/unit/graph/graph/test_base.py +++ b/src/lfx/tests/unit/graph/graph/test_base.py @@ -1,6 +1,7 @@ from collections import deque import pytest +from ag_ui.core import RunFinishedEvent, RunStartedEvent from lfx.components.input_output import ChatInput, ChatOutput, TextOutputComponent from lfx.graph import Graph from lfx.graph.graph.constants import Finish @@ -111,3 +112,117 @@ def test_graph_set_with_valid_component(): tool = YfinanceToolComponent() tool_calling_agent = ToolCallingAgentComponent() tool_calling_agent.set(tools=[tool]) + + +def test_graph_before_callback_event(): + """Test that before_callback_event generates the correct RunStartedEvent payload.""" + # Create a simple graph with two components and a flow_id + chat_input = ChatInput(_id="chat_input") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + graph = Graph(chat_input, chat_output, flow_id="test_flow_id") + + # Call before_callback_event + event = graph.before_callback_event() + + # Assert the event is a RunStartedEvent + assert isinstance(event, RunStartedEvent) + + # Assert the event has the correct run_id and thread_id + assert event.run_id == graph._run_id + assert event.thread_id == graph.flow_id + assert event.thread_id == "test_flow_id" + + # Assert the raw_event contains metrics + assert event.raw_event is not None + assert isinstance(event.raw_event, dict) + + # Assert the raw_event contains timestamp + assert "timestamp" in event.raw_event + assert isinstance(event.raw_event["timestamp"], float) + + # Assert the raw_event contains total_components + assert "total_components" in event.raw_event + assert event.raw_event["total_components"] == len(graph.vertices) + assert 
event.raw_event["total_components"] == 2 # chat_input and chat_output + + +def test_graph_after_callback_event(): + """Test that after_callback_event generates the correct RunFinishedEvent payload.""" + # Create a simple graph with two components and a flow_id + chat_input = ChatInput(_id="chat_input") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + graph = Graph(chat_input, chat_output, flow_id="test_flow_id") + + # Call after_callback_event + event = graph.after_callback_event(result="test_result") + + # Assert the event is a RunFinishedEvent + assert isinstance(event, RunFinishedEvent) + + # Assert the event has the correct run_id and thread_id + assert event.run_id == graph._run_id + assert event.thread_id == graph.flow_id + assert event.thread_id == "test_flow_id" + + # Assert the result is None (as per the implementation) + assert event.result is None + + # Assert the raw_event contains metrics + assert event.raw_event is not None + assert isinstance(event.raw_event, dict) + + # Assert the raw_event contains timestamp + assert "timestamp" in event.raw_event + assert isinstance(event.raw_event["timestamp"], float) + + # Assert the raw_event contains total_components + assert "total_components" in event.raw_event + assert event.raw_event["total_components"] == len(graph.vertices) + assert event.raw_event["total_components"] == 2 # chat_input and chat_output + + +def test_graph_raw_event_metrics(): + """Test that raw_event_metrics generates the correct metrics dictionary.""" + # Create a simple graph with flow_id + chat_input = ChatInput(_id="chat_input") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + graph = Graph(chat_input, chat_output, flow_id="test_flow_id") + + # Call raw_event_metrics with optional fields + metrics = graph.raw_event_metrics({"custom_field": "custom_value"}) + + # Assert metrics is a 
dictionary + assert isinstance(metrics, dict) + + # Assert timestamp is present and is a float + assert "timestamp" in metrics + assert isinstance(metrics["timestamp"], float) + + # Assert custom field is present + assert "custom_field" in metrics + assert metrics["custom_field"] == "custom_value" + + +def test_graph_raw_event_metrics_no_optional_fields(): + """Test that raw_event_metrics works without optional fields.""" + # Create a simple graph with flow_id + chat_input = ChatInput(_id="chat_input") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + graph = Graph(chat_input, chat_output, flow_id="test_flow_id") + + # Call raw_event_metrics without optional fields + metrics = graph.raw_event_metrics() + + # Assert metrics is a dictionary + assert isinstance(metrics, dict) + + # Assert timestamp is present and is a float + assert "timestamp" in metrics + assert isinstance(metrics["timestamp"], float) + + # Assert only timestamp is present (no optional fields) + assert len(metrics) == 1 diff --git a/src/lfx/tests/unit/graph/vertex/test_vertex_base.py b/src/lfx/tests/unit/graph/vertex/test_vertex_base.py index f1e1ea2623cf..357cc55d6c6f 100644 --- a/src/lfx/tests/unit/graph/vertex/test_vertex_base.py +++ b/src/lfx/tests/unit/graph/vertex/test_vertex_base.py @@ -7,6 +7,8 @@ from unittest.mock import Mock import pytest +from ag_ui.core import StepFinishedEvent, StepStartedEvent +from lfx.components.input_output import ChatInput from lfx.graph.edge.base import Edge from lfx.graph.vertex.base import ParameterHandler, Vertex from lfx.services.storage.service import StorageService @@ -263,3 +265,140 @@ def test_process_field_parameters_table_field_invalid(parameter_handler, mock_ve with pytest.raises(ValueError, match="Invalid value type"): parameter_handler.process_field_parameters() + + +def test_vertex_before_callback_event(): + """Test that Vertex.before_callback_event generates the correct 
StepStartedEvent payload.""" + # Create a graph with a ChatInput component, which creates a vertex + from lfx.graph import Graph + + chat_input = ChatInput(_id="test_vertex_id") + chat_output = ChatInput(_id="output_id") # Need two components for Graph + graph = Graph(chat_input, chat_output, flow_id="test_flow") + + # Get the vertex from the graph + vertex = graph.vertices[0] # First vertex should be chat_input + assert vertex.id == "test_vertex_id" + + # Call before_callback_event + event = vertex.before_callback_event() + + # Assert the event is a StepStartedEvent + assert isinstance(event, StepStartedEvent) + + # Assert the event has the correct step_name + assert event.step_name == vertex.display_name + + # Assert the raw_event contains the langflow metrics + assert event.raw_event is not None + assert isinstance(event.raw_event, dict) + assert "langflow" in event.raw_event + + # Assert the langflow metrics contain expected fields + langflow_metrics = event.raw_event["langflow"] + assert isinstance(langflow_metrics, dict) + assert "timestamp" in langflow_metrics + assert isinstance(langflow_metrics["timestamp"], float) + assert "component_id" in langflow_metrics + assert langflow_metrics["component_id"] == vertex.id + assert langflow_metrics["component_id"] == "test_vertex_id" + + +def test_vertex_after_callback_event(): + """Test that Vertex.after_callback_event generates the correct StepFinishedEvent payload.""" + # Create a graph with a ChatInput component, which creates a vertex + from lfx.graph import Graph + + chat_input = ChatInput(_id="test_vertex_id") + chat_output = ChatInput(_id="output_id") # Need two components for Graph + graph = Graph(chat_input, chat_output, flow_id="test_flow") + + # Get the vertex from the graph + vertex = graph.vertices[0] # First vertex should be chat_input + assert vertex.id == "test_vertex_id" + + # Call after_callback_event with a result + test_result = "test_result_value" + event = 
vertex.after_callback_event(result=test_result) + + # Assert the event is a StepFinishedEvent + assert isinstance(event, StepFinishedEvent) + + # Assert the event has the correct step_name + assert event.step_name == vertex.display_name + + # Assert the raw_event contains the langflow metrics + assert event.raw_event is not None + assert isinstance(event.raw_event, dict) + assert "langflow" in event.raw_event + + # Assert the langflow metrics contain expected fields + langflow_metrics = event.raw_event["langflow"] + assert isinstance(langflow_metrics, dict) + assert "timestamp" in langflow_metrics + assert isinstance(langflow_metrics["timestamp"], float) + assert "component_id" in langflow_metrics + assert langflow_metrics["component_id"] == vertex.id + assert langflow_metrics["component_id"] == "test_vertex_id" + + +def test_vertex_raw_event_metrics(): + """Test that Vertex.raw_event_metrics generates the correct metrics dictionary.""" + # Create a graph with a ChatInput component, which creates a vertex + from lfx.graph import Graph + + chat_input = ChatInput(_id="test_vertex_id") + chat_output = ChatInput(_id="output_id") # Need two components for Graph + graph = Graph(chat_input, chat_output, flow_id="test_flow") + + # Get the vertex from the graph + vertex = graph.vertices[0] # First vertex should be chat_input + assert vertex.id == "test_vertex_id" + + # Call raw_event_metrics with optional fields + metrics = vertex.raw_event_metrics({"custom_field": "custom_value"}) + + # Assert metrics is a dictionary + assert isinstance(metrics, dict) + + # Assert timestamp is present and is a float + assert "timestamp" in metrics + assert isinstance(metrics["timestamp"], float) + + # Assert custom field is present + assert "custom_field" in metrics + assert metrics["custom_field"] == "custom_value" + + # Assert raw_metrics from the vertex are included + # (raw_metrics should be an empty dict by default for ChatInput) + for key, value in vertex.raw_metrics.items(): + 
assert key in metrics + assert metrics[key] == value + + +def test_vertex_raw_event_metrics_no_optional_fields(): + """Test that Vertex.raw_event_metrics works without optional fields.""" + # Create a graph with a ChatInput component, which creates a vertex + from lfx.graph import Graph + + chat_input = ChatInput(_id="test_vertex_id") + chat_output = ChatInput(_id="output_id") # Need two components for Graph + graph = Graph(chat_input, chat_output, flow_id="test_flow") + + # Get the vertex from the graph + vertex = graph.vertices[0] # First vertex should be chat_input + assert vertex.id == "test_vertex_id" + + # Call raw_event_metrics without optional fields + metrics = vertex.raw_event_metrics() + + # Assert metrics is a dictionary + assert isinstance(metrics, dict) + + # Assert timestamp is present and is a float + assert "timestamp" in metrics + assert isinstance(metrics["timestamp"], float) + + # The metrics should contain timestamp plus any raw_metrics from the vertex + # For ChatInput, raw_metrics is typically empty, so we should have at least timestamp + assert len(metrics) >= 1 diff --git a/uv.lock b/uv.lock index 5f01f43f7a85..c03f6d2e3ffe 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10, <3.14" resolution-markers = [ "python_full_version >= '3.13' and platform_machine == 'arm64' and sys_platform == 'darwin'", @@ -49,6 +49,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/d2/c581486aa6c4fbd7394c23c47b83fa1a919d34194e16944241daf9e762dd/accelerate-1.12.0-py3-none-any.whl", hash = "sha256:3e2091cd341423207e2f084a6654b1efcd250dc326f2a37d6dde446e07cabb11", size = 380935, upload-time = "2025-11-21T11:27:44.522Z" }, ] +[[package]] +name = "ag-ui-protocol" +version = "0.1.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/67/bb/5a5ec893eea5805fb9a3db76a9888c3429710dfb6f24bbb37568f2cf7320/ag_ui_protocol-0.1.10.tar.gz", hash = "sha256:3213991c6b2eb24bb1a8c362ee270c16705a07a4c5962267a083d0959ed894f4", size = 6945, upload-time = "2025-11-06T15:17:17.068Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/78/eb55fabaab41abc53f52c0918a9a8c0f747807e5306273f51120fd695957/ag_ui_protocol-0.1.10-py3-none-any.whl", hash = "sha256:c81e6981f30aabdf97a7ee312bfd4df0cd38e718d9fc10019c7d438128b93ab5", size = 7889, upload-time = "2025-11-06T15:17:15.325Z" }, +] + [[package]] name = "agent-lifecycle-toolkit" version = "0.4.4" @@ -5531,6 +5543,7 @@ name = "langflow" version = "1.7.0" source = { editable = "." } dependencies = [ + { name = "ag-ui-protocol" }, { name = "agent-lifecycle-toolkit" }, { name = "aioboto3" }, { name = "aiofile" }, @@ -5740,6 +5753,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "ag-ui-protocol", specifier = ">=0.1.10" }, { name = "agent-lifecycle-toolkit", specifier = "~=0.4.4" }, { name = "aioboto3", specifier = ">=15.2.0,<16.0.0" }, { name = "aiofile", specifier = ">=3.9.0,<4.0.0" }, @@ -6393,6 +6407,7 @@ name = "lfx" version = "0.2.0" source = { editable = "src/lfx" } dependencies = [ + { name = "ag-ui-protocol" }, { name = "aiofile" }, { name = "aiofiles" }, { name = "asyncer" }, @@ -6442,6 +6457,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "ag-ui-protocol", specifier = ">=0.1.10" }, { name = "aiofile", specifier = ">=3.8.0,<4.0.0" }, { name = "aiofiles", specifier = ">=24.1.0,<25.0.0" }, { name = "asyncer", specifier = ">=0.0.8,<1.0.0" },