diff --git a/libs/core/langchain_core/runnables/router.py b/libs/core/langchain_core/runnables/router.py index 4b62beba0c388..03095d06644f9 100644 --- a/libs/core/langchain_core/runnables/router.py +++ b/libs/core/langchain_core/runnables/router.py @@ -12,6 +12,8 @@ from pydantic import ConfigDict from typing_extensions import TypedDict, override +import langchain_core.callbacks.manager as cb_manager +import langchain_core.runnables.config as run_config from langchain_core.runnables.base import ( Runnable, RunnableSerializable, @@ -21,6 +23,7 @@ RunnableConfig, get_config_list, get_executor_for_config, + set_config_context, ) from langchain_core.runnables.utils import ( ConfigurableFieldSpec, @@ -80,7 +83,8 @@ def __init__( runnables: A mapping of keys to `Runnable` objects. """ super().__init__( - runnables={key: coerce_to_runnable(r) for key, r in runnables.items()} + runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}, + name="RouterRunnable", ) model_config = ConfigDict( @@ -107,14 +111,42 @@ def get_lc_namespace(cls) -> list[str]: def invoke( self, input: RouterInput, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: + config = run_config.ensure_config(config) + callback_manager = cb_manager.CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + run_manager = callback_manager.on_chain_start( + None, + input, + name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), + ) key = input["key"] actual_input = input["input"] if key not in self.runnables: msg = f"No runnable associated with key '{key}'" + run_manager.on_chain_error(ValueError(msg)) raise ValueError(msg) - runnable = self.runnables[key] - return runnable.invoke(actual_input, config) + try: + runnable = self.runnables[key] + child_config = run_config.patch_config( + config, callbacks=run_manager.get_child() + ) + with set_config_context(child_config) as context: + output = context.run(runnable.invoke, actual_input, child_config) + except BaseException as e: + run_manager.on_chain_error(e) + raise + else: + run_manager.on_chain_end(output) + return output @override async def ainvoke( @@ -123,14 +155,34 @@ async def ainvoke( config: RunnableConfig | None = None, **kwargs: Any | None, ) -> Output: + config = run_config.ensure_config(config) + callback_manager = run_config.get_async_callback_manager_for_config(config) + run_manager = await callback_manager.on_chain_start( + None, + input, + name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), + ) key = input["key"] actual_input = input["input"] if key not in self.runnables: msg = f"No runnable associated with key '{key}'" + await run_manager.on_chain_error(ValueError(msg)) raise ValueError(msg) - runnable = self.runnables[key] - return await runnable.ainvoke(actual_input, config) + try: + runnable = self.runnables[key] + child_config = run_config.patch_config( + config, callbacks=run_manager.get_child() + ) + with set_config_context(child_config) as context: + output = await context.run(runnable.ainvoke, actual_input, child_config) + except BaseException as e: + await run_manager.on_chain_error(e) + raise + else: + await run_manager.on_chain_end(output) + return output @override def batch( diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index da03e5052046c..481cb3b124068 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -1799,6 +1799,7 @@ "RouterRunnable" ], "kwargs": { + "name": "RouterRunnable", "runnables": { "math": { "lc": 1, diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 91702130ade85..c851470b531b6 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -2901,8 +2901,9 @@ def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) -> parent_run = next(r for r in tracer.runs if r.parent_run_id is None) assert len(parent_run.child_runs) == 2 router_run = parent_run.child_runs[1] - assert router_run.name == "RunnableSequence" # TODO: should be RunnableRouter - assert len(router_run.child_runs) == 2 + assert router_run.name == "RouterRunnable" + assert len(router_run.child_runs) == 1 + assert len(router_run.child_runs[0].child_runs) == 2 async def test_router_runnable_async() -> None: