From 33e7af160bfc75ffe8d187ebc6e9c2dbbd70693e Mon Sep 17 00:00:00 2001
From: Debojit Kaushik
Date: Tue, 25 Nov 2025 12:55:34 -0500
Subject: [PATCH] Proof of Concept: Adding lifecycle events.

Added AGUI events to vertices and the graph. Created a decorator to add
observability to Langflow workflows.

TODO: feature gating to nullify impact on current production code paths.

Added unit tests for lifecycle_events. Refraining from using event_manager
for now, since this capability is going to stay dormant until APIs are ready
for streaming.

Added unit tests for before_callback_event and after_callback_event in the
graph and vertex classes.
---
 pyproject.toml                                |   3 +-
 src/lfx/pyproject.toml                        |   1 +
 .../src/lfx/events/observability/__init__.py  |   0
 .../events/observability/lifecycle_events.py  | 111 ++++++++
 src/lfx/src/lfx/graph/graph/base.py           |  23 ++
 src/lfx/src/lfx/graph/vertex/base.py          |  42 ++-
 .../unit/events/observability/__init__.py     |   0
 .../observability/test_lifecycle_events.py    | 247 ++++++++++++++++++
 src/lfx/tests/unit/graph/graph/test_base.py   | 115 ++++++++
 .../unit/graph/vertex/test_vertex_base.py     | 132 ++++++++++
 uv.lock                                       |  16 ++
 11 files changed, 688 insertions(+), 2 deletions(-)
 create mode 100644 src/lfx/src/lfx/events/observability/__init__.py
 create mode 100644 src/lfx/src/lfx/events/observability/lifecycle_events.py
 create mode 100644 src/lfx/tests/unit/events/observability/__init__.py
 create mode 100644 src/lfx/tests/unit/events/observability/test_lifecycle_events.py

diff --git a/pyproject.toml b/pyproject.toml
index 9277a91680d8..2347178472d2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -139,7 +139,8 @@ dependencies = [
     "langchain-mcp-adapters>=0.1.14,<0.2.0", # Pin to avoid incompatibility with langchain-core<1.0.0
     "agent-lifecycle-toolkit~=0.4.4",
     "astrapy>=2.1.0,<3.0.0",
-    "aioboto3>=15.2.0,<16.0.0"
+    "aioboto3>=15.2.0,<16.0.0",
+    "ag-ui-protocol>=0.1.10",
 ]
 
 
diff --git a/src/lfx/pyproject.toml b/src/lfx/pyproject.toml
index 8b3412b6fd42..b94038d1c2f6 100644
--- a/src/lfx/pyproject.toml
+++ b/src/lfx/pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
     "filelock>=3.20.0",
     "pypdf>=5.1.0",
     "cryptography>=43.0.0",
+    "ag-ui-protocol>=0.1.10",
 ]
 
 [project.scripts]
diff --git a/src/lfx/src/lfx/events/observability/__init__.py b/src/lfx/src/lfx/events/observability/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/src/lfx/src/lfx/events/observability/lifecycle_events.py b/src/lfx/src/lfx/events/observability/lifecycle_events.py
new file mode 100644
index 000000000000..fa41d1343c34
--- /dev/null
+++ b/src/lfx/src/lfx/events/observability/lifecycle_events.py
@@ -0,0 +1,111 @@
+import functools
+from collections.abc import Awaitable, Callable
+from typing import Any
+
+from ag_ui.encoder.encoder import EventEncoder
+
+from lfx.log.logger import logger
+
+AsyncMethod = Callable[..., Awaitable[Any]]
+
+encoder: EventEncoder = EventEncoder()
+
+
+def observable(observed_method: AsyncMethod) -> AsyncMethod:
+    """Decorator to make an async method observable by emitting lifecycle events.
+
+    Decorated classes are expected to implement specific methods to emit AGUI events:
+    - `before_callback_event(*args, **kwargs)`: Called before the decorated method executes.
+      It should return a dictionary representing the event payload.
+    - `after_callback_event(result, *args, **kwargs)`: Called after the decorated method
+      successfully completes. It should return a dictionary representing the event payload.
+      The `result` of the decorated method is passed as the first argument.
+    - `error_callback_event(exception, *args, **kwargs)`: (Optional) Called if the decorated
+      method raises an exception. It should return a dictionary representing the error event payload.
+      The `exception` is passed as the first argument.
+
+    If these methods are implemented, the decorator will call them to generate event payloads.
+    If an implementation is missing, the corresponding event publishing will be skipped without error.
+
+    Payloads returned by these methods can include custom metrics by placing them
+    under the 'langflow' key within the 'raw_event' dictionary.
+
+    Example:
+        class MyClass:
+            display_name = "My Observable Class"
+
+            def before_callback_event(self, *args, **kwargs):
+                return {"event_name": "my_method_started", "data": {"input_args": args}}
+
+            async def my_method(self, event_manager: EventManager, data: str):
+                # ... method logic ...
+                return "processed_data"
+
+            def after_callback_event(self, result, *args, **kwargs):
+                return {"event_name": "my_method_completed", "data": {"output": result}}
+
+            def error_callback_event(self, exception, *args, **kwargs):
+                return {"event_name": "my_method_failed", "error": str(exception)}
+
+            @observable
+            async def my_observable_method(self, event_manager: EventManager, data: str):
+                # ... method logic ...
+                pass
+    """
+
+    async def check_event_manager(self, **kwargs):
+        if "event_manager" not in kwargs or kwargs["event_manager"] is None:
+            await logger.awarning(
+                f"EventManager not available/provided, skipping observable event publishing "
+                f"from {self.__class__.__name__}"
+            )
+            return False
+        return True
+
+    async def before_callback(self, *args, **kwargs):
+        if not await check_event_manager(self, **kwargs):
+            return
+
+        if hasattr(self, "before_callback_event"):
+            event_payload = self.before_callback_event(*args, **kwargs)
+            event_payload = encoder.encode(event_payload)
+            # TODO: Publish event per request, would require context-based queues
+        else:
+            await logger.awarning(
+                f"before_callback_event not implemented for {self.__class__.__name__}. Skipping event publishing."
+            )
+
+    async def after_callback(self, res: Any | None = None, *args, **kwargs):
+        if not await check_event_manager(self, **kwargs):
+            return
+        if hasattr(self, "after_callback_event"):
+            event_payload = self.after_callback_event(res, *args, **kwargs)
+            event_payload = encoder.encode(event_payload)
+            # TODO: Publish event per request, would require context-based queues
+        else:
+            await logger.awarning(
+                f"after_callback_event not implemented for {self.__class__.__name__}. Skipping event publishing."
+ ) + + @functools.wraps(observed_method) + async def wrapper(self, *args, **kwargs): + await before_callback(self, *args, **kwargs) + result = None + try: + result = await observed_method(self, *args, **kwargs) + await after_callback(self, result, *args, **kwargs) + except Exception as e: + await logger.aerror(f"Exception in {self.__class__.__name__}: {e}") + if hasattr(self, "error_callback_event"): + try: + event_payload = self.error_callback_event(e, *args, **kwargs) + event_payload = encoder.encode(event_payload) + # TODO: Publish event per request, would required context based queues + except Exception as callback_e: # noqa: BLE001 + await logger.aerror( + f"Exception during error_callback_event for {self.__class__.__name__}: {callback_e}" + ) + raise + return result + + return wrapper diff --git a/src/lfx/src/lfx/graph/graph/base.py b/src/lfx/src/lfx/graph/graph/base.py index d7ac23785c1d..a8f1356b9eb0 100644 --- a/src/lfx/src/lfx/graph/graph/base.py +++ b/src/lfx/src/lfx/graph/graph/base.py @@ -15,6 +15,9 @@ from itertools import chain from typing import TYPE_CHECKING, Any, cast +from ag_ui.core import RunFinishedEvent, RunStartedEvent + +from lfx.events.observability.lifecycle_events import observable from lfx.exceptions.component import ComponentBuildError from lfx.graph.edge.base import CycleEdge, Edge from lfx.graph.graph.constants import Finish, lazy_load_vertex_dict @@ -728,6 +731,7 @@ def _set_inputs(self, input_components: list[str], inputs: dict[str, str], input raise ValueError(msg) vertex.update_raw_params(inputs, overwrite=True) + @observable async def _run( self, *, @@ -2309,3 +2313,22 @@ def __to_dict(self) -> dict[str, dict[str, list[str]]]: predecessors = [i.id for i in self.get_predecessors(vertex)] result |= {vertex_id: {"successors": sucessors, "predecessors": predecessors}} return result + + def raw_event_metrics(self, optional_fields: dict | None = None) -> dict: + if optional_fields is None: + optional_fields = {} + import time + + return {"timestamp": time.time(), **optional_fields} + + def before_callback_event(self, *args, **kwargs) -> RunStartedEvent: # noqa: ARG002 + metrics = {} + if hasattr(self, "raw_event_metrics"): + metrics = self.raw_event_metrics({"total_components": len(self.vertices)}) + return RunStartedEvent(run_id=self._run_id, thread_id=self.flow_id, raw_event=metrics) + + def after_callback_event(self, result: Any = None, *args, **kwargs) -> RunFinishedEvent: # noqa: ARG002 + metrics = {} + if hasattr(self, "raw_event_metrics"): + metrics = self.raw_event_metrics({"total_components": len(self.vertices)}) + return RunFinishedEvent(run_id=self._run_id, thread_id=self.flow_id, result=None, raw_event=metrics) diff --git a/src/lfx/src/lfx/graph/vertex/base.py b/src/lfx/src/lfx/graph/vertex/base.py index 30f8380e9e52..525390158263 100644 --- a/src/lfx/src/lfx/graph/vertex/base.py +++ b/src/lfx/src/lfx/graph/vertex/base.py @@ -8,6 +8,9 @@ from enum import Enum from typing import TYPE_CHECKING, Any +from ag_ui.core import StepFinishedEvent, StepStartedEvent + +from lfx.events.observability.lifecycle_events import observable from lfx.exceptions.component import ComponentBuildError from lfx.graph.schema import INPUT_COMPONENTS, OUTPUT_COMPONENTS, InterfaceComponentTypes, ResultData from lfx.graph.utils import UnbuiltObject, UnbuiltResult, log_transaction @@ -179,6 +182,7 @@ def get_built_result(self): if isinstance(self.built_result, UnbuiltResult): return {} + return self.built_result if isinstance(self.built_result, dict) else {"result": 
self.built_result} def set_artifacts(self) -> None: @@ -380,6 +384,7 @@ def instantiate_component(self, user_id=None) -> None: vertex=self, ) + @observable async def _build( self, fallback_to_env_vars, @@ -389,7 +394,6 @@ async def _build( """Initiate the build process.""" await logger.adebug(f"Building {self.display_name}") await self._build_each_vertex_in_params_dict() - if self.base_type is None: msg = f"Base type for vertex {self.display_name} not found" raise ValueError(msg) @@ -833,3 +837,39 @@ def apply_on_outputs(self, func: Callable[[Any], Any]) -> None: return # Apply the function to each output [func(output) for output in self.custom_component.get_outputs_map().values()] + + # AGUI/AG UI Event Streaming Callbacks/Methods - (Optional, see Observable decorator) + def raw_event_metrics(self, optional_fields: dict | None) -> dict: + """This method is used to get the metrics of the vertex by the Observable decorator. + + If the vertex has a get_metrics method, it will be called, and the metrics will be captured + to stream back to the user in an AGUI compliant format. + Additional fields/metrics to be captured can be modified in this method, or in the callback methods, + which are before_callback_event and after_callback_event before returning the AGUI event. + """ + if optional_fields is None: + optional_fields = {} + import time + + return {"timestamp": time.time(), **optional_fields} + + def before_callback_event(self, *args, **kwargs) -> StepStartedEvent: # noqa: ARG002 + """Should be a AGUI compatible event. + + VERTEX class generates a StepStartedEvent event. + """ + metrics = {} + if hasattr(self, "raw_event_metrics"): + metrics = self.raw_event_metrics({"component_id": self.id}) + + return StepStartedEvent(step_name=self.display_name, raw_event={"langflow": metrics}) + + def after_callback_event(self, result, *args, **kwargs) -> StepFinishedEvent: # noqa: ARG002 + """Should be a AGUI compatible event. + + VERTEX class generates a StepFinishedEvent event. + """ + metrics = {} + if hasattr(self, "raw_event_metrics"): + metrics = self.raw_event_metrics({"component_id": self.id}) + return StepFinishedEvent(step_name=self.display_name, raw_event={"langflow": metrics}) diff --git a/src/lfx/tests/unit/events/observability/__init__.py b/src/lfx/tests/unit/events/observability/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/lfx/tests/unit/events/observability/test_lifecycle_events.py b/src/lfx/tests/unit/events/observability/test_lifecycle_events.py new file mode 100644 index 000000000000..f9e817d6cd54 --- /dev/null +++ b/src/lfx/tests/unit/events/observability/test_lifecycle_events.py @@ -0,0 +1,247 @@ +import asyncio +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from ag_ui.core import CustomEvent, StepFinishedEvent, StepStartedEvent + +# Import the actual decorator we want to test +from lfx.events.observability.lifecycle_events import observable + + +# Mock classes for dependencies +class MockEventManager: + """Mock for lfx.events.event_manager.EventManager.""" + + def __init__(self): + # We'll use AsyncMock for publish + self.publish = AsyncMock() + + +class MockLogger: + """Mock for lfx.log.logger.logger.""" + + def __init__(self): + self.awarning = AsyncMock() + self.aerror = AsyncMock() + + +# --- Pytest Fixtures --- + + +@pytest.fixture +def mock_dependencies(): + """Provides mocked instances of external dependencies and patches them.""" + # 1. 
Logger Mock + mock_logger_instance = MockLogger() + + # 2. EventManager Mock + mock_event_manager = MockEventManager() + + # 3. Encoder Mock - create a mock instance with a mocked encode method + mock_encoder_instance = MagicMock() + # The encode method should return a string (SSE format) + mock_encoder_instance.encode = MagicMock(side_effect=lambda payload: f"data: {payload}\n\n") + + # Patch the actual imports in the lifecycle_events module + with ( + patch("lfx.events.observability.lifecycle_events.logger", mock_logger_instance), + patch("lfx.events.observability.lifecycle_events.encoder", mock_encoder_instance), + ): + yield { + "event_manager": mock_event_manager, + "logger": mock_logger_instance, + "encoder": mock_encoder_instance, + } + + +@pytest.fixture(autouse=True) +def reset_mocks(mock_dependencies): + """Resets the state of the mocks before each test.""" + # Ensure all mocks are reset before test execution + mock_dependencies["logger"].awarning.reset_mock() + mock_dependencies["logger"].aerror.reset_mock() + mock_dependencies["encoder"].encode.reset_mock() + + +# --- Test Classes (remain largely the same, but now used by pytest functions) --- + + +class TestClassWithCallbacks: + display_name = "ObservableTest" + + def before_callback_event(self, *args, **kwargs): + return StepStartedEvent( + step_name=self.display_name, + raw_event={"lifecycle": "start", "args_len": len(args), "kw_keys": list(kwargs.keys())}, + ) + + def after_callback_event(self, result: Any, *args, **kwargs): # noqa: ARG002 + return StepFinishedEvent( + step_name=self.display_name, + raw_event={"lifecycle": "end", "result": result, "kw_keys": list(kwargs.keys())}, + ) + + def error_callback_event(self, exception: Exception, *args, **kwargs): # noqa: ARG002 + return CustomEvent( + name="error", + value={ + "error": str(exception), + "error_type": type(exception).__name__, + }, + raw_event={"lifecycle": "error", "kw_keys": list(kwargs.keys())}, + ) + + # Mock observable method + @observable + async def run_success(self, event_manager: MockEventManager, data: str) -> str: # noqa: ARG002 + await asyncio.sleep(0.001) + return f"Processed:{data}" + + @observable + async def run_exception(self, event_manager: MockEventManager, data: str) -> str: # noqa: ARG002 + await asyncio.sleep(0.001) + raise ValueError + + +class TestClassWithoutCallbacks: + display_name = "NonObservableTest" + + @observable + async def run_success(self, event_manager: MockEventManager, data: str) -> str: # noqa: ARG002 + await asyncio.sleep(0.001) + return f"Processed:{data}" + + +# --- Pytest Test Functions --- + + +# Use pytest.mark.asyncio for running async functions +@pytest.mark.asyncio +async def test_successful_run_with_callbacks(mock_dependencies): + instance = TestClassWithCallbacks() + data = "test_data" + + event_manager = mock_dependencies["event_manager"] + + result = await instance.run_success(event_manager=event_manager, data=data) + + # 1. Assert result + assert result == f"Processed:{data}" + + # 2. Assert encoder was called twice (once for BEFORE, once for AFTER) + assert mock_dependencies["encoder"].encode.call_count == 2 + + # 3. 
Verify the encoder was called with the correct payloads + encoder_instance = mock_dependencies["encoder"] + assert encoder_instance.encode.call_count == 2 + + # Get the actual calls to encode + encode_calls = encoder_instance.encode.call_args_list + + # First call should be the BEFORE event (StepStartedEvent) + before_event = encode_calls[0][0][0] + assert isinstance(before_event, StepStartedEvent) + assert before_event.step_name == "ObservableTest" + assert before_event.raw_event["lifecycle"] == "start" + assert before_event.raw_event["args_len"] == 0 + assert "event_manager" in before_event.raw_event["kw_keys"] + assert "data" in before_event.raw_event["kw_keys"] + + # Second call should be the AFTER event (StepFinishedEvent) + after_event = encode_calls[1][0][0] + assert isinstance(after_event, StepFinishedEvent) + assert after_event.step_name == "ObservableTest" + assert after_event.raw_event["lifecycle"] == "end" + assert after_event.raw_event["result"] == f"Processed:{data}" + assert "event_manager" in after_event.raw_event["kw_keys"] + assert "data" in after_event.raw_event["kw_keys"] + + # 4. Assert no warnings or errors were logged + mock_dependencies["logger"].awarning.assert_not_called() + mock_dependencies["logger"].aerror.assert_not_called() + + +@pytest.mark.asyncio +async def test_exception_run_with_callbacks(mock_dependencies): + instance = TestClassWithCallbacks() + + event_manager = mock_dependencies["event_manager"] + + # The decorator now re-raises the exception after logging and encoding the error event + with pytest.raises(ValueError): # noqa: PT011 + await instance.run_exception(event_manager=event_manager, data="fail_data") + + # 1. Assert error was logged + mock_dependencies["logger"].aerror.assert_called_once() + mock_dependencies["logger"].aerror.assert_called_with("Exception in TestClassWithCallbacks: ") + + # 2. Assert encoder was called twice (once for BEFORE event, once for ERROR event) + assert mock_dependencies["encoder"].encode.call_count == 2 + + # 3. Verify the encoder was called with the correct payloads + encoder_instance = mock_dependencies["encoder"] + assert encoder_instance.encode.call_count == 2 + + # Get the actual calls to encode + encode_calls = encoder_instance.encode.call_args_list + + # First call should be the BEFORE event (StepStartedEvent) + before_event = encode_calls[0][0][0] + assert isinstance(before_event, StepStartedEvent) + assert before_event.raw_event["lifecycle"] == "start" + + # Second call should be the ERROR event (CustomEvent) + error_event = encode_calls[1][0][0] + assert isinstance(error_event, CustomEvent) + assert error_event.name == "error" + assert error_event.value["error"] == "" + assert error_event.value["error_type"] == "ValueError" + assert error_event.raw_event["lifecycle"] == "error" + + # 4. Assert no warnings were logged + mock_dependencies["logger"].awarning.assert_not_called() + + +@pytest.mark.asyncio +async def test_run_without_event_manager(mock_dependencies): + instance = TestClassWithCallbacks() + data = "no_manager" + + # No event_manager passed (or explicitly passed as None) + result = await instance.run_success(event_manager=None, data=data) + + # 1. Assert result is correct + assert result == f"Processed:{data}" + + # 2. 
Assert warning for missing EventManager was logged twice (once for before, once for after) + assert mock_dependencies["logger"].awarning.call_count == 2 + mock_dependencies["logger"].awarning.assert_any_call( + "EventManager not available/provided, skipping observable event publishing from TestClassWithCallbacks" + ) + + +@pytest.mark.asyncio +async def test_run_without_callbacks(mock_dependencies): + instance = TestClassWithoutCallbacks() + data = "no_callbacks" + + event_manager = mock_dependencies["event_manager"] + + # Run the method with a manager + result = await instance.run_success(event_manager=event_manager, data=data) + + # 1. Assert result is correct + assert result == f"Processed:{data}" + + # 2. Assert warnings for missing callbacks were logged + assert mock_dependencies["logger"].awarning.call_count == 2 + mock_dependencies["logger"].awarning.assert_any_call( + "before_callback_event not implemented for TestClassWithoutCallbacks. Skipping event publishing." + ) + mock_dependencies["logger"].awarning.assert_any_call( + "after_callback_event not implemented for TestClassWithoutCallbacks. Skipping event publishing." + ) + + # 3. Assert no errors were logged + mock_dependencies["logger"].aerror.assert_not_called() diff --git a/src/lfx/tests/unit/graph/graph/test_base.py b/src/lfx/tests/unit/graph/graph/test_base.py index 950d60a1c35d..8b98da9c3f88 100644 --- a/src/lfx/tests/unit/graph/graph/test_base.py +++ b/src/lfx/tests/unit/graph/graph/test_base.py @@ -1,6 +1,7 @@ from collections import deque import pytest +from ag_ui.core import RunFinishedEvent, RunStartedEvent from lfx.components.input_output import ChatInput, ChatOutput, TextOutputComponent from lfx.components.langchain_utilities.tool_calling import ToolCallingAgentComponent from lfx.components.processing.combine_text import CombineTextComponent @@ -247,3 +248,117 @@ def test_graph_set_with_valid_component(): tool = YfinanceToolComponent() tool_calling_agent = ToolCallingAgentComponent() tool_calling_agent.set(tools=[tool]) + + +def test_graph_before_callback_event(): + """Test that before_callback_event generates the correct RunStartedEvent payload.""" + # Create a simple graph with two components and a flow_id + chat_input = ChatInput(_id="chat_input") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + graph = Graph(chat_input, chat_output, flow_id="test_flow_id") + + # Call before_callback_event + event = graph.before_callback_event() + + # Assert the event is a RunStartedEvent + assert isinstance(event, RunStartedEvent) + + # Assert the event has the correct run_id and thread_id + assert event.run_id == graph._run_id + assert event.thread_id == graph.flow_id + assert event.thread_id == "test_flow_id" + + # Assert the raw_event contains metrics + assert event.raw_event is not None + assert isinstance(event.raw_event, dict) + + # Assert the raw_event contains timestamp + assert "timestamp" in event.raw_event + assert isinstance(event.raw_event["timestamp"], float) + + # Assert the raw_event contains total_components + assert "total_components" in event.raw_event + assert event.raw_event["total_components"] == len(graph.vertices) + assert event.raw_event["total_components"] == 2 # chat_input and chat_output + + +def test_graph_after_callback_event(): + """Test that after_callback_event generates the correct RunFinishedEvent payload.""" + # Create a simple graph with two components and a flow_id + chat_input = ChatInput(_id="chat_input") + 
chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + graph = Graph(chat_input, chat_output, flow_id="test_flow_id") + + # Call after_callback_event + event = graph.after_callback_event(result="test_result") + + # Assert the event is a RunFinishedEvent + assert isinstance(event, RunFinishedEvent) + + # Assert the event has the correct run_id and thread_id + assert event.run_id == graph._run_id + assert event.thread_id == graph.flow_id + assert event.thread_id == "test_flow_id" + + # Assert the result is None (as per the implementation) + assert event.result is None + + # Assert the raw_event contains metrics + assert event.raw_event is not None + assert isinstance(event.raw_event, dict) + + # Assert the raw_event contains timestamp + assert "timestamp" in event.raw_event + assert isinstance(event.raw_event["timestamp"], float) + + # Assert the raw_event contains total_components + assert "total_components" in event.raw_event + assert event.raw_event["total_components"] == len(graph.vertices) + assert event.raw_event["total_components"] == 2 # chat_input and chat_output + + +def test_graph_raw_event_metrics(): + """Test that raw_event_metrics generates the correct metrics dictionary.""" + # Create a simple graph with flow_id + chat_input = ChatInput(_id="chat_input") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + graph = Graph(chat_input, chat_output, flow_id="test_flow_id") + + # Call raw_event_metrics with optional fields + metrics = graph.raw_event_metrics({"custom_field": "custom_value"}) + + # Assert metrics is a dictionary + assert isinstance(metrics, dict) + + # Assert timestamp is present and is a float + assert "timestamp" in metrics + assert isinstance(metrics["timestamp"], float) + + # Assert custom field is present + assert "custom_field" in metrics + assert metrics["custom_field"] == "custom_value" + + +def test_graph_raw_event_metrics_no_optional_fields(): + """Test that raw_event_metrics works without optional fields.""" + # Create a simple graph with flow_id + chat_input = ChatInput(_id="chat_input") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + graph = Graph(chat_input, chat_output, flow_id="test_flow_id") + + # Call raw_event_metrics without optional fields + metrics = graph.raw_event_metrics() + + # Assert metrics is a dictionary + assert isinstance(metrics, dict) + + # Assert timestamp is present and is a float + assert "timestamp" in metrics + assert isinstance(metrics["timestamp"], float) + + # Assert only timestamp is present (no optional fields) + assert len(metrics) == 1 diff --git a/src/lfx/tests/unit/graph/vertex/test_vertex_base.py b/src/lfx/tests/unit/graph/vertex/test_vertex_base.py index f1e1ea2623cf..27271650210d 100644 --- a/src/lfx/tests/unit/graph/vertex/test_vertex_base.py +++ b/src/lfx/tests/unit/graph/vertex/test_vertex_base.py @@ -7,6 +7,8 @@ from unittest.mock import Mock import pytest +from ag_ui.core import StepFinishedEvent, StepStartedEvent +from lfx.components.input_output import ChatInput from lfx.graph.edge.base import Edge from lfx.graph.vertex.base import ParameterHandler, Vertex from lfx.services.storage.service import StorageService @@ -263,3 +265,133 @@ def test_process_field_parameters_table_field_invalid(parameter_handler, mock_ve with pytest.raises(ValueError, match="Invalid value type"): 
parameter_handler.process_field_parameters() + + +def test_vertex_before_callback_event(): + """Test that Vertex.before_callback_event generates the correct StepStartedEvent payload.""" + # Create a graph with a ChatInput component, which creates a vertex + from lfx.graph import Graph + + chat_input = ChatInput(_id="test_vertex_id") + chat_output = ChatInput(_id="output_id") # Need two components for Graph + graph = Graph(chat_input, chat_output, flow_id="test_flow") + + # Get the vertex from the graph + vertex = graph.vertices[0] # First vertex should be chat_input + assert vertex.id == "test_vertex_id" + + # Call before_callback_event + event = vertex.before_callback_event() + + # Assert the event is a StepStartedEvent + assert isinstance(event, StepStartedEvent) + + # Assert the event has the correct step_name + assert event.step_name == vertex.display_name + + # Assert the raw_event contains the langflow metrics + assert event.raw_event is not None + assert isinstance(event.raw_event, dict) + assert "langflow" in event.raw_event + + # Assert the langflow metrics contain expected fields + langflow_metrics = event.raw_event["langflow"] + assert isinstance(langflow_metrics, dict) + assert "timestamp" in langflow_metrics + assert isinstance(langflow_metrics["timestamp"], float) + assert "component_id" in langflow_metrics + assert langflow_metrics["component_id"] == vertex.id + assert langflow_metrics["component_id"] == "test_vertex_id" + + +def test_vertex_after_callback_event(): + """Test that Vertex.after_callback_event generates the correct StepFinishedEvent payload.""" + # Create a graph with a ChatInput component, which creates a vertex + from lfx.graph import Graph + + chat_input = ChatInput(_id="test_vertex_id") + chat_output = ChatInput(_id="output_id") # Need two components for Graph + graph = Graph(chat_input, chat_output, flow_id="test_flow") + + # Get the vertex from the graph + vertex = graph.vertices[0] # First vertex should be chat_input + assert vertex.id == "test_vertex_id" + + # Call after_callback_event with a result + test_result = "test_result_value" + event = vertex.after_callback_event(result=test_result) + + # Assert the event is a StepFinishedEvent + assert isinstance(event, StepFinishedEvent) + + # Assert the event has the correct step_name + assert event.step_name == vertex.display_name + + # Assert the raw_event contains the langflow metrics + assert event.raw_event is not None + assert isinstance(event.raw_event, dict) + assert "langflow" in event.raw_event + + # Assert the langflow metrics contain expected fields + langflow_metrics = event.raw_event["langflow"] + assert isinstance(langflow_metrics, dict) + assert "timestamp" in langflow_metrics + assert isinstance(langflow_metrics["timestamp"], float) + assert "component_id" in langflow_metrics + assert langflow_metrics["component_id"] == vertex.id + assert langflow_metrics["component_id"] == "test_vertex_id" + + +def test_vertex_raw_event_metrics(): + """Test that Vertex.raw_event_metrics generates the correct metrics dictionary.""" + # Create a graph with a ChatInput component, which creates a vertex + from lfx.graph import Graph + + chat_input = ChatInput(_id="test_vertex_id") + chat_output = ChatInput(_id="output_id") # Need two components for Graph + graph = Graph(chat_input, chat_output, flow_id="test_flow") + + # Get the vertex from the graph + vertex = graph.vertices[0] # First vertex should be chat_input + assert vertex.id == "test_vertex_id" + + # Call raw_event_metrics with optional fields + metrics 
= vertex.raw_event_metrics({"custom_field": "custom_value"}) + + # Assert metrics is a dictionary + assert isinstance(metrics, dict) + + # Assert timestamp is present and is a float + assert "timestamp" in metrics + assert isinstance(metrics["timestamp"], float) + + # Assert custom field is present + assert "custom_field" in metrics + assert metrics["custom_field"] == "custom_value" + + +def test_vertex_raw_event_metrics_no_optional_fields(): + """Test that Vertex.raw_event_metrics works without optional fields.""" + # Create a graph with a ChatInput component, which creates a vertex + from lfx.graph import Graph + + chat_input = ChatInput(_id="test_vertex_id") + chat_output = ChatInput(_id="output_id") # Need two components for Graph + graph = Graph(chat_input, chat_output, flow_id="test_flow") + + # Get the vertex from the graph + vertex = graph.vertices[0] # First vertex should be chat_input + assert vertex.id == "test_vertex_id" + + # Call raw_event_metrics without optional fields (pass None) + metrics = vertex.raw_event_metrics(None) + + # Assert metrics is a dictionary + assert isinstance(metrics, dict) + + # Assert timestamp is present and is a float + assert "timestamp" in metrics + assert isinstance(metrics["timestamp"], float) + + # The metrics should contain only timestamp when no optional fields are provided + assert len(metrics) == 1 diff --git a/uv.lock b/uv.lock index 891025f35cb9..8061f3d4e0e1 100644 --- a/uv.lock +++ b/uv.lock @@ -50,6 +50,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/d2/c581486aa6c4fbd7394c23c47b83fa1a919d34194e16944241daf9e762dd/accelerate-1.12.0-py3-none-any.whl", hash = "sha256:3e2091cd341423207e2f084a6654b1efcd250dc326f2a37d6dde446e07cabb11", size = 380935, upload-time = "2025-11-21T11:27:44.522Z" }, ] +[[package]] +name = "ag-ui-protocol" +version = "0.1.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/67/bb/5a5ec893eea5805fb9a3db76a9888c3429710dfb6f24bbb37568f2cf7320/ag_ui_protocol-0.1.10.tar.gz", hash = "sha256:3213991c6b2eb24bb1a8c362ee270c16705a07a4c5962267a083d0959ed894f4", size = 6945, upload-time = "2025-11-06T15:17:17.068Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/78/eb55fabaab41abc53f52c0918a9a8c0f747807e5306273f51120fd695957/ag_ui_protocol-0.1.10-py3-none-any.whl", hash = "sha256:c81e6981f30aabdf97a7ee312bfd4df0cd38e718d9fc10019c7d438128b93ab5", size = 7889, upload-time = "2025-11-06T15:17:15.325Z" }, +] + [[package]] name = "agent-lifecycle-toolkit" version = "0.4.5" @@ -5531,6 +5543,7 @@ name = "langflow" version = "1.8.0" source = { editable = "." 
} dependencies = [ + { name = "ag-ui-protocol" }, { name = "agent-lifecycle-toolkit" }, { name = "aioboto3" }, { name = "aiofile" }, @@ -5741,6 +5754,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "ag-ui-protocol", specifier = ">=0.1.10" }, { name = "agent-lifecycle-toolkit", specifier = "~=0.4.4" }, { name = "aioboto3", specifier = ">=15.2.0,<16.0.0" }, { name = "aiofile", specifier = ">=3.9.0,<4.0.0" }, @@ -6395,6 +6409,7 @@ name = "lfx" version = "0.3.0" source = { editable = "src/lfx" } dependencies = [ + { name = "ag-ui-protocol" }, { name = "aiofile" }, { name = "aiofiles" }, { name = "asyncer" }, @@ -6451,6 +6466,7 @@ integration = [ [package.metadata] requires-dist = [ + { name = "ag-ui-protocol", specifier = ">=0.1.10" }, { name = "aiofile", specifier = ">=3.8.0,<4.0.0" }, { name = "aiofiles", specifier = ">=24.1.0,<25.0.0" }, { name = "asyncer", specifier = ">=0.0.8,<1.0.0" },
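
A minimal usage sketch of the observable decorator introduced above, assuming this patch is applied and the ag-ui-protocol package is installed. The EchoStep class, its run() method, and the plain object() passed as event_manager are hypothetical and only illustrate the callback contract; since publishing is still a TODO, the decorator encodes the events but does not stream them anywhere yet.

import asyncio

from ag_ui.core import StepFinishedEvent, StepStartedEvent

from lfx.events.observability.lifecycle_events import observable


class EchoStep:
    display_name = "Echo Step"

    def before_callback_event(self, *args, **kwargs):
        # Encoded by the module-level EventEncoder before run() executes.
        return StepStartedEvent(step_name=self.display_name, raw_event={"langflow": {"phase": "start"}})

    def after_callback_event(self, result, *args, **kwargs):
        # Encoded after run() returns; `result` is the decorated method's return value.
        return StepFinishedEvent(step_name=self.display_name, raw_event={"langflow": {"result": result}})

    @observable
    async def run(self, event_manager, text: str) -> str:
        return text.upper()


async def main() -> None:
    # Any non-None event_manager keyword satisfies the decorator's check;
    # passing None (or omitting it) only logs a warning and skips the events.
    step = EchoStep()
    print(await step.run(event_manager=object(), text="hello"))


if __name__ == "__main__":
    asyncio.run(main())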