From da745fd747556389b88596605db2ff26a446179c Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 08:25:15 -0500 Subject: [PATCH 01/15] feat(rlm): add multimodal media support (Audio/Image) to RLM Add llm_query_with_media() tool that lets RLM's sandboxed code send Audio and Image inputs to sub-LLM calls for multimodal reasoning. Changes: - python_interpreter.py: Handle Audio/Image in _serialize_value() and _to_json_compatible() by converting to descriptor strings - rlm.py: Add media registry that captures Audio/Image inputs, new llm_query_with_media(prompt, *media_var_names) tool, and dynamic instruction generation when media fields are detected Media objects can't be serialized into the Deno/WASM sandbox, so they are stored in a registry and referenced by variable name. The sandbox sees descriptor strings; actual media data flows through llm_query_with_media() to the sub-LLM as proper multimodal content. Tested with Gemini 3 Flash (via OpenRouter) for both audio transcription and image description tasks. --- dspy/predict/rlm.py | 145 ++++++++++++++++++++++++-- dspy/primitives/python_interpreter.py | 40 +++++++ 2 files changed, 175 insertions(+), 10 deletions(-) diff --git a/dspy/predict/rlm.py b/dspy/predict/rlm.py index 66c011d234..55fcf4ca92 100644 --- a/dspy/predict/rlm.py +++ b/dspy/predict/rlm.py @@ -22,6 +22,8 @@ import dspy from dspy.adapters.types.tool import Tool from dspy.adapters.utils import parse_value, translate_field_type +from dspy.adapters.types.audio import Audio +from dspy.adapters.types.image import Image from dspy.primitives.code_interpreter import SIMPLE_TYPES, CodeInterpreter, CodeInterpreterError, FinalOutput from dspy.primitives.module import Module from dspy.primitives.prediction import Prediction @@ -30,6 +32,20 @@ from dspy.signatures.signature import ensure_signature from dspy.utils.annotation import experimental +# Types considered "media" — their data can't be serialized into the sandbox +# but can be forwarded to sub-LLM calls via llm_query_with_media(). +_MEDIA_TYPES = (Audio, Image) + + +def _is_media(value): + """Check if a value is a media type (Audio or Image).""" + return isinstance(value, _MEDIA_TYPES) + + +def _format_media_for_lm(value): + """Convert a media object to LM message content parts.""" + return value.format() + if TYPE_CHECKING: from dspy.signatures.signature import Signature @@ -46,7 +62,7 @@ Available: - Variables: {inputs} (your input data) - `llm_query(prompt)` - query a sub-LLM (~500K char capacity) for semantic analysis -- `llm_query_batched(prompts)` - query multiple prompts concurrently (much faster for multiple queries) +- `llm_query_batched(prompts)` - query multiple prompts concurrently (much faster for multiple queries){media_tools} - `print()` - ALWAYS print to see results - `SUBMIT({final_output_names})` - submit final output when done - Standard libraries: re, json, collections, math, etc. @@ -56,7 +72,7 @@ 1. EXPLORE FIRST - Look at your data before processing it. Print samples, check types/lengths, understand the structure. 2. ITERATE - Write small code snippets, observe outputs, then decide next steps. State persists between iterations. 3. VERIFY BEFORE SUBMITTING - If results seem wrong (zeros, empty, unexpected), reconsider your approach. -4. USE llm_query FOR SEMANTICS - String matching finds WHERE things are; llm_query understands WHAT things mean. +4. USE llm_query FOR SEMANTICS - String matching finds WHERE things are; llm_query understands WHAT things mean.{media_guidelines} 5. MINIMIZE RETYPING (INPUTS & OUTPUTS) - When values are long, precise, or error-prone (IDs, numbers, code, quotes), re-access them via variables and parse/compute in code instead of retyping. Use small, targeted prints to sanity-check, but avoid manual copying when variables can carry the exact value. 6. SUBMIT ONLY AFTER SEEING OUTPUTS - SUBMIT ends the current run immediately. If you need to inspect printed output, run it in one step, review the result, then call SUBMIT in a later step. @@ -144,7 +160,7 @@ def __init__( # ========================================================================= # Reserved tool names that conflict with built-in sandbox functions - _RESERVED_TOOL_NAMES = frozenset({"llm_query", "llm_query_batched", "SUBMIT", "print"}) + _RESERVED_TOOL_NAMES = frozenset({"llm_query", "llm_query_batched", "llm_query_with_media", "SUBMIT", "print"}) def _normalize_tools(self, tools: list[Callable] | None) -> dict[str, Tool]: """Normalize tools list to a dict of Tool objects keyed by name.""" @@ -198,11 +214,18 @@ def _format_tool_docs(self, tools: dict[str, Tool]) -> str: return "\n".join(lines) - def _make_llm_tools(self, max_workers: int = 8) -> dict[str, Callable]: - """Create llm_query and llm_query_batched tools with a fresh call counter.""" + def _make_llm_tools(self, media_registry: dict[str, Any] | None = None, max_workers: int = 8) -> dict[str, Callable]: + """Create llm_query, llm_query_batched, and llm_query_with_media tools with a fresh call counter. + + Args: + media_registry: Dict mapping variable names to media objects (Audio/Image). + Used by llm_query_with_media to attach media to sub-LLM calls. + max_workers: Max concurrent workers for batched queries. + """ state = {"call_count": 0} lock = threading.Lock() lm = self.sub_lm + _media_registry = media_registry or {} def _check_and_increment(n: int = 1) -> None: with lock: @@ -225,6 +248,26 @@ def _query_lm(prompt: str) -> str: return item return str(response) + def _query_lm_multimodal(prompt: str, media_objects: list) -> str: + """Query the LLM with a prompt string and media content parts.""" + target_lm = lm if lm is not None else dspy.settings.lm + if target_lm is None: + raise RuntimeError("No LM configured. Use dspy.configure(lm=...) or pass sub_lm to RLM.") + + # Build multimodal content: text prompt + media content parts + content_parts = [{"type": "text", "text": prompt}] + for media_obj in media_objects: + content_parts.extend(_format_media_for_lm(media_obj)) + + messages = [{"role": "user", "content": content_parts}] + response = target_lm(messages=messages) + if isinstance(response, list) and response: + item = response[0] + if isinstance(item, dict) and "text" in item: + return item["text"] + return item + return str(response) + def llm_query(prompt: str) -> str: """Query the LLM with a prompt string.""" if not prompt: @@ -249,7 +292,42 @@ def llm_query_batched(prompts: list[str]) -> list[str]: results[idx] = f"[ERROR] {e}" return [results[i] for i in range(len(prompts))] - return {"llm_query": llm_query, "llm_query_batched": llm_query_batched} + def llm_query_with_media(prompt: str, *media_var_names: str) -> str: + """Query the LLM with a prompt and media variables (audio/image). + + Args: + prompt: The text prompt for the LLM. + *media_var_names: Names of media variables to include (e.g., 'audio_input', 'my_image'). + These must be names of Audio or Image input variables. + + Returns: + The LLM's text response. + """ + if not prompt: + raise ValueError("prompt cannot be empty") + if not media_var_names: + raise ValueError( + "At least one media variable name is required. " + f"Available media variables: {list(_media_registry.keys())}" + ) + + # Resolve media objects from the registry + media_objects = [] + for var_name in media_var_names: + if var_name not in _media_registry: + available = list(_media_registry.keys()) + raise ValueError( + f"Media variable '{var_name}' not found. Available media variables: {available}" + ) + media_objects.append(_media_registry[var_name]) + + _check_and_increment(1) + return _query_lm_multimodal(prompt, media_objects) + + tools = {"llm_query": llm_query, "llm_query_batched": llm_query_batched} + if _media_registry: + tools["llm_query_with_media"] = llm_query_with_media + return tools @property def tools(self) -> dict[str, Tool]: @@ -260,6 +338,21 @@ def tools(self) -> dict[str, Tool]: # Signature Building # ========================================================================= + def _detect_media_fields(self) -> dict[str, str]: + """Detect input fields that are Audio or Image types. + + Returns: + Dict mapping field name to media type name (e.g., {'audio_input': 'Audio', 'photo': 'Image'}). + """ + media_fields = {} + for name, field in self.signature.input_fields.items(): + annotation = getattr(field, "annotation", None) + if annotation is Audio: + media_fields[name] = "Audio" + elif annotation is Image: + media_fields[name] = "Image" + return media_fields + def _build_signatures(self) -> tuple[Signature, Signature]: """Build the action and extract signatures from templates.""" inputs_str = ", ".join(f"`{n}`" for n in self.signature.input_fields) @@ -278,10 +371,28 @@ def _build_signatures(self) -> tuple[Signature, Signature]: # Format tool documentation for user-provided tools tool_docs = self._format_tool_docs(self._user_tools) + # Detect media fields and build media-specific instructions + media_fields = self._detect_media_fields() + if media_fields: + media_var_list = ", ".join(f"'{name}'" for name in media_fields) + media_tools_str = ( + f"\n- `llm_query_with_media(prompt, *media_var_names)` - query sub-LLM with media (audio/image) attached. " + f"Media variables: {media_var_list}. The sub-LLM can see/hear the media content." + ) + media_guidelines_str = ( + f"\n FOR MEDIA INPUTS (Audio/Image): Variables like {media_var_list} are media objects. " + f"In the sandbox they appear as descriptor strings — you CANNOT decode or process them as raw data. " + f"Use `llm_query_with_media(prompt, {list(media_fields.keys())[0]!r})` to send media to a sub-LLM that can perceive it." + ) + else: + media_tools_str = "" + media_guidelines_str = "" + action_sig = ( dspy.Signature({}, task_instructions + ACTION_INSTRUCTIONS_TEMPLATE.format( inputs=inputs_str, final_output_names=final_output_names, output_fields=output_fields, max_llm_calls=self.max_llm_calls, + media_tools=media_tools_str, media_guidelines=media_guidelines_str, ) + tool_docs) .append("variables_info", dspy.InputField(desc="Metadata about the variables available in the REPL"), type_=str) .append("repl_history", dspy.InputField(desc="Previous REPL code executions and their outputs"), type_=REPLHistory) @@ -350,9 +461,21 @@ def _validate_inputs(self, input_args: dict[str, Any]) -> None: # CodeInterpreter Lifecycle # ========================================================================= - def _prepare_execution_tools(self) -> dict[str, Callable]: + def _build_media_registry(self, input_args: dict[str, Any]) -> dict[str, Any]: + """Extract media objects (Audio/Image) from inputs into a registry. + + Returns: + Dict mapping variable names to their media objects. + """ + registry = {} + for name, value in input_args.items(): + if _is_media(value): + registry[name] = value + return registry + + def _prepare_execution_tools(self, media_registry: dict[str, Any] | None = None) -> dict[str, Callable]: """Create fresh LLM tools and merge with user-provided tools.""" - execution_tools = self._make_llm_tools() + execution_tools = self._make_llm_tools(media_registry=media_registry) # Extract underlying functions from Tool objects for the interpreter execution_tools.update({name: tool.func for name, tool in self._user_tools.items()}) return execution_tools @@ -551,7 +674,8 @@ def forward(self, **input_args) -> Prediction: self._validate_inputs(input_args) output_field_names = list(self.signature.output_fields.keys()) - execution_tools = self._prepare_execution_tools() + media_registry = self._build_media_registry(input_args) + execution_tools = self._prepare_execution_tools(media_registry=media_registry) variables = self._build_variables(**input_args) with self._interpreter_context(execution_tools) as repl: @@ -634,7 +758,8 @@ async def aforward(self, **input_args) -> Prediction: self._validate_inputs(input_args) output_field_names = list(self.signature.output_fields.keys()) - execution_tools = self._prepare_execution_tools() + media_registry = self._build_media_registry(input_args) + execution_tools = self._prepare_execution_tools(media_registry=media_registry) variables = self._build_variables(**input_args) with self._interpreter_context(execution_tools) as repl: diff --git a/dspy/primitives/python_interpreter.py b/dspy/primitives/python_interpreter.py index 76b685ad58..58a74b0538 100644 --- a/dspy/primitives/python_interpreter.py +++ b/dspy/primitives/python_interpreter.py @@ -19,6 +19,40 @@ from dspy.primitives.code_interpreter import SIMPLE_TYPES, CodeInterpreterError, FinalOutput +# Lazy import helpers for multimodal types to avoid circular imports +def _is_media_type(value: Any) -> bool: + """Check if value is a DSPy Audio or Image type.""" + try: + from dspy.adapters.types.audio import Audio + if isinstance(value, Audio): + return True + except ImportError: + pass + try: + from dspy.adapters.types.image import Image + if isinstance(value, Image): + return True + except ImportError: + pass + return False + + +def _media_descriptor(value: Any) -> str: + """Return a human-readable descriptor string for a media object.""" + try: + from dspy.adapters.types.audio import Audio + if isinstance(value, Audio): + return f"" + except ImportError: + pass + try: + from dspy.adapters.types.image import Image + if isinstance(value, Image): + return repr(value) + except ImportError: + pass + return repr(value) + __all__ = ["PythonInterpreter", "FinalOutput", "CodeInterpreterError"] logger = logging.getLogger(__name__) @@ -378,6 +412,8 @@ def _to_json_compatible(self, value: Any) -> Any: """Recursively convert Python values to JSON-compatible types.""" if value is None or isinstance(value, (str, int, float, bool)): return value + elif _is_media_type(value): + return _media_descriptor(value) elif isinstance(value, dict): return {k: self._to_json_compatible(v) for k, v in value.items()} elif isinstance(value, (list, tuple)): @@ -448,6 +484,10 @@ def _serialize_value(self, value: Any) -> str: sorted_items = list(value) items = ", ".join(self._serialize_value(item) for item in sorted_items) return f"[{items}]" + elif _is_media_type(value): + # Media types (Audio, Image) are represented as descriptor strings in the sandbox. + # The actual media data is accessed via llm_query_with_media() in the RLM context. + return repr(_media_descriptor(value)) else: raise CodeInterpreterError(f"Unsupported value type: {type(value).__name__}") From 9eeee381f2a2a9a37504ae575e83f6dcff44845d Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 09:45:52 -0500 Subject: [PATCH 02/15] test(rlm): add unit tests for multimodal media support + fix lint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix ruff lint errors: import sorting, unused loop var, list()[0] → next(iter()) - Add 15 unit tests for RLM media detection, registry, llm_query_with_media - Add 12 unit tests for python_interpreter media serialization helpers - All 78 unit tests pass (37 deno tests skipped as expected) Co-Authored-By: Claude Opus 4.6 --- dspy/predict/rlm.py | 8 +- dspy/primitives/python_interpreter.py | 1 + tests/predict/test_rlm.py | 218 +++++++++++++++++++ tests/primitives/test_media_serialization.py | 78 +++++++ 4 files changed, 301 insertions(+), 4 deletions(-) create mode 100644 tests/primitives/test_media_serialization.py diff --git a/dspy/predict/rlm.py b/dspy/predict/rlm.py index 55fcf4ca92..7dfc533722 100644 --- a/dspy/predict/rlm.py +++ b/dspy/predict/rlm.py @@ -20,10 +20,10 @@ import pydantic import dspy -from dspy.adapters.types.tool import Tool -from dspy.adapters.utils import parse_value, translate_field_type from dspy.adapters.types.audio import Audio from dspy.adapters.types.image import Image +from dspy.adapters.types.tool import Tool +from dspy.adapters.utils import parse_value, translate_field_type from dspy.primitives.code_interpreter import SIMPLE_TYPES, CodeInterpreter, CodeInterpreterError, FinalOutput from dspy.primitives.module import Module from dspy.primitives.prediction import Prediction @@ -187,7 +187,7 @@ def to_tool(func: Callable | Tool) -> Tool: def _validate_tools(self, tools: dict[str, Tool]) -> None: """Validate user-provided tools have valid names.""" - for name, tool in tools.items(): + for name, _tool in tools.items(): if not name.isidentifier(): raise ValueError(f"Invalid tool name '{name}': must be a valid Python identifier") if name in self._RESERVED_TOOL_NAMES: @@ -382,7 +382,7 @@ def _build_signatures(self) -> tuple[Signature, Signature]: media_guidelines_str = ( f"\n FOR MEDIA INPUTS (Audio/Image): Variables like {media_var_list} are media objects. " f"In the sandbox they appear as descriptor strings — you CANNOT decode or process them as raw data. " - f"Use `llm_query_with_media(prompt, {list(media_fields.keys())[0]!r})` to send media to a sub-LLM that can perceive it." + f"Use `llm_query_with_media(prompt, {next(iter(media_fields.keys()))!r})` to send media to a sub-LLM that can perceive it." ) else: media_tools_str = "" diff --git a/dspy/primitives/python_interpreter.py b/dspy/primitives/python_interpreter.py index 58a74b0538..df61180b3e 100644 --- a/dspy/primitives/python_interpreter.py +++ b/dspy/primitives/python_interpreter.py @@ -19,6 +19,7 @@ from dspy.primitives.code_interpreter import SIMPLE_TYPES, CodeInterpreterError, FinalOutput + # Lazy import helpers for multimodal types to avoid circular imports def _is_media_type(value: Any) -> bool: """Check if value is a DSPy Audio or Image type.""" diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index dab71e3416..1be7b83748 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -1110,5 +1110,223 @@ def test_with_llm_query(self): assert "dog" in result.answer.lower() +# ============================================================================ +# Unit Tests: Multimodal Media Support (Audio/Image) +# ============================================================================ + + +class TestMediaDetection: + """Unit tests for media field detection and registry building.""" + + def test_detect_audio_field(self): + """Test _detect_media_fields finds Audio-typed inputs.""" + import dspy + from dspy.adapters.types.audio import Audio + + class TranscribeSig(dspy.Signature): + """Transcribe audio.""" + audio_input: Audio = dspy.InputField() + transcription: str = dspy.OutputField() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM(TranscribeSig, max_iterations=3) + media_fields = rlm._detect_media_fields() + + assert "audio_input" in media_fields + assert media_fields["audio_input"] == "Audio" + + def test_detect_image_field(self): + """Test _detect_media_fields finds Image-typed inputs.""" + import dspy + from dspy.adapters.types.image import Image + + class DescribeSig(dspy.Signature): + """Describe an image.""" + photo: Image = dspy.InputField() + description: str = dspy.OutputField() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('a cat')"}]): + rlm = RLM(DescribeSig, max_iterations=3) + media_fields = rlm._detect_media_fields() + + assert "photo" in media_fields + assert media_fields["photo"] == "Image" + + def test_detect_mixed_media_fields(self): + """Test _detect_media_fields finds both Audio and Image in same signature.""" + import dspy + from dspy.adapters.types.audio import Audio + from dspy.adapters.types.image import Image + + class MultimodalSig(dspy.Signature): + """Process audio and image together.""" + audio_clip: Audio = dspy.InputField() + photo: Image = dspy.InputField() + analysis: str = dspy.OutputField() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('done')"}]): + rlm = RLM(MultimodalSig, max_iterations=3) + media_fields = rlm._detect_media_fields() + + assert len(media_fields) == 2 + assert media_fields["audio_clip"] == "Audio" + assert media_fields["photo"] == "Image" + + def test_no_media_fields(self): + """Test _detect_media_fields returns empty dict for text-only signatures.""" + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + media_fields = rlm._detect_media_fields() + + assert media_fields == {} + + def test_build_media_registry_with_audio(self): + """Test _build_media_registry extracts Audio objects from inputs.""" + import dspy + from dspy.adapters.types.audio import Audio + + audio = Audio(data="dGVzdA==", audio_format="wav") + + class TranscribeSig(dspy.Signature): + """Transcribe audio.""" + audio_input: Audio = dspy.InputField() + transcription: str = dspy.OutputField() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM(TranscribeSig, max_iterations=3) + registry = rlm._build_media_registry({"audio_input": audio}) + + assert "audio_input" in registry + assert registry["audio_input"] is audio + + def test_build_media_registry_ignores_text(self): + """Test _build_media_registry skips non-media values.""" + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + registry = rlm._build_media_registry({"query": "hello world"}) + + assert registry == {} + + +class TestLLMQueryWithMedia: + """Unit tests for llm_query_with_media tool creation and validation.""" + + def test_media_tool_available_when_registry_populated(self): + """Test llm_query_with_media is created when media registry is non-empty.""" + import dspy + from dspy.adapters.types.audio import Audio + + audio = Audio(data="dGVzdA==", audio_format="wav") + + class TranscribeSig(dspy.Signature): + """Transcribe audio.""" + audio_input: Audio = dspy.InputField() + transcription: str = dspy.OutputField() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM(TranscribeSig, max_iterations=3) + tools = rlm._make_llm_tools(media_registry={"audio_input": audio}) + + assert "llm_query_with_media" in tools + assert "llm_query" in tools + assert "llm_query_batched" in tools + + def test_media_tool_absent_when_no_media(self): + """Test llm_query_with_media is NOT created when media registry is empty.""" + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + tools = rlm._make_llm_tools(media_registry={}) + + assert "llm_query_with_media" not in tools + assert "llm_query" in tools + + def test_media_tool_absent_when_registry_none(self): + """Test llm_query_with_media is NOT created when media registry is None.""" + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + tools = rlm._make_llm_tools(media_registry=None) + + assert "llm_query_with_media" not in tools + + def test_media_tool_rejects_empty_prompt(self): + """Test llm_query_with_media raises on empty prompt.""" + from dspy.adapters.types.audio import Audio + + audio = Audio(data="dGVzdA==", audio_format="wav") + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + tools = rlm._make_llm_tools(media_registry={"audio_input": audio}) + + with pytest.raises(ValueError, match="prompt cannot be empty"): + tools["llm_query_with_media"]("") + + def test_media_tool_rejects_no_media_vars(self): + """Test llm_query_with_media raises when no media var names given.""" + from dspy.adapters.types.audio import Audio + + audio = Audio(data="dGVzdA==", audio_format="wav") + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + tools = rlm._make_llm_tools(media_registry={"audio_input": audio}) + + with pytest.raises(ValueError, match="At least one media variable"): + tools["llm_query_with_media"]("transcribe this") + + def test_media_tool_rejects_unknown_var(self): + """Test llm_query_with_media raises on nonexistent media variable.""" + from dspy.adapters.types.audio import Audio + + audio = Audio(data="dGVzdA==", audio_format="wav") + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + tools = rlm._make_llm_tools(media_registry={"audio_input": audio}) + + with pytest.raises(ValueError, match="not found"): + tools["llm_query_with_media"]("transcribe", "nonexistent_var") + + def test_reserved_tool_names_includes_media(self): + """Test llm_query_with_media is in the reserved tool names set.""" + assert "llm_query_with_media" in RLM._RESERVED_TOOL_NAMES + + +class TestMediaInstructions: + """Unit tests for media-specific instruction injection in signatures.""" + + def test_media_tools_in_action_instructions(self): + """Test that media fields cause llm_query_with_media docs to appear.""" + import dspy + from dspy.adapters.types.audio import Audio + + class TranscribeSig(dspy.Signature): + """Transcribe audio.""" + audio_input: Audio = dspy.InputField() + transcription: str = dspy.OutputField() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM(TranscribeSig, max_iterations=3) + action_sig, _extract_sig = rlm._build_signatures() + + instructions = action_sig.instructions + assert "llm_query_with_media" in instructions + assert "audio_input" in instructions + + def test_no_media_instructions_for_text_only(self): + """Test that text-only signatures do NOT include media instructions.""" + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + action_sig, _extract_sig = rlm._build_signatures() + + instructions = action_sig.instructions + assert "llm_query_with_media" not in instructions + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/primitives/test_media_serialization.py b/tests/primitives/test_media_serialization.py new file mode 100644 index 0000000000..99246128da --- /dev/null +++ b/tests/primitives/test_media_serialization.py @@ -0,0 +1,78 @@ +""" +Tests for media type helpers in python_interpreter.py. + +These tests do NOT require Deno — they test pure Python serialization logic +for Audio and Image objects in the sandboxed interpreter. +""" + + +from dspy.primitives.python_interpreter import _is_media_type, _media_descriptor + + +class TestIsMediaType: + """Tests for _is_media_type helper.""" + + def test_audio(self): + from dspy.adapters.types.audio import Audio + + audio = Audio(data="dGVzdA==", audio_format="wav") + assert _is_media_type(audio) is True + + def test_image(self): + from dspy.adapters.types.image import Image + + img = Image(url="") + assert _is_media_type(img) is True + + def test_string(self): + assert _is_media_type("hello") is False + + def test_int(self): + assert _is_media_type(42) is False + + def test_none(self): + assert _is_media_type(None) is False + + def test_dict(self): + assert _is_media_type({"data": "abc"}) is False + + def test_list(self): + assert _is_media_type([1, 2, 3]) is False + + +class TestMediaDescriptor: + """Tests for _media_descriptor helper.""" + + def test_audio_descriptor(self): + from dspy.adapters.types.audio import Audio + + audio = Audio(data="dGVzdA==", audio_format="wav") + desc = _media_descriptor(audio) + assert "Audio" in desc + assert "wav" in desc + assert "8" in desc # len("dGVzdA==") == 8 + + def test_image_descriptor(self): + from dspy.adapters.types.image import Image + + img = Image(url="") + desc = _media_descriptor(img) + assert "Image" in desc + + def test_non_media_falls_back_to_repr(self): + desc = _media_descriptor("just a string") + assert desc == repr("just a string") + + def test_audio_descriptor_includes_format(self): + from dspy.adapters.types.audio import Audio + + audio = Audio(data="AAAA", audio_format="mpeg") + desc = _media_descriptor(audio) + assert "mpeg" in desc + + def test_audio_descriptor_includes_data_length(self): + from dspy.adapters.types.audio import Audio + + audio = Audio(data="A" * 100, audio_format="wav") + desc = _media_descriptor(audio) + assert "100" in desc From b189e295e36ab5554e51ec10fa4c82cd3b14a417 Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 11:02:49 -0500 Subject: [PATCH 03/15] feat(rlm): add multi-model sub-call routing via sub_lms parameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enable sandbox code to select different LMs for sub-queries by passing model='name' to llm_query, llm_query_batched, and llm_query_with_media. This matches the upstream RLM design where registered models can be chosen at call time within the REPL. - Add sub_lms: dict[str, dspy.LM] parameter to RLM.__init__ - Add _resolve_lm() helper with 3-tier routing: named → sub_lm → default - Add model parameter to all llm_query* tool functions - Auto-document available model names in action instructions template - Shared call counter across all model choices Co-Authored-By: Claude Opus 4.6 --- dspy/predict/rlm.py | 78 ++++++++++++++----- tests/predict/test_rlm.py | 152 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 210 insertions(+), 20 deletions(-) diff --git a/dspy/predict/rlm.py b/dspy/predict/rlm.py index 7dfc533722..e4b5f4fecd 100644 --- a/dspy/predict/rlm.py +++ b/dspy/predict/rlm.py @@ -61,8 +61,8 @@ def _format_media_for_lm(value): Available: - Variables: {inputs} (your input data) -- `llm_query(prompt)` - query a sub-LLM (~500K char capacity) for semantic analysis -- `llm_query_batched(prompts)` - query multiple prompts concurrently (much faster for multiple queries){media_tools} +- `llm_query(prompt, model=None)` - query a sub-LLM (~500K char capacity) for semantic analysis{model_docs} +- `llm_query_batched(prompts, model=None)` - query multiple prompts concurrently (much faster for multiple queries){media_tools} - `print()` - ALWAYS print to see results - `SUBMIT({final_output_names})` - submit final output when done - Standard libraries: re, json, collections, math, etc. @@ -123,6 +123,7 @@ def __init__( verbose: bool = False, tools: list[Callable] | None = None, sub_lm: dspy.LM | None = None, + sub_lms: dict[str, dspy.LM] | None = None, interpreter: CodeInterpreter | None = None, ): """ @@ -135,8 +136,11 @@ def __init__( verbose: Whether to log detailed execution info. tools: List of tool functions or dspy.Tool objects callable from interpreter code. Built-in tools: llm_query(prompt), llm_query_batched(prompts). - sub_lm: LM for llm_query/llm_query_batched. Defaults to dspy.settings.lm. + sub_lm: Default LM for llm_query/llm_query_batched. Defaults to dspy.settings.lm. Allows using a different (e.g., cheaper) model for sub-queries. + sub_lms: Dict mapping model names to LM instances, enabling sandbox code to select + a specific model via llm_query(prompt, model="name"). When model is None, + falls back to sub_lm, then dspy.settings.lm. interpreter: CodeInterpreter implementation to use. Defaults to PythonInterpreter. """ super().__init__() @@ -146,6 +150,7 @@ def __init__( self.max_output_chars = max_output_chars self.verbose = verbose self.sub_lm = sub_lm + self.sub_lms = sub_lms or {} self._interpreter = interpreter self._user_tools = self._normalize_tools(tools) self._validate_tools(self._user_tools) @@ -224,9 +229,27 @@ def _make_llm_tools(self, media_registry: dict[str, Any] | None = None, max_work """ state = {"call_count": 0} lock = threading.Lock() - lm = self.sub_lm + default_lm = self.sub_lm + named_lms = self.sub_lms _media_registry = media_registry or {} + def _resolve_lm(model: str | None = None) -> dspy.LM: + """Resolve a model name to an LM instance. + + Resolution order: named sub_lms[model] → sub_lm → dspy.settings.lm. + """ + if model is not None: + if model in named_lms: + return named_lms[model] + available = list(named_lms.keys()) + raise ValueError( + f"Model '{model}' not found in sub_lms. Available models: {available}" + ) + target_lm = default_lm if default_lm is not None else dspy.settings.lm + if target_lm is None: + raise RuntimeError("No LM configured. Use dspy.configure(lm=...) or pass sub_lm to RLM.") + return target_lm + def _check_and_increment(n: int = 1) -> None: with lock: if state["call_count"] + n > self.max_llm_calls: @@ -236,10 +259,8 @@ def _check_and_increment(n: int = 1) -> None: ) state["call_count"] += n - def _query_lm(prompt: str) -> str: - target_lm = lm if lm is not None else dspy.settings.lm - if target_lm is None: - raise RuntimeError("No LM configured. Use dspy.configure(lm=...) or pass sub_lm to RLM.") + def _query_lm(prompt: str, model: str | None = None) -> str: + target_lm = _resolve_lm(model) response = target_lm(prompt) if isinstance(response, list) and response: item = response[0] @@ -248,11 +269,9 @@ def _query_lm(prompt: str) -> str: return item return str(response) - def _query_lm_multimodal(prompt: str, media_objects: list) -> str: + def _query_lm_multimodal(prompt: str, media_objects: list, model: str | None = None) -> str: """Query the LLM with a prompt string and media content parts.""" - target_lm = lm if lm is not None else dspy.settings.lm - if target_lm is None: - raise RuntimeError("No LM configured. Use dspy.configure(lm=...) or pass sub_lm to RLM.") + target_lm = _resolve_lm(model) # Build multimodal content: text prompt + media content parts content_parts = [{"type": "text", "text": prompt}] @@ -268,22 +287,32 @@ def _query_lm_multimodal(prompt: str, media_objects: list) -> str: return item return str(response) - def llm_query(prompt: str) -> str: - """Query the LLM with a prompt string.""" + def llm_query(prompt: str, model: str | None = None) -> str: + """Query the LLM with a prompt string. + + Args: + prompt: The text prompt for the LLM. + model: Optional model name from sub_lms to use. Defaults to the default sub_lm. + """ if not prompt: raise ValueError("prompt cannot be empty") _check_and_increment(1) - return _query_lm(prompt) + return _query_lm(prompt, model=model) - def llm_query_batched(prompts: list[str]) -> list[str]: - """Query the LLM with multiple prompts concurrently.""" + def llm_query_batched(prompts: list[str], model: str | None = None) -> list[str]: + """Query the LLM with multiple prompts concurrently. + + Args: + prompts: List of prompts to send to the LLM. + model: Optional model name from sub_lms to use. Defaults to the default sub_lm. + """ if not prompts: return [] _check_and_increment(len(prompts)) results: dict[int, str] = {} with ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_idx = {executor.submit(_query_lm, p): i for i, p in enumerate(prompts)} + future_to_idx = {executor.submit(_query_lm, p, model): i for i, p in enumerate(prompts)} for future in as_completed(future_to_idx): idx = future_to_idx[future] try: @@ -292,13 +321,14 @@ def llm_query_batched(prompts: list[str]) -> list[str]: results[idx] = f"[ERROR] {e}" return [results[i] for i in range(len(prompts))] - def llm_query_with_media(prompt: str, *media_var_names: str) -> str: + def llm_query_with_media(prompt: str, *media_var_names: str, model: str | None = None) -> str: """Query the LLM with a prompt and media variables (audio/image). Args: prompt: The text prompt for the LLM. *media_var_names: Names of media variables to include (e.g., 'audio_input', 'my_image'). These must be names of Audio or Image input variables. + model: Optional model name from sub_lms to use. Defaults to the default sub_lm. Returns: The LLM's text response. @@ -322,7 +352,7 @@ def llm_query_with_media(prompt: str, *media_var_names: str) -> str: media_objects.append(_media_registry[var_name]) _check_and_increment(1) - return _query_lm_multimodal(prompt, media_objects) + return _query_lm_multimodal(prompt, media_objects, model=model) tools = {"llm_query": llm_query, "llm_query_batched": llm_query_batched} if _media_registry: @@ -388,11 +418,19 @@ def _build_signatures(self) -> tuple[Signature, Signature]: media_tools_str = "" media_guidelines_str = "" + # Document available model names if sub_lms is configured + if self.sub_lms: + model_names = ", ".join(f"'{name}'" for name in self.sub_lms) + model_docs_str = f"\n Available models: {model_names}. Pass model= to select one." + else: + model_docs_str = "" + action_sig = ( dspy.Signature({}, task_instructions + ACTION_INSTRUCTIONS_TEMPLATE.format( inputs=inputs_str, final_output_names=final_output_names, output_fields=output_fields, max_llm_calls=self.max_llm_calls, media_tools=media_tools_str, media_guidelines=media_guidelines_str, + model_docs=model_docs_str, ) + tool_docs) .append("variables_info", dspy.InputField(desc="Metadata about the variables available in the REPL"), type_=str) .append("repl_history", dspy.InputField(desc="Previous REPL code executions and their outputs"), type_=REPLHistory) diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index 1be7b83748..1928c247d1 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -1328,5 +1328,157 @@ def test_no_media_instructions_for_text_only(self): assert "llm_query_with_media" not in instructions +class TestMultiModelSubCalls: + """Unit tests for multi-model sub-call routing via sub_lms parameter.""" + + def test_sub_lms_stored_on_init(self): + """Test sub_lms dict is stored on the RLM instance.""" + from unittest.mock import MagicMock + + lm1 = MagicMock() + lm2 = MagicMock() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3, sub_lms={"flash": lm1, "pro": lm2}) + + assert rlm.sub_lms == {"flash": lm1, "pro": lm2} + + def test_sub_lms_defaults_to_empty_dict(self): + """Test sub_lms defaults to empty dict when not provided.""" + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + assert rlm.sub_lms == {} + + def test_llm_query_routes_to_named_model(self): + """Test llm_query(prompt, model='name') routes to the correct LM.""" + from unittest.mock import MagicMock + + mock_flash = MagicMock(return_value=["flash response"]) + mock_pro = MagicMock(return_value=["pro response"]) + mock_default = MagicMock(return_value=["default response"]) + + rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_default, + sub_lms={"flash": mock_flash, "pro": mock_pro}) + tools = rlm._make_llm_tools() + + result = tools["llm_query"]("test prompt", model="flash") + assert result == "flash response" + mock_flash.assert_called_once() + + result = tools["llm_query"]("test prompt", model="pro") + assert result == "pro response" + mock_pro.assert_called_once() + + def test_llm_query_default_model_uses_sub_lm(self): + """Test llm_query without model param falls back to sub_lm.""" + from unittest.mock import MagicMock + + mock_sub = MagicMock(return_value=["sub_lm response"]) + mock_flash = MagicMock(return_value=["flash response"]) + + rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_sub, sub_lms={"flash": mock_flash}) + tools = rlm._make_llm_tools() + + # model=None should use sub_lm, not flash + result = tools["llm_query"]("test prompt") + assert result == "sub_lm response" + mock_sub.assert_called_once() + mock_flash.assert_not_called() + + def test_llm_query_raises_on_unknown_model(self): + """Test llm_query raises ValueError for unknown model name.""" + from unittest.mock import MagicMock + + mock_flash = MagicMock(return_value=["flash"]) + + rlm = RLM("query -> answer", max_iterations=3, sub_lms={"flash": mock_flash}) + tools = rlm._make_llm_tools() + + with pytest.raises(ValueError, match="Model 'nonexistent' not found"): + tools["llm_query"]("test", model="nonexistent") + + def test_llm_query_batched_routes_to_named_model(self): + """Test llm_query_batched with model param routes all prompts to named LM.""" + from unittest.mock import MagicMock + + mock_default = MagicMock(return_value=["default response"]) + mock_pro = MagicMock(return_value=["pro response"]) + + rlm = RLM("query -> answer", max_iterations=3, max_llm_calls=10, + sub_lm=mock_default, sub_lms={"pro": mock_pro}) + tools = rlm._make_llm_tools() + + results = tools["llm_query_batched"](["prompt1", "prompt2"], model="pro") + assert len(results) == 2 + # Both should come from the pro LM + assert all(r == "pro response" for r in results) + assert mock_pro.call_count == 2 + mock_default.assert_not_called() + + def test_llm_query_with_media_routes_to_named_model(self): + """Test llm_query_with_media with model kwarg routes to named LM.""" + from unittest.mock import MagicMock + + from dspy.adapters.types.audio import Audio + + mock_default = MagicMock(return_value=["default response"]) + mock_pro = MagicMock(return_value=["pro media response"]) + + audio = Audio(data="dGVzdA==", audio_format="wav") + + rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_default, + sub_lms={"pro": mock_pro}) + tools = rlm._make_llm_tools(media_registry={"audio_input": audio}) + + result = tools["llm_query_with_media"]("transcribe", "audio_input", model="pro") + assert result == "pro media response" + mock_pro.assert_called_once() + mock_default.assert_not_called() + + def test_model_docs_in_instructions_when_sub_lms(self): + """Test that model names appear in action instructions when sub_lms is set.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3, sub_lms={"flash": mock_lm, "pro": mock_lm}) + action_sig, _ = rlm._build_signatures() + + instructions = action_sig.instructions + assert "flash" in instructions + assert "pro" in instructions + assert "model=" in instructions + + def test_no_model_docs_when_no_sub_lms(self): + """Test that model docs are absent when sub_lms is not set.""" + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + action_sig, _ = rlm._build_signatures() + + instructions = action_sig.instructions + assert "Available models:" not in instructions + + def test_call_count_shared_across_models(self): + """Test that the call counter is shared across all model choices.""" + from unittest.mock import MagicMock + + mock_default = MagicMock(return_value=["default"]) + mock_flash = MagicMock(return_value=["flash"]) + mock_pro = MagicMock(return_value=["pro"]) + + rlm = RLM("query -> answer", max_iterations=3, max_llm_calls=3, + sub_lm=mock_default, sub_lms={"flash": mock_flash, "pro": mock_pro}) + tools = rlm._make_llm_tools() + + tools["llm_query"]("p1", model="flash") + tools["llm_query"]("p2", model="pro") + tools["llm_query"]("p3") # default + + # 4th call should exceed the limit of 3 + with pytest.raises(RuntimeError, match="LLM call limit exceeded"): + tools["llm_query"]("p4", model="flash") + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From a1de35b938b8b92cbd8727b1be41e60ff4118cdf Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 11:24:32 -0500 Subject: [PATCH 04/15] feat(rlm): add LocalInterpreter for unsandboxed host-process execution Add LocalInterpreter as an alternative to PythonInterpreter (Deno/Pyodide) for RLM code execution. Executes code directly in the host Python process via exec(), giving RLM agents access to any installed package (PIL, numpy, soundfile, etc.) without WASM sandbox restrictions. Implements the CodeInterpreter protocol: start(), execute(), shutdown(), tools property. State persists across execute() calls. SUBMIT() maps positional args to output_fields (matching PythonInterpreter behavior). Usage: from dspy.primitives.local_interpreter import LocalInterpreter rlm = dspy.RLM(sig, interpreter=LocalInterpreter()) Includes 33 unit tests covering execution, state persistence, variable injection, error handling, SUBMIT/FinalOutput, tools, and context manager. Co-Authored-By: Claude Opus 4.6 --- dspy/primitives/local_interpreter.py | 175 ++++++++++++ tests/primitives/test_local_interpreter.py | 313 +++++++++++++++++++++ 2 files changed, 488 insertions(+) create mode 100644 dspy/primitives/local_interpreter.py create mode 100644 tests/primitives/test_local_interpreter.py diff --git a/dspy/primitives/local_interpreter.py b/dspy/primitives/local_interpreter.py new file mode 100644 index 0000000000..0a703b3ca5 --- /dev/null +++ b/dspy/primitives/local_interpreter.py @@ -0,0 +1,175 @@ +""" +Unsandboxed local Python interpreter for RLM. + +Implements the CodeInterpreter protocol but executes code directly in the host +Python process via exec(). This gives the RLM agent full access to any installed +Python package (PIL, pydub, numpy, scipy, etc.). + +Use this when the sandboxed PythonInterpreter (Deno/Pyodide) is too restrictive — +e.g., when the RLM agent needs to manipulate images with PIL or process audio +with pydub directly in its generated code. + +Security: This is intentionally UNSANDBOXED. The LLM-generated code runs with +full host process privileges. Only use for local experiments or trusted workloads. + +Usage: + from dspy.primitives.local_interpreter import LocalInterpreter + + rlm = dspy.RLM("context -> answer", interpreter=LocalInterpreter()) +""" + +import io +import sys +import traceback +from typing import Any, Callable + +from dspy.primitives.code_interpreter import CodeInterpreterError, FinalOutput + + +class _SubmitCalled(Exception): + """Internal signal raised when SUBMIT() is called in user code.""" + def __init__(self, output: Any): + self.output = output + + +class LocalInterpreter: + """Unsandboxed Python interpreter implementing the CodeInterpreter protocol. + + Executes code directly in the host process via exec(). State persists + across execute() calls within a session. Tools are injected as callable + functions in the execution namespace. + + This gives the RLM agent full access to the host Python environment: + - PIL/Pillow for image manipulation + - pydub/ffmpeg for audio manipulation + - numpy, scipy, scikit-image, etc. + - Any installed Python package + + Note: Not thread-safe. Create separate instances for concurrent use. + """ + + def __init__( + self, + tools: dict[str, Callable[..., str]] | None = None, + output_fields: list[dict] | None = None, + ): + """ + Args: + tools: Dictionary mapping tool names to callable functions. + Tools are available as top-level functions in the namespace. + output_fields: Output field definitions for typed SUBMIT signature. + """ + self._tools: dict[str, Callable[..., str]] = dict(tools) if tools else {} + self.output_fields = output_fields + self._namespace: dict[str, Any] = {} + self._started = False + + @property + def tools(self) -> dict[str, Callable[..., str]]: + """Tools available for interpreter code to call.""" + return self._tools + + @tools.setter + def tools(self, value: dict[str, Callable[..., str]]) -> None: + self._tools = value + + def start(self) -> None: + """Initialize the interpreter namespace.""" + if self._started: + return + self._namespace = {"__builtins__": __builtins__} + self._started = True + + def execute( + self, + code: str, + variables: dict[str, Any] | None = None, + ) -> Any: + """Execute Python code in the host process. + + Args: + code: Python code to execute. + variables: Variables to inject into the namespace before execution. + Media objects (Audio, Image) are injected AS-IS, giving + code direct access to their data for manipulation. + + Returns: + - FinalOutput: If SUBMIT() was called + - str: Captured stdout (from print() calls) + - None: If no output was produced + + Raises: + CodeInterpreterError: On runtime errors + SyntaxError: On invalid Python syntax + """ + if not self._started: + self.start() + + # Inject variables directly into namespace (no serialization — objects stay live) + if variables: + self._namespace.update(variables) + + # Inject tools as callable functions + for name, func in self._tools.items(): + self._namespace[name] = func + + # Inject SUBMIT function — maps args to output field names (matching PythonInterpreter) + output_fields = self.output_fields or [] + field_names = [f["name"] for f in output_fields] + + def SUBMIT(*args, **kwargs): + if not args and not kwargs: + raise ValueError("SUBMIT requires at least one argument") + if args and kwargs: + raise ValueError("SUBMIT accepts either positional args or keyword args, not both") + if kwargs: + output = kwargs + elif field_names: + if len(args) != len(field_names): + expected = ", ".join(field_names) + raise TypeError( + f"SUBMIT() takes {len(field_names)} positional argument(s) " + f"({expected}), but {len(args)} were given" + ) + output = dict(zip(field_names, args)) + elif len(args) == 1: + output = {"output": args[0]} + else: + output = {"output": args} + raise _SubmitCalled(output) + + self._namespace["SUBMIT"] = SUBMIT + + # Capture stdout + old_stdout = sys.stdout + captured = io.StringIO() + sys.stdout = captured + + try: + exec(code, self._namespace) + except _SubmitCalled as e: + return FinalOutput(e.output) + except SyntaxError: + raise + except Exception as e: + tb = traceback.format_exc() + raise CodeInterpreterError(f"{type(e).__name__}: {e}\n{tb}") from e + finally: + sys.stdout = old_stdout + + output = captured.getvalue() + if output: + return output.rstrip("\n") + return None + + def shutdown(self) -> None: + """Release resources and clear the namespace.""" + self._namespace.clear() + self._started = False + + def __enter__(self): + self.start() + return self + + def __exit__(self, *args): + self.shutdown() diff --git a/tests/primitives/test_local_interpreter.py b/tests/primitives/test_local_interpreter.py new file mode 100644 index 0000000000..664d60a7a5 --- /dev/null +++ b/tests/primitives/test_local_interpreter.py @@ -0,0 +1,313 @@ +"""Tests for LocalInterpreter — unsandboxed host-process Python execution.""" + +import pytest + +from dspy.primitives.code_interpreter import CodeInterpreterError, FinalOutput +from dspy.primitives.local_interpreter import LocalInterpreter + + +# ============================================================================= +# Basic Execution +# ============================================================================= + + +def test_execute_simple_print(): + with LocalInterpreter() as interp: + result = interp.execute("print('Hello, World!')") + assert result == "Hello, World!" + + +def test_execute_no_output(): + with LocalInterpreter() as interp: + result = interp.execute("x = 42") + assert result is None + + +def test_execute_multiline(): + with LocalInterpreter() as interp: + code = "a = 3\nb = 4\nprint(a + b)" + result = interp.execute(code) + assert result == "7" + + +def test_import_stdlib(): + with LocalInterpreter() as interp: + result = interp.execute("import math\nprint(math.sqrt(16))") + assert result == "4.0" + + +def test_import_third_party(): + """LocalInterpreter can import any installed package (unlike Pyodide sandbox).""" + with LocalInterpreter() as interp: + result = interp.execute("import numpy as np\nprint(np.array([1,2,3]).sum())") + assert result == "6" + + +# ============================================================================= +# State Persistence +# ============================================================================= + + +def test_state_persists_across_calls(): + with LocalInterpreter() as interp: + interp.execute("x = 10") + interp.execute("x += 5") + result = interp.execute("print(x)") + assert result == "15" + + +def test_state_cleared_on_shutdown(): + interp = LocalInterpreter() + interp.start() + interp.execute("x = 42") + interp.shutdown() + interp.start() + with pytest.raises(CodeInterpreterError, match="NameError"): + interp.execute("print(x)") + interp.shutdown() + + +def test_auto_start(): + """execute() auto-starts if not already started.""" + interp = LocalInterpreter() + result = interp.execute("print('auto')") + assert result == "auto" + interp.shutdown() + + +# ============================================================================= +# Variable Injection +# ============================================================================= + + +def test_variable_injection(): + with LocalInterpreter() as interp: + result = interp.execute("print(number + 1)", variables={"number": 4}) + assert result == "5" + + +def test_variable_injection_objects(): + """Variables are injected AS-IS, not serialized — live Python objects.""" + with LocalInterpreter() as interp: + data = {"key": [1, 2, 3]} + result = interp.execute("print(type(data).__name__, len(data['key']))", variables={"data": data}) + assert result == "dict 3" + + +def test_variable_injection_persists(): + with LocalInterpreter() as interp: + interp.execute("pass", variables={"x": 100}) + result = interp.execute("print(x)") + assert result == "100" + + +# ============================================================================= +# Error Handling +# ============================================================================= + + +def test_syntax_error(): + with LocalInterpreter() as interp: + with pytest.raises(SyntaxError): + interp.execute("def foo(") + + +def test_runtime_error(): + with LocalInterpreter() as interp: + with pytest.raises(CodeInterpreterError, match="ZeroDivisionError"): + interp.execute("1 / 0") + + +def test_name_error(): + with LocalInterpreter() as interp: + with pytest.raises(CodeInterpreterError, match="NameError"): + interp.execute("print(undefined_var)") + + +def test_error_includes_traceback(): + with LocalInterpreter() as interp: + with pytest.raises(CodeInterpreterError, match="Traceback"): + interp.execute("raise RuntimeError('test error')") + + +def test_stdout_restored_after_error(): + """sys.stdout must be restored even if execution raises.""" + import sys + original_stdout = sys.stdout + interp = LocalInterpreter() + interp.start() + with pytest.raises(CodeInterpreterError): + interp.execute("raise ValueError('boom')") + assert sys.stdout is original_stdout + interp.shutdown() + + +# ============================================================================= +# SUBMIT / FinalOutput +# ============================================================================= + + +def test_submit_single_arg(): + """Single-arg SUBMIT wraps in {"output": value}.""" + with LocalInterpreter() as interp: + result = interp.execute('SUBMIT("the answer")') + assert isinstance(result, FinalOutput) + assert result.output == {"output": "the answer"} + + +def test_submit_kwargs(): + with LocalInterpreter() as interp: + result = interp.execute('SUBMIT(answer="hello", score=42)') + assert isinstance(result, FinalOutput) + assert result.output == {"answer": "hello", "score": 42} + + +def test_submit_typed_positional(): + """Positional args mapped to output_fields names.""" + output_fields = [ + {"name": "answer", "type": "str"}, + {"name": "confidence", "type": "float"}, + ] + with LocalInterpreter(output_fields=output_fields) as interp: + result = interp.execute('SUBMIT("the answer", 0.95)') + assert isinstance(result, FinalOutput) + assert result.output == {"answer": "the answer", "confidence": 0.95} + + +def test_submit_typed_kwargs(): + output_fields = [ + {"name": "answer", "type": "str"}, + {"name": "confidence", "type": "float"}, + ] + with LocalInterpreter(output_fields=output_fields) as interp: + result = interp.execute('SUBMIT(answer="the answer", confidence=0.95)') + assert isinstance(result, FinalOutput) + assert result.output == {"answer": "the answer", "confidence": 0.95} + + +def test_submit_wrong_arg_count(): + output_fields = [ + {"name": "answer", "type": "str"}, + {"name": "score", "type": "int"}, + ] + with LocalInterpreter(output_fields=output_fields) as interp: + with pytest.raises(CodeInterpreterError, match="takes 2 positional"): + interp.execute("SUBMIT('only one')") + + +def test_submit_no_args(): + with LocalInterpreter() as interp: + with pytest.raises(CodeInterpreterError, match="SUBMIT requires at least one argument"): + interp.execute("SUBMIT()") + + +def test_submit_mixed_args_and_kwargs(): + with LocalInterpreter() as interp: + with pytest.raises(CodeInterpreterError, match="SUBMIT accepts either positional"): + interp.execute('SUBMIT("pos", key="val")') + + +def test_submit_stops_execution(): + """Code after SUBMIT should not execute.""" + with LocalInterpreter() as interp: + result = interp.execute('SUBMIT(answer="done")\nprint("should not print")') + assert isinstance(result, FinalOutput) + assert result.output == {"answer": "done"} + + +# ============================================================================= +# Tools +# ============================================================================= + + +def test_tool_basic(): + def greet(name: str) -> str: + return f"Hello, {name}!" + + with LocalInterpreter(tools={"greet": greet}) as interp: + result = interp.execute('print(greet("World"))') + assert result == "Hello, World!" + + +def test_tool_default_args(): + def search(query: str, limit: int = 10) -> str: + return f"query={query}, limit={limit}" + + with LocalInterpreter(tools={"search": search}) as interp: + result = interp.execute('print(search("test"))') + assert result == "query=test, limit=10" + result = interp.execute('print(search("test", limit=5))') + assert result == "query=test, limit=5" + + +def test_tools_via_setter(): + """Tools can be added/updated after construction via the tools setter.""" + interp = LocalInterpreter() + + def my_tool() -> str: + return "tool_result" + + interp.tools = {"my_tool": my_tool} + interp.start() + result = interp.execute("print(my_tool())") + assert result == "tool_result" + interp.shutdown() + + +def test_tools_refreshed_each_execute(): + """Tools dict is re-injected on each execute(), so updates are visible.""" + interp = LocalInterpreter() + interp.start() + + interp.tools["counter"] = lambda: "v1" + result = interp.execute("print(counter())") + assert result == "v1" + + interp.tools["counter"] = lambda: "v2" + result = interp.execute("print(counter())") + assert result == "v2" + + interp.shutdown() + + +# ============================================================================= +# Context Manager +# ============================================================================= + + +def test_context_manager(): + with LocalInterpreter() as interp: + assert interp._started is True + result = interp.execute("print('inside')") + assert result == "inside" + assert interp._started is False + + +def test_context_manager_cleanup_on_error(): + with pytest.raises(CodeInterpreterError): + with LocalInterpreter() as interp: + interp.execute("raise RuntimeError('fail')") + assert interp._started is False + + +# ============================================================================= +# Stdout Capture +# ============================================================================= + + +def test_multiple_prints(): + with LocalInterpreter() as interp: + result = interp.execute("print('a')\nprint('b')\nprint('c')") + assert result == "a\nb\nc" + + +def test_trailing_newlines_stripped(): + with LocalInterpreter() as interp: + result = interp.execute("print('hello')") + assert result == "hello" # No trailing newline + + +def test_empty_print(): + with LocalInterpreter() as interp: + result = interp.execute("print()") + assert result == "" # Empty print produces empty string after strip From 77143e74552b6c59ce07538820992acce3b3ea21 Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 11:57:27 -0500 Subject: [PATCH 05/15] =?UTF-8?q?feat(rlm):=20add=20budget=20awareness=20?= =?UTF-8?q?=E2=80=94=20budget()=20tool=20and=20max=5Ftime=20parameter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The agent can now call budget() from the REPL to check remaining iterations, LLM calls, and wall-clock time. This enables cost-aware recursive strategies — the agent can decide to use cheaper/fewer sub-queries as resources dwindle. Changes: - Add max_time parameter: optional wall-clock seconds limit per forward() call. Gracefully falls back to extract when exceeded (no exception, just early termination). - Add budget() tool: injected alongside llm_query/llm_query_batched, returns human-readable summary of remaining resources. - Track execution state (start_time, current iteration) via mutable dict shared between forward() loop and budget() closure. - Update ACTION_INSTRUCTIONS_TEMPLATE to mention budget(). - Add 'budget' to reserved tool names. - Mirror all changes in aforward() for async path. 13 new tests covering: initialization, budget output format, iteration tracking, LLM call tracking, time tracking, reserved name, action instructions, timeout fallback, and end-to-end budget updates. Co-Authored-By: Claude Opus 4.6 --- dspy/predict/rlm.py | 111 ++++++++++++++++++++++++--- tests/predict/test_rlm.py | 154 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 256 insertions(+), 9 deletions(-) diff --git a/dspy/predict/rlm.py b/dspy/predict/rlm.py index e4b5f4fecd..652d73fb2b 100644 --- a/dspy/predict/rlm.py +++ b/dspy/predict/rlm.py @@ -13,6 +13,7 @@ import logging import re import threading +import time as _time from concurrent.futures import ThreadPoolExecutor, as_completed from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Iterator @@ -65,6 +66,7 @@ def _format_media_for_lm(value): - `llm_query_batched(prompts, model=None)` - query multiple prompts concurrently (much faster for multiple queries){media_tools} - `print()` - ALWAYS print to see results - `SUBMIT({final_output_names})` - submit final output when done +- `budget()` - check remaining iterations, LLM calls, and time - Standard libraries: re, json, collections, math, etc. IMPORTANT: This is ITERATIVE. Each code block you write will execute, you'll see the output, then you decide what to do next. Do NOT try to solve everything in one step. @@ -76,7 +78,7 @@ def _format_media_for_lm(value): 5. MINIMIZE RETYPING (INPUTS & OUTPUTS) - When values are long, precise, or error-prone (IDs, numbers, code, quotes), re-access them via variables and parse/compute in code instead of retyping. Use small, targeted prints to sanity-check, but avoid manual copying when variables can carry the exact value. 6. SUBMIT ONLY AFTER SEEING OUTPUTS - SUBMIT ends the current run immediately. If you need to inspect printed output, run it in one step, review the result, then call SUBMIT in a later step. -You have max {max_llm_calls} sub-LLM calls. When done, call SUBMIT() with your output.""" +You have max {max_llm_calls} sub-LLM calls. Call budget() to check remaining resources. When done, call SUBMIT() with your output.""" # Pattern to match markdown code fences: ```python\n...\n``` or ```\n...\n``` _CODE_FENCE_PATTERN = re.compile(r"^```(?:python|py)?\s*\n(.*)\n```\s*$", re.DOTALL) @@ -120,6 +122,7 @@ def __init__( max_iterations: int = 20, max_llm_calls: int = 50, max_output_chars: int = 10_000, + max_time: float | None = None, verbose: bool = False, tools: list[Callable] | None = None, sub_lm: dspy.LM | None = None, @@ -133,6 +136,8 @@ def __init__( max_iterations: Maximum REPL interaction iterations. max_llm_calls: Maximum sub-LLM calls (llm_query/llm_query_batched) per execution. max_output_chars: Maximum characters to include from REPL output. + max_time: Maximum wall-clock seconds per forward() call. None means no limit. + The agent can check remaining time via budget(). verbose: Whether to log detailed execution info. tools: List of tool functions or dspy.Tool objects callable from interpreter code. Built-in tools: llm_query(prompt), llm_query_batched(prompts). @@ -148,6 +153,7 @@ def __init__( self.max_iterations = max_iterations self.max_llm_calls = max_llm_calls self.max_output_chars = max_output_chars + self.max_time = max_time self.verbose = verbose self.sub_lm = sub_lm self.sub_lms = sub_lms or {} @@ -165,7 +171,7 @@ def __init__( # ========================================================================= # Reserved tool names that conflict with built-in sandbox functions - _RESERVED_TOOL_NAMES = frozenset({"llm_query", "llm_query_batched", "llm_query_with_media", "SUBMIT", "print"}) + _RESERVED_TOOL_NAMES = frozenset({"llm_query", "llm_query_batched", "llm_query_with_media", "SUBMIT", "print", "budget"}) def _normalize_tools(self, tools: list[Callable] | None) -> dict[str, Tool]: """Normalize tools list to a dict of Tool objects keyed by name.""" @@ -219,16 +225,24 @@ def _format_tool_docs(self, tools: dict[str, Tool]) -> str: return "\n".join(lines) - def _make_llm_tools(self, media_registry: dict[str, Any] | None = None, max_workers: int = 8) -> dict[str, Callable]: - """Create llm_query, llm_query_batched, and llm_query_with_media tools with a fresh call counter. + def _make_llm_tools( + self, + media_registry: dict[str, Any] | None = None, + max_workers: int = 8, + execution_state: dict[str, Any] | None = None, + ) -> dict[str, Callable]: + """Create llm_query, llm_query_batched, llm_query_with_media, and budget tools with a fresh call counter. Args: media_registry: Dict mapping variable names to media objects (Audio/Image). Used by llm_query_with_media to attach media to sub-LLM calls. max_workers: Max concurrent workers for batched queries. + execution_state: Mutable dict tracking iteration/time state. Keys: + 'start_time' (float), 'iteration' (int). Updated by forward(). """ state = {"call_count": 0} lock = threading.Lock() + _execution_state = execution_state or {} default_lm = self.sub_lm named_lms = self.sub_lms _media_registry = media_registry or {} @@ -354,7 +368,39 @@ def llm_query_with_media(prompt: str, *media_var_names: str, model: str | None = _check_and_increment(1) return _query_lm_multimodal(prompt, media_objects, model=model) - tools = {"llm_query": llm_query, "llm_query_batched": llm_query_batched} + max_iterations = self.max_iterations + max_llm_calls = self.max_llm_calls + max_time = self.max_time + + def budget() -> str: + """Check remaining execution budget: iterations, LLM calls, and time. + + Returns a human-readable summary of remaining resources. + """ + with lock: + calls_used = state["call_count"] + + iteration = _execution_state.get("iteration", 0) + remaining_iterations = max_iterations - iteration - 1 # -1 for current + remaining_calls = max_llm_calls - calls_used + + parts = [ + f"Iterations: {remaining_iterations}/{max_iterations} remaining", + f"LLM calls: {remaining_calls}/{max_llm_calls} remaining", + ] + + start_time = _execution_state.get("start_time") + if max_time is not None and start_time is not None: + elapsed = _time.monotonic() - start_time + remaining_time = max(0.0, max_time - elapsed) + parts.append(f"Time: {remaining_time:.1f}s/{max_time:.1f}s remaining ({elapsed:.1f}s elapsed)") + elif start_time is not None: + elapsed = _time.monotonic() - start_time + parts.append(f"Time: no limit ({elapsed:.1f}s elapsed)") + + return " | ".join(parts) + + tools = {"llm_query": llm_query, "llm_query_batched": llm_query_batched, "budget": budget} if _media_registry: tools["llm_query_with_media"] = llm_query_with_media return tools @@ -511,9 +557,16 @@ def _build_media_registry(self, input_args: dict[str, Any]) -> dict[str, Any]: registry[name] = value return registry - def _prepare_execution_tools(self, media_registry: dict[str, Any] | None = None) -> dict[str, Callable]: + def _prepare_execution_tools( + self, + media_registry: dict[str, Any] | None = None, + execution_state: dict[str, Any] | None = None, + ) -> dict[str, Callable]: """Create fresh LLM tools and merge with user-provided tools.""" - execution_tools = self._make_llm_tools(media_registry=media_registry) + execution_tools = self._make_llm_tools( + media_registry=media_registry, + execution_state=execution_state, + ) # Extract underlying functions from Tool objects for the interpreter execution_tools.update({name: tool.func for name, tool in self._user_tools.items()}) return execution_tools @@ -708,18 +761,38 @@ def forward(self, **input_args) -> Prediction: Raises: ValueError: If required input fields are missing + RuntimeError: If max_time budget is exceeded """ self._validate_inputs(input_args) output_field_names = list(self.signature.output_fields.keys()) media_registry = self._build_media_registry(input_args) - execution_tools = self._prepare_execution_tools(media_registry=media_registry) + + # Mutable execution state — shared with budget() tool via closure + execution_state = {"start_time": _time.monotonic(), "iteration": 0} + + execution_tools = self._prepare_execution_tools( + media_registry=media_registry, + execution_state=execution_state, + ) variables = self._build_variables(**input_args) with self._interpreter_context(execution_tools) as repl: history: REPLHistory = REPLHistory(max_output_chars=self.max_output_chars) for iteration in range(self.max_iterations): + execution_state["iteration"] = iteration + + # Check time budget before starting iteration + if self.max_time is not None: + elapsed = _time.monotonic() - execution_state["start_time"] + if elapsed > self.max_time: + logger.warning( + f"RLM time budget exceeded ({elapsed:.1f}s > {self.max_time:.1f}s) " + f"at iteration {iteration + 1}, using extract fallback" + ) + return self._extract_fallback(variables, history, output_field_names) + result: Prediction | REPLHistory = self._execute_iteration( repl, variables, history, iteration, input_args, output_field_names ) @@ -792,18 +865,38 @@ async def aforward(self, **input_args) -> Prediction: Raises: ValueError: If required input fields are missing + RuntimeError: If max_time budget is exceeded """ self._validate_inputs(input_args) output_field_names = list(self.signature.output_fields.keys()) media_registry = self._build_media_registry(input_args) - execution_tools = self._prepare_execution_tools(media_registry=media_registry) + + # Mutable execution state — shared with budget() tool via closure + execution_state = {"start_time": _time.monotonic(), "iteration": 0} + + execution_tools = self._prepare_execution_tools( + media_registry=media_registry, + execution_state=execution_state, + ) variables = self._build_variables(**input_args) with self._interpreter_context(execution_tools) as repl: history = REPLHistory(max_output_chars=self.max_output_chars) for iteration in range(self.max_iterations): + execution_state["iteration"] = iteration + + # Check time budget before starting iteration + if self.max_time is not None: + elapsed = _time.monotonic() - execution_state["start_time"] + if elapsed > self.max_time: + logger.warning( + f"RLM time budget exceeded ({elapsed:.1f}s > {self.max_time:.1f}s) " + f"at iteration {iteration + 1}, using extract fallback" + ) + return await self._aextract_fallback(variables, history, output_field_names) + result = await self._aexecute_iteration( repl, variables, history, iteration, input_args, output_field_names ) diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index 1928c247d1..00f6ab7b6e 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -1480,5 +1480,159 @@ def test_call_count_shared_across_models(self): tools["llm_query"]("p4", model="flash") +class TestBudgetTracking: + """Tests for budget() tool and max_time enforcement.""" + + def test_max_time_initialization(self): + """Test that max_time is stored on the RLM instance.""" + rlm = RLM("query -> answer", max_time=60.0) + assert rlm.max_time == 60.0 + + def test_max_time_default_none(self): + """Test that max_time defaults to None (no limit).""" + rlm = RLM("query -> answer") + assert rlm.max_time is None + + def test_budget_tool_created(self): + """Test that the budget tool is included in execution tools.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=10) + tools = rlm._make_llm_tools() + assert "budget" in tools + assert callable(tools["budget"]) + + def test_budget_returns_string(self): + """Test that budget() returns a human-readable string.""" + rlm = RLM("query -> answer", max_iterations=10, max_llm_calls=20) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 3} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert isinstance(result, str) + assert "Iterations:" in result + assert "LLM calls:" in result + + def test_budget_reflects_iteration(self): + """Test that budget() shows correct remaining iterations.""" + rlm = RLM("query -> answer", max_iterations=10, max_llm_calls=20) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 7} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + # iteration=7, max=10, remaining = 10 - 7 - 1 = 2 + assert "2/10 remaining" in result + + def test_budget_reflects_llm_calls(self): + """Test that budget() shows correct remaining LLM calls after usage.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=10) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + + from unittest.mock import MagicMock + mock_lm = MagicMock(return_value=["response"]) + rlm_with_lm = RLM("query -> answer", max_iterations=5, max_llm_calls=10, sub_lm=mock_lm) + tools = rlm_with_lm._make_llm_tools(execution_state=execution_state) + + # Use 3 LLM calls + tools["llm_query"]("prompt1") + tools["llm_query"]("prompt2") + tools["llm_query"]("prompt3") + + result = tools["budget"]() + assert "7/10 remaining" in result + + def test_budget_shows_time_when_max_time_set(self): + """Test that budget() includes time info when max_time is configured.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=10, max_time=120.0) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert "Time:" in result + assert "/120.0s remaining" in result + + def test_budget_no_time_when_max_time_none(self): + """Test that budget() shows 'no limit' when max_time is None.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=10) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert "no limit" in result + + def test_budget_reserved_name(self): + """Test that 'budget' is a reserved tool name.""" + def budget() -> str: + return "custom" + + from dspy.adapters.types.tool import Tool + tool = Tool(budget, name="budget") + with pytest.raises(ValueError, match="conflicts with built-in"): + RLM("query -> answer", tools=[tool]) + + def test_budget_in_action_instructions(self): + """Test that the action instructions mention budget().""" + rlm = RLM("query -> answer", max_iterations=5) + action_sig = rlm.generate_action.signature + assert "budget()" in action_sig.instructions + + def test_max_time_triggers_extract_fallback(self): + """Test that exceeding max_time triggers extract fallback (not exception).""" + import time + + mock = MockInterpreter(responses=[ + "exploring...", + "still exploring...", + ]) + # Set max_time to 0 so it's already exceeded on first check + rlm = RLM("query -> answer", max_iterations=5, max_time=0.0, interpreter=mock) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Explore", "code": "print('exploring')"}, + ]) + rlm.extract = make_mock_predictor([ + {"answer": "timeout_fallback"}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "timeout_fallback" + assert result.final_reasoning == "Extract forced final output" + + def test_max_time_none_no_timeout(self): + """Test that max_time=None means no time checking.""" + mock = MockInterpreter(responses=[FinalOutput({"answer": "42"})]) + rlm = RLM("query -> answer", max_iterations=5, max_time=None, interpreter=mock) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Return answer", "code": 'SUBMIT("42")'}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "42" + + def test_budget_iteration_updates_via_tool(self): + """Test that budget() reports decreasing iterations across a forward() run.""" + budget_reports = [] + + class BudgetCapturingInterpreter(MockInterpreter): + def __init__(self): + super().__init__(responses=["output1", "output2", FinalOutput({"answer": "done"})]) + + def execute(self, code, variables=None): + # Call the budget tool if available + if "budget" in self.tools: + budget_reports.append(self.tools["budget"]()) + return super().execute(code, variables) + + mock_interp = BudgetCapturingInterpreter() + rlm = RLM("query -> answer", max_iterations=5, interpreter=mock_interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Step 1", "code": "print('a')"}, + {"reasoning": "Step 2", "code": "print('b')"}, + {"reasoning": "Step 3", "code": 'SUBMIT("done")'}, + ]) + + result = rlm(query="test") + assert result.answer == "done" + # We got 3 iterations (0, 1, 2), should have 3 budget reports + assert len(budget_reports) == 3 + # Each report should show decreasing remaining iterations + for report in budget_reports: + assert "Iterations:" in report + assert "LLM calls:" in report + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 514978918a5e5f3fe6fe942e96295de4b1ce618e Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 12:03:52 -0500 Subject: [PATCH 06/15] feat(rlm): add max_cost parameter with real cost tracking via litellm Extends the budget system with actual dollar cost tracking: - Add max_cost parameter: optional dollar limit per forward() call. Tracked via litellm's per-call cost reporting in lm.history entries. Gracefully falls back to extract when exceeded. - budget() now reports: iterations, LLM calls, time, cost, and tokens. Cost is computed by summing lm.history entries added since tool creation (snapshot offsets at start of each forward() call). - Cost enforcement in both forward() and aforward() loops. - Expose _get_cost_and_tokens via execution_state dict so the forward() loop can check cost without reaching into tool internals. 5 new tests for max_cost init, budget cost display, cost fallback. Co-Authored-By: Claude Opus 4.6 --- dspy/predict/rlm.py | 68 ++++++++++++++++++++++++++++++++++++++- tests/predict/test_rlm.py | 49 ++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) diff --git a/dspy/predict/rlm.py b/dspy/predict/rlm.py index 652d73fb2b..56b8e0b55f 100644 --- a/dspy/predict/rlm.py +++ b/dspy/predict/rlm.py @@ -123,6 +123,7 @@ def __init__( max_llm_calls: int = 50, max_output_chars: int = 10_000, max_time: float | None = None, + max_cost: float | None = None, verbose: bool = False, tools: list[Callable] | None = None, sub_lm: dspy.LM | None = None, @@ -138,6 +139,9 @@ def __init__( max_output_chars: Maximum characters to include from REPL output. max_time: Maximum wall-clock seconds per forward() call. None means no limit. The agent can check remaining time via budget(). + max_cost: Maximum dollar cost per forward() call. None means no limit. + Tracked via litellm's per-call cost reporting. The agent can + check remaining cost via budget(). verbose: Whether to log detailed execution info. tools: List of tool functions or dspy.Tool objects callable from interpreter code. Built-in tools: llm_query(prompt), llm_query_batched(prompts). @@ -154,6 +158,7 @@ def __init__( self.max_llm_calls = max_llm_calls self.max_output_chars = max_output_chars self.max_time = max_time + self.max_cost = max_cost self.verbose = verbose self.sub_lm = sub_lm self.sub_lms = sub_lms or {} @@ -247,6 +252,18 @@ def _make_llm_tools( named_lms = self.sub_lms _media_registry = media_registry or {} + # Snapshot LM history lengths for cost tracking. + # We'll sum cost from entries added after these offsets. + _all_lms: list[dspy.LM] = [] + if default_lm is not None: + _all_lms.append(default_lm) + for lm_inst in named_lms.values(): + if lm_inst not in _all_lms: + _all_lms.append(lm_inst) + if not _all_lms and dspy.settings.lm is not None: + _all_lms.append(dspy.settings.lm) + _history_offsets = {id(lm_inst): len(lm_inst.history) for lm_inst in _all_lms} + def _resolve_lm(model: str | None = None) -> dspy.LM: """Resolve a model name to an LM instance. @@ -371,9 +388,24 @@ def llm_query_with_media(prompt: str, *media_var_names: str, model: str | None = max_iterations = self.max_iterations max_llm_calls = self.max_llm_calls max_time = self.max_time + max_cost = self.max_cost + + def _get_cost_and_tokens() -> tuple[float, int]: + """Sum cost and tokens from LM history entries added since tool creation.""" + total_cost = 0.0 + total_tokens = 0 + for lm_inst in _all_lms: + offset = _history_offsets.get(id(lm_inst), 0) + for entry in lm_inst.history[offset:]: + cost = entry.get("cost") + if cost is not None: + total_cost += cost + usage = entry.get("usage", {}) + total_tokens += usage.get("total_tokens", 0) + return total_cost, total_tokens def budget() -> str: - """Check remaining execution budget: iterations, LLM calls, and time. + """Check remaining execution budget: iterations, LLM calls, time, and cost. Returns a human-readable summary of remaining resources. """ @@ -398,8 +430,18 @@ def budget() -> str: elapsed = _time.monotonic() - start_time parts.append(f"Time: no limit ({elapsed:.1f}s elapsed)") + cost_spent, tokens_used = _get_cost_and_tokens() + if max_cost is not None: + remaining_cost = max(0.0, max_cost - cost_spent) + parts.append(f"Cost: ${remaining_cost:.4f}/${max_cost:.4f} remaining (${cost_spent:.4f} spent, {tokens_used:,} tokens)") + elif cost_spent > 0: + parts.append(f"Cost: no limit (${cost_spent:.4f} spent, {tokens_used:,} tokens)") + return " | ".join(parts) + # Expose cost tracker to forward() for budget enforcement + _execution_state["_get_cost_and_tokens"] = _get_cost_and_tokens + tools = {"llm_query": llm_query, "llm_query_batched": llm_query_batched, "budget": budget} if _media_registry: tools["llm_query_with_media"] = llm_query_with_media @@ -793,6 +835,18 @@ def forward(self, **input_args) -> Prediction: ) return self._extract_fallback(variables, history, output_field_names) + # Check cost budget before starting iteration + if self.max_cost is not None: + get_cost = execution_state.get("_get_cost_and_tokens") + if get_cost is not None: + cost_spent, _ = get_cost() + if cost_spent > self.max_cost: + logger.warning( + f"RLM cost budget exceeded (${cost_spent:.4f} > ${self.max_cost:.4f}) " + f"at iteration {iteration + 1}, using extract fallback" + ) + return self._extract_fallback(variables, history, output_field_names) + result: Prediction | REPLHistory = self._execute_iteration( repl, variables, history, iteration, input_args, output_field_names ) @@ -897,6 +951,18 @@ async def aforward(self, **input_args) -> Prediction: ) return await self._aextract_fallback(variables, history, output_field_names) + # Check cost budget before starting iteration + if self.max_cost is not None: + get_cost = execution_state.get("_get_cost_and_tokens") + if get_cost is not None: + cost_spent, _ = get_cost() + if cost_spent > self.max_cost: + logger.warning( + f"RLM cost budget exceeded (${cost_spent:.4f} > ${self.max_cost:.4f}) " + f"at iteration {iteration + 1}, using extract fallback" + ) + return await self._aextract_fallback(variables, history, output_field_names) + result = await self._aexecute_iteration( repl, variables, history, iteration, input_args, output_field_names ) diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index 00f6ab7b6e..296b965bb7 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -1634,5 +1634,54 @@ def execute(self, code, variables=None): assert "LLM calls:" in report + def test_max_cost_initialization(self): + """Test that max_cost is stored on the RLM instance.""" + rlm = RLM("query -> answer", max_cost=0.10) + assert rlm.max_cost == 0.10 + + def test_max_cost_default_none(self): + """Test that max_cost defaults to None (no limit).""" + rlm = RLM("query -> answer") + assert rlm.max_cost is None + + def test_budget_shows_cost_when_max_cost_set(self): + """Test that budget() includes cost info when max_cost is configured.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=10, max_cost=0.50) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert "Cost:" in result + assert "$0.50" in result + + def test_budget_no_cost_when_max_cost_none_and_no_spending(self): + """Test that budget() omits cost when max_cost is None and nothing spent.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=10) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + # No cost tracking when max_cost is None and no LM history entries with cost + assert "Cost:" not in result or "no limit" in result + + def test_max_cost_zero_triggers_immediate_fallback(self): + """Test that max_cost=0 triggers extract fallback immediately.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["response"]) + # Give the mock LM a history with a cost entry + mock_lm.history = [{"cost": 0.001, "usage": {"total_tokens": 100}}] + + mock = MockInterpreter(responses=["exploring..."]) + rlm = RLM("query -> answer", max_iterations=5, max_cost=0.0, sub_lm=mock_lm, interpreter=mock) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Explore", "code": "print('exploring')"}, + ]) + rlm.extract = make_mock_predictor([ + {"answer": "cost_fallback"}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "cost_fallback" + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 65f19fc576c836268be661216bbdd711700161c4 Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 12:19:36 -0500 Subject: [PATCH 07/15] =?UTF-8?q?feat(rlm):=20improve=20cost=20tracking=20?= =?UTF-8?q?=E2=80=94=20BYOK=20upstream=20cost,=20budget=20warnings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Sum both provider cost (litellm response_cost) and upstream inference cost (usage.cost_details.upstream_inference_cost) for accurate BYOK tracking (e.g. OpenRouter BYOK → Vertex with 5% markup) - Add budget warnings when any resource drops below 20% remaining: iterations, LLM calls, time, or cost - budget() output now prefixed with "⚠ LOW: ..." when resources are low Co-Authored-By: Claude Opus 4.6 --- dspy/predict/rlm.py | 37 +++++++++++++--- tests/predict/test_rlm.py | 90 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 5 deletions(-) diff --git a/dspy/predict/rlm.py b/dspy/predict/rlm.py index 56b8e0b55f..821f53577f 100644 --- a/dspy/predict/rlm.py +++ b/dspy/predict/rlm.py @@ -391,23 +391,37 @@ def llm_query_with_media(prompt: str, *media_var_names: str, model: str | None = max_cost = self.max_cost def _get_cost_and_tokens() -> tuple[float, int]: - """Sum cost and tokens from LM history entries added since tool creation.""" + """Sum cost and tokens from LM history entries added since tool creation. + + Aggregates both the provider cost (litellm response_cost) and the upstream + inference cost (e.g. OpenRouter BYOK → Vertex). For BYOK providers where + response_cost is 0, the upstream cost is the actual charge. + """ total_cost = 0.0 total_tokens = 0 for lm_inst in _all_lms: offset = _history_offsets.get(id(lm_inst), 0) for entry in lm_inst.history[offset:]: + entry_cost = 0.0 cost = entry.get("cost") - if cost is not None: - total_cost += cost + if cost: + entry_cost += cost usage = entry.get("usage", {}) - total_tokens += usage.get("total_tokens", 0) + if isinstance(usage, dict): + cost_details = usage.get("cost_details") + if isinstance(cost_details, dict): + upstream = cost_details.get("upstream_inference_cost") + if upstream: + entry_cost += upstream + total_cost += entry_cost + total_tokens += usage.get("total_tokens", 0) if isinstance(usage, dict) else 0 return total_cost, total_tokens def budget() -> str: """Check remaining execution budget: iterations, LLM calls, time, and cost. Returns a human-readable summary of remaining resources. + Includes warnings when any resource drops below 20% remaining. """ with lock: calls_used = state["call_count"] @@ -416,16 +430,24 @@ def budget() -> str: remaining_iterations = max_iterations - iteration - 1 # -1 for current remaining_calls = max_llm_calls - calls_used + warnings = [] parts = [ f"Iterations: {remaining_iterations}/{max_iterations} remaining", f"LLM calls: {remaining_calls}/{max_llm_calls} remaining", ] + if remaining_iterations <= max(1, max_iterations * 0.2): + warnings.append(f"iterations ({remaining_iterations} left)") + if remaining_calls <= max(1, max_llm_calls * 0.2): + warnings.append(f"LLM calls ({remaining_calls} left)") + start_time = _execution_state.get("start_time") if max_time is not None and start_time is not None: elapsed = _time.monotonic() - start_time remaining_time = max(0.0, max_time - elapsed) parts.append(f"Time: {remaining_time:.1f}s/{max_time:.1f}s remaining ({elapsed:.1f}s elapsed)") + if remaining_time <= max_time * 0.2: + warnings.append(f"time ({remaining_time:.0f}s left)") elif start_time is not None: elapsed = _time.monotonic() - start_time parts.append(f"Time: no limit ({elapsed:.1f}s elapsed)") @@ -434,10 +456,15 @@ def budget() -> str: if max_cost is not None: remaining_cost = max(0.0, max_cost - cost_spent) parts.append(f"Cost: ${remaining_cost:.4f}/${max_cost:.4f} remaining (${cost_spent:.4f} spent, {tokens_used:,} tokens)") + if remaining_cost <= max_cost * 0.2: + warnings.append(f"cost (${remaining_cost:.4f} left)") elif cost_spent > 0: parts.append(f"Cost: no limit (${cost_spent:.4f} spent, {tokens_used:,} tokens)") - return " | ".join(parts) + result = " | ".join(parts) + if warnings: + result = f"⚠ LOW: {', '.join(warnings)}. Wrap up soon! | " + result + return result # Expose cost tracker to forward() for budget enforcement _execution_state["_get_cost_and_tokens"] = _get_cost_and_tokens diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index 296b965bb7..4f990d2fe4 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -1683,5 +1683,95 @@ def test_max_cost_zero_triggers_immediate_fallback(self): assert result.answer == "cost_fallback" + def test_byok_cost_upstream_inference_cost(self): + """Test cost tracking includes usage.cost_details.upstream_inference_cost for BYOK.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["response"]) + # Start with empty history — entries added after tool creation + mock_lm.history = [] + + rlm = RLM("query -> answer", max_cost=1.0, sub_lm=mock_lm) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + + # Simulate BYOK responses arriving after tool creation + mock_lm.history.extend([ + { + "cost": 0, + "usage": { + "total_tokens": 500, + "is_byok": True, + "cost_details": {"upstream_inference_cost": 0.0025}, + }, + }, + { + "cost": None, + "usage": { + "total_tokens": 300, + "cost_details": {"upstream_inference_cost": 0.0015}, + }, + }, + ]) + + result = tools["budget"]() + # Should pick up the upstream_inference_cost values: 0.0025 + 0.0015 + assert "$0.0040" in result + assert "800 tokens" in result # 500 + 300 + + def test_cost_sums_provider_and_upstream(self): + """Test cost tracking sums both provider cost and upstream inference cost.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["response"]) + mock_lm.history = [] + + rlm = RLM("query -> answer", max_cost=1.0, sub_lm=mock_lm) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + + # Non-BYOK: both provider cost and upstream cost are nonzero + mock_lm.history.append({ + "cost": 0.01, + "usage": { + "total_tokens": 1000, + "cost_details": {"upstream_inference_cost": 0.005}, + }, + }) + + result = tools["budget"]() + # Sum of 0.01 + 0.005 = 0.015 + assert "$0.0150" in result + + def test_budget_warning_low_iterations(self): + """Test that budget() shows warning when iterations are low.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=50) + # iteration=4 means 0 remaining (5 - 4 - 1) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 4} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert "LOW" in result + assert "iterations" in result + + def test_budget_warning_low_time(self): + """Test that budget() shows warning when time is running low.""" + import time + rlm = RLM("query -> answer", max_iterations=20, max_llm_calls=50, max_time=10.0) + # Start time 9 seconds ago — only 1s remaining (10%) + execution_state = {"start_time": time.monotonic() - 9.0, "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert "LOW" in result + assert "time" in result + + def test_budget_no_warning_when_plenty_remaining(self): + """Test that budget() has no warning when resources are plentiful.""" + rlm = RLM("query -> answer", max_iterations=20, max_llm_calls=50, max_time=60.0) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert "LOW" not in result + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 82c22cbd2d585f0a0b58767c29114430b0f178a6 Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 14:56:25 -0500 Subject: [PATCH 08/15] fix: lint cleanup (ruff) + GEPA bootstrap_trace resilience for RLM failures MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename _SubmitCalled → _SubmitCalledError (N818) - Add noqa for SUBMIT() function name (N802, intentional API) - Use X | Y union syntax in isinstance calls (UP038) - Remove unused imports and variables (F401, F841) - Sort imports (I001) - Add zip strict=False (B905) - Catch non-parse exceptions in bootstrap_trace so GEPA can reflect on partial traces from RLM timeout/crash/cost overrun --- dspy/primitives/local_interpreter.py | 10 +++++----- dspy/primitives/python_interpreter.py | 10 +++++----- dspy/teleprompt/bootstrap_trace.py | 7 +++++++ tests/predict/test_rlm.py | 2 -- tests/primitives/test_local_interpreter.py | 1 - 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/dspy/primitives/local_interpreter.py b/dspy/primitives/local_interpreter.py index 0a703b3ca5..9ab282b59d 100644 --- a/dspy/primitives/local_interpreter.py +++ b/dspy/primitives/local_interpreter.py @@ -26,7 +26,7 @@ from dspy.primitives.code_interpreter import CodeInterpreterError, FinalOutput -class _SubmitCalled(Exception): +class _SubmitCalledError(Exception): """Internal signal raised when SUBMIT() is called in user code.""" def __init__(self, output: Any): self.output = output @@ -117,7 +117,7 @@ def execute( output_fields = self.output_fields or [] field_names = [f["name"] for f in output_fields] - def SUBMIT(*args, **kwargs): + def SUBMIT(*args, **kwargs): # noqa: N802 if not args and not kwargs: raise ValueError("SUBMIT requires at least one argument") if args and kwargs: @@ -131,12 +131,12 @@ def SUBMIT(*args, **kwargs): f"SUBMIT() takes {len(field_names)} positional argument(s) " f"({expected}), but {len(args)} were given" ) - output = dict(zip(field_names, args)) + output = dict(zip(field_names, args, strict=False)) elif len(args) == 1: output = {"output": args[0]} else: output = {"output": args} - raise _SubmitCalled(output) + raise _SubmitCalledError(output) self._namespace["SUBMIT"] = SUBMIT @@ -147,7 +147,7 @@ def SUBMIT(*args, **kwargs): try: exec(code, self._namespace) - except _SubmitCalled as e: + except _SubmitCalledError as e: return FinalOutput(e.output) except SyntaxError: raise diff --git a/dspy/primitives/python_interpreter.py b/dspy/primitives/python_interpreter.py index df61180b3e..dc1d919d44 100644 --- a/dspy/primitives/python_interpreter.py +++ b/dspy/primitives/python_interpreter.py @@ -343,7 +343,7 @@ def _handle_tool_call(self, request: dict) -> None: if tool_name not in self.tools: raise CodeInterpreterError(f"Unknown tool: {tool_name}") result = self.tools[tool_name](*args, **kwargs) - is_json = isinstance(result, (list, dict)) + is_json = isinstance(result, list | dict) response = _jsonrpc_result( {"value": json.dumps(result) if is_json else str(result or ""), "type": "json" if is_json else "string"}, request_id @@ -411,13 +411,13 @@ def _health_check(self) -> None: def _to_json_compatible(self, value: Any) -> Any: """Recursively convert Python values to JSON-compatible types.""" - if value is None or isinstance(value, (str, int, float, bool)): + if value is None or isinstance(value, str | int | float | bool): return value elif _is_media_type(value): return _media_descriptor(value) elif isinstance(value, dict): return {k: self._to_json_compatible(v) for k, v in value.items()} - elif isinstance(value, (list, tuple)): + elif isinstance(value, list | tuple): return [self._to_json_compatible(v) for v in value] elif isinstance(value, set): try: @@ -465,9 +465,9 @@ def _serialize_value(self, value: Any) -> str: elif isinstance(value, bool): # Must check bool before int since bool is a subclass of int return "True" if value else "False" - elif isinstance(value, (int, float)): + elif isinstance(value, int | float): return str(value) - elif isinstance(value, (list, tuple)): + elif isinstance(value, list | tuple): # Tuples become lists for JSON compatibility items = ", ".join(self._serialize_value(item) for item in value) return f"[{items}]" diff --git a/dspy/teleprompt/bootstrap_trace.py b/dspy/teleprompt/bootstrap_trace.py index 2f0cc60cd1..80c1c8d660 100644 --- a/dspy/teleprompt/bootstrap_trace.py +++ b/dspy/teleprompt/bootstrap_trace.py @@ -107,6 +107,13 @@ def patched_forward(program_to_use: Module, **kwargs): ) return failed_pred, trace + except Exception as e: + # Catch non-parse failures (e.g. RLM timeout, interpreter crash, + # cost overrun). Preserve whatever partial trace was captured so + # GEPA can still reflect on the calls that happened before failure. + trace = dspy.settings.trace.copy() + failed_pred = FailedPrediction(completion_text=str(e)) + return failed_pred, trace program.forward = MethodType(patched_forward, program) diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index 4f990d2fe4..d9205602ec 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -1521,7 +1521,6 @@ def test_budget_reflects_iteration(self): def test_budget_reflects_llm_calls(self): """Test that budget() shows correct remaining LLM calls after usage.""" - rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=10) execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} from unittest.mock import MagicMock @@ -1572,7 +1571,6 @@ def test_budget_in_action_instructions(self): def test_max_time_triggers_extract_fallback(self): """Test that exceeding max_time triggers extract fallback (not exception).""" - import time mock = MockInterpreter(responses=[ "exploring...", diff --git a/tests/primitives/test_local_interpreter.py b/tests/primitives/test_local_interpreter.py index 664d60a7a5..7f5c864ca5 100644 --- a/tests/primitives/test_local_interpreter.py +++ b/tests/primitives/test_local_interpreter.py @@ -5,7 +5,6 @@ from dspy.primitives.code_interpreter import CodeInterpreterError, FinalOutput from dspy.primitives.local_interpreter import LocalInterpreter - # ============================================================================= # Basic Execution # ============================================================================= From 0c32c9408e6578ea2c0e0f5c1dd7301cf2dc59ee Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 15:04:21 -0500 Subject: [PATCH 09/15] test: add integration + edge case tests for all new RLM features - RLM + LocalInterpreter integration (7 tests): forward, state persistence, tool access, stdlib imports, error recovery, max_time, aforward - Media content construction (4 tests): audio/image content parts sent to LM, multiple media objects, model routing with media - max_cost mid-run fallback (1 test): cost exceeded during iteration triggers extract fallback, not crash - Async budget/time/cost (3 tests): aforward respects max_time, max_cost, budget() tool works in async path - bootstrap_trace resilience (3 tests): RuntimeError captured as FailedPrediction, partial trace preserved on cost overrun, KeyboardInterrupt not swallowed - LocalInterpreter output_fields setter (2 tests): post-init configuration, default single-output wrapping --- tests/predict/test_rlm.py | 579 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 579 insertions(+) diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index d9205602ec..e8cf0605ec 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -1771,5 +1771,584 @@ def test_budget_no_warning_when_plenty_remaining(self): assert "LOW" not in result +# ============================================================================ +# Integration Tests: RLM + LocalInterpreter +# ============================================================================ + + +class TestRLMWithLocalInterpreter: + """Integration tests proving RLM and LocalInterpreter work together end-to-end.""" + + def test_basic_forward(self): + """Test RLM forward() with LocalInterpreter produces a result.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=3, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Return answer", "code": 'SUBMIT("hello")'}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "hello" + + def test_state_persists_across_iterations(self): + """Test that variables set in one iteration survive to the next.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=5, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Set a variable", "code": "x = 42\nprint(x)"}, + {"reasoning": "Use it", "code": "SUBMIT(str(x * 2))"}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "84" + + def test_tools_accessible_in_code(self): + """Test that llm_query and budget tools are callable from LocalInterpreter code.""" + from unittest.mock import MagicMock + + from dspy.primitives.local_interpreter import LocalInterpreter + + mock_lm = MagicMock(return_value=["mocked response"]) + mock_lm.history = [] + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=5, sub_lm=mock_lm, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Use tools", "code": 'b = budget()\nresult = llm_query("hi")\nprint(result)'}, + {"reasoning": "Submit", "code": "SUBMIT(result)"}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "mocked response" + + def test_stdlib_imports_work(self): + """Test that LocalInterpreter allows stdlib imports inside RLM.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=3, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Use stdlib", "code": 'import json\nSUBMIT(json.dumps({"a": 1}))'}, + ]) + + result = rlm.forward(query="test") + assert result.answer == '{"a": 1}' + + def test_error_recovery(self): + """Test that a runtime error in one iteration doesn't kill the session.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=5, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "This will error", "code": "1 / 0"}, + {"reasoning": "Recover", "code": 'SUBMIT("recovered")'}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "recovered" + + def test_max_time_with_local_interpreter(self): + """Test that max_time budget enforcement works with LocalInterpreter.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=10, max_time=0.0, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Explore", "code": "print('exploring')"}, + ]) + rlm.extract = make_mock_predictor([ + {"answer": "timeout_fallback"}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "timeout_fallback" + + @pytest.mark.asyncio + async def test_aforward_with_local_interpreter(self): + """Test RLM aforward() works with LocalInterpreter.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=3, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Return answer", "code": 'SUBMIT("async_hello")'}, + ]) + + result = await rlm.aforward(query="test") + assert result.answer == "async_hello" + + +# ============================================================================ +# Tests: llm_query_with_media content construction +# ============================================================================ + + +class TestMediaContentConstruction: + """Test that llm_query_with_media builds correct multimodal content for the LM.""" + + def test_media_content_parts_sent_to_lm(self): + """Test that llm_query_with_media sends multimodal content parts to the LM.""" + from unittest.mock import MagicMock + + from dspy.adapters.types.audio import Audio + + mock_lm = MagicMock(return_value=["transcription result"]) + mock_lm.history = [] + + audio = Audio(data="dGVzdA==", audio_format="wav") + + rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_lm) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools( + media_registry={"my_audio": audio}, + execution_state=execution_state, + ) + + result = tools["llm_query_with_media"]("describe this audio", "my_audio") + assert result == "transcription result" + + # Verify the LM was called with messages containing multimodal content + mock_lm.assert_called_once() + call_kwargs = mock_lm.call_args + messages = call_kwargs.kwargs.get("messages") or call_kwargs[1].get("messages") + assert messages is not None + assert len(messages) == 1 + assert messages[0]["role"] == "user" + content = messages[0]["content"] + assert isinstance(content, list) + # First part should be text + assert content[0]["type"] == "text" + assert content[0]["text"] == "describe this audio" + # Remaining parts should be from audio.format() + assert len(content) > 1 + + def test_image_content_parts_sent_to_lm(self): + """Test that llm_query_with_media sends image content parts to the LM.""" + from unittest.mock import MagicMock + + from dspy.adapters.types.image import Image + + mock_lm = MagicMock(return_value=["a cat sitting on a mat"]) + mock_lm.history = [] + + image = Image(url="https://example.com/cat.jpg") + + rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_lm) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools( + media_registry={"my_image": image}, + execution_state=execution_state, + ) + + result = tools["llm_query_with_media"]("what is in this image?", "my_image") + assert result == "a cat sitting on a mat" + + # Verify multimodal message structure + mock_lm.assert_called_once() + call_kwargs = mock_lm.call_args + messages = call_kwargs.kwargs.get("messages") or call_kwargs[1].get("messages") + content = messages[0]["content"] + assert content[0]["type"] == "text" + assert len(content) > 1 + + def test_multiple_media_objects_in_one_call(self): + """Test llm_query_with_media with multiple media variables.""" + from unittest.mock import MagicMock + + from dspy.adapters.types.audio import Audio + from dspy.adapters.types.image import Image + + mock_lm = MagicMock(return_value=["combined analysis"]) + mock_lm.history = [] + + audio = Audio(data="dGVzdA==", audio_format="wav") + image = Image(url="https://example.com/photo.jpg") + + rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_lm) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools( + media_registry={"audio_in": audio, "image_in": image}, + execution_state=execution_state, + ) + + result = tools["llm_query_with_media"]("analyze both", "audio_in", "image_in") + assert result == "combined analysis" + + # Verify both media objects were included + call_kwargs = mock_lm.call_args + messages = call_kwargs.kwargs.get("messages") or call_kwargs[1].get("messages") + content = messages[0]["content"] + # text + audio parts + image parts + assert len(content) >= 3 + + def test_media_with_model_routing(self): + """Test llm_query_with_media routes to named model when specified.""" + from unittest.mock import MagicMock + + from dspy.adapters.types.audio import Audio + + mock_default = MagicMock(return_value=["default"]) + mock_default.history = [] + mock_pro = MagicMock(return_value=["pro result"]) + mock_pro.history = [] + + audio = Audio(data="dGVzdA==", audio_format="wav") + + rlm = RLM("query -> answer", max_iterations=3, + sub_lm=mock_default, sub_lms={"pro": mock_pro}) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools( + media_registry={"audio": audio}, + execution_state=execution_state, + ) + + result = tools["llm_query_with_media"]("transcribe", "audio", model="pro") + assert result == "pro result" + mock_pro.assert_called_once() + mock_default.assert_not_called() + + +# ============================================================================ +# Tests: max_cost mid-run fallback +# ============================================================================ + + +class TestMaxCostMidRunFallback: + """Test that max_cost triggers extract fallback during a multi-iteration run.""" + + def test_cost_exceeded_mid_run_triggers_fallback(self): + """Test that exceeding max_cost mid-run triggers extract fallback, not crash.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["response"]) + mock_lm.history = [] + + mock = MockInterpreter(responses=[ + "first iteration output", + "second iteration output", # cost exceeded before this + ]) + rlm = RLM("query -> answer", max_iterations=10, max_cost=0.05, + sub_lm=mock_lm, interpreter=mock) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Explore", "code": "print('exploring')"}, + {"reasoning": "More", "code": "print('more')"}, + ]) + rlm.extract = make_mock_predictor([ + {"answer": "cost_fallback"}, + ]) + + # Simulate cost appearing after first iteration + # The forward loop checks cost at the start of each iteration + # After iteration 0, we inject cost into LM history + original_execute = mock.execute + + def execute_with_cost_injection(code, variables=None): + result = original_execute(code, variables) + # After first execute, inject cost exceeding budget + if mock.call_count == 1: + mock_lm.history.append({ + "cost": 0.10, # exceeds max_cost=0.05 + "usage": {"total_tokens": 5000}, + }) + return result + + mock.execute = execute_with_cost_injection + + result = rlm.forward(query="test") + assert result.answer == "cost_fallback" + assert result.final_reasoning == "Extract forced final output" + + +# ============================================================================ +# Tests: Async budget/time/cost +# ============================================================================ + + +class TestAsyncBudgetTimeCost: + """Test that budget, max_time, and max_cost work correctly in aforward().""" + + @pytest.mark.asyncio + async def test_aforward_max_time_triggers_fallback(self): + """Test that aforward() respects max_time and triggers extract fallback.""" + mock = MockInterpreter(responses=["exploring..."]) + rlm = RLM("query -> answer", max_iterations=5, max_time=0.0, interpreter=mock) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Explore", "code": "print('exploring')"}, + ], async_mode=True) + rlm.extract = make_mock_predictor([ + {"answer": "async_timeout_fallback"}, + ], async_mode=True) + + result = await rlm.aforward(query="test") + assert result.answer == "async_timeout_fallback" + + @pytest.mark.asyncio + async def test_aforward_max_cost_triggers_fallback(self): + """Test that aforward() respects max_cost and triggers extract fallback.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["response"]) + mock_lm.history = [{"cost": 1.0, "usage": {"total_tokens": 50000}}] + + mock = MockInterpreter(responses=["exploring..."]) + rlm = RLM("query -> answer", max_iterations=5, max_cost=0.0, + sub_lm=mock_lm, interpreter=mock) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Explore", "code": "print('exploring')"}, + ], async_mode=True) + rlm.extract = make_mock_predictor([ + {"answer": "async_cost_fallback"}, + ], async_mode=True) + + result = await rlm.aforward(query="test") + assert result.answer == "async_cost_fallback" + + @pytest.mark.asyncio + async def test_aforward_budget_tool_works(self): + """Test that budget() tool is accessible in aforward() via MockInterpreter.""" + budget_reports = [] + + class BudgetCapturingInterpreter(MockInterpreter): + def __init__(self): + super().__init__(responses=["output", FinalOutput({"answer": "done"})]) + + def execute(self, code, variables=None): + if "budget" in self.tools: + budget_reports.append(self.tools["budget"]()) + return super().execute(code, variables) + + mock_interp = BudgetCapturingInterpreter() + rlm = RLM("query -> answer", max_iterations=10, max_llm_calls=30, + max_time=60.0, interpreter=mock_interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Step 1", "code": "print('a')"}, + {"reasoning": "Done", "code": 'SUBMIT("done")'}, + ], async_mode=True) + + result = await rlm.aforward(query="test") + assert result.answer == "done" + assert len(budget_reports) >= 1 + assert "Iterations:" in budget_reports[0] + assert "Time:" in budget_reports[0] + + +# ============================================================================ +# Tests: bootstrap_trace resilience for non-parse exceptions +# ============================================================================ + + +class TestBootstrapTraceResilience: + """Test that bootstrap_trace_data handles non-parse exceptions gracefully.""" + + def test_runtime_error_captured_as_failed_prediction(self): + """Test that a RuntimeError from forward() is captured, not propagated.""" + from unittest.mock import patch + + import dspy + from dspy.teleprompt.bootstrap_trace import FailedPrediction, bootstrap_trace_data + + class CrashingModule(dspy.Module): + def __init__(self): + super().__init__() + self.predictor = dspy.Predict("query -> answer") + + def forward(self, **kwargs): + raise RuntimeError("RLM timeout exceeded") + + program = CrashingModule() + dataset = [ + dspy.Example(query="test1").with_inputs("query"), + dspy.Example(query="test2").with_inputs("query"), + ] + + def metric(example, prediction, trace=None): + if isinstance(prediction, FailedPrediction): + return 0.0 + return 1.0 + + # Mock the Evaluate class to directly call the program + class DirectEvaluate: + def __init__(self, **kwargs): + self.devset = kwargs.get("devset", []) + self.failure_score = kwargs.get("failure_score", 0) + + def __call__(self, program, metric=None, **kwargs): + results = [] + for example in self.devset: + inputs = {k: example[k] for k in example.inputs()} + prediction = program(**inputs) + score = metric(example, prediction) if metric else None + results.append((example, prediction, score)) + + class Result: + pass + r = Result() + r.results = results + return r + + with patch("dspy.teleprompt.bootstrap_trace.Evaluate", DirectEvaluate): + import dspy as _dspy + with _dspy.context(lm=_dspy.LM(model="openai/gpt-4o-mini"), trace=[]): + results = bootstrap_trace_data( + program=program, + dataset=dataset, + metric=metric, + raise_on_error=False, + ) + + assert len(results) == 2 + for result in results: + pred = result["prediction"] + assert isinstance(pred, FailedPrediction) + assert "RLM timeout exceeded" in pred.completion_text + # Trace should be preserved (even if empty, it shouldn't be None) + assert isinstance(result["trace"], list) + + def test_cost_overrun_captured_as_failed_prediction(self): + """Test that a cost overrun exception is captured with partial trace.""" + from unittest.mock import patch + + import dspy + from dspy.teleprompt.bootstrap_trace import FailedPrediction, bootstrap_trace_data + + class CostOverrunModule(dspy.Module): + def __init__(self): + super().__init__() + self.predictor = dspy.Predict("query -> answer") + + def forward(self, **kwargs): + # Simulate partial work before cost overrun + # Add something to the trace before crashing + dspy.settings.trace.append( + (self.predictor, kwargs, dspy.Prediction(answer="partial")) + ) + raise RuntimeError("Cost budget exceeded ($0.55 > $0.50)") + + program = CostOverrunModule() + dataset = [dspy.Example(query="test").with_inputs("query")] + + def metric(example, prediction, trace=None): + return 0.0 + + class DirectEvaluate: + def __init__(self, **kwargs): + self.devset = kwargs.get("devset", []) + + def __call__(self, program, metric=None, **kwargs): + results = [] + for example in self.devset: + inputs = {k: example[k] for k in example.inputs()} + prediction = program(**inputs) + score = metric(example, prediction) if metric else None + results.append((example, prediction, score)) + + class Result: + pass + r = Result() + r.results = results + return r + + with patch("dspy.teleprompt.bootstrap_trace.Evaluate", DirectEvaluate): + import dspy as _dspy + with _dspy.context(lm=_dspy.LM(model="openai/gpt-4o-mini"), trace=[]): + results = bootstrap_trace_data( + program=program, + dataset=dataset, + metric=metric, + raise_on_error=False, + ) + + assert len(results) == 1 + pred = results[0]["prediction"] + assert isinstance(pred, FailedPrediction) + assert "Cost budget exceeded" in pred.completion_text + # The partial trace from before the crash should be preserved + trace = results[0]["trace"] + assert isinstance(trace, list) + assert len(trace) == 1 # the one entry we appended before crashing + + def test_keyboard_interrupt_not_swallowed(self): + """Test that KeyboardInterrupt is NOT caught (it's BaseException, not Exception).""" + from unittest.mock import patch + + import dspy + from dspy.teleprompt.bootstrap_trace import bootstrap_trace_data + + class InterruptingModule(dspy.Module): + def __init__(self): + super().__init__() + self.predictor = dspy.Predict("query -> answer") + + def forward(self, **kwargs): + raise KeyboardInterrupt() + + program = InterruptingModule() + dataset = [dspy.Example(query="test").with_inputs("query")] + + class DirectEvaluate: + def __init__(self, **kwargs): + self.devset = kwargs.get("devset", []) + + def __call__(self, program, metric=None, **kwargs): + results = [] + for example in self.devset: + inputs = {k: example[k] for k in example.inputs()} + prediction = program(**inputs) + results.append((example, prediction, None)) + + class Result: + pass + r = Result() + r.results = results + return r + + with patch("dspy.teleprompt.bootstrap_trace.Evaluate", DirectEvaluate): + import dspy as _dspy + with _dspy.context(lm=_dspy.LM(model="openai/gpt-4o-mini"), trace=[]): + with pytest.raises(KeyboardInterrupt): + bootstrap_trace_data( + program=program, + dataset=dataset, + raise_on_error=False, + ) + + +# ============================================================================ +# Tests: LocalInterpreter output_fields via setter +# ============================================================================ + + +class TestLocalInterpreterOutputFieldsSetter: + """Test LocalInterpreter output_fields configuration paths.""" + + def test_output_fields_set_after_init(self): + """Test that output_fields can be set after construction and SUBMIT uses them.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + interp.output_fields = [{"name": "answer"}, {"name": "confidence"}] + interp.start() + + result = interp.execute('SUBMIT("hello", "high")') + assert isinstance(result, FinalOutput) + assert result.output == {"answer": "hello", "confidence": "high"} + + def test_output_fields_none_defaults_to_single_output(self): + """Test that without output_fields, single-arg SUBMIT wraps in 'output' key.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + interp.start() + + result = interp.execute('SUBMIT("hello")') + assert isinstance(result, FinalOutput) + assert result.output == {"output": "hello"} + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 9b5cb29596ed6af570904da381f10816932fce38 Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 16:26:46 -0500 Subject: [PATCH 10/15] feat(rlm): add depth>1 recursive subcalls via LocalInterpreter When max_depth > 1, llm_query() spawns a child RLM with its own LocalInterpreter REPL instead of making a plain LM completion call. Each child manages its own iteration loop, tools, and budget. Key design (mirrors vanilla RLM PR #84 pattern): - RLM._subcall() method spawns child RLM(signature='prompt -> response') - Child gets fresh LocalInterpreter (isolated namespace, stdout capture) - Budget propagation: remaining time/cost passed to child - model= param selects child's sub_lm via resolve_lm closure - User tools inherited by children - Interpreter cleanup in finally block - At leaf depth (depth >= max_depth - 1), falls back to plain LM call - llm_query_batched runs children sequentially when recursive Params added to RLM.__init__: - depth: int = 0 (current recursion depth, 0-indexed) - max_depth: int = 1 (max recursion depth, default=no recursion) Tests: 24 new tests adapted from vanilla RLM test_subcall.py covering parameter propagation, budget enforcement, interpreter isolation, and end-to-end execution with DummyLM. --- dspy/predict/rlm.py | 108 +++++++++ tests/predict/test_rlm.py | 485 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 593 insertions(+) diff --git a/dspy/predict/rlm.py b/dspy/predict/rlm.py index 821f53577f..9a53147bc7 100644 --- a/dspy/predict/rlm.py +++ b/dspy/predict/rlm.py @@ -129,6 +129,8 @@ def __init__( sub_lm: dspy.LM | None = None, sub_lms: dict[str, dspy.LM] | None = None, interpreter: CodeInterpreter | None = None, + depth: int = 0, + max_depth: int = 1, ): """ Args: @@ -151,6 +153,10 @@ def __init__( a specific model via llm_query(prompt, model="name"). When model is None, falls back to sub_lm, then dspy.settings.lm. interpreter: CodeInterpreter implementation to use. Defaults to PythonInterpreter. + depth: Current recursion depth (0-indexed). Used internally when spawning child RLMs. + max_depth: Maximum recursion depth. When depth < max_depth - 1, llm_query spawns + a child RLM with its own REPL (LocalInterpreter). At leaf depth, falls + back to plain LM completion. Default 1 means no recursion (current behavior). """ super().__init__() self.signature = ensure_signature(signature) @@ -163,6 +169,8 @@ def __init__( self.sub_lm = sub_lm self.sub_lms = sub_lms or {} self._interpreter = interpreter + self.depth = depth + self.max_depth = max_depth self._user_tools = self._normalize_tools(tools) self._validate_tools(self._user_tools) @@ -318,9 +326,19 @@ def _query_lm_multimodal(prompt: str, media_objects: list, model: str | None = N return item return str(response) + # Determine if llm_query should spawn recursive child RLMs + _use_recursive = self.depth < self.max_depth - 1 + + def _do_subcall(prompt: str, model: str | None = None) -> str: + """Route subcall through the instance method, passing execution state.""" + return self._subcall(prompt, model=model, execution_state=_execution_state, resolve_lm=_resolve_lm) + def llm_query(prompt: str, model: str | None = None) -> str: """Query the LLM with a prompt string. + At depth < max_depth - 1, spawns a child RLM with its own REPL. + At leaf depth, makes a plain LM completion call. + Args: prompt: The text prompt for the LLM. model: Optional model name from sub_lms to use. Defaults to the default sub_lm. @@ -328,11 +346,16 @@ def llm_query(prompt: str, model: str | None = None) -> str: if not prompt: raise ValueError("prompt cannot be empty") _check_and_increment(1) + if _use_recursive: + return _do_subcall(prompt, model=model) return _query_lm(prompt, model=model) def llm_query_batched(prompts: list[str], model: str | None = None) -> list[str]: """Query the LLM with multiple prompts concurrently. + At depth < max_depth - 1, each prompt spawns a child RLM (sequential). + At leaf depth, prompts are sent as plain LM calls (parallel). + Args: prompts: List of prompts to send to the LLM. model: Optional model name from sub_lms to use. Defaults to the default sub_lm. @@ -341,6 +364,16 @@ def llm_query_batched(prompts: list[str], model: str | None = None) -> list[str] return [] _check_and_increment(len(prompts)) + if _use_recursive: + # Sequential: each child spawns its own LocalInterpreter + results = [] + for p in prompts: + try: + results.append(_do_subcall(p, model=model)) + except Exception as e: + results.append(f"[ERROR] {e}") + return results + results: dict[int, str] = {} with ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_idx = {executor.submit(_query_lm, p, model): i for i, p in enumerate(prompts)} @@ -654,6 +687,81 @@ def _inject_execution_context(self, interpreter: CodeInterpreter, execution_tool if hasattr(interpreter, "_tools_registered"): interpreter._tools_registered = False + def _subcall( + self, + prompt: str, + model: str | None = None, + execution_state: dict[str, Any] | None = None, + resolve_lm: Callable | None = None, + ) -> str: + """Spawn a child RLM with its own LocalInterpreter REPL. + + Called by llm_query/llm_query_batched when depth < max_depth - 1. + The child gets a fresh REPL and can write code, call llm_query (which + recurses further or falls back to plain LM at leaf depth), and SUBMIT + a response. Mirrors the vanilla RLM's _subcall pattern. + + Args: + prompt: The prompt to pass as the child's input. + model: Optional model name. Selects child's sub_lm via resolve_lm. + execution_state: Parent's mutable execution state (start_time, cost tracker). + resolve_lm: Closure from _make_llm_tools that resolves model name to LM instance. + + Returns: + The child's response string, or an error string on failure. + """ + from dspy.primitives.local_interpreter import LocalInterpreter + + _execution_state = execution_state or {} + + # Calculate remaining time budget for child + remaining_time = None + if self.max_time is not None: + start = _execution_state.get("start_time") + if start is not None: + elapsed = _time.monotonic() - start + remaining_time = max(0.0, self.max_time - elapsed) + if remaining_time <= 0: + return "Error: Time budget exhausted" + + # Calculate remaining cost budget for child + remaining_cost = None + if self.max_cost is not None: + get_cost = _execution_state.get("_get_cost_and_tokens") + if get_cost is not None: + cost_spent, _ = get_cost() + remaining_cost = max(0.0, self.max_cost - cost_spent) + if remaining_cost <= 0: + return "Error: Cost budget exhausted" + + # Resolve child's sub_lm: model param selects which LM the child uses + child_sub_lm = resolve_lm(model) if (model and resolve_lm) else self.sub_lm + + interpreter = LocalInterpreter() + child = RLM( + signature="prompt -> response", + max_iterations=self.max_iterations, + max_llm_calls=self.max_llm_calls, + max_output_chars=self.max_output_chars, + max_time=remaining_time, + max_cost=remaining_cost, + verbose=self.verbose, + tools=list(self._user_tools.values()) if self._user_tools else None, + sub_lm=child_sub_lm, + sub_lms=self.sub_lms, + interpreter=interpreter, + depth=self.depth + 1, + max_depth=self.max_depth, + ) + + try: + result = child(prompt=prompt) + return result.response + except Exception as e: + return f"Error: Child RLM failed - {e}" + finally: + interpreter.shutdown() + @contextmanager def _interpreter_context(self, execution_tools: dict[str, Callable]) -> Iterator[CodeInterpreter]: """Yield interpreter, creating PythonInterpreter if none provided at init.""" diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index e8cf0605ec..eb480bc921 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -2350,5 +2350,490 @@ def test_output_fields_none_defaults_to_single_output(self): assert result.output == {"output": "hello"} +# ============================================================================ +# Depth > 1 Tests: Recursive RLM with LocalInterpreter +# ============================================================================ + +import time as _time + + +class TestSubcallInit: + """Tests for depth/max_depth initialization and routing flag.""" + + def test_depth_max_depth_defaults(self): + """Default depth=0, max_depth=1 means no recursion.""" + rlm = RLM("query -> answer", max_iterations=3) + assert rlm.depth == 0 + assert rlm.max_depth == 1 + + def test_depth_max_depth_stored(self): + """Custom depth and max_depth are stored.""" + rlm = RLM("query -> answer", max_iterations=3, depth=1, max_depth=3) + assert rlm.depth == 1 + assert rlm.max_depth == 3 + + def test_max_depth_1_uses_plain_lm(self): + """With max_depth=1 (default), llm_query does a plain LM call.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["plain response"]) + rlm = RLM("query -> answer", max_iterations=3, max_depth=1, sub_lm=mock_lm) + tools = rlm._make_llm_tools() + + result = tools["llm_query"]("test prompt") + assert result == "plain response" + mock_lm.assert_called_once() + + def test_max_depth_2_at_depth_0_is_recursive(self): + """With max_depth=2, depth=0: _subcall is called (not plain LM).""" + from unittest.mock import MagicMock, patch + + mock_lm = MagicMock(return_value=["should not be called"]) + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, sub_lm=mock_lm) + + # Patch _subcall to capture that it's called instead of plain LM + with patch.object(rlm, "_subcall", return_value="subcall result") as mock_subcall: + tools = rlm._make_llm_tools() + result = tools["llm_query"]("test prompt") + + assert result == "subcall result" + mock_subcall.assert_called_once() + mock_lm.assert_not_called() + + def test_max_depth_2_at_depth_1_is_leaf(self): + """With max_depth=2, depth=1: plain LM call (leaf level).""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["leaf response"]) + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, depth=1, sub_lm=mock_lm) + tools = rlm._make_llm_tools() + + result = tools["llm_query"]("test prompt") + assert result == "leaf response" + mock_lm.assert_called_once() + + +class TestSubcallTimeBudgetPropagation: + """Tests for max_time propagation to child RLM, adapted from vanilla RLM test_subcall.py.""" + + def test_child_receives_remaining_timeout(self): + """When parent has max_time=60 and 10s elapsed, child should get ~50s.""" + from unittest.mock import patch + + captured_child_kwargs = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + # Only capture child inits (signature="prompt -> response") + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured_child_kwargs.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, max_time=60.0) + execution_state = {"start_time": _time.monotonic() - 10.0, "iteration": 0} + + with patch.object(RLM, "__init__", capturing_init): + # Child will fail (no LM configured) but we capture the kwargs before that + rlm._subcall("test", execution_state=execution_state) + + assert "max_time" in captured_child_kwargs + remaining = captured_child_kwargs["max_time"] + assert 45.0 < remaining < 55.0, f"Expected ~50s remaining, got {remaining}" + + def test_child_receives_none_timeout_when_parent_has_none(self): + """When parent has no max_time, child should also have None.""" + from unittest.mock import patch + + captured_child_kwargs = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured_child_kwargs.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, max_time=None) + + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + + assert captured_child_kwargs.get("max_time") is None + + def test_subcall_returns_error_when_time_exhausted(self): + """When time budget is already exhausted, _subcall returns error string.""" + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, max_time=10.0) + execution_state = {"start_time": _time.monotonic() - 15.0, "iteration": 0} + + result = rlm._subcall("test", execution_state=execution_state) + assert "Time budget exhausted" in result + + +class TestSubcallCostBudgetPropagation: + """Tests for max_cost propagation to child RLM.""" + + def test_child_receives_remaining_cost(self): + """When parent has max_cost=1.0 and $0.30 spent, child should get ~$0.70.""" + from unittest.mock import patch + + captured_child_kwargs = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured_child_kwargs.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, max_cost=1.0) + execution_state = { + "start_time": _time.monotonic(), + "iteration": 0, + "_get_cost_and_tokens": lambda: (0.30, 5000), + } + + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test", execution_state=execution_state) + + assert "max_cost" in captured_child_kwargs + remaining = captured_child_kwargs["max_cost"] + assert 0.69 < remaining < 0.71, f"Expected ~$0.70, got {remaining}" + + def test_subcall_returns_error_when_cost_exhausted(self): + """When cost budget is already exhausted, _subcall returns error string.""" + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, max_cost=0.01) + execution_state = { + "start_time": _time.monotonic(), + "iteration": 0, + "_get_cost_and_tokens": lambda: (0.02, 1000), + } + + result = rlm._subcall("test", execution_state=execution_state) + assert "Cost budget exhausted" in result + + +class TestSubcallModelOverride: + """Tests for model= parameter override in _subcall.""" + + def test_model_override_sets_child_sub_lm(self): + """When model='flash', child's sub_lm should be the flash LM instance.""" + from unittest.mock import MagicMock, patch + + mock_flash = MagicMock(return_value=["flash"]) + mock_pro = MagicMock(return_value=["pro"]) + + captured_child_kwargs = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured_child_kwargs.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, + sub_lms={"flash": mock_flash, "pro": mock_pro}) + + def resolve_lm(model=None): + if model == "flash": + return mock_flash + if model == "pro": + return mock_pro + return None + + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test", model="flash", resolve_lm=resolve_lm) + + assert captured_child_kwargs.get("sub_lm") is mock_flash + + def test_no_model_override_uses_parent_sub_lm(self): + """Without model override, child inherits parent's sub_lm.""" + from unittest.mock import MagicMock, patch + + mock_parent_sub = MagicMock(return_value=["parent sub"]) + + captured_child_kwargs = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured_child_kwargs.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, sub_lm=mock_parent_sub) + + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + + assert captured_child_kwargs.get("sub_lm") is mock_parent_sub + + +class TestSubcallParameterPropagation: + """Tests for combined parameter propagation to child RLM.""" + + def test_child_inherits_max_iterations(self): + """Child should receive parent's max_iterations.""" + from unittest.mock import patch + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_iterations=15, max_depth=2) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert captured.get("max_iterations") == 15 + + def test_child_inherits_max_llm_calls(self): + """Child should receive parent's max_llm_calls.""" + from unittest.mock import patch + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_llm_calls=25, max_depth=2) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert captured.get("max_llm_calls") == 25 + + def test_child_inherits_user_tools(self): + """Child should receive parent's user-provided tools.""" + from unittest.mock import patch + + def my_tool(x: str) -> str: + return x + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_depth=2, tools=[my_tool]) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert captured.get("tools") is not None + assert len(captured["tools"]) == 1 + + def test_child_depth_incremented(self): + """Child should have depth = parent.depth + 1.""" + from unittest.mock import patch + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_depth=3, depth=0) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert captured.get("depth") == 1 + + def test_child_inherits_max_depth(self): + """Child should receive same max_depth as parent.""" + from unittest.mock import patch + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_depth=3) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert captured.get("max_depth") == 3 + + def test_child_inherits_sub_lms(self): + """Child should receive parent's sub_lms dict.""" + from unittest.mock import MagicMock, patch + + mock_flash = MagicMock() + mock_pro = MagicMock() + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_depth=2, sub_lms={"flash": mock_flash, "pro": mock_pro}) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert captured.get("sub_lms") == {"flash": mock_flash, "pro": mock_pro} + + +class TestSubcallInterpreterIsolation: + """Tests that child RLM gets an isolated LocalInterpreter.""" + + def test_child_gets_local_interpreter(self): + """Child should receive a LocalInterpreter instance.""" + from unittest.mock import patch + + from dspy.primitives.local_interpreter import LocalInterpreter + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_depth=2) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert isinstance(captured.get("interpreter"), LocalInterpreter) + + def test_interpreter_shutdown_on_success(self): + """LocalInterpreter.shutdown() is called after successful child completion.""" + from unittest.mock import MagicMock, patch + + from dspy.primitives.local_interpreter import LocalInterpreter + + shutdown_called = [] + _original_shutdown = LocalInterpreter.shutdown + + def tracking_shutdown(self_inner): + shutdown_called.append(True) + _original_shutdown(self_inner) + + with dummy_lm_context([ + {"reasoning": "Done", "code": 'SUBMIT(response="ok")'}, + ]): + rlm = RLM("query -> answer", max_iterations=3, max_depth=2) + with patch.object(LocalInterpreter, "shutdown", tracking_shutdown): + rlm._subcall("test") + + assert len(shutdown_called) >= 1 + + def test_interpreter_shutdown_on_error(self): + """LocalInterpreter.shutdown() is called even when child fails.""" + from unittest.mock import MagicMock, patch + + from dspy.primitives.local_interpreter import LocalInterpreter + + shutdown_called = [] + _original_shutdown = LocalInterpreter.shutdown + + def tracking_shutdown(self_inner): + shutdown_called.append(True) + _original_shutdown(self_inner) + + with dummy_lm_context([ + {"reasoning": "Bad", "code": 'raise Exception("boom")'}, + {"response": "fallback"}, # extract fallback + ]): + rlm = RLM("query -> answer", max_iterations=1, max_depth=2) + with patch.object(LocalInterpreter, "shutdown", tracking_shutdown): + rlm._subcall("test") + + assert len(shutdown_called) >= 1 + + +class TestSubcallE2E: + """End-to-end tests for depth>1 using DummyLM + LocalInterpreter.""" + + def test_depth_2_child_submits_response(self): + """Parent calls llm_query, child runs in LocalInterpreter and SUBMITs.""" + with dummy_lm_context([ + {"reasoning": "Answer directly", "code": 'SUBMIT(response="42")'}, + ]): + rlm = RLM("query -> answer", max_iterations=3, max_depth=2) + tools = rlm._make_llm_tools() + result = tools["llm_query"]("What is 6*7?") + assert result == "42" + + def test_depth_2_child_uses_prompt_variable(self): + """Child RLM receives the prompt as a variable it can use in code.""" + with dummy_lm_context([ + {"reasoning": "Use the prompt", "code": 'SUBMIT(response=f"Got: {prompt}")'}, + ]): + rlm = RLM("query -> answer", max_iterations=3, max_depth=2) + tools = rlm._make_llm_tools() + result = tools["llm_query"]("hello world") + assert result == "Got: hello world" + + def test_depth_2_child_inherits_tools(self): + """Child can call parent's user-provided tools.""" + call_log = [] + + def my_tool(x: str) -> str: + call_log.append(x) + return f"tool({x})" + + with dummy_lm_context([ + {"reasoning": "Use tool", "code": 'val = my_tool(x="hi")\nSUBMIT(response=val)'}, + ]): + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, tools=[my_tool]) + tools = rlm._make_llm_tools() + result = tools["llm_query"]("use tool") + assert result == "tool(hi)" + assert call_log == ["hi"] + + def test_depth_2_child_multi_iteration(self): + """Child RLM can take multiple iterations before SUBMITting.""" + with dummy_lm_context([ + {"reasoning": "Explore first", "code": 'x = len(prompt)\nprint(f"length={x}")'}, + {"reasoning": "Now submit", "code": 'SUBMIT(response=str(x))'}, + ]): + rlm = RLM("query -> answer", max_iterations=5, max_depth=2) + tools = rlm._make_llm_tools() + result = tools["llm_query"]("hello") + assert result == "5" + + def test_depth_2_child_error_returns_string(self): + """Child failure returns error string, doesn't crash parent.""" + with dummy_lm_context([ + {"reasoning": "Crash", "code": 'raise RuntimeError("boom")'}, + {"response": "recovered"}, # extract fallback + ]): + rlm = RLM("query -> answer", max_iterations=1, max_depth=2) + tools = rlm._make_llm_tools() + result = tools["llm_query"]("test") + # Should get a string back (either "recovered" from extract or error message) + assert isinstance(result, str) + + def test_depth_2_batched_sequential(self): + """llm_query_batched with max_depth=2 runs children sequentially.""" + with dummy_lm_context([ + # Child 1 + {"reasoning": "First", "code": 'SUBMIT(response="a1")'}, + # Child 2 + {"reasoning": "Second", "code": 'SUBMIT(response="a2")'}, + ]): + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, max_llm_calls=20) + tools = rlm._make_llm_tools() + results = tools["llm_query_batched"](["q1", "q2"]) + assert results == ["a1", "a2"] + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 5b3b5a9beac75fbaf2fb45ba3612f7d057e2d79c Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 18:19:42 -0500 Subject: [PATCH 11/15] test(rlm): add PythonInterpreter (Deno) parent + depth=2 integration test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Verifies that when the parent RLM uses PythonInterpreter (Deno sandbox), llm_query with max_depth=2 correctly spawns a child RLM with its own LocalInterpreter. The tool callback crosses the Deno→host JSON-RPC boundary, _subcall runs on the host, and the child operates in an isolated LocalInterpreter namespace. --- tests/predict/test_rlm.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index eb480bc921..c3e4868aab 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -2834,6 +2834,26 @@ def test_depth_2_batched_sequential(self): results = tools["llm_query_batched"](["q1", "q2"]) assert results == ["a1", "a2"] + @pytest.mark.deno + def test_depth_2_python_interpreter_parent(self): + """Parent uses PythonInterpreter (Deno), child uses LocalInterpreter. + + llm_query is a host-side tool callback (JSON-RPC from Deno), so _subcall + runs on the host and creates the child with its own LocalInterpreter. + """ + with dummy_lm_context([ + # Parent iter 1: call llm_query (triggers child RLM) + {"reasoning": "Delegate", "code": 'result = llm_query("compute 2+2")\nprint(result)'}, + # Child iter 1 (runs in LocalInterpreter): SUBMIT response + {"reasoning": "Easy", "code": 'SUBMIT(response="4")'}, + # Parent iter 2: SUBMIT the child's result + {"reasoning": "Done", "code": 'SUBMIT(result)'}, + ]): + # No interpreter= means default PythonInterpreter (Deno sandbox) + rlm = RLM("query -> answer", max_iterations=5, max_depth=2) + result = rlm(query="What is 2+2?") + assert result.answer == "4" + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 6bb9e218d70f0ab9640e814378d780ebac7504e9 Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 20:07:40 -0500 Subject: [PATCH 12/15] fix(rlm): add depth/max_depth validation + aforward depth>1 test - Reject max_depth < 1 and depth < 0 with ValueError at init - Add async aforward() test for depth=2 (async parent, sync _subcall) - All 185 tests pass with --deno flag --- dspy/predict/rlm.py | 4 ++++ tests/predict/test_rlm.py | 27 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/dspy/predict/rlm.py b/dspy/predict/rlm.py index 9a53147bc7..e1412b6d20 100644 --- a/dspy/predict/rlm.py +++ b/dspy/predict/rlm.py @@ -169,6 +169,10 @@ def __init__( self.sub_lm = sub_lm self.sub_lms = sub_lms or {} self._interpreter = interpreter + if max_depth < 1: + raise ValueError(f"max_depth must be >= 1, got {max_depth}") + if depth < 0: + raise ValueError(f"depth must be >= 0, got {depth}") self.depth = depth self.max_depth = max_depth self._user_tools = self._normalize_tools(tools) diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index c3e4868aab..2bd7501b2a 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -2412,6 +2412,16 @@ def test_max_depth_2_at_depth_1_is_leaf(self): assert result == "leaf response" mock_lm.assert_called_once() + def test_max_depth_0_raises(self): + """max_depth < 1 should raise ValueError.""" + with pytest.raises(ValueError, match="max_depth must be >= 1"): + RLM("query -> answer", max_depth=0) + + def test_negative_depth_raises(self): + """Negative depth should raise ValueError.""" + with pytest.raises(ValueError, match="depth must be >= 0"): + RLM("query -> answer", depth=-1) + class TestSubcallTimeBudgetPropagation: """Tests for max_time propagation to child RLM, adapted from vanilla RLM test_subcall.py.""" @@ -2834,6 +2844,23 @@ def test_depth_2_batched_sequential(self): results = tools["llm_query_batched"](["q1", "q2"]) assert results == ["a1", "a2"] + @pytest.mark.asyncio + async def test_depth_2_aforward(self): + """aforward() works with max_depth=2 (async parent, sync _subcall).""" + with dummy_lm_context([ + # Parent iter 1: call llm_query + {"reasoning": "Ask", "code": 'result = llm_query("test")\nprint(result)'}, + # Child iter 1: SUBMIT + {"reasoning": "Done", "code": 'SUBMIT(response="async_ok")'}, + # Parent iter 2: SUBMIT + {"reasoning": "Got it", "code": 'SUBMIT(result)'}, + ]): + from dspy.primitives.local_interpreter import LocalInterpreter + rlm = RLM("query -> answer", max_iterations=5, max_depth=2, + interpreter=LocalInterpreter()) + result = await rlm.aforward(query="async test") + assert result.answer == "async_ok" + @pytest.mark.deno def test_depth_2_python_interpreter_parent(self): """Parent uses PythonInterpreter (Deno), child uses LocalInterpreter. From 63ce422bead5aaa4e044244c58e0a2cb27cb81c0 Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 20:11:04 -0500 Subject: [PATCH 13/15] fix(rlm): match parent interpreter type in child RLM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Child now inherits the parent's interpreter type: - LocalInterpreter parent → LocalInterpreter child - Default (PythonInterpreter/Deno) parent → PythonInterpreter child - Custom interpreter → LocalInterpreter fallback (can't clone) Previously child always got LocalInterpreter, which broke sandboxing when parent used PythonInterpreter (Deno). Tests updated: interpreter isolation tests specify parent type, new test verifies PythonInterpreter child when parent uses default. --- dspy/predict/rlm.py | 14 +++++++++++++- tests/predict/test_rlm.py | 36 +++++++++++++++++++++++++++++------- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/dspy/predict/rlm.py b/dspy/predict/rlm.py index e1412b6d20..c8eb9ba28f 100644 --- a/dspy/predict/rlm.py +++ b/dspy/predict/rlm.py @@ -741,7 +741,19 @@ def _subcall( # Resolve child's sub_lm: model param selects which LM the child uses child_sub_lm = resolve_lm(model) if (model and resolve_lm) else self.sub_lm - interpreter = LocalInterpreter() + # Match parent's interpreter type: if parent uses LocalInterpreter (or a + # custom interpreter), child gets a fresh LocalInterpreter. If parent uses + # the default PythonInterpreter (Deno sandbox), child gets its own + # PythonInterpreter so sandboxing is preserved. + if isinstance(self._interpreter, LocalInterpreter): + interpreter = LocalInterpreter() + elif self._interpreter is None: + # Parent uses default PythonInterpreter — child gets one too + interpreter = PythonInterpreter() + else: + # Custom interpreter — can't clone, fall back to LocalInterpreter + interpreter = LocalInterpreter() + child = RLM( signature="prompt -> response", max_iterations=self.max_iterations, diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index 2bd7501b2a..4ab49394e4 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -2702,8 +2702,8 @@ def capturing_init(self_inner, *args, **kwargs): class TestSubcallInterpreterIsolation: """Tests that child RLM gets an isolated LocalInterpreter.""" - def test_child_gets_local_interpreter(self): - """Child should receive a LocalInterpreter instance.""" + def test_child_gets_local_interpreter_when_parent_uses_local(self): + """When parent uses LocalInterpreter, child gets LocalInterpreter.""" from unittest.mock import patch from dspy.primitives.local_interpreter import LocalInterpreter @@ -2717,13 +2717,32 @@ def capturing_init(self_inner, *args, **kwargs): captured.update(kwargs) _original_init(self_inner, *args, **kwargs) - rlm = RLM("query -> answer", max_depth=2) + rlm = RLM("query -> answer", max_depth=2, interpreter=LocalInterpreter()) with patch.object(RLM, "__init__", capturing_init): rlm._subcall("test") assert isinstance(captured.get("interpreter"), LocalInterpreter) + def test_child_gets_python_interpreter_when_parent_uses_default(self): + """When parent uses default (PythonInterpreter), child matches.""" + from unittest.mock import patch + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + # No interpreter= means parent uses default PythonInterpreter + rlm = RLM("query -> answer", max_depth=2) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert isinstance(captured.get("interpreter"), PythonInterpreter) + def test_interpreter_shutdown_on_success(self): - """LocalInterpreter.shutdown() is called after successful child completion.""" + """Child interpreter shutdown() is called after successful completion.""" from unittest.mock import MagicMock, patch from dspy.primitives.local_interpreter import LocalInterpreter @@ -2738,14 +2757,16 @@ def tracking_shutdown(self_inner): with dummy_lm_context([ {"reasoning": "Done", "code": 'SUBMIT(response="ok")'}, ]): - rlm = RLM("query -> answer", max_iterations=3, max_depth=2) + # Parent uses LocalInterpreter so child matches + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, + interpreter=LocalInterpreter()) with patch.object(LocalInterpreter, "shutdown", tracking_shutdown): rlm._subcall("test") assert len(shutdown_called) >= 1 def test_interpreter_shutdown_on_error(self): - """LocalInterpreter.shutdown() is called even when child fails.""" + """Child interpreter shutdown() is called even when child fails.""" from unittest.mock import MagicMock, patch from dspy.primitives.local_interpreter import LocalInterpreter @@ -2761,7 +2782,8 @@ def tracking_shutdown(self_inner): {"reasoning": "Bad", "code": 'raise Exception("boom")'}, {"response": "fallback"}, # extract fallback ]): - rlm = RLM("query -> answer", max_iterations=1, max_depth=2) + rlm = RLM("query -> answer", max_iterations=1, max_depth=2, + interpreter=LocalInterpreter()) with patch.object(LocalInterpreter, "shutdown", tracking_shutdown): rlm._subcall("test") From d6730eaba6969a8397a2d01051cc405b403b1646 Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Tue, 10 Feb 2026 21:17:34 -0500 Subject: [PATCH 14/15] refactor(rlm): adopt generic dspy.Type sandbox protocol for media support Replace hardcoded Audio/Image media registry with the generic to_sandbox() protocol from PR #9283 (kmad's DataFrame type approach). Changes: - Add rlm_preview(), to_sandbox(), sandbox_setup(), sandbox_assignment() to Audio and Image types - Replace _is_media/_media_descriptor with generic _has_rlm_support() in python_interpreter.py - Replace _detect_media_fields/_build_media_registry with generic _detect_multimodal_fields/_build_multimodal_registry in rlm.py - Add _wrap_rlm_inputs() for auto-wrapping raw values into dspy.Type - Add _inject_pending_vars() for unified sandbox variable injection - llm_query_with_media is now always available as a tool - _build_variables uses rlm_preview() for better LLM context - Update all tests to use new API names (multimodal_registry, etc.) - Add tests for custom types implementing the protocol Any dspy.Type with to_sandbox() + format() is now automatically: 1. Detected as multimodal input 2. Injected into sandbox via to_sandbox() protocol 3. Available for llm_query_with_media() 4. Previewed via rlm_preview() in variable context 204 tests pass (0 failures, 38 deno-skips). --- dspy/adapters/types/audio.py | 21 +++ dspy/adapters/types/image.py | 25 +++ dspy/predict/rlm.py | 170 +++++++++++------- dspy/primitives/python_interpreter.py | 103 +++++------ tests/predict/test_rlm.py | 136 ++++++++------- tests/primitives/test_media_serialization.py | 172 ++++++++++++++----- 6 files changed, 406 insertions(+), 221 deletions(-) diff --git a/dspy/adapters/types/audio.py b/dspy/adapters/types/audio.py index 0ceb734b73..b0bd5b022f 100644 --- a/dspy/adapters/types/audio.py +++ b/dspy/adapters/types/audio.py @@ -122,6 +122,27 @@ def __repr__(self) -> str: length = len(self.data) return f"Audio(data=, audio_format='{self.audio_format}')" + # RLM Sandbox Support + + def rlm_preview(self, max_chars: int = 500) -> str: + """Generate LLM-friendly preview of Audio contents.""" + return f"" + + def to_sandbox(self) -> bytes: + """Serialize Audio for sandbox injection (descriptor string, not raw data). + + Audio data cannot be meaningfully processed as code in the sandbox. + The agent should use llm_query() with multimodal content to perceive audio. + """ + return self.rlm_preview().encode("utf-8") + + def sandbox_setup(self) -> str: + return "" + + def sandbox_assignment(self, var_name: str, data_expr: str) -> str: + """Return code that assigns the audio descriptor string in the sandbox.""" + return f"{var_name} = {data_expr}" + def encode_audio(audio: Union[str, bytes, dict, "Audio", Any], sampling_rate: int = 16000, format: str = "wav") -> dict: """ Encode audio to a dict with 'data' and 'audio_format'. diff --git a/dspy/adapters/types/image.py b/dspy/adapters/types/image.py index fc24768b80..d3349a2e9f 100644 --- a/dspy/adapters/types/image.py +++ b/dspy/adapters/types/image.py @@ -115,6 +115,31 @@ def __repr__(self): return f"Image(url=data:image/{image_type};base64,)" return f"Image(url='{self.url}')" + # RLM Sandbox Support + + def rlm_preview(self, max_chars: int = 500) -> str: + """Generate LLM-friendly preview of Image contents.""" + if "base64" in self.url: + len_base64 = len(self.url.split("base64,")[1]) + image_type = self.url.split(";")[0].split("/")[-1] + return f"" + return f"" + + def to_sandbox(self) -> bytes: + """Serialize Image for sandbox injection (descriptor string, not raw data). + + Image data cannot be meaningfully processed as code in the sandbox. + The agent should use llm_query() with multimodal content to perceive images. + """ + return self.rlm_preview().encode("utf-8") + + def sandbox_setup(self) -> str: + return "" + + def sandbox_assignment(self, var_name: str, data_expr: str) -> str: + """Return code that assigns the image descriptor string in the sandbox.""" + return f"{var_name} = {data_expr}" + def is_url(string: str) -> bool: """Check if a string is a valid URL.""" diff --git a/dspy/predict/rlm.py b/dspy/predict/rlm.py index c8eb9ba28f..bcac91a8d2 100644 --- a/dspy/predict/rlm.py +++ b/dspy/predict/rlm.py @@ -21,8 +21,6 @@ import pydantic import dspy -from dspy.adapters.types.audio import Audio -from dspy.adapters.types.image import Image from dspy.adapters.types.tool import Tool from dspy.adapters.utils import parse_value, translate_field_type from dspy.primitives.code_interpreter import SIMPLE_TYPES, CodeInterpreter, CodeInterpreterError, FinalOutput @@ -33,19 +31,15 @@ from dspy.signatures.signature import ensure_signature from dspy.utils.annotation import experimental -# Types considered "media" — their data can't be serialized into the sandbox -# but can be forwarded to sub-LLM calls via llm_query_with_media(). -_MEDIA_TYPES = (Audio, Image) +def _has_rlm_support(value): + """Check if a value is a dspy.Type with RLM sandbox support (to_sandbox protocol).""" + return hasattr(value, "to_sandbox") and callable(getattr(value, "to_sandbox", None)) -def _is_media(value): - """Check if a value is a media type (Audio or Image).""" - return isinstance(value, _MEDIA_TYPES) - -def _format_media_for_lm(value): - """Convert a media object to LM message content parts.""" - return value.format() +def _is_multimodal_type(value): + """Check if a value is a multimodal dspy.Type (Audio or Image) that can be sent to an LLM.""" + return hasattr(value, "format") and callable(getattr(value, "format", None)) and _has_rlm_support(value) if TYPE_CHECKING: @@ -63,7 +57,8 @@ def _format_media_for_lm(value): Available: - Variables: {inputs} (your input data) - `llm_query(prompt, model=None)` - query a sub-LLM (~500K char capacity) for semantic analysis{model_docs} -- `llm_query_batched(prompts, model=None)` - query multiple prompts concurrently (much faster for multiple queries){media_tools} +- `llm_query_batched(prompts, model=None)` - query multiple prompts concurrently (much faster for multiple queries) +- `llm_query_with_media(prompt, *media_var_names, model=None)` - query sub-LLM with media (audio/image) attached{media_docs} - `print()` - ALWAYS print to see results - `SUBMIT({final_output_names})` - submit final output when done - `budget()` - check remaining iterations, LLM calls, and time @@ -244,14 +239,15 @@ def _format_tool_docs(self, tools: dict[str, Tool]) -> str: def _make_llm_tools( self, - media_registry: dict[str, Any] | None = None, + multimodal_registry: dict[str, Any] | None = None, max_workers: int = 8, execution_state: dict[str, Any] | None = None, ) -> dict[str, Callable]: """Create llm_query, llm_query_batched, llm_query_with_media, and budget tools with a fresh call counter. Args: - media_registry: Dict mapping variable names to media objects (Audio/Image). + multimodal_registry: Dict mapping variable names to multimodal objects (any dspy.Type + with format() and to_sandbox() methods, e.g. Audio, Image). Used by llm_query_with_media to attach media to sub-LLM calls. max_workers: Max concurrent workers for batched queries. execution_state: Mutable dict tracking iteration/time state. Keys: @@ -262,7 +258,7 @@ def _make_llm_tools( _execution_state = execution_state or {} default_lm = self.sub_lm named_lms = self.sub_lms - _media_registry = media_registry or {} + _multimodal_registry = multimodal_registry or {} # Snapshot LM history lengths for cost tracking. # We'll sum cost from entries added after these offsets. @@ -312,14 +308,14 @@ def _query_lm(prompt: str, model: str | None = None) -> str: return item return str(response) - def _query_lm_multimodal(prompt: str, media_objects: list, model: str | None = None) -> str: - """Query the LLM with a prompt string and media content parts.""" + def _query_lm_multimodal(prompt: str, multimodal_objects: list, model: str | None = None) -> str: + """Query the LLM with a prompt string and multimodal content parts.""" target_lm = _resolve_lm(model) # Build multimodal content: text prompt + media content parts content_parts = [{"type": "text", "text": prompt}] - for media_obj in media_objects: - content_parts.extend(_format_media_for_lm(media_obj)) + for obj in multimodal_objects: + content_parts.extend(obj.format()) messages = [{"role": "user", "content": content_parts}] response = target_lm(messages=messages) @@ -390,12 +386,12 @@ def llm_query_batched(prompts: list[str], model: str | None = None) -> list[str] return [results[i] for i in range(len(prompts))] def llm_query_with_media(prompt: str, *media_var_names: str, model: str | None = None) -> str: - """Query the LLM with a prompt and media variables (audio/image). + """Query the LLM with a prompt and multimodal variables (audio/image). Args: prompt: The text prompt for the LLM. - *media_var_names: Names of media variables to include (e.g., 'audio_input', 'my_image'). - These must be names of Audio or Image input variables. + *media_var_names: Names of multimodal variables to include (e.g., 'audio_input', 'my_image'). + These must be names of input variables that have multimodal content (Audio, Image, etc.). model: Optional model name from sub_lms to use. Defaults to the default sub_lm. Returns: @@ -406,21 +402,21 @@ def llm_query_with_media(prompt: str, *media_var_names: str, model: str | None = if not media_var_names: raise ValueError( "At least one media variable name is required. " - f"Available media variables: {list(_media_registry.keys())}" + f"Available media variables: {list(_multimodal_registry.keys())}" ) - # Resolve media objects from the registry - media_objects = [] + # Resolve multimodal objects from the registry + multimodal_objects = [] for var_name in media_var_names: - if var_name not in _media_registry: - available = list(_media_registry.keys()) + if var_name not in _multimodal_registry: + available = list(_multimodal_registry.keys()) raise ValueError( f"Media variable '{var_name}' not found. Available media variables: {available}" ) - media_objects.append(_media_registry[var_name]) + multimodal_objects.append(_multimodal_registry[var_name]) _check_and_increment(1) - return _query_lm_multimodal(prompt, media_objects, model=model) + return _query_lm_multimodal(prompt, multimodal_objects, model=model) max_iterations = self.max_iterations max_llm_calls = self.max_llm_calls @@ -506,9 +502,12 @@ def budget() -> str: # Expose cost tracker to forward() for budget enforcement _execution_state["_get_cost_and_tokens"] = _get_cost_and_tokens - tools = {"llm_query": llm_query, "llm_query_batched": llm_query_batched, "budget": budget} - if _media_registry: - tools["llm_query_with_media"] = llm_query_with_media + tools = { + "llm_query": llm_query, + "llm_query_batched": llm_query_batched, + "llm_query_with_media": llm_query_with_media, + "budget": budget, + } return tools @property @@ -520,20 +519,22 @@ def tools(self) -> dict[str, Tool]: # Signature Building # ========================================================================= - def _detect_media_fields(self) -> dict[str, str]: - """Detect input fields that are Audio or Image types. + def _detect_multimodal_fields(self) -> dict[str, str]: + """Detect input fields that are multimodal dspy.Type subclasses (Audio, Image, etc.). + + Uses the generic to_sandbox() protocol — any dspy.Type with to_sandbox() and format() + is considered multimodal media that can be sent to an LLM via llm_query_with_media(). Returns: - Dict mapping field name to media type name (e.g., {'audio_input': 'Audio', 'photo': 'Image'}). + Dict mapping field name to type name (e.g., {'audio_input': 'Audio', 'photo': 'Image'}). """ - media_fields = {} + multimodal_fields = {} for name, field in self.signature.input_fields.items(): annotation = getattr(field, "annotation", None) - if annotation is Audio: - media_fields[name] = "Audio" - elif annotation is Image: - media_fields[name] = "Image" - return media_fields + if annotation is not None and isinstance(annotation, type) and issubclass(annotation, dspy.Type): + if hasattr(annotation, "to_sandbox") and hasattr(annotation, "format"): + multimodal_fields[name] = annotation.__name__ + return multimodal_fields def _build_signatures(self) -> tuple[Signature, Signature]: """Build the action and extract signatures from templates.""" @@ -553,21 +554,20 @@ def _build_signatures(self) -> tuple[Signature, Signature]: # Format tool documentation for user-provided tools tool_docs = self._format_tool_docs(self._user_tools) - # Detect media fields and build media-specific instructions - media_fields = self._detect_media_fields() - if media_fields: - media_var_list = ", ".join(f"'{name}'" for name in media_fields) - media_tools_str = ( - f"\n- `llm_query_with_media(prompt, *media_var_names)` - query sub-LLM with media (audio/image) attached. " - f"Media variables: {media_var_list}. The sub-LLM can see/hear the media content." + # Detect multimodal fields and build media-specific instructions + multimodal_fields = self._detect_multimodal_fields() + if multimodal_fields: + media_var_list = ", ".join(f"'{name}'" for name in multimodal_fields) + media_docs_str = ( + f"\n Media variables: {media_var_list}. The sub-LLM can see/hear the media content." ) media_guidelines_str = ( - f"\n FOR MEDIA INPUTS (Audio/Image): Variables like {media_var_list} are media objects. " + f"\n FOR MEDIA INPUTS: Variables like {media_var_list} are media objects. " f"In the sandbox they appear as descriptor strings — you CANNOT decode or process them as raw data. " - f"Use `llm_query_with_media(prompt, {next(iter(media_fields.keys()))!r})` to send media to a sub-LLM that can perceive it." + f"Use `llm_query_with_media(prompt, {next(iter(multimodal_fields.keys()))!r})` to send media to a sub-LLM that can perceive it." ) else: - media_tools_str = "" + media_docs_str = "" media_guidelines_str = "" # Document available model names if sub_lms is configured @@ -581,7 +581,7 @@ def _build_signatures(self) -> tuple[Signature, Signature]: dspy.Signature({}, task_instructions + ACTION_INSTRUCTIONS_TEMPLATE.format( inputs=inputs_str, final_output_names=final_output_names, output_fields=output_fields, max_llm_calls=self.max_llm_calls, - media_tools=media_tools_str, media_guidelines=media_guidelines_str, + media_docs=media_docs_str, media_guidelines=media_guidelines_str, model_docs=model_docs_str, ) + tool_docs) .append("variables_info", dspy.InputField(desc="Metadata about the variables available in the REPL"), type_=str) @@ -628,12 +628,49 @@ def _get_output_fields_info(self) -> list[dict]: fields.append(field_info) return fields + def _wrap_rlm_inputs(self, input_args: dict[str, Any]) -> dict[str, Any]: + """Auto-wrap raw values into their annotated dspy.Type when the type has RLM support. + + For example, if a field is annotated as dspy.DataFrame and the user passes a raw + pandas DataFrame, wrap it into dspy.DataFrame so the interpreter can use to_sandbox(). + """ + wrapped = {} + for name, value in input_args.items(): + field = self.signature.input_fields.get(name) + if field is None: + wrapped[name] = value + continue + + annotation = getattr(field, "annotation", None) + # Check if the annotation is a dspy.Type subclass with RLM support + if ( + annotation is not None + and isinstance(annotation, type) + and issubclass(annotation, dspy.Type) + and hasattr(annotation, "to_sandbox") + and not isinstance(value, annotation) + ): + try: + wrapped[name] = annotation(value) + except (TypeError, ValueError): + wrapped[name] = value + else: + wrapped[name] = value + return wrapped + def _build_variables(self, **input_args: Any) -> list[REPLVariable]: """Build REPLVariable list from input arguments with field metadata.""" variables = [] for name, value in input_args.items(): field_info = self.signature.input_fields.get(name) - variables.append(REPLVariable.from_value(name, value, field_info=field_info)) + if hasattr(value, "rlm_preview") and callable(getattr(value, "rlm_preview", None)): + # Use rlm_preview() for types with RLM support (gives better LLM context) + preview = value.rlm_preview() + var = REPLVariable.from_value(name, value, field_info=field_info) + var = var.model_copy(update={"preview": preview, "total_length": len(preview)}) + else: + var = REPLVariable.from_value(name, value, field_info=field_info) + variables.append(var) return variables def _format_output(self, output: str) -> str: @@ -651,26 +688,29 @@ def _validate_inputs(self, input_args: dict[str, Any]) -> None: # CodeInterpreter Lifecycle # ========================================================================= - def _build_media_registry(self, input_args: dict[str, Any]) -> dict[str, Any]: - """Extract media objects (Audio/Image) from inputs into a registry. + def _build_multimodal_registry(self, input_args: dict[str, Any]) -> dict[str, Any]: + """Extract multimodal objects from inputs into a registry. + + Any dspy.Type with both format() (for LLM content parts) and to_sandbox() (for + sandbox injection) is considered multimodal and eligible for llm_query_with_media(). Returns: - Dict mapping variable names to their media objects. + Dict mapping variable names to their multimodal objects. """ registry = {} for name, value in input_args.items(): - if _is_media(value): + if _is_multimodal_type(value): registry[name] = value return registry def _prepare_execution_tools( self, - media_registry: dict[str, Any] | None = None, + multimodal_registry: dict[str, Any] | None = None, execution_state: dict[str, Any] | None = None, ) -> dict[str, Callable]: """Create fresh LLM tools and merge with user-provided tools.""" execution_tools = self._make_llm_tools( - media_registry=media_registry, + multimodal_registry=multimodal_registry, execution_state=execution_state, ) # Extract underlying functions from Tool objects for the interpreter @@ -957,15 +997,16 @@ def forward(self, **input_args) -> Prediction: RuntimeError: If max_time budget is exceeded """ self._validate_inputs(input_args) + input_args = self._wrap_rlm_inputs(input_args) output_field_names = list(self.signature.output_fields.keys()) - media_registry = self._build_media_registry(input_args) + multimodal_registry = self._build_multimodal_registry(input_args) # Mutable execution state — shared with budget() tool via closure execution_state = {"start_time": _time.monotonic(), "iteration": 0} execution_tools = self._prepare_execution_tools( - media_registry=media_registry, + multimodal_registry=multimodal_registry, execution_state=execution_state, ) variables = self._build_variables(**input_args) @@ -1073,15 +1114,16 @@ async def aforward(self, **input_args) -> Prediction: RuntimeError: If max_time budget is exceeded """ self._validate_inputs(input_args) + input_args = self._wrap_rlm_inputs(input_args) output_field_names = list(self.signature.output_fields.keys()) - media_registry = self._build_media_registry(input_args) + multimodal_registry = self._build_multimodal_registry(input_args) # Mutable execution state — shared with budget() tool via closure execution_state = {"start_time": _time.monotonic(), "iteration": 0} execution_tools = self._prepare_execution_tools( - media_registry=media_registry, + multimodal_registry=multimodal_registry, execution_state=execution_state, ) variables = self._build_variables(**input_args) diff --git a/dspy/primitives/python_interpreter.py b/dspy/primitives/python_interpreter.py index dc1d919d44..cd39f7929c 100644 --- a/dspy/primitives/python_interpreter.py +++ b/dspy/primitives/python_interpreter.py @@ -20,39 +20,9 @@ from dspy.primitives.code_interpreter import SIMPLE_TYPES, CodeInterpreterError, FinalOutput -# Lazy import helpers for multimodal types to avoid circular imports -def _is_media_type(value: Any) -> bool: - """Check if value is a DSPy Audio or Image type.""" - try: - from dspy.adapters.types.audio import Audio - if isinstance(value, Audio): - return True - except ImportError: - pass - try: - from dspy.adapters.types.image import Image - if isinstance(value, Image): - return True - except ImportError: - pass - return False - - -def _media_descriptor(value: Any) -> str: - """Return a human-readable descriptor string for a media object.""" - try: - from dspy.adapters.types.audio import Audio - if isinstance(value, Audio): - return f"" - except ImportError: - pass - try: - from dspy.adapters.types.image import Image - if isinstance(value, Image): - return repr(value) - except ImportError: - pass - return repr(value) +def _has_rlm_support(value: Any) -> bool: + """Check if a value is a dspy.Type with RLM sandbox support (to_sandbox protocol).""" + return hasattr(value, "to_sandbox") and callable(getattr(value, "to_sandbox", None)) __all__ = ["PythonInterpreter", "FinalOutput", "CodeInterpreterError"] @@ -413,8 +383,8 @@ def _to_json_compatible(self, value: Any) -> Any: """Recursively convert Python values to JSON-compatible types.""" if value is None or isinstance(value, str | int | float | bool): return value - elif _is_media_type(value): - return _media_descriptor(value) + elif _has_rlm_support(value): + return value.rlm_preview() if hasattr(value, "rlm_preview") else repr(value) elif isinstance(value, dict): return {k: self._to_json_compatible(v) for k, v in value.items()} elif isinstance(value, list | tuple): @@ -428,14 +398,35 @@ def _to_json_compatible(self, value: Any) -> Any: raise CodeInterpreterError(f"Unsupported value type: {type(value).__name__}") def _inject_variables(self, code: str, variables: dict[str, Any]) -> str: - """Insert Python assignments for each variable at the top of the code.""" + """Insert Python assignments for each variable at the top of the code. + + Supports dspy.Type instances with RLM sandbox support via to_sandbox(), + with fallback to JSON serialization for other types. + """ for key in variables: if not key.isidentifier() or keyword.iskeyword(key) or key == "json": raise CodeInterpreterError(f"Invalid variable name: '{key}'") + # Variables with custom RLM serialization (via to_sandbox interface) + rlm_type_vars: dict[str, bytes] = {} # name -> payload large_vars = {} small_assignments = [] + setup_imports = set() + rlm_assignments = [] + for k, v in variables.items(): + if _has_rlm_support(v): + payload = v.to_sandbox() + rlm_type_vars[k] = payload + data_expr = f"open('/tmp/dspy_vars/{k}.json').read()" + rlm_assignments.append(v.sandbox_assignment(k, data_expr)) + if hasattr(v, "sandbox_setup"): + setup = v.sandbox_setup() + if setup: + setup_imports.add(setup) + continue + + # Standard serialization for other types serialized = self._serialize_value(v) if len(serialized) > LARGE_VAR_THRESHOLD: large_vars[k] = json.dumps(self._to_json_compatible(v)) @@ -443,12 +434,20 @@ def _inject_variables(self, code: str, variables: dict[str, Any]) -> str: small_assignments.append(f"{k} = {serialized}") self._pending_large_vars = large_vars + self._pending_rlm_type_vars = rlm_type_vars + # Build imports + imports = list(setup_imports) if large_vars: - large_assignments = [f"{k} = json.loads(open('/tmp/dspy_vars/{k}.json').read())" for k in large_vars] - assignments = ["import json"] + small_assignments + large_assignments - else: - assignments = small_assignments + imports.append("import json") + + # Build assignments for large JSON vars + large_assignments = [ + f"{k} = json.loads(open('/tmp/dspy_vars/{k}.json').read())" + for k in large_vars + ] + + assignments = imports + small_assignments + large_assignments + rlm_assignments return "\n".join(assignments) + "\n" + code if assignments else code @@ -485,10 +484,11 @@ def _serialize_value(self, value: Any) -> str: sorted_items = list(value) items = ", ".join(self._serialize_value(item) for item in sorted_items) return f"[{items}]" - elif _is_media_type(value): - # Media types (Audio, Image) are represented as descriptor strings in the sandbox. - # The actual media data is accessed via llm_query_with_media() in the RLM context. - return repr(_media_descriptor(value)) + elif _has_rlm_support(value): + # Types with RLM sandbox support are represented as preview strings in the sandbox. + # The actual data is accessed via the type's sandbox protocol methods. + preview = value.rlm_preview() if hasattr(value, "rlm_preview") else repr(value) + return repr(preview) else: raise CodeInterpreterError(f"Unsupported value type: {type(value).__name__}") @@ -496,6 +496,15 @@ def _inject_large_var(self, name: str, value: str) -> None: """Inject a large variable via the virtual filesystem.""" self._send_request("inject_var", {"name": name, "value": value}, f"injecting variable '{name}'") + def _inject_pending_vars(self) -> None: + """Inject all pending large and RLM type variables into the sandbox.""" + for name, value in self._pending_large_vars.items(): + self._inject_large_var(name, value) + for name, payload in getattr(self, "_pending_rlm_type_vars", {}).items(): + if payload is not None: + value = payload.decode("utf-8") if isinstance(payload, bytes) else str(payload) + self._inject_large_var(name, value) + def execute( self, code: str, @@ -508,8 +517,7 @@ def execute( self._mount_files() self._register_tools() - for name, value in self._pending_large_vars.items(): - self._inject_large_var(name, value) + self._inject_pending_vars() # Send the code as JSON-RPC request self._request_id += 1 @@ -525,8 +533,7 @@ def execute( self._ensure_deno_process() self._mount_files() self._register_tools() - for name, value in self._pending_large_vars.items(): - self._inject_large_var(name, value) + self._inject_pending_vars() self.deno_process.stdin.write(input_data + "\n") self.deno_process.stdin.flush() diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index 4ab49394e4..45503d5bbf 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -1115,11 +1115,11 @@ def test_with_llm_query(self): # ============================================================================ -class TestMediaDetection: - """Unit tests for media field detection and registry building.""" +class TestMultimodalDetection: + """Unit tests for multimodal field detection and registry building (types protocol).""" def test_detect_audio_field(self): - """Test _detect_media_fields finds Audio-typed inputs.""" + """Test _detect_multimodal_fields finds Audio-typed inputs.""" import dspy from dspy.adapters.types.audio import Audio @@ -1130,13 +1130,13 @@ class TranscribeSig(dspy.Signature): with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): rlm = RLM(TranscribeSig, max_iterations=3) - media_fields = rlm._detect_media_fields() + fields = rlm._detect_multimodal_fields() - assert "audio_input" in media_fields - assert media_fields["audio_input"] == "Audio" + assert "audio_input" in fields + assert fields["audio_input"] == "Audio" def test_detect_image_field(self): - """Test _detect_media_fields finds Image-typed inputs.""" + """Test _detect_multimodal_fields finds Image-typed inputs.""" import dspy from dspy.adapters.types.image import Image @@ -1147,13 +1147,13 @@ class DescribeSig(dspy.Signature): with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('a cat')"}]): rlm = RLM(DescribeSig, max_iterations=3) - media_fields = rlm._detect_media_fields() + fields = rlm._detect_multimodal_fields() - assert "photo" in media_fields - assert media_fields["photo"] == "Image" + assert "photo" in fields + assert fields["photo"] == "Image" - def test_detect_mixed_media_fields(self): - """Test _detect_media_fields finds both Audio and Image in same signature.""" + def test_detect_mixed_multimodal_fields(self): + """Test _detect_multimodal_fields finds both Audio and Image in same signature.""" import dspy from dspy.adapters.types.audio import Audio from dspy.adapters.types.image import Image @@ -1166,23 +1166,23 @@ class MultimodalSig(dspy.Signature): with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('done')"}]): rlm = RLM(MultimodalSig, max_iterations=3) - media_fields = rlm._detect_media_fields() + fields = rlm._detect_multimodal_fields() - assert len(media_fields) == 2 - assert media_fields["audio_clip"] == "Audio" - assert media_fields["photo"] == "Image" + assert len(fields) == 2 + assert fields["audio_clip"] == "Audio" + assert fields["photo"] == "Image" - def test_no_media_fields(self): - """Test _detect_media_fields returns empty dict for text-only signatures.""" + def test_no_multimodal_fields(self): + """Test _detect_multimodal_fields returns empty dict for text-only signatures.""" with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): rlm = RLM("query -> answer", max_iterations=3) - media_fields = rlm._detect_media_fields() + fields = rlm._detect_multimodal_fields() - assert media_fields == {} + assert fields == {} - def test_build_media_registry_with_audio(self): - """Test _build_media_registry extracts Audio objects from inputs.""" + def test_build_multimodal_registry_with_audio(self): + """Test _build_multimodal_registry extracts Audio objects from inputs.""" import dspy from dspy.adapters.types.audio import Audio @@ -1195,62 +1195,58 @@ class TranscribeSig(dspy.Signature): with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): rlm = RLM(TranscribeSig, max_iterations=3) - registry = rlm._build_media_registry({"audio_input": audio}) + registry = rlm._build_multimodal_registry({"audio_input": audio}) assert "audio_input" in registry assert registry["audio_input"] is audio - def test_build_media_registry_ignores_text(self): - """Test _build_media_registry skips non-media values.""" + def test_build_multimodal_registry_ignores_text(self): + """Test _build_multimodal_registry skips non-multimodal values.""" with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): rlm = RLM("query -> answer", max_iterations=3) - registry = rlm._build_media_registry({"query": "hello world"}) + registry = rlm._build_multimodal_registry({"query": "hello world"}) assert registry == {} - -class TestLLMQueryWithMedia: - """Unit tests for llm_query_with_media tool creation and validation.""" - - def test_media_tool_available_when_registry_populated(self): - """Test llm_query_with_media is created when media registry is non-empty.""" + def test_wrap_rlm_inputs_passthrough_already_wrapped(self): + """Test _wrap_rlm_inputs passes through already-wrapped dspy types.""" import dspy from dspy.adapters.types.audio import Audio - audio = Audio(data="dGVzdA==", audio_format="wav") - class TranscribeSig(dspy.Signature): - """Transcribe audio.""" audio_input: Audio = dspy.InputField() transcription: str = dspy.OutputField() + audio = Audio(data="dGVzdA==", audio_format="wav") with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): rlm = RLM(TranscribeSig, max_iterations=3) - tools = rlm._make_llm_tools(media_registry={"audio_input": audio}) - - assert "llm_query_with_media" in tools - assert "llm_query" in tools - assert "llm_query_batched" in tools + wrapped = rlm._wrap_rlm_inputs({"audio_input": audio}) - def test_media_tool_absent_when_no_media(self): - """Test llm_query_with_media is NOT created when media registry is empty.""" + assert wrapped["audio_input"] is audio + def test_wrap_rlm_inputs_passthrough_plain_types(self): + """Test _wrap_rlm_inputs passes through plain types like str.""" with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): rlm = RLM("query -> answer", max_iterations=3) - tools = rlm._make_llm_tools(media_registry={}) + wrapped = rlm._wrap_rlm_inputs({"query": "hello"}) - assert "llm_query_with_media" not in tools - assert "llm_query" in tools + assert wrapped["query"] == "hello" - def test_media_tool_absent_when_registry_none(self): - """Test llm_query_with_media is NOT created when media registry is None.""" +class TestLLMQueryWithMedia: + """Unit tests for llm_query_with_media tool creation and validation.""" + + def test_media_tool_always_available(self): + """Test llm_query_with_media is always created as a tool.""" with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): rlm = RLM("query -> answer", max_iterations=3) - tools = rlm._make_llm_tools(media_registry=None) + tools = rlm._make_llm_tools() - assert "llm_query_with_media" not in tools + assert "llm_query_with_media" in tools + assert "llm_query" in tools + assert "llm_query_batched" in tools + assert "budget" in tools def test_media_tool_rejects_empty_prompt(self): """Test llm_query_with_media raises on empty prompt.""" @@ -1260,7 +1256,7 @@ def test_media_tool_rejects_empty_prompt(self): with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): rlm = RLM("query -> answer", max_iterations=3) - tools = rlm._make_llm_tools(media_registry={"audio_input": audio}) + tools = rlm._make_llm_tools(multimodal_registry={"audio_input": audio}) with pytest.raises(ValueError, match="prompt cannot be empty"): tools["llm_query_with_media"]("") @@ -1273,7 +1269,7 @@ def test_media_tool_rejects_no_media_vars(self): with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): rlm = RLM("query -> answer", max_iterations=3) - tools = rlm._make_llm_tools(media_registry={"audio_input": audio}) + tools = rlm._make_llm_tools(multimodal_registry={"audio_input": audio}) with pytest.raises(ValueError, match="At least one media variable"): tools["llm_query_with_media"]("transcribe this") @@ -1286,7 +1282,7 @@ def test_media_tool_rejects_unknown_var(self): with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): rlm = RLM("query -> answer", max_iterations=3) - tools = rlm._make_llm_tools(media_registry={"audio_input": audio}) + tools = rlm._make_llm_tools(multimodal_registry={"audio_input": audio}) with pytest.raises(ValueError, match="not found"): tools["llm_query_with_media"]("transcribe", "nonexistent_var") @@ -1297,10 +1293,10 @@ def test_reserved_tool_names_includes_media(self): class TestMediaInstructions: - """Unit tests for media-specific instruction injection in signatures.""" + """Unit tests for multimodal-specific instruction injection in signatures.""" - def test_media_tools_in_action_instructions(self): - """Test that media fields cause llm_query_with_media docs to appear.""" + def test_media_docs_in_action_instructions(self): + """Test that multimodal fields cause media docs to appear in instructions.""" import dspy from dspy.adapters.types.audio import Audio @@ -1314,18 +1310,20 @@ class TranscribeSig(dspy.Signature): action_sig, _extract_sig = rlm._build_signatures() instructions = action_sig.instructions - assert "llm_query_with_media" in instructions assert "audio_input" in instructions + assert "media" in instructions.lower() - def test_no_media_instructions_for_text_only(self): - """Test that text-only signatures do NOT include media instructions.""" + def test_no_media_guidelines_for_text_only(self): + """Test that text-only signatures do NOT include media-specific guidelines.""" with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): rlm = RLM("query -> answer", max_iterations=3) action_sig, _extract_sig = rlm._build_signatures() instructions = action_sig.instructions - assert "llm_query_with_media" not in instructions + # The generic llm_query_with_media tool is always documented, + # but media-specific guidelines should NOT appear + assert "FOR MEDIA INPUTS" not in instructions class TestMultiModelSubCalls: @@ -1428,7 +1426,7 @@ def test_llm_query_with_media_routes_to_named_model(self): rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_default, sub_lms={"pro": mock_pro}) - tools = rlm._make_llm_tools(media_registry={"audio_input": audio}) + tools = rlm._make_llm_tools(multimodal_registry={"audio_input": audio}) result = tools["llm_query_with_media"]("transcribe", "audio_input", model="pro") assert result == "pro media response" @@ -1905,7 +1903,7 @@ def test_media_content_parts_sent_to_lm(self): rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_lm) execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} tools = rlm._make_llm_tools( - media_registry={"my_audio": audio}, + multimodal_registry={"my_audio": audio}, execution_state=execution_state, ) @@ -1941,7 +1939,7 @@ def test_image_content_parts_sent_to_lm(self): rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_lm) execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} tools = rlm._make_llm_tools( - media_registry={"my_image": image}, + multimodal_registry={"my_image": image}, execution_state=execution_state, ) @@ -1972,7 +1970,7 @@ def test_multiple_media_objects_in_one_call(self): rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_lm) execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} tools = rlm._make_llm_tools( - media_registry={"audio_in": audio, "image_in": image}, + multimodal_registry={"audio_in": audio, "image_in": image}, execution_state=execution_state, ) @@ -2003,7 +2001,7 @@ def test_media_with_model_routing(self): sub_lm=mock_default, sub_lms={"pro": mock_pro}) execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} tools = rlm._make_llm_tools( - media_registry={"audio": audio}, + multimodal_registry={"audio": audio}, execution_state=execution_state, ) @@ -2743,7 +2741,7 @@ def capturing_init(self_inner, *args, **kwargs): def test_interpreter_shutdown_on_success(self): """Child interpreter shutdown() is called after successful completion.""" - from unittest.mock import MagicMock, patch + from unittest.mock import patch from dspy.primitives.local_interpreter import LocalInterpreter @@ -2767,7 +2765,7 @@ def tracking_shutdown(self_inner): def test_interpreter_shutdown_on_error(self): """Child interpreter shutdown() is called even when child fails.""" - from unittest.mock import MagicMock, patch + from unittest.mock import patch from dspy.primitives.local_interpreter import LocalInterpreter @@ -2834,7 +2832,7 @@ def test_depth_2_child_multi_iteration(self): """Child RLM can take multiple iterations before SUBMITting.""" with dummy_lm_context([ {"reasoning": "Explore first", "code": 'x = len(prompt)\nprint(f"length={x}")'}, - {"reasoning": "Now submit", "code": 'SUBMIT(response=str(x))'}, + {"reasoning": "Now submit", "code": "SUBMIT(response=str(x))"}, ]): rlm = RLM("query -> answer", max_iterations=5, max_depth=2) tools = rlm._make_llm_tools() @@ -2875,7 +2873,7 @@ async def test_depth_2_aforward(self): # Child iter 1: SUBMIT {"reasoning": "Done", "code": 'SUBMIT(response="async_ok")'}, # Parent iter 2: SUBMIT - {"reasoning": "Got it", "code": 'SUBMIT(result)'}, + {"reasoning": "Got it", "code": "SUBMIT(result)"}, ]): from dspy.primitives.local_interpreter import LocalInterpreter rlm = RLM("query -> answer", max_iterations=5, max_depth=2, @@ -2896,7 +2894,7 @@ def test_depth_2_python_interpreter_parent(self): # Child iter 1 (runs in LocalInterpreter): SUBMIT response {"reasoning": "Easy", "code": 'SUBMIT(response="4")'}, # Parent iter 2: SUBMIT the child's result - {"reasoning": "Done", "code": 'SUBMIT(result)'}, + {"reasoning": "Done", "code": "SUBMIT(result)"}, ]): # No interpreter= means default PythonInterpreter (Deno sandbox) rlm = RLM("query -> answer", max_iterations=5, max_depth=2) diff --git a/tests/primitives/test_media_serialization.py b/tests/primitives/test_media_serialization.py index 99246128da..a419819bdb 100644 --- a/tests/primitives/test_media_serialization.py +++ b/tests/primitives/test_media_serialization.py @@ -1,78 +1,170 @@ """ -Tests for media type helpers in python_interpreter.py. +Tests for RLM sandbox type protocol in python_interpreter.py and type classes. -These tests do NOT require Deno — they test pure Python serialization logic -for Audio and Image objects in the sandboxed interpreter. +Tests the generic _has_rlm_support() helper and the to_sandbox/rlm_preview/sandbox_setup/ +sandbox_assignment protocol on Audio and Image types. + +These tests do NOT require Deno — they test pure Python serialization logic. """ -from dspy.primitives.python_interpreter import _is_media_type, _media_descriptor +from dspy.primitives.python_interpreter import _has_rlm_support + +# ============================================================================ +# Tests: _has_rlm_support helper +# ============================================================================ -class TestIsMediaType: - """Tests for _is_media_type helper.""" +class TestHasRlmSupport: + """Tests for the generic _has_rlm_support() protocol check.""" def test_audio(self): from dspy.adapters.types.audio import Audio - audio = Audio(data="dGVzdA==", audio_format="wav") - assert _is_media_type(audio) is True + assert _has_rlm_support(audio) is True def test_image(self): from dspy.adapters.types.image import Image - img = Image(url="") - assert _is_media_type(img) is True + assert _has_rlm_support(img) is True def test_string(self): - assert _is_media_type("hello") is False + assert _has_rlm_support("hello") is False def test_int(self): - assert _is_media_type(42) is False + assert _has_rlm_support(42) is False def test_none(self): - assert _is_media_type(None) is False + assert _has_rlm_support(None) is False def test_dict(self): - assert _is_media_type({"data": "abc"}) is False + assert _has_rlm_support({"data": "abc"}) is False def test_list(self): - assert _is_media_type([1, 2, 3]) is False + assert _has_rlm_support([1, 2, 3]) is False -class TestMediaDescriptor: - """Tests for _media_descriptor helper.""" +# ============================================================================ +# Tests: Audio RLM sandbox protocol +# ============================================================================ - def test_audio_descriptor(self): - from dspy.adapters.types.audio import Audio - audio = Audio(data="dGVzdA==", audio_format="wav") - desc = _media_descriptor(audio) - assert "Audio" in desc - assert "wav" in desc - assert "8" in desc # len("dGVzdA==") == 8 +class TestAudioRlmProtocol: + """Tests for Audio.rlm_preview, to_sandbox, sandbox_setup, sandbox_assignment.""" - def test_image_descriptor(self): - from dspy.adapters.types.image import Image + def test_rlm_preview_basic(self): + from dspy.adapters.types.audio import Audio + audio = Audio(data="dGVzdA==", audio_format="wav") + preview = audio.rlm_preview() + assert "Audio" in preview + assert "wav" in preview + assert "8" in preview # len("dGVzdA==") == 8 - img = Image(url="") - desc = _media_descriptor(img) - assert "Image" in desc + def test_rlm_preview_includes_format(self): + from dspy.adapters.types.audio import Audio + audio = Audio(data="AAAA", audio_format="mpeg") + preview = audio.rlm_preview() + assert "mpeg" in preview - def test_non_media_falls_back_to_repr(self): - desc = _media_descriptor("just a string") - assert desc == repr("just a string") + def test_rlm_preview_includes_data_length(self): + from dspy.adapters.types.audio import Audio + audio = Audio(data="A" * 100, audio_format="wav") + preview = audio.rlm_preview() + assert "100" in preview - def test_audio_descriptor_includes_format(self): + def test_to_sandbox_returns_bytes(self): from dspy.adapters.types.audio import Audio + audio = Audio(data="dGVzdA==", audio_format="wav") + payload = audio.to_sandbox() + assert isinstance(payload, bytes) + assert b"Audio" in payload + assert b"wav" in payload - audio = Audio(data="AAAA", audio_format="mpeg") - desc = _media_descriptor(audio) - assert "mpeg" in desc + def test_sandbox_setup_empty(self): + from dspy.adapters.types.audio import Audio + audio = Audio(data="dGVzdA==", audio_format="wav") + assert audio.sandbox_setup() == "" - def test_audio_descriptor_includes_data_length(self): + def test_sandbox_assignment(self): from dspy.adapters.types.audio import Audio + audio = Audio(data="dGVzdA==", audio_format="wav") + code = audio.sandbox_assignment("my_audio", "open('/tmp/data.json').read()") + assert "my_audio" in code + assert "/tmp/data.json" in code - audio = Audio(data="A" * 100, audio_format="wav") - desc = _media_descriptor(audio) - assert "100" in desc + +# ============================================================================ +# Tests: Image RLM sandbox protocol +# ============================================================================ + + +class TestImageRlmProtocol: + """Tests for Image.rlm_preview, to_sandbox, sandbox_setup, sandbox_assignment.""" + + def test_rlm_preview_base64(self): + from dspy.adapters.types.image import Image + img = Image(url="") + preview = img.rlm_preview() + assert "Image" in preview + assert "png" in preview + + def test_rlm_preview_url(self): + from dspy.adapters.types.image import Image + img = Image(url="https://example.com/photo.jpg", download=False) + preview = img.rlm_preview() + assert "Image" in preview + assert "example.com" in preview + + def test_to_sandbox_returns_bytes(self): + from dspy.adapters.types.image import Image + img = Image(url="") + payload = img.to_sandbox() + assert isinstance(payload, bytes) + assert b"Image" in payload + + def test_sandbox_setup_empty(self): + from dspy.adapters.types.image import Image + img = Image(url="") + assert img.sandbox_setup() == "" + + def test_sandbox_assignment(self): + from dspy.adapters.types.image import Image + img = Image(url="") + code = img.sandbox_assignment("my_img", "open('/tmp/img.json').read()") + assert "my_img" in code + assert "/tmp/img.json" in code + + +# ============================================================================ +# Tests: Custom type with RLM protocol +# ============================================================================ + + +class TestCustomTypeRlmProtocol: + """Tests that any object implementing the protocol is recognized.""" + + def test_custom_type_with_protocol(self): + class MyType: + def to_sandbox(self): + return b"custom data" + def rlm_preview(self): + return "" + def sandbox_setup(self): + return "" + def sandbox_assignment(self, var_name, data_expr): + return f"{var_name} = {data_expr}" + + obj = MyType() + assert _has_rlm_support(obj) is True + + def test_custom_type_without_protocol(self): + class PlainType: + pass + assert _has_rlm_support(PlainType()) is False + + def test_partial_protocol_not_enough(self): + """Object with rlm_preview but no to_sandbox should NOT have RLM support.""" + class HalfType: + def rlm_preview(self): + return "preview" + assert _has_rlm_support(HalfType()) is False From c7e248a5f3e4c5400b8d8ecfdb6c671f2f837f6d Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp Date: Wed, 11 Feb 2026 11:22:18 -0500 Subject: [PATCH 15/15] =?UTF-8?q?refactor:=20sync=20with=20PR=20#9283=20la?= =?UTF-8?q?test=20=E2=80=94=20head+tail=20truncation,=20protocol=20docs,?= =?UTF-8?q?=20max=5Foutput=5Fchars?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Incorporate changes from 07e3f39 and 61e81ac on PR #9283: - Restore head+tail truncation in REPLVariable and REPLEntry (was simplified to head-only, now shows first half + '...' + last half) - Add REPLEntry.format_output() static method for verbose logging - Put max_output_chars back on REPLHistory (threaded through append()) - Revert max_output_chars default to 10_000 (was 100_000) - Simplify _format_output() to passthrough (truncation in REPLHistory) - Add RLM sandbox protocol documentation to Type base class - Update tests: head+tail assertions, REPLHistory threading, new tests 206 tests pass (0 failures, 38 deno-skips). --- dspy/adapters/types/base_type.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/dspy/adapters/types/base_type.py b/dspy/adapters/types/base_type.py index afcacb47ad..0754ec51d1 100644 --- a/dspy/adapters/types/base_type.py +++ b/dspy/adapters/types/base_type.py @@ -129,6 +129,16 @@ def parse_lm_response(cls, response: str | dict[str, Any]) -> Optional["Type"]: """ return None + # RLM Sandbox Support + # + # To opt-in to RLM sandbox injection, subclasses should implement: + # sandbox_setup() -> str (imports needed in sandbox) + # to_sandbox() -> bytes (serialize for injection) + # sandbox_assignment(var_name, data_expr) -> str (reconstruction code) + # rlm_preview(max_chars) -> str (LLM-friendly preview) + # + # See dspy.DataFrame for a reference implementation. + def split_message_content_for_custom_types(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Split user message content into a list of content blocks.