Skip to content
62 changes: 57 additions & 5 deletions libs/core/langchain_core/runnables/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from pydantic import ConfigDict
from typing_extensions import TypedDict, override

import langchain_core.callbacks.manager as cb_manager
import langchain_core.runnables.config as run_config
from langchain_core.runnables.base import (
Runnable,
RunnableSerializable,
Expand All @@ -21,6 +23,7 @@
RunnableConfig,
get_config_list,
get_executor_for_config,
set_config_context,
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
Expand Down Expand Up @@ -80,7 +83,8 @@ def __init__(
runnables: A mapping of keys to `Runnable` objects.
"""
super().__init__(
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()},
name="RouterRunnable",
)

model_config = ConfigDict(
Expand All @@ -107,14 +111,42 @@ def get_lc_namespace(cls) -> list[str]:
def invoke(
self, input: RouterInput, config: RunnableConfig | None = None, **kwargs: Any
) -> Output:
config = run_config.ensure_config(config)
callback_manager = cb_manager.CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
run_manager = callback_manager.on_chain_start(
None,
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
msg = f"No runnable associated with key '{key}'"
run_manager.on_chain_error(ValueError(msg))
raise ValueError(msg)

runnable = self.runnables[key]
return runnable.invoke(actual_input, config)
try:
runnable = self.runnables[key]
child_config = run_config.patch_config(
config, callbacks=run_manager.get_child()
)
with set_config_context(child_config) as context:
output = context.run(runnable.invoke, actual_input, child_config)
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(output)
return output

@override
async def ainvoke(
Expand All @@ -123,14 +155,34 @@ async def ainvoke(
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Output:
config = run_config.ensure_config(config)
callback_manager = run_config.get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
None,
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
msg = f"No runnable associated with key '{key}'"
await run_manager.on_chain_error(ValueError(msg))
raise ValueError(msg)

runnable = self.runnables[key]
return await runnable.ainvoke(actual_input, config)
try:
runnable = self.runnables[key]
child_config = run_config.patch_config(
config, callbacks=run_manager.get_child()
)
with set_config_context(child_config) as context:
output = await context.run(runnable.ainvoke, actual_input, child_config)
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(output)
return output

@override
def batch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1799,6 +1799,7 @@
"RouterRunnable"
],
"kwargs": {
"name": "RouterRunnable",
"runnables": {
"math": {
"lc": 1,
Expand Down
5 changes: 3 additions & 2 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2901,8 +2901,9 @@ def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableSequence" # TODO: should be RunnableRouter
assert len(router_run.child_runs) == 2
assert router_run.name == "RouterRunnable"
assert len(router_run.child_runs) == 1
assert len(router_run.child_runs[0].child_runs) == 2


async def test_router_runnable_async() -> None:
Expand Down