diff --git a/dspy/adapters/types/audio.py b/dspy/adapters/types/audio.py
index 0ceb734b73..b0bd5b022f 100644
--- a/dspy/adapters/types/audio.py
+++ b/dspy/adapters/types/audio.py
@@ -122,6 +122,27 @@ def __repr__(self) -> str:
         length = len(self.data)
         return f"Audio(data=<AUDIO_BASE64_ENCODED({length})>, audio_format='{self.audio_format}')"
 
+    # RLM Sandbox Support
+
+    def rlm_preview(self, max_chars: int = 500) -> str:
+        """Generate LLM-friendly preview of Audio contents."""
+        return f"<dspy.Audio: format='{self.audio_format}', {len(self.data)} base64 chars. Use llm_query_with_media() to perceive this audio.>"
+
+    def to_sandbox(self) -> bytes:
+        """Serialize Audio for sandbox injection (descriptor string, not raw data).
+
+        Audio data cannot be meaningfully processed as code in the sandbox.
+        The agent should use llm_query() with multimodal content to perceive audio.
+        """
+        return self.rlm_preview().encode("utf-8")
+
+    def sandbox_setup(self) -> str:
+        return ""
+
+    def sandbox_assignment(self, var_name: str, data_expr: str) -> str:
+        """Return code that assigns the audio descriptor string in the sandbox."""
+        return f"{var_name} = {data_expr}"
+
 def encode_audio(audio: Union[str, bytes, dict, "Audio", Any], sampling_rate: int = 16000, format: str = "wav") -> dict:
     """
     Encode audio to a dict with 'data' and 'audio_format'.
diff --git a/dspy/adapters/types/base_type.py b/dspy/adapters/types/base_type.py
index afcacb47ad..0754ec51d1 100644
--- a/dspy/adapters/types/base_type.py
+++ b/dspy/adapters/types/base_type.py
@@ -129,6 +129,16 @@ def parse_lm_response(cls, response: str | dict[str, Any]) -> Optional["Type"]:
         """
         return None
 
+    # RLM Sandbox Support
+    #
+    # To opt in to RLM sandbox injection, subclasses should implement:
+    #   sandbox_setup() -> str                          (imports needed in sandbox)
+    #   to_sandbox() -> bytes                           (serialize for injection)
+    #   sandbox_assignment(var_name, data_expr) -> str  (reconstruction code)
+    #   rlm_preview(max_chars) -> str                   (LLM-friendly preview)
+    #
+    # See dspy.DataFrame for a reference implementation.
+
 def split_message_content_for_custom_types(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
     """Split user message content into a list of content blocks.
diff --git a/dspy/adapters/types/image.py b/dspy/adapters/types/image.py
index fc24768b80..d3349a2e9f 100644
--- a/dspy/adapters/types/image.py
+++ b/dspy/adapters/types/image.py
@@ -115,6 +115,31 @@ def __repr__(self):
             return f"Image(url=data:image/{image_type};base64,<IMAGE_BASE64_ENCODED({len_base64})>)"
         return f"Image(url='{self.url}')"
 
+    # RLM Sandbox Support
+
+    def rlm_preview(self, max_chars: int = 500) -> str:
+        """Generate LLM-friendly preview of Image contents."""
+        if "base64" in self.url:
+            len_base64 = len(self.url.split("base64,")[1])
+            image_type = self.url.split(";")[0].split("/")[-1]
+            return f"<dspy.Image: {image_type}, {len_base64} base64 chars. Use llm_query_with_media() to perceive this image.>"
+        return f"<dspy.Image: url='{self.url}'. Use llm_query_with_media() to perceive this image.>"
+
+    def to_sandbox(self) -> bytes:
+        """Serialize Image for sandbox injection (descriptor string, not raw data).
+
+        Image data cannot be meaningfully processed as code in the sandbox.
+        The agent should use llm_query() with multimodal content to perceive images.
+ """ + return self.rlm_preview().encode("utf-8") + + def sandbox_setup(self) -> str: + return "" + + def sandbox_assignment(self, var_name: str, data_expr: str) -> str: + """Return code that assigns the image descriptor string in the sandbox.""" + return f"{var_name} = {data_expr}" + def is_url(string: str) -> bool: """Check if a string is a valid URL.""" diff --git a/dspy/predict/rlm.py b/dspy/predict/rlm.py index 66c011d234..bcac91a8d2 100644 --- a/dspy/predict/rlm.py +++ b/dspy/predict/rlm.py @@ -13,6 +13,7 @@ import logging import re import threading +import time as _time from concurrent.futures import ThreadPoolExecutor, as_completed from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Iterator @@ -30,6 +31,16 @@ from dspy.signatures.signature import ensure_signature from dspy.utils.annotation import experimental + +def _has_rlm_support(value): + """Check if a value is a dspy.Type with RLM sandbox support (to_sandbox protocol).""" + return hasattr(value, "to_sandbox") and callable(getattr(value, "to_sandbox", None)) + + +def _is_multimodal_type(value): + """Check if a value is a multimodal dspy.Type (Audio or Image) that can be sent to an LLM.""" + return hasattr(value, "format") and callable(getattr(value, "format", None)) and _has_rlm_support(value) + if TYPE_CHECKING: from dspy.signatures.signature import Signature @@ -45,10 +56,12 @@ Available: - Variables: {inputs} (your input data) -- `llm_query(prompt)` - query a sub-LLM (~500K char capacity) for semantic analysis -- `llm_query_batched(prompts)` - query multiple prompts concurrently (much faster for multiple queries) +- `llm_query(prompt, model=None)` - query a sub-LLM (~500K char capacity) for semantic analysis{model_docs} +- `llm_query_batched(prompts, model=None)` - query multiple prompts concurrently (much faster for multiple queries) +- `llm_query_with_media(prompt, *media_var_names, model=None)` - query sub-LLM with media (audio/image) attached{media_docs} - `print()` - ALWAYS print to see results - `SUBMIT({final_output_names})` - submit final output when done +- `budget()` - check remaining iterations, LLM calls, and time - Standard libraries: re, json, collections, math, etc. IMPORTANT: This is ITERATIVE. Each code block you write will execute, you'll see the output, then you decide what to do next. Do NOT try to solve everything in one step. @@ -56,11 +69,11 @@ 1. EXPLORE FIRST - Look at your data before processing it. Print samples, check types/lengths, understand the structure. 2. ITERATE - Write small code snippets, observe outputs, then decide next steps. State persists between iterations. 3. VERIFY BEFORE SUBMITTING - If results seem wrong (zeros, empty, unexpected), reconsider your approach. -4. USE llm_query FOR SEMANTICS - String matching finds WHERE things are; llm_query understands WHAT things mean. +4. USE llm_query FOR SEMANTICS - String matching finds WHERE things are; llm_query understands WHAT things mean.{media_guidelines} 5. MINIMIZE RETYPING (INPUTS & OUTPUTS) - When values are long, precise, or error-prone (IDs, numbers, code, quotes), re-access them via variables and parse/compute in code instead of retyping. Use small, targeted prints to sanity-check, but avoid manual copying when variables can carry the exact value. 6. SUBMIT ONLY AFTER SEEING OUTPUTS - SUBMIT ends the current run immediately. If you need to inspect printed output, run it in one step, review the result, then call SUBMIT in a later step. -You have max {max_llm_calls} sub-LLM calls. 
When done, call SUBMIT() with your output.""" +You have max {max_llm_calls} sub-LLM calls. Call budget() to check remaining resources. When done, call SUBMIT() with your output.""" # Pattern to match markdown code fences: ```python\n...\n``` or ```\n...\n``` _CODE_FENCE_PATTERN = re.compile(r"^```(?:python|py)?\s*\n(.*)\n```\s*$", re.DOTALL) @@ -104,10 +117,15 @@ def __init__( max_iterations: int = 20, max_llm_calls: int = 50, max_output_chars: int = 10_000, + max_time: float | None = None, + max_cost: float | None = None, verbose: bool = False, tools: list[Callable] | None = None, sub_lm: dspy.LM | None = None, + sub_lms: dict[str, dspy.LM] | None = None, interpreter: CodeInterpreter | None = None, + depth: int = 0, + max_depth: int = 1, ): """ Args: @@ -116,21 +134,42 @@ def __init__( max_iterations: Maximum REPL interaction iterations. max_llm_calls: Maximum sub-LLM calls (llm_query/llm_query_batched) per execution. max_output_chars: Maximum characters to include from REPL output. + max_time: Maximum wall-clock seconds per forward() call. None means no limit. + The agent can check remaining time via budget(). + max_cost: Maximum dollar cost per forward() call. None means no limit. + Tracked via litellm's per-call cost reporting. The agent can + check remaining cost via budget(). verbose: Whether to log detailed execution info. tools: List of tool functions or dspy.Tool objects callable from interpreter code. Built-in tools: llm_query(prompt), llm_query_batched(prompts). - sub_lm: LM for llm_query/llm_query_batched. Defaults to dspy.settings.lm. + sub_lm: Default LM for llm_query/llm_query_batched. Defaults to dspy.settings.lm. Allows using a different (e.g., cheaper) model for sub-queries. + sub_lms: Dict mapping model names to LM instances, enabling sandbox code to select + a specific model via llm_query(prompt, model="name"). When model is None, + falls back to sub_lm, then dspy.settings.lm. interpreter: CodeInterpreter implementation to use. Defaults to PythonInterpreter. + depth: Current recursion depth (0-indexed). Used internally when spawning child RLMs. + max_depth: Maximum recursion depth. When depth < max_depth - 1, llm_query spawns + a child RLM with its own REPL (LocalInterpreter). At leaf depth, falls + back to plain LM completion. Default 1 means no recursion (current behavior). 
""" super().__init__() self.signature = ensure_signature(signature) self.max_iterations = max_iterations self.max_llm_calls = max_llm_calls self.max_output_chars = max_output_chars + self.max_time = max_time + self.max_cost = max_cost self.verbose = verbose self.sub_lm = sub_lm + self.sub_lms = sub_lms or {} self._interpreter = interpreter + if max_depth < 1: + raise ValueError(f"max_depth must be >= 1, got {max_depth}") + if depth < 0: + raise ValueError(f"depth must be >= 0, got {depth}") + self.depth = depth + self.max_depth = max_depth self._user_tools = self._normalize_tools(tools) self._validate_tools(self._user_tools) @@ -144,7 +183,7 @@ def __init__( # ========================================================================= # Reserved tool names that conflict with built-in sandbox functions - _RESERVED_TOOL_NAMES = frozenset({"llm_query", "llm_query_batched", "SUBMIT", "print"}) + _RESERVED_TOOL_NAMES = frozenset({"llm_query", "llm_query_batched", "llm_query_with_media", "SUBMIT", "print", "budget"}) def _normalize_tools(self, tools: list[Callable] | None) -> dict[str, Tool]: """Normalize tools list to a dict of Tool objects keyed by name.""" @@ -171,7 +210,7 @@ def to_tool(func: Callable | Tool) -> Tool: def _validate_tools(self, tools: dict[str, Tool]) -> None: """Validate user-provided tools have valid names.""" - for name, tool in tools.items(): + for name, _tool in tools.items(): if not name.isidentifier(): raise ValueError(f"Invalid tool name '{name}': must be a valid Python identifier") if name in self._RESERVED_TOOL_NAMES: @@ -198,11 +237,57 @@ def _format_tool_docs(self, tools: dict[str, Tool]) -> str: return "\n".join(lines) - def _make_llm_tools(self, max_workers: int = 8) -> dict[str, Callable]: - """Create llm_query and llm_query_batched tools with a fresh call counter.""" + def _make_llm_tools( + self, + multimodal_registry: dict[str, Any] | None = None, + max_workers: int = 8, + execution_state: dict[str, Any] | None = None, + ) -> dict[str, Callable]: + """Create llm_query, llm_query_batched, llm_query_with_media, and budget tools with a fresh call counter. + + Args: + multimodal_registry: Dict mapping variable names to multimodal objects (any dspy.Type + with format() and to_sandbox() methods, e.g. Audio, Image). + Used by llm_query_with_media to attach media to sub-LLM calls. + max_workers: Max concurrent workers for batched queries. + execution_state: Mutable dict tracking iteration/time state. Keys: + 'start_time' (float), 'iteration' (int). Updated by forward(). + """ state = {"call_count": 0} lock = threading.Lock() - lm = self.sub_lm + _execution_state = execution_state or {} + default_lm = self.sub_lm + named_lms = self.sub_lms + _multimodal_registry = multimodal_registry or {} + + # Snapshot LM history lengths for cost tracking. + # We'll sum cost from entries added after these offsets. + _all_lms: list[dspy.LM] = [] + if default_lm is not None: + _all_lms.append(default_lm) + for lm_inst in named_lms.values(): + if lm_inst not in _all_lms: + _all_lms.append(lm_inst) + if not _all_lms and dspy.settings.lm is not None: + _all_lms.append(dspy.settings.lm) + _history_offsets = {id(lm_inst): len(lm_inst.history) for lm_inst in _all_lms} + + def _resolve_lm(model: str | None = None) -> dspy.LM: + """Resolve a model name to an LM instance. + + Resolution order: named sub_lms[model] → sub_lm → dspy.settings.lm. 
+ """ + if model is not None: + if model in named_lms: + return named_lms[model] + available = list(named_lms.keys()) + raise ValueError( + f"Model '{model}' not found in sub_lms. Available models: {available}" + ) + target_lm = default_lm if default_lm is not None else dspy.settings.lm + if target_lm is None: + raise RuntimeError("No LM configured. Use dspy.configure(lm=...) or pass sub_lm to RLM.") + return target_lm def _check_and_increment(n: int = 1) -> None: with lock: @@ -213,10 +298,8 @@ def _check_and_increment(n: int = 1) -> None: ) state["call_count"] += n - def _query_lm(prompt: str) -> str: - target_lm = lm if lm is not None else dspy.settings.lm - if target_lm is None: - raise RuntimeError("No LM configured. Use dspy.configure(lm=...) or pass sub_lm to RLM.") + def _query_lm(prompt: str, model: str | None = None) -> str: + target_lm = _resolve_lm(model) response = target_lm(prompt) if isinstance(response, list) and response: item = response[0] @@ -225,22 +308,75 @@ def _query_lm(prompt: str) -> str: return item return str(response) - def llm_query(prompt: str) -> str: - """Query the LLM with a prompt string.""" + def _query_lm_multimodal(prompt: str, multimodal_objects: list, model: str | None = None) -> str: + """Query the LLM with a prompt string and multimodal content parts.""" + target_lm = _resolve_lm(model) + + # Build multimodal content: text prompt + media content parts + content_parts = [{"type": "text", "text": prompt}] + for obj in multimodal_objects: + content_parts.extend(obj.format()) + + messages = [{"role": "user", "content": content_parts}] + response = target_lm(messages=messages) + if isinstance(response, list) and response: + item = response[0] + if isinstance(item, dict) and "text" in item: + return item["text"] + return item + return str(response) + + # Determine if llm_query should spawn recursive child RLMs + _use_recursive = self.depth < self.max_depth - 1 + + def _do_subcall(prompt: str, model: str | None = None) -> str: + """Route subcall through the instance method, passing execution state.""" + return self._subcall(prompt, model=model, execution_state=_execution_state, resolve_lm=_resolve_lm) + + def llm_query(prompt: str, model: str | None = None) -> str: + """Query the LLM with a prompt string. + + At depth < max_depth - 1, spawns a child RLM with its own REPL. + At leaf depth, makes a plain LM completion call. + + Args: + prompt: The text prompt for the LLM. + model: Optional model name from sub_lms to use. Defaults to the default sub_lm. + """ if not prompt: raise ValueError("prompt cannot be empty") _check_and_increment(1) - return _query_lm(prompt) + if _use_recursive: + return _do_subcall(prompt, model=model) + return _query_lm(prompt, model=model) + + def llm_query_batched(prompts: list[str], model: str | None = None) -> list[str]: + """Query the LLM with multiple prompts concurrently. - def llm_query_batched(prompts: list[str]) -> list[str]: - """Query the LLM with multiple prompts concurrently.""" + At depth < max_depth - 1, each prompt spawns a child RLM (sequential). + At leaf depth, prompts are sent as plain LM calls (parallel). + + Args: + prompts: List of prompts to send to the LLM. + model: Optional model name from sub_lms to use. Defaults to the default sub_lm. 
+ """ if not prompts: return [] _check_and_increment(len(prompts)) + if _use_recursive: + # Sequential: each child spawns its own LocalInterpreter + results = [] + for p in prompts: + try: + results.append(_do_subcall(p, model=model)) + except Exception as e: + results.append(f"[ERROR] {e}") + return results + results: dict[int, str] = {} with ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_idx = {executor.submit(_query_lm, p): i for i, p in enumerate(prompts)} + future_to_idx = {executor.submit(_query_lm, p, model): i for i, p in enumerate(prompts)} for future in as_completed(future_to_idx): idx = future_to_idx[future] try: @@ -249,7 +385,130 @@ def llm_query_batched(prompts: list[str]) -> list[str]: results[idx] = f"[ERROR] {e}" return [results[i] for i in range(len(prompts))] - return {"llm_query": llm_query, "llm_query_batched": llm_query_batched} + def llm_query_with_media(prompt: str, *media_var_names: str, model: str | None = None) -> str: + """Query the LLM with a prompt and multimodal variables (audio/image). + + Args: + prompt: The text prompt for the LLM. + *media_var_names: Names of multimodal variables to include (e.g., 'audio_input', 'my_image'). + These must be names of input variables that have multimodal content (Audio, Image, etc.). + model: Optional model name from sub_lms to use. Defaults to the default sub_lm. + + Returns: + The LLM's text response. + """ + if not prompt: + raise ValueError("prompt cannot be empty") + if not media_var_names: + raise ValueError( + "At least one media variable name is required. " + f"Available media variables: {list(_multimodal_registry.keys())}" + ) + + # Resolve multimodal objects from the registry + multimodal_objects = [] + for var_name in media_var_names: + if var_name not in _multimodal_registry: + available = list(_multimodal_registry.keys()) + raise ValueError( + f"Media variable '{var_name}' not found. Available media variables: {available}" + ) + multimodal_objects.append(_multimodal_registry[var_name]) + + _check_and_increment(1) + return _query_lm_multimodal(prompt, multimodal_objects, model=model) + + max_iterations = self.max_iterations + max_llm_calls = self.max_llm_calls + max_time = self.max_time + max_cost = self.max_cost + + def _get_cost_and_tokens() -> tuple[float, int]: + """Sum cost and tokens from LM history entries added since tool creation. + + Aggregates both the provider cost (litellm response_cost) and the upstream + inference cost (e.g. OpenRouter BYOK → Vertex). For BYOK providers where + response_cost is 0, the upstream cost is the actual charge. + """ + total_cost = 0.0 + total_tokens = 0 + for lm_inst in _all_lms: + offset = _history_offsets.get(id(lm_inst), 0) + for entry in lm_inst.history[offset:]: + entry_cost = 0.0 + cost = entry.get("cost") + if cost: + entry_cost += cost + usage = entry.get("usage", {}) + if isinstance(usage, dict): + cost_details = usage.get("cost_details") + if isinstance(cost_details, dict): + upstream = cost_details.get("upstream_inference_cost") + if upstream: + entry_cost += upstream + total_cost += entry_cost + total_tokens += usage.get("total_tokens", 0) if isinstance(usage, dict) else 0 + return total_cost, total_tokens + + def budget() -> str: + """Check remaining execution budget: iterations, LLM calls, time, and cost. + + Returns a human-readable summary of remaining resources. + Includes warnings when any resource drops below 20% remaining. 
+ """ + with lock: + calls_used = state["call_count"] + + iteration = _execution_state.get("iteration", 0) + remaining_iterations = max_iterations - iteration - 1 # -1 for current + remaining_calls = max_llm_calls - calls_used + + warnings = [] + parts = [ + f"Iterations: {remaining_iterations}/{max_iterations} remaining", + f"LLM calls: {remaining_calls}/{max_llm_calls} remaining", + ] + + if remaining_iterations <= max(1, max_iterations * 0.2): + warnings.append(f"iterations ({remaining_iterations} left)") + if remaining_calls <= max(1, max_llm_calls * 0.2): + warnings.append(f"LLM calls ({remaining_calls} left)") + + start_time = _execution_state.get("start_time") + if max_time is not None and start_time is not None: + elapsed = _time.monotonic() - start_time + remaining_time = max(0.0, max_time - elapsed) + parts.append(f"Time: {remaining_time:.1f}s/{max_time:.1f}s remaining ({elapsed:.1f}s elapsed)") + if remaining_time <= max_time * 0.2: + warnings.append(f"time ({remaining_time:.0f}s left)") + elif start_time is not None: + elapsed = _time.monotonic() - start_time + parts.append(f"Time: no limit ({elapsed:.1f}s elapsed)") + + cost_spent, tokens_used = _get_cost_and_tokens() + if max_cost is not None: + remaining_cost = max(0.0, max_cost - cost_spent) + parts.append(f"Cost: ${remaining_cost:.4f}/${max_cost:.4f} remaining (${cost_spent:.4f} spent, {tokens_used:,} tokens)") + if remaining_cost <= max_cost * 0.2: + warnings.append(f"cost (${remaining_cost:.4f} left)") + elif cost_spent > 0: + parts.append(f"Cost: no limit (${cost_spent:.4f} spent, {tokens_used:,} tokens)") + + result = " | ".join(parts) + if warnings: + result = f"⚠ LOW: {', '.join(warnings)}. Wrap up soon! | " + result + return result + + # Expose cost tracker to forward() for budget enforcement + _execution_state["_get_cost_and_tokens"] = _get_cost_and_tokens + + tools = { + "llm_query": llm_query, + "llm_query_batched": llm_query_batched, + "llm_query_with_media": llm_query_with_media, + "budget": budget, + } + return tools @property def tools(self) -> dict[str, Tool]: @@ -260,6 +519,23 @@ def tools(self) -> dict[str, Tool]: # Signature Building # ========================================================================= + def _detect_multimodal_fields(self) -> dict[str, str]: + """Detect input fields that are multimodal dspy.Type subclasses (Audio, Image, etc.). + + Uses the generic to_sandbox() protocol — any dspy.Type with to_sandbox() and format() + is considered multimodal media that can be sent to an LLM via llm_query_with_media(). + + Returns: + Dict mapping field name to type name (e.g., {'audio_input': 'Audio', 'photo': 'Image'}). 
+ """ + multimodal_fields = {} + for name, field in self.signature.input_fields.items(): + annotation = getattr(field, "annotation", None) + if annotation is not None and isinstance(annotation, type) and issubclass(annotation, dspy.Type): + if hasattr(annotation, "to_sandbox") and hasattr(annotation, "format"): + multimodal_fields[name] = annotation.__name__ + return multimodal_fields + def _build_signatures(self) -> tuple[Signature, Signature]: """Build the action and extract signatures from templates.""" inputs_str = ", ".join(f"`{n}`" for n in self.signature.input_fields) @@ -278,10 +554,35 @@ def _build_signatures(self) -> tuple[Signature, Signature]: # Format tool documentation for user-provided tools tool_docs = self._format_tool_docs(self._user_tools) + # Detect multimodal fields and build media-specific instructions + multimodal_fields = self._detect_multimodal_fields() + if multimodal_fields: + media_var_list = ", ".join(f"'{name}'" for name in multimodal_fields) + media_docs_str = ( + f"\n Media variables: {media_var_list}. The sub-LLM can see/hear the media content." + ) + media_guidelines_str = ( + f"\n FOR MEDIA INPUTS: Variables like {media_var_list} are media objects. " + f"In the sandbox they appear as descriptor strings — you CANNOT decode or process them as raw data. " + f"Use `llm_query_with_media(prompt, {next(iter(multimodal_fields.keys()))!r})` to send media to a sub-LLM that can perceive it." + ) + else: + media_docs_str = "" + media_guidelines_str = "" + + # Document available model names if sub_lms is configured + if self.sub_lms: + model_names = ", ".join(f"'{name}'" for name in self.sub_lms) + model_docs_str = f"\n Available models: {model_names}. Pass model= to select one." + else: + model_docs_str = "" + action_sig = ( dspy.Signature({}, task_instructions + ACTION_INSTRUCTIONS_TEMPLATE.format( inputs=inputs_str, final_output_names=final_output_names, output_fields=output_fields, max_llm_calls=self.max_llm_calls, + media_docs=media_docs_str, media_guidelines=media_guidelines_str, + model_docs=model_docs_str, ) + tool_docs) .append("variables_info", dspy.InputField(desc="Metadata about the variables available in the REPL"), type_=str) .append("repl_history", dspy.InputField(desc="Previous REPL code executions and their outputs"), type_=REPLHistory) @@ -327,12 +628,49 @@ def _get_output_fields_info(self) -> list[dict]: fields.append(field_info) return fields + def _wrap_rlm_inputs(self, input_args: dict[str, Any]) -> dict[str, Any]: + """Auto-wrap raw values into their annotated dspy.Type when the type has RLM support. + + For example, if a field is annotated as dspy.DataFrame and the user passes a raw + pandas DataFrame, wrap it into dspy.DataFrame so the interpreter can use to_sandbox(). 
+ """ + wrapped = {} + for name, value in input_args.items(): + field = self.signature.input_fields.get(name) + if field is None: + wrapped[name] = value + continue + + annotation = getattr(field, "annotation", None) + # Check if the annotation is a dspy.Type subclass with RLM support + if ( + annotation is not None + and isinstance(annotation, type) + and issubclass(annotation, dspy.Type) + and hasattr(annotation, "to_sandbox") + and not isinstance(value, annotation) + ): + try: + wrapped[name] = annotation(value) + except (TypeError, ValueError): + wrapped[name] = value + else: + wrapped[name] = value + return wrapped + def _build_variables(self, **input_args: Any) -> list[REPLVariable]: """Build REPLVariable list from input arguments with field metadata.""" variables = [] for name, value in input_args.items(): field_info = self.signature.input_fields.get(name) - variables.append(REPLVariable.from_value(name, value, field_info=field_info)) + if hasattr(value, "rlm_preview") and callable(getattr(value, "rlm_preview", None)): + # Use rlm_preview() for types with RLM support (gives better LLM context) + preview = value.rlm_preview() + var = REPLVariable.from_value(name, value, field_info=field_info) + var = var.model_copy(update={"preview": preview, "total_length": len(preview)}) + else: + var = REPLVariable.from_value(name, value, field_info=field_info) + variables.append(var) return variables def _format_output(self, output: str) -> str: @@ -350,9 +688,31 @@ def _validate_inputs(self, input_args: dict[str, Any]) -> None: # CodeInterpreter Lifecycle # ========================================================================= - def _prepare_execution_tools(self) -> dict[str, Callable]: + def _build_multimodal_registry(self, input_args: dict[str, Any]) -> dict[str, Any]: + """Extract multimodal objects from inputs into a registry. + + Any dspy.Type with both format() (for LLM content parts) and to_sandbox() (for + sandbox injection) is considered multimodal and eligible for llm_query_with_media(). + + Returns: + Dict mapping variable names to their multimodal objects. + """ + registry = {} + for name, value in input_args.items(): + if _is_multimodal_type(value): + registry[name] = value + return registry + + def _prepare_execution_tools( + self, + multimodal_registry: dict[str, Any] | None = None, + execution_state: dict[str, Any] | None = None, + ) -> dict[str, Callable]: """Create fresh LLM tools and merge with user-provided tools.""" - execution_tools = self._make_llm_tools() + execution_tools = self._make_llm_tools( + multimodal_registry=multimodal_registry, + execution_state=execution_state, + ) # Extract underlying functions from Tool objects for the interpreter execution_tools.update({name: tool.func for name, tool in self._user_tools.items()}) return execution_tools @@ -371,6 +731,93 @@ def _inject_execution_context(self, interpreter: CodeInterpreter, execution_tool if hasattr(interpreter, "_tools_registered"): interpreter._tools_registered = False + def _subcall( + self, + prompt: str, + model: str | None = None, + execution_state: dict[str, Any] | None = None, + resolve_lm: Callable | None = None, + ) -> str: + """Spawn a child RLM with its own LocalInterpreter REPL. + + Called by llm_query/llm_query_batched when depth < max_depth - 1. + The child gets a fresh REPL and can write code, call llm_query (which + recurses further or falls back to plain LM at leaf depth), and SUBMIT + a response. Mirrors the vanilla RLM's _subcall pattern. 
+ + Args: + prompt: The prompt to pass as the child's input. + model: Optional model name. Selects child's sub_lm via resolve_lm. + execution_state: Parent's mutable execution state (start_time, cost tracker). + resolve_lm: Closure from _make_llm_tools that resolves model name to LM instance. + + Returns: + The child's response string, or an error string on failure. + """ + from dspy.primitives.local_interpreter import LocalInterpreter + + _execution_state = execution_state or {} + + # Calculate remaining time budget for child + remaining_time = None + if self.max_time is not None: + start = _execution_state.get("start_time") + if start is not None: + elapsed = _time.monotonic() - start + remaining_time = max(0.0, self.max_time - elapsed) + if remaining_time <= 0: + return "Error: Time budget exhausted" + + # Calculate remaining cost budget for child + remaining_cost = None + if self.max_cost is not None: + get_cost = _execution_state.get("_get_cost_and_tokens") + if get_cost is not None: + cost_spent, _ = get_cost() + remaining_cost = max(0.0, self.max_cost - cost_spent) + if remaining_cost <= 0: + return "Error: Cost budget exhausted" + + # Resolve child's sub_lm: model param selects which LM the child uses + child_sub_lm = resolve_lm(model) if (model and resolve_lm) else self.sub_lm + + # Match parent's interpreter type: if parent uses LocalInterpreter (or a + # custom interpreter), child gets a fresh LocalInterpreter. If parent uses + # the default PythonInterpreter (Deno sandbox), child gets its own + # PythonInterpreter so sandboxing is preserved. + if isinstance(self._interpreter, LocalInterpreter): + interpreter = LocalInterpreter() + elif self._interpreter is None: + # Parent uses default PythonInterpreter — child gets one too + interpreter = PythonInterpreter() + else: + # Custom interpreter — can't clone, fall back to LocalInterpreter + interpreter = LocalInterpreter() + + child = RLM( + signature="prompt -> response", + max_iterations=self.max_iterations, + max_llm_calls=self.max_llm_calls, + max_output_chars=self.max_output_chars, + max_time=remaining_time, + max_cost=remaining_cost, + verbose=self.verbose, + tools=list(self._user_tools.values()) if self._user_tools else None, + sub_lm=child_sub_lm, + sub_lms=self.sub_lms, + interpreter=interpreter, + depth=self.depth + 1, + max_depth=self.max_depth, + ) + + try: + result = child(prompt=prompt) + return result.response + except Exception as e: + return f"Error: Child RLM failed - {e}" + finally: + interpreter.shutdown() + @contextmanager def _interpreter_context(self, execution_tools: dict[str, Callable]) -> Iterator[CodeInterpreter]: """Yield interpreter, creating PythonInterpreter if none provided at init.""" @@ -547,17 +994,51 @@ def forward(self, **input_args) -> Prediction: Raises: ValueError: If required input fields are missing + RuntimeError: If max_time budget is exceeded """ self._validate_inputs(input_args) + input_args = self._wrap_rlm_inputs(input_args) output_field_names = list(self.signature.output_fields.keys()) - execution_tools = self._prepare_execution_tools() + multimodal_registry = self._build_multimodal_registry(input_args) + + # Mutable execution state — shared with budget() tool via closure + execution_state = {"start_time": _time.monotonic(), "iteration": 0} + + execution_tools = self._prepare_execution_tools( + multimodal_registry=multimodal_registry, + execution_state=execution_state, + ) variables = self._build_variables(**input_args) with self._interpreter_context(execution_tools) as repl: history: 
REPLHistory = REPLHistory(max_output_chars=self.max_output_chars) for iteration in range(self.max_iterations): + execution_state["iteration"] = iteration + + # Check time budget before starting iteration + if self.max_time is not None: + elapsed = _time.monotonic() - execution_state["start_time"] + if elapsed > self.max_time: + logger.warning( + f"RLM time budget exceeded ({elapsed:.1f}s > {self.max_time:.1f}s) " + f"at iteration {iteration + 1}, using extract fallback" + ) + return self._extract_fallback(variables, history, output_field_names) + + # Check cost budget before starting iteration + if self.max_cost is not None: + get_cost = execution_state.get("_get_cost_and_tokens") + if get_cost is not None: + cost_spent, _ = get_cost() + if cost_spent > self.max_cost: + logger.warning( + f"RLM cost budget exceeded (${cost_spent:.4f} > ${self.max_cost:.4f}) " + f"at iteration {iteration + 1}, using extract fallback" + ) + return self._extract_fallback(variables, history, output_field_names) + result: Prediction | REPLHistory = self._execute_iteration( repl, variables, history, iteration, input_args, output_field_names ) @@ -630,17 +1111,51 @@ async def aforward(self, **input_args) -> Prediction: Raises: ValueError: If required input fields are missing + RuntimeError: If max_time budget is exceeded """ self._validate_inputs(input_args) + input_args = self._wrap_rlm_inputs(input_args) output_field_names = list(self.signature.output_fields.keys()) - execution_tools = self._prepare_execution_tools() + multimodal_registry = self._build_multimodal_registry(input_args) + + # Mutable execution state — shared with budget() tool via closure + execution_state = {"start_time": _time.monotonic(), "iteration": 0} + + execution_tools = self._prepare_execution_tools( + multimodal_registry=multimodal_registry, + execution_state=execution_state, + ) variables = self._build_variables(**input_args) with self._interpreter_context(execution_tools) as repl: history = REPLHistory(max_output_chars=self.max_output_chars) for iteration in range(self.max_iterations): + execution_state["iteration"] = iteration + + # Check time budget before starting iteration + if self.max_time is not None: + elapsed = _time.monotonic() - execution_state["start_time"] + if elapsed > self.max_time: + logger.warning( + f"RLM time budget exceeded ({elapsed:.1f}s > {self.max_time:.1f}s) " + f"at iteration {iteration + 1}, using extract fallback" + ) + return await self._aextract_fallback(variables, history, output_field_names) + + # Check cost budget before starting iteration + if self.max_cost is not None: + get_cost = execution_state.get("_get_cost_and_tokens") + if get_cost is not None: + cost_spent, _ = get_cost() + if cost_spent > self.max_cost: + logger.warning( + f"RLM cost budget exceeded (${cost_spent:.4f} > ${self.max_cost:.4f}) " + f"at iteration {iteration + 1}, using extract fallback" + ) + return await self._aextract_fallback(variables, history, output_field_names) + result = await self._aexecute_iteration( repl, variables, history, iteration, input_args, output_field_names ) diff --git a/dspy/primitives/local_interpreter.py b/dspy/primitives/local_interpreter.py new file mode 100644 index 0000000000..9ab282b59d --- /dev/null +++ b/dspy/primitives/local_interpreter.py @@ -0,0 +1,175 @@ +""" +Unsandboxed local Python interpreter for RLM. + +Implements the CodeInterpreter protocol but executes code directly in the host +Python process via exec(). 
This gives the RLM agent full access to any installed +Python package (PIL, pydub, numpy, scipy, etc.). + +Use this when the sandboxed PythonInterpreter (Deno/Pyodide) is too restrictive — +e.g., when the RLM agent needs to manipulate images with PIL or process audio +with pydub directly in its generated code. + +Security: This is intentionally UNSANDBOXED. The LLM-generated code runs with +full host process privileges. Only use for local experiments or trusted workloads. + +Usage: + from dspy.primitives.local_interpreter import LocalInterpreter + + rlm = dspy.RLM("context -> answer", interpreter=LocalInterpreter()) +""" + +import io +import sys +import traceback +from typing import Any, Callable + +from dspy.primitives.code_interpreter import CodeInterpreterError, FinalOutput + + +class _SubmitCalledError(Exception): + """Internal signal raised when SUBMIT() is called in user code.""" + def __init__(self, output: Any): + self.output = output + + +class LocalInterpreter: + """Unsandboxed Python interpreter implementing the CodeInterpreter protocol. + + Executes code directly in the host process via exec(). State persists + across execute() calls within a session. Tools are injected as callable + functions in the execution namespace. + + This gives the RLM agent full access to the host Python environment: + - PIL/Pillow for image manipulation + - pydub/ffmpeg for audio manipulation + - numpy, scipy, scikit-image, etc. + - Any installed Python package + + Note: Not thread-safe. Create separate instances for concurrent use. + """ + + def __init__( + self, + tools: dict[str, Callable[..., str]] | None = None, + output_fields: list[dict] | None = None, + ): + """ + Args: + tools: Dictionary mapping tool names to callable functions. + Tools are available as top-level functions in the namespace. + output_fields: Output field definitions for typed SUBMIT signature. + """ + self._tools: dict[str, Callable[..., str]] = dict(tools) if tools else {} + self.output_fields = output_fields + self._namespace: dict[str, Any] = {} + self._started = False + + @property + def tools(self) -> dict[str, Callable[..., str]]: + """Tools available for interpreter code to call.""" + return self._tools + + @tools.setter + def tools(self, value: dict[str, Callable[..., str]]) -> None: + self._tools = value + + def start(self) -> None: + """Initialize the interpreter namespace.""" + if self._started: + return + self._namespace = {"__builtins__": __builtins__} + self._started = True + + def execute( + self, + code: str, + variables: dict[str, Any] | None = None, + ) -> Any: + """Execute Python code in the host process. + + Args: + code: Python code to execute. + variables: Variables to inject into the namespace before execution. + Media objects (Audio, Image) are injected AS-IS, giving + code direct access to their data for manipulation. 
+ + Returns: + - FinalOutput: If SUBMIT() was called + - str: Captured stdout (from print() calls) + - None: If no output was produced + + Raises: + CodeInterpreterError: On runtime errors + SyntaxError: On invalid Python syntax + """ + if not self._started: + self.start() + + # Inject variables directly into namespace (no serialization — objects stay live) + if variables: + self._namespace.update(variables) + + # Inject tools as callable functions + for name, func in self._tools.items(): + self._namespace[name] = func + + # Inject SUBMIT function — maps args to output field names (matching PythonInterpreter) + output_fields = self.output_fields or [] + field_names = [f["name"] for f in output_fields] + + def SUBMIT(*args, **kwargs): # noqa: N802 + if not args and not kwargs: + raise ValueError("SUBMIT requires at least one argument") + if args and kwargs: + raise ValueError("SUBMIT accepts either positional args or keyword args, not both") + if kwargs: + output = kwargs + elif field_names: + if len(args) != len(field_names): + expected = ", ".join(field_names) + raise TypeError( + f"SUBMIT() takes {len(field_names)} positional argument(s) " + f"({expected}), but {len(args)} were given" + ) + output = dict(zip(field_names, args, strict=False)) + elif len(args) == 1: + output = {"output": args[0]} + else: + output = {"output": args} + raise _SubmitCalledError(output) + + self._namespace["SUBMIT"] = SUBMIT + + # Capture stdout + old_stdout = sys.stdout + captured = io.StringIO() + sys.stdout = captured + + try: + exec(code, self._namespace) + except _SubmitCalledError as e: + return FinalOutput(e.output) + except SyntaxError: + raise + except Exception as e: + tb = traceback.format_exc() + raise CodeInterpreterError(f"{type(e).__name__}: {e}\n{tb}") from e + finally: + sys.stdout = old_stdout + + output = captured.getvalue() + if output: + return output.rstrip("\n") + return None + + def shutdown(self) -> None: + """Release resources and clear the namespace.""" + self._namespace.clear() + self._started = False + + def __enter__(self): + self.start() + return self + + def __exit__(self, *args): + self.shutdown() diff --git a/dspy/primitives/python_interpreter.py b/dspy/primitives/python_interpreter.py index 76b685ad58..cd39f7929c 100644 --- a/dspy/primitives/python_interpreter.py +++ b/dspy/primitives/python_interpreter.py @@ -19,6 +19,11 @@ from dspy.primitives.code_interpreter import SIMPLE_TYPES, CodeInterpreterError, FinalOutput + +def _has_rlm_support(value: Any) -> bool: + """Check if a value is a dspy.Type with RLM sandbox support (to_sandbox protocol).""" + return hasattr(value, "to_sandbox") and callable(getattr(value, "to_sandbox", None)) + __all__ = ["PythonInterpreter", "FinalOutput", "CodeInterpreterError"] logger = logging.getLogger(__name__) @@ -308,7 +313,7 @@ def _handle_tool_call(self, request: dict) -> None: if tool_name not in self.tools: raise CodeInterpreterError(f"Unknown tool: {tool_name}") result = self.tools[tool_name](*args, **kwargs) - is_json = isinstance(result, (list, dict)) + is_json = isinstance(result, list | dict) response = _jsonrpc_result( {"value": json.dumps(result) if is_json else str(result or ""), "type": "json" if is_json else "string"}, request_id @@ -376,11 +381,13 @@ def _health_check(self) -> None: def _to_json_compatible(self, value: Any) -> Any: """Recursively convert Python values to JSON-compatible types.""" - if value is None or isinstance(value, (str, int, float, bool)): + if value is None or isinstance(value, str | int | float | bool): 
return value + elif _has_rlm_support(value): + return value.rlm_preview() if hasattr(value, "rlm_preview") else repr(value) elif isinstance(value, dict): return {k: self._to_json_compatible(v) for k, v in value.items()} - elif isinstance(value, (list, tuple)): + elif isinstance(value, list | tuple): return [self._to_json_compatible(v) for v in value] elif isinstance(value, set): try: @@ -391,14 +398,35 @@ def _to_json_compatible(self, value: Any) -> Any: raise CodeInterpreterError(f"Unsupported value type: {type(value).__name__}") def _inject_variables(self, code: str, variables: dict[str, Any]) -> str: - """Insert Python assignments for each variable at the top of the code.""" + """Insert Python assignments for each variable at the top of the code. + + Supports dspy.Type instances with RLM sandbox support via to_sandbox(), + with fallback to JSON serialization for other types. + """ for key in variables: if not key.isidentifier() or keyword.iskeyword(key) or key == "json": raise CodeInterpreterError(f"Invalid variable name: '{key}'") + # Variables with custom RLM serialization (via to_sandbox interface) + rlm_type_vars: dict[str, bytes] = {} # name -> payload large_vars = {} small_assignments = [] + setup_imports = set() + rlm_assignments = [] + for k, v in variables.items(): + if _has_rlm_support(v): + payload = v.to_sandbox() + rlm_type_vars[k] = payload + data_expr = f"open('/tmp/dspy_vars/{k}.json').read()" + rlm_assignments.append(v.sandbox_assignment(k, data_expr)) + if hasattr(v, "sandbox_setup"): + setup = v.sandbox_setup() + if setup: + setup_imports.add(setup) + continue + + # Standard serialization for other types serialized = self._serialize_value(v) if len(serialized) > LARGE_VAR_THRESHOLD: large_vars[k] = json.dumps(self._to_json_compatible(v)) @@ -406,12 +434,20 @@ def _inject_variables(self, code: str, variables: dict[str, Any]) -> str: small_assignments.append(f"{k} = {serialized}") self._pending_large_vars = large_vars + self._pending_rlm_type_vars = rlm_type_vars + # Build imports + imports = list(setup_imports) if large_vars: - large_assignments = [f"{k} = json.loads(open('/tmp/dspy_vars/{k}.json').read())" for k in large_vars] - assignments = ["import json"] + small_assignments + large_assignments - else: - assignments = small_assignments + imports.append("import json") + + # Build assignments for large JSON vars + large_assignments = [ + f"{k} = json.loads(open('/tmp/dspy_vars/{k}.json').read())" + for k in large_vars + ] + + assignments = imports + small_assignments + large_assignments + rlm_assignments return "\n".join(assignments) + "\n" + code if assignments else code @@ -428,9 +464,9 @@ def _serialize_value(self, value: Any) -> str: elif isinstance(value, bool): # Must check bool before int since bool is a subclass of int return "True" if value else "False" - elif isinstance(value, (int, float)): + elif isinstance(value, int | float): return str(value) - elif isinstance(value, (list, tuple)): + elif isinstance(value, list | tuple): # Tuples become lists for JSON compatibility items = ", ".join(self._serialize_value(item) for item in value) return f"[{items}]" @@ -448,6 +484,11 @@ def _serialize_value(self, value: Any) -> str: sorted_items = list(value) items = ", ".join(self._serialize_value(item) for item in sorted_items) return f"[{items}]" + elif _has_rlm_support(value): + # Types with RLM sandbox support are represented as preview strings in the sandbox. + # The actual data is accessed via the type's sandbox protocol methods. 
+ preview = value.rlm_preview() if hasattr(value, "rlm_preview") else repr(value) + return repr(preview) else: raise CodeInterpreterError(f"Unsupported value type: {type(value).__name__}") @@ -455,6 +496,15 @@ def _inject_large_var(self, name: str, value: str) -> None: """Inject a large variable via the virtual filesystem.""" self._send_request("inject_var", {"name": name, "value": value}, f"injecting variable '{name}'") + def _inject_pending_vars(self) -> None: + """Inject all pending large and RLM type variables into the sandbox.""" + for name, value in self._pending_large_vars.items(): + self._inject_large_var(name, value) + for name, payload in getattr(self, "_pending_rlm_type_vars", {}).items(): + if payload is not None: + value = payload.decode("utf-8") if isinstance(payload, bytes) else str(payload) + self._inject_large_var(name, value) + def execute( self, code: str, @@ -467,8 +517,7 @@ def execute( self._mount_files() self._register_tools() - for name, value in self._pending_large_vars.items(): - self._inject_large_var(name, value) + self._inject_pending_vars() # Send the code as JSON-RPC request self._request_id += 1 @@ -484,8 +533,7 @@ def execute( self._ensure_deno_process() self._mount_files() self._register_tools() - for name, value in self._pending_large_vars.items(): - self._inject_large_var(name, value) + self._inject_pending_vars() self.deno_process.stdin.write(input_data + "\n") self.deno_process.stdin.flush() diff --git a/dspy/teleprompt/bootstrap_trace.py b/dspy/teleprompt/bootstrap_trace.py index 2f0cc60cd1..80c1c8d660 100644 --- a/dspy/teleprompt/bootstrap_trace.py +++ b/dspy/teleprompt/bootstrap_trace.py @@ -107,6 +107,13 @@ def patched_forward(program_to_use: Module, **kwargs): ) return failed_pred, trace + except Exception as e: + # Catch non-parse failures (e.g. RLM timeout, interpreter crash, + # cost overrun). Preserve whatever partial trace was captured so + # GEPA can still reflect on the calls that happened before failure. 
+ trace = dspy.settings.trace.copy() + failed_pred = FailedPrediction(completion_text=str(e)) + return failed_pred, trace program.forward = MethodType(patched_forward, program) diff --git a/tests/predict/test_rlm.py b/tests/predict/test_rlm.py index dab71e3416..45503d5bbf 100644 --- a/tests/predict/test_rlm.py +++ b/tests/predict/test_rlm.py @@ -1110,5 +1110,1797 @@ def test_with_llm_query(self): assert "dog" in result.answer.lower() +# ============================================================================ +# Unit Tests: Multimodal Media Support (Audio/Image) +# ============================================================================ + + +class TestMultimodalDetection: + """Unit tests for multimodal field detection and registry building (types protocol).""" + + def test_detect_audio_field(self): + """Test _detect_multimodal_fields finds Audio-typed inputs.""" + import dspy + from dspy.adapters.types.audio import Audio + + class TranscribeSig(dspy.Signature): + """Transcribe audio.""" + audio_input: Audio = dspy.InputField() + transcription: str = dspy.OutputField() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM(TranscribeSig, max_iterations=3) + fields = rlm._detect_multimodal_fields() + + assert "audio_input" in fields + assert fields["audio_input"] == "Audio" + + def test_detect_image_field(self): + """Test _detect_multimodal_fields finds Image-typed inputs.""" + import dspy + from dspy.adapters.types.image import Image + + class DescribeSig(dspy.Signature): + """Describe an image.""" + photo: Image = dspy.InputField() + description: str = dspy.OutputField() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('a cat')"}]): + rlm = RLM(DescribeSig, max_iterations=3) + fields = rlm._detect_multimodal_fields() + + assert "photo" in fields + assert fields["photo"] == "Image" + + def test_detect_mixed_multimodal_fields(self): + """Test _detect_multimodal_fields finds both Audio and Image in same signature.""" + import dspy + from dspy.adapters.types.audio import Audio + from dspy.adapters.types.image import Image + + class MultimodalSig(dspy.Signature): + """Process audio and image together.""" + audio_clip: Audio = dspy.InputField() + photo: Image = dspy.InputField() + analysis: str = dspy.OutputField() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('done')"}]): + rlm = RLM(MultimodalSig, max_iterations=3) + fields = rlm._detect_multimodal_fields() + + assert len(fields) == 2 + assert fields["audio_clip"] == "Audio" + assert fields["photo"] == "Image" + + def test_no_multimodal_fields(self): + """Test _detect_multimodal_fields returns empty dict for text-only signatures.""" + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + fields = rlm._detect_multimodal_fields() + + assert fields == {} + + def test_build_multimodal_registry_with_audio(self): + """Test _build_multimodal_registry extracts Audio objects from inputs.""" + import dspy + from dspy.adapters.types.audio import Audio + + audio = Audio(data="dGVzdA==", audio_format="wav") + + class TranscribeSig(dspy.Signature): + """Transcribe audio.""" + audio_input: Audio = dspy.InputField() + transcription: str = dspy.OutputField() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM(TranscribeSig, max_iterations=3) + registry = rlm._build_multimodal_registry({"audio_input": audio}) + + assert "audio_input" in registry + assert registry["audio_input"] is 
audio + + def test_build_multimodal_registry_ignores_text(self): + """Test _build_multimodal_registry skips non-multimodal values.""" + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + registry = rlm._build_multimodal_registry({"query": "hello world"}) + + assert registry == {} + + def test_wrap_rlm_inputs_passthrough_already_wrapped(self): + """Test _wrap_rlm_inputs passes through already-wrapped dspy types.""" + import dspy + from dspy.adapters.types.audio import Audio + + class TranscribeSig(dspy.Signature): + audio_input: Audio = dspy.InputField() + transcription: str = dspy.OutputField() + + audio = Audio(data="dGVzdA==", audio_format="wav") + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM(TranscribeSig, max_iterations=3) + wrapped = rlm._wrap_rlm_inputs({"audio_input": audio}) + + assert wrapped["audio_input"] is audio + + def test_wrap_rlm_inputs_passthrough_plain_types(self): + """Test _wrap_rlm_inputs passes through plain types like str.""" + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + wrapped = rlm._wrap_rlm_inputs({"query": "hello"}) + + assert wrapped["query"] == "hello" + + +class TestLLMQueryWithMedia: + """Unit tests for llm_query_with_media tool creation and validation.""" + + def test_media_tool_always_available(self): + """Test llm_query_with_media is always created as a tool.""" + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + tools = rlm._make_llm_tools() + + assert "llm_query_with_media" in tools + assert "llm_query" in tools + assert "llm_query_batched" in tools + assert "budget" in tools + + def test_media_tool_rejects_empty_prompt(self): + """Test llm_query_with_media raises on empty prompt.""" + from dspy.adapters.types.audio import Audio + + audio = Audio(data="dGVzdA==", audio_format="wav") + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + tools = rlm._make_llm_tools(multimodal_registry={"audio_input": audio}) + + with pytest.raises(ValueError, match="prompt cannot be empty"): + tools["llm_query_with_media"]("") + + def test_media_tool_rejects_no_media_vars(self): + """Test llm_query_with_media raises when no media var names given.""" + from dspy.adapters.types.audio import Audio + + audio = Audio(data="dGVzdA==", audio_format="wav") + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + tools = rlm._make_llm_tools(multimodal_registry={"audio_input": audio}) + + with pytest.raises(ValueError, match="At least one media variable"): + tools["llm_query_with_media"]("transcribe this") + + def test_media_tool_rejects_unknown_var(self): + """Test llm_query_with_media raises on nonexistent media variable.""" + from dspy.adapters.types.audio import Audio + + audio = Audio(data="dGVzdA==", audio_format="wav") + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + tools = rlm._make_llm_tools(multimodal_registry={"audio_input": audio}) + + with pytest.raises(ValueError, match="not found"): + tools["llm_query_with_media"]("transcribe", "nonexistent_var") + + def test_reserved_tool_names_includes_media(self): + """Test llm_query_with_media is in the reserved tool names set.""" + assert 
"llm_query_with_media" in RLM._RESERVED_TOOL_NAMES + + +class TestMediaInstructions: + """Unit tests for multimodal-specific instruction injection in signatures.""" + + def test_media_docs_in_action_instructions(self): + """Test that multimodal fields cause media docs to appear in instructions.""" + import dspy + from dspy.adapters.types.audio import Audio + + class TranscribeSig(dspy.Signature): + """Transcribe audio.""" + audio_input: Audio = dspy.InputField() + transcription: str = dspy.OutputField() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM(TranscribeSig, max_iterations=3) + action_sig, _extract_sig = rlm._build_signatures() + + instructions = action_sig.instructions + assert "audio_input" in instructions + assert "media" in instructions.lower() + + def test_no_media_guidelines_for_text_only(self): + """Test that text-only signatures do NOT include media-specific guidelines.""" + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + action_sig, _extract_sig = rlm._build_signatures() + + instructions = action_sig.instructions + # The generic llm_query_with_media tool is always documented, + # but media-specific guidelines should NOT appear + assert "FOR MEDIA INPUTS" not in instructions + + +class TestMultiModelSubCalls: + """Unit tests for multi-model sub-call routing via sub_lms parameter.""" + + def test_sub_lms_stored_on_init(self): + """Test sub_lms dict is stored on the RLM instance.""" + from unittest.mock import MagicMock + + lm1 = MagicMock() + lm2 = MagicMock() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3, sub_lms={"flash": lm1, "pro": lm2}) + + assert rlm.sub_lms == {"flash": lm1, "pro": lm2} + + def test_sub_lms_defaults_to_empty_dict(self): + """Test sub_lms defaults to empty dict when not provided.""" + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + assert rlm.sub_lms == {} + + def test_llm_query_routes_to_named_model(self): + """Test llm_query(prompt, model='name') routes to the correct LM.""" + from unittest.mock import MagicMock + + mock_flash = MagicMock(return_value=["flash response"]) + mock_pro = MagicMock(return_value=["pro response"]) + mock_default = MagicMock(return_value=["default response"]) + + rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_default, + sub_lms={"flash": mock_flash, "pro": mock_pro}) + tools = rlm._make_llm_tools() + + result = tools["llm_query"]("test prompt", model="flash") + assert result == "flash response" + mock_flash.assert_called_once() + + result = tools["llm_query"]("test prompt", model="pro") + assert result == "pro response" + mock_pro.assert_called_once() + + def test_llm_query_default_model_uses_sub_lm(self): + """Test llm_query without model param falls back to sub_lm.""" + from unittest.mock import MagicMock + + mock_sub = MagicMock(return_value=["sub_lm response"]) + mock_flash = MagicMock(return_value=["flash response"]) + + rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_sub, sub_lms={"flash": mock_flash}) + tools = rlm._make_llm_tools() + + # model=None should use sub_lm, not flash + result = tools["llm_query"]("test prompt") + assert result == "sub_lm response" + mock_sub.assert_called_once() + mock_flash.assert_not_called() + + def test_llm_query_raises_on_unknown_model(self): + """Test llm_query raises ValueError for unknown model 
name.""" + from unittest.mock import MagicMock + + mock_flash = MagicMock(return_value=["flash"]) + + rlm = RLM("query -> answer", max_iterations=3, sub_lms={"flash": mock_flash}) + tools = rlm._make_llm_tools() + + with pytest.raises(ValueError, match="Model 'nonexistent' not found"): + tools["llm_query"]("test", model="nonexistent") + + def test_llm_query_batched_routes_to_named_model(self): + """Test llm_query_batched with model param routes all prompts to named LM.""" + from unittest.mock import MagicMock + + mock_default = MagicMock(return_value=["default response"]) + mock_pro = MagicMock(return_value=["pro response"]) + + rlm = RLM("query -> answer", max_iterations=3, max_llm_calls=10, + sub_lm=mock_default, sub_lms={"pro": mock_pro}) + tools = rlm._make_llm_tools() + + results = tools["llm_query_batched"](["prompt1", "prompt2"], model="pro") + assert len(results) == 2 + # Both should come from the pro LM + assert all(r == "pro response" for r in results) + assert mock_pro.call_count == 2 + mock_default.assert_not_called() + + def test_llm_query_with_media_routes_to_named_model(self): + """Test llm_query_with_media with model kwarg routes to named LM.""" + from unittest.mock import MagicMock + + from dspy.adapters.types.audio import Audio + + mock_default = MagicMock(return_value=["default response"]) + mock_pro = MagicMock(return_value=["pro media response"]) + + audio = Audio(data="dGVzdA==", audio_format="wav") + + rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_default, + sub_lms={"pro": mock_pro}) + tools = rlm._make_llm_tools(multimodal_registry={"audio_input": audio}) + + result = tools["llm_query_with_media"]("transcribe", "audio_input", model="pro") + assert result == "pro media response" + mock_pro.assert_called_once() + mock_default.assert_not_called() + + def test_model_docs_in_instructions_when_sub_lms(self): + """Test that model names appear in action instructions when sub_lms is set.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock() + + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3, sub_lms={"flash": mock_lm, "pro": mock_lm}) + action_sig, _ = rlm._build_signatures() + + instructions = action_sig.instructions + assert "flash" in instructions + assert "pro" in instructions + assert "model=" in instructions + + def test_no_model_docs_when_no_sub_lms(self): + """Test that model docs are absent when sub_lms is not set.""" + with dummy_lm_context([{"reasoning": "test", "code": "SUBMIT('hi')"}]): + rlm = RLM("query -> answer", max_iterations=3) + action_sig, _ = rlm._build_signatures() + + instructions = action_sig.instructions + assert "Available models:" not in instructions + + def test_call_count_shared_across_models(self): + """Test that the call counter is shared across all model choices.""" + from unittest.mock import MagicMock + + mock_default = MagicMock(return_value=["default"]) + mock_flash = MagicMock(return_value=["flash"]) + mock_pro = MagicMock(return_value=["pro"]) + + rlm = RLM("query -> answer", max_iterations=3, max_llm_calls=3, + sub_lm=mock_default, sub_lms={"flash": mock_flash, "pro": mock_pro}) + tools = rlm._make_llm_tools() + + tools["llm_query"]("p1", model="flash") + tools["llm_query"]("p2", model="pro") + tools["llm_query"]("p3") # default + + # 4th call should exceed the limit of 3 + with pytest.raises(RuntimeError, match="LLM call limit exceeded"): + tools["llm_query"]("p4", model="flash") + + +class TestBudgetTracking: + """Tests for budget() 
tool and max_time enforcement.""" + + def test_max_time_initialization(self): + """Test that max_time is stored on the RLM instance.""" + rlm = RLM("query -> answer", max_time=60.0) + assert rlm.max_time == 60.0 + + def test_max_time_default_none(self): + """Test that max_time defaults to None (no limit).""" + rlm = RLM("query -> answer") + assert rlm.max_time is None + + def test_budget_tool_created(self): + """Test that the budget tool is included in execution tools.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=10) + tools = rlm._make_llm_tools() + assert "budget" in tools + assert callable(tools["budget"]) + + def test_budget_returns_string(self): + """Test that budget() returns a human-readable string.""" + rlm = RLM("query -> answer", max_iterations=10, max_llm_calls=20) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 3} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert isinstance(result, str) + assert "Iterations:" in result + assert "LLM calls:" in result + + def test_budget_reflects_iteration(self): + """Test that budget() shows correct remaining iterations.""" + rlm = RLM("query -> answer", max_iterations=10, max_llm_calls=20) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 7} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + # iteration=7, max=10, remaining = 10 - 7 - 1 = 2 + assert "2/10 remaining" in result + + def test_budget_reflects_llm_calls(self): + """Test that budget() shows correct remaining LLM calls after usage.""" + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + + from unittest.mock import MagicMock + mock_lm = MagicMock(return_value=["response"]) + rlm_with_lm = RLM("query -> answer", max_iterations=5, max_llm_calls=10, sub_lm=mock_lm) + tools = rlm_with_lm._make_llm_tools(execution_state=execution_state) + + # Use 3 LLM calls + tools["llm_query"]("prompt1") + tools["llm_query"]("prompt2") + tools["llm_query"]("prompt3") + + result = tools["budget"]() + assert "7/10 remaining" in result + + def test_budget_shows_time_when_max_time_set(self): + """Test that budget() includes time info when max_time is configured.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=10, max_time=120.0) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert "Time:" in result + assert "/120.0s remaining" in result + + def test_budget_no_time_when_max_time_none(self): + """Test that budget() shows 'no limit' when max_time is None.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=10) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert "no limit" in result + + def test_budget_reserved_name(self): + """Test that 'budget' is a reserved tool name.""" + def budget() -> str: + return "custom" + + from dspy.adapters.types.tool import Tool + tool = Tool(budget, name="budget") + with pytest.raises(ValueError, match="conflicts with built-in"): + RLM("query -> answer", tools=[tool]) + + def test_budget_in_action_instructions(self): + """Test that the action instructions mention budget().""" + rlm = RLM("query -> answer", max_iterations=5) + action_sig = rlm.generate_action.signature + assert "budget()" 
in action_sig.instructions + + def test_max_time_triggers_extract_fallback(self): + """Test that exceeding max_time triggers extract fallback (not exception).""" + + mock = MockInterpreter(responses=[ + "exploring...", + "still exploring...", + ]) + # Set max_time to 0 so it's already exceeded on first check + rlm = RLM("query -> answer", max_iterations=5, max_time=0.0, interpreter=mock) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Explore", "code": "print('exploring')"}, + ]) + rlm.extract = make_mock_predictor([ + {"answer": "timeout_fallback"}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "timeout_fallback" + assert result.final_reasoning == "Extract forced final output" + + def test_max_time_none_no_timeout(self): + """Test that max_time=None means no time checking.""" + mock = MockInterpreter(responses=[FinalOutput({"answer": "42"})]) + rlm = RLM("query -> answer", max_iterations=5, max_time=None, interpreter=mock) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Return answer", "code": 'SUBMIT("42")'}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "42" + + def test_budget_iteration_updates_via_tool(self): + """Test that budget() reports decreasing iterations across a forward() run.""" + budget_reports = [] + + class BudgetCapturingInterpreter(MockInterpreter): + def __init__(self): + super().__init__(responses=["output1", "output2", FinalOutput({"answer": "done"})]) + + def execute(self, code, variables=None): + # Call the budget tool if available + if "budget" in self.tools: + budget_reports.append(self.tools["budget"]()) + return super().execute(code, variables) + + mock_interp = BudgetCapturingInterpreter() + rlm = RLM("query -> answer", max_iterations=5, interpreter=mock_interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Step 1", "code": "print('a')"}, + {"reasoning": "Step 2", "code": "print('b')"}, + {"reasoning": "Step 3", "code": 'SUBMIT("done")'}, + ]) + + result = rlm(query="test") + assert result.answer == "done" + # We got 3 iterations (0, 1, 2), should have 3 budget reports + assert len(budget_reports) == 3 + # Each report should show decreasing remaining iterations + for report in budget_reports: + assert "Iterations:" in report + assert "LLM calls:" in report + + + def test_max_cost_initialization(self): + """Test that max_cost is stored on the RLM instance.""" + rlm = RLM("query -> answer", max_cost=0.10) + assert rlm.max_cost == 0.10 + + def test_max_cost_default_none(self): + """Test that max_cost defaults to None (no limit).""" + rlm = RLM("query -> answer") + assert rlm.max_cost is None + + def test_budget_shows_cost_when_max_cost_set(self): + """Test that budget() includes cost info when max_cost is configured.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=10, max_cost=0.50) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert "Cost:" in result + assert "$0.50" in result + + def test_budget_no_cost_when_max_cost_none_and_no_spending(self): + """Test that budget() omits cost when max_cost is None and nothing spent.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=10) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + # No cost tracking when max_cost is None and no LM history 
entries with cost + assert "Cost:" not in result or "no limit" in result + + def test_max_cost_zero_triggers_immediate_fallback(self): + """Test that max_cost=0 triggers extract fallback immediately.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["response"]) + # Give the mock LM a history with a cost entry + mock_lm.history = [{"cost": 0.001, "usage": {"total_tokens": 100}}] + + mock = MockInterpreter(responses=["exploring..."]) + rlm = RLM("query -> answer", max_iterations=5, max_cost=0.0, sub_lm=mock_lm, interpreter=mock) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Explore", "code": "print('exploring')"}, + ]) + rlm.extract = make_mock_predictor([ + {"answer": "cost_fallback"}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "cost_fallback" + + + def test_byok_cost_upstream_inference_cost(self): + """Test cost tracking includes usage.cost_details.upstream_inference_cost for BYOK.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["response"]) + # Start with empty history — entries added after tool creation + mock_lm.history = [] + + rlm = RLM("query -> answer", max_cost=1.0, sub_lm=mock_lm) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + + # Simulate BYOK responses arriving after tool creation + mock_lm.history.extend([ + { + "cost": 0, + "usage": { + "total_tokens": 500, + "is_byok": True, + "cost_details": {"upstream_inference_cost": 0.0025}, + }, + }, + { + "cost": None, + "usage": { + "total_tokens": 300, + "cost_details": {"upstream_inference_cost": 0.0015}, + }, + }, + ]) + + result = tools["budget"]() + # Should pick up the upstream_inference_cost values: 0.0025 + 0.0015 + assert "$0.0040" in result + assert "800 tokens" in result # 500 + 300 + + def test_cost_sums_provider_and_upstream(self): + """Test cost tracking sums both provider cost and upstream inference cost.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["response"]) + mock_lm.history = [] + + rlm = RLM("query -> answer", max_cost=1.0, sub_lm=mock_lm) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + + # Non-BYOK: both provider cost and upstream cost are nonzero + mock_lm.history.append({ + "cost": 0.01, + "usage": { + "total_tokens": 1000, + "cost_details": {"upstream_inference_cost": 0.005}, + }, + }) + + result = tools["budget"]() + # Sum of 0.01 + 0.005 = 0.015 + assert "$0.0150" in result + + def test_budget_warning_low_iterations(self): + """Test that budget() shows warning when iterations are low.""" + rlm = RLM("query -> answer", max_iterations=5, max_llm_calls=50) + # iteration=4 means 0 remaining (5 - 4 - 1) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 4} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert "LOW" in result + assert "iterations" in result + + def test_budget_warning_low_time(self): + """Test that budget() shows warning when time is running low.""" + import time + rlm = RLM("query -> answer", max_iterations=20, max_llm_calls=50, max_time=10.0) + # Start time 9 seconds ago — only 1s remaining (10%) + execution_state = {"start_time": time.monotonic() - 9.0, "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert "LOW" in result + 
assert "time" in result + + def test_budget_no_warning_when_plenty_remaining(self): + """Test that budget() has no warning when resources are plentiful.""" + rlm = RLM("query -> answer", max_iterations=20, max_llm_calls=50, max_time=60.0) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools(execution_state=execution_state) + result = tools["budget"]() + assert "LOW" not in result + + +# ============================================================================ +# Integration Tests: RLM + LocalInterpreter +# ============================================================================ + + +class TestRLMWithLocalInterpreter: + """Integration tests proving RLM and LocalInterpreter work together end-to-end.""" + + def test_basic_forward(self): + """Test RLM forward() with LocalInterpreter produces a result.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=3, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Return answer", "code": 'SUBMIT("hello")'}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "hello" + + def test_state_persists_across_iterations(self): + """Test that variables set in one iteration survive to the next.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=5, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Set a variable", "code": "x = 42\nprint(x)"}, + {"reasoning": "Use it", "code": "SUBMIT(str(x * 2))"}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "84" + + def test_tools_accessible_in_code(self): + """Test that llm_query and budget tools are callable from LocalInterpreter code.""" + from unittest.mock import MagicMock + + from dspy.primitives.local_interpreter import LocalInterpreter + + mock_lm = MagicMock(return_value=["mocked response"]) + mock_lm.history = [] + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=5, sub_lm=mock_lm, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Use tools", "code": 'b = budget()\nresult = llm_query("hi")\nprint(result)'}, + {"reasoning": "Submit", "code": "SUBMIT(result)"}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "mocked response" + + def test_stdlib_imports_work(self): + """Test that LocalInterpreter allows stdlib imports inside RLM.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=3, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Use stdlib", "code": 'import json\nSUBMIT(json.dumps({"a": 1}))'}, + ]) + + result = rlm.forward(query="test") + assert result.answer == '{"a": 1}' + + def test_error_recovery(self): + """Test that a runtime error in one iteration doesn't kill the session.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=5, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "This will error", "code": "1 / 0"}, + {"reasoning": "Recover", "code": 'SUBMIT("recovered")'}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "recovered" + + def test_max_time_with_local_interpreter(self): + """Test that 
max_time budget enforcement works with LocalInterpreter.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=10, max_time=0.0, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Explore", "code": "print('exploring')"}, + ]) + rlm.extract = make_mock_predictor([ + {"answer": "timeout_fallback"}, + ]) + + result = rlm.forward(query="test") + assert result.answer == "timeout_fallback" + + @pytest.mark.asyncio + async def test_aforward_with_local_interpreter(self): + """Test RLM aforward() works with LocalInterpreter.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + rlm = RLM("query -> answer", max_iterations=3, interpreter=interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Return answer", "code": 'SUBMIT("async_hello")'}, + ]) + + result = await rlm.aforward(query="test") + assert result.answer == "async_hello" + + +# ============================================================================ +# Tests: llm_query_with_media content construction +# ============================================================================ + + +class TestMediaContentConstruction: + """Test that llm_query_with_media builds correct multimodal content for the LM.""" + + def test_media_content_parts_sent_to_lm(self): + """Test that llm_query_with_media sends multimodal content parts to the LM.""" + from unittest.mock import MagicMock + + from dspy.adapters.types.audio import Audio + + mock_lm = MagicMock(return_value=["transcription result"]) + mock_lm.history = [] + + audio = Audio(data="dGVzdA==", audio_format="wav") + + rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_lm) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools( + multimodal_registry={"my_audio": audio}, + execution_state=execution_state, + ) + + result = tools["llm_query_with_media"]("describe this audio", "my_audio") + assert result == "transcription result" + + # Verify the LM was called with messages containing multimodal content + mock_lm.assert_called_once() + call_kwargs = mock_lm.call_args + messages = call_kwargs.kwargs.get("messages") or call_kwargs[1].get("messages") + assert messages is not None + assert len(messages) == 1 + assert messages[0]["role"] == "user" + content = messages[0]["content"] + assert isinstance(content, list) + # First part should be text + assert content[0]["type"] == "text" + assert content[0]["text"] == "describe this audio" + # Remaining parts should be from audio.format() + assert len(content) > 1 + + def test_image_content_parts_sent_to_lm(self): + """Test that llm_query_with_media sends image content parts to the LM.""" + from unittest.mock import MagicMock + + from dspy.adapters.types.image import Image + + mock_lm = MagicMock(return_value=["a cat sitting on a mat"]) + mock_lm.history = [] + + image = Image(url="https://example.com/cat.jpg") + + rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_lm) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools( + multimodal_registry={"my_image": image}, + execution_state=execution_state, + ) + + result = tools["llm_query_with_media"]("what is in this image?", "my_image") + assert result == "a cat sitting on a mat" + + # Verify multimodal message structure + mock_lm.assert_called_once() + call_kwargs = mock_lm.call_args + messages = 
call_kwargs.kwargs.get("messages") or call_kwargs[1].get("messages") + content = messages[0]["content"] + assert content[0]["type"] == "text" + assert len(content) > 1 + + def test_multiple_media_objects_in_one_call(self): + """Test llm_query_with_media with multiple media variables.""" + from unittest.mock import MagicMock + + from dspy.adapters.types.audio import Audio + from dspy.adapters.types.image import Image + + mock_lm = MagicMock(return_value=["combined analysis"]) + mock_lm.history = [] + + audio = Audio(data="dGVzdA==", audio_format="wav") + image = Image(url="https://example.com/photo.jpg") + + rlm = RLM("query -> answer", max_iterations=3, sub_lm=mock_lm) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools( + multimodal_registry={"audio_in": audio, "image_in": image}, + execution_state=execution_state, + ) + + result = tools["llm_query_with_media"]("analyze both", "audio_in", "image_in") + assert result == "combined analysis" + + # Verify both media objects were included + call_kwargs = mock_lm.call_args + messages = call_kwargs.kwargs.get("messages") or call_kwargs[1].get("messages") + content = messages[0]["content"] + # text + audio parts + image parts + assert len(content) >= 3 + + def test_media_with_model_routing(self): + """Test llm_query_with_media routes to named model when specified.""" + from unittest.mock import MagicMock + + from dspy.adapters.types.audio import Audio + + mock_default = MagicMock(return_value=["default"]) + mock_default.history = [] + mock_pro = MagicMock(return_value=["pro result"]) + mock_pro.history = [] + + audio = Audio(data="dGVzdA==", audio_format="wav") + + rlm = RLM("query -> answer", max_iterations=3, + sub_lm=mock_default, sub_lms={"pro": mock_pro}) + execution_state = {"start_time": __import__("time").monotonic(), "iteration": 0} + tools = rlm._make_llm_tools( + multimodal_registry={"audio": audio}, + execution_state=execution_state, + ) + + result = tools["llm_query_with_media"]("transcribe", "audio", model="pro") + assert result == "pro result" + mock_pro.assert_called_once() + mock_default.assert_not_called() + + +# ============================================================================ +# Tests: max_cost mid-run fallback +# ============================================================================ + + +class TestMaxCostMidRunFallback: + """Test that max_cost triggers extract fallback during a multi-iteration run.""" + + def test_cost_exceeded_mid_run_triggers_fallback(self): + """Test that exceeding max_cost mid-run triggers extract fallback, not crash.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["response"]) + mock_lm.history = [] + + mock = MockInterpreter(responses=[ + "first iteration output", + "second iteration output", # cost exceeded before this + ]) + rlm = RLM("query -> answer", max_iterations=10, max_cost=0.05, + sub_lm=mock_lm, interpreter=mock) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Explore", "code": "print('exploring')"}, + {"reasoning": "More", "code": "print('more')"}, + ]) + rlm.extract = make_mock_predictor([ + {"answer": "cost_fallback"}, + ]) + + # Simulate cost appearing after first iteration + # The forward loop checks cost at the start of each iteration + # After iteration 0, we inject cost into LM history + original_execute = mock.execute + + def execute_with_cost_injection(code, variables=None): + result = original_execute(code, variables) + # After first execute, inject cost 
exceeding budget + if mock.call_count == 1: + mock_lm.history.append({ + "cost": 0.10, # exceeds max_cost=0.05 + "usage": {"total_tokens": 5000}, + }) + return result + + mock.execute = execute_with_cost_injection + + result = rlm.forward(query="test") + assert result.answer == "cost_fallback" + assert result.final_reasoning == "Extract forced final output" + + +# ============================================================================ +# Tests: Async budget/time/cost +# ============================================================================ + + +class TestAsyncBudgetTimeCost: + """Test that budget, max_time, and max_cost work correctly in aforward().""" + + @pytest.mark.asyncio + async def test_aforward_max_time_triggers_fallback(self): + """Test that aforward() respects max_time and triggers extract fallback.""" + mock = MockInterpreter(responses=["exploring..."]) + rlm = RLM("query -> answer", max_iterations=5, max_time=0.0, interpreter=mock) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Explore", "code": "print('exploring')"}, + ], async_mode=True) + rlm.extract = make_mock_predictor([ + {"answer": "async_timeout_fallback"}, + ], async_mode=True) + + result = await rlm.aforward(query="test") + assert result.answer == "async_timeout_fallback" + + @pytest.mark.asyncio + async def test_aforward_max_cost_triggers_fallback(self): + """Test that aforward() respects max_cost and triggers extract fallback.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["response"]) + mock_lm.history = [{"cost": 1.0, "usage": {"total_tokens": 50000}}] + + mock = MockInterpreter(responses=["exploring..."]) + rlm = RLM("query -> answer", max_iterations=5, max_cost=0.0, + sub_lm=mock_lm, interpreter=mock) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Explore", "code": "print('exploring')"}, + ], async_mode=True) + rlm.extract = make_mock_predictor([ + {"answer": "async_cost_fallback"}, + ], async_mode=True) + + result = await rlm.aforward(query="test") + assert result.answer == "async_cost_fallback" + + @pytest.mark.asyncio + async def test_aforward_budget_tool_works(self): + """Test that budget() tool is accessible in aforward() via MockInterpreter.""" + budget_reports = [] + + class BudgetCapturingInterpreter(MockInterpreter): + def __init__(self): + super().__init__(responses=["output", FinalOutput({"answer": "done"})]) + + def execute(self, code, variables=None): + if "budget" in self.tools: + budget_reports.append(self.tools["budget"]()) + return super().execute(code, variables) + + mock_interp = BudgetCapturingInterpreter() + rlm = RLM("query -> answer", max_iterations=10, max_llm_calls=30, + max_time=60.0, interpreter=mock_interp) + rlm.generate_action = make_mock_predictor([ + {"reasoning": "Step 1", "code": "print('a')"}, + {"reasoning": "Done", "code": 'SUBMIT("done")'}, + ], async_mode=True) + + result = await rlm.aforward(query="test") + assert result.answer == "done" + assert len(budget_reports) >= 1 + assert "Iterations:" in budget_reports[0] + assert "Time:" in budget_reports[0] + + +# ============================================================================ +# Tests: bootstrap_trace resilience for non-parse exceptions +# ============================================================================ + + +class TestBootstrapTraceResilience: + """Test that bootstrap_trace_data handles non-parse exceptions gracefully.""" + + def test_runtime_error_captured_as_failed_prediction(self): + """Test that a RuntimeError from forward() 
is captured, not propagated.""" + from unittest.mock import patch + + import dspy + from dspy.teleprompt.bootstrap_trace import FailedPrediction, bootstrap_trace_data + + class CrashingModule(dspy.Module): + def __init__(self): + super().__init__() + self.predictor = dspy.Predict("query -> answer") + + def forward(self, **kwargs): + raise RuntimeError("RLM timeout exceeded") + + program = CrashingModule() + dataset = [ + dspy.Example(query="test1").with_inputs("query"), + dspy.Example(query="test2").with_inputs("query"), + ] + + def metric(example, prediction, trace=None): + if isinstance(prediction, FailedPrediction): + return 0.0 + return 1.0 + + # Mock the Evaluate class to directly call the program + class DirectEvaluate: + def __init__(self, **kwargs): + self.devset = kwargs.get("devset", []) + self.failure_score = kwargs.get("failure_score", 0) + + def __call__(self, program, metric=None, **kwargs): + results = [] + for example in self.devset: + inputs = {k: example[k] for k in example.inputs()} + prediction = program(**inputs) + score = metric(example, prediction) if metric else None + results.append((example, prediction, score)) + + class Result: + pass + r = Result() + r.results = results + return r + + with patch("dspy.teleprompt.bootstrap_trace.Evaluate", DirectEvaluate): + import dspy as _dspy + with _dspy.context(lm=_dspy.LM(model="openai/gpt-4o-mini"), trace=[]): + results = bootstrap_trace_data( + program=program, + dataset=dataset, + metric=metric, + raise_on_error=False, + ) + + assert len(results) == 2 + for result in results: + pred = result["prediction"] + assert isinstance(pred, FailedPrediction) + assert "RLM timeout exceeded" in pred.completion_text + # Trace should be preserved (even if empty, it shouldn't be None) + assert isinstance(result["trace"], list) + + def test_cost_overrun_captured_as_failed_prediction(self): + """Test that a cost overrun exception is captured with partial trace.""" + from unittest.mock import patch + + import dspy + from dspy.teleprompt.bootstrap_trace import FailedPrediction, bootstrap_trace_data + + class CostOverrunModule(dspy.Module): + def __init__(self): + super().__init__() + self.predictor = dspy.Predict("query -> answer") + + def forward(self, **kwargs): + # Simulate partial work before cost overrun + # Add something to the trace before crashing + dspy.settings.trace.append( + (self.predictor, kwargs, dspy.Prediction(answer="partial")) + ) + raise RuntimeError("Cost budget exceeded ($0.55 > $0.50)") + + program = CostOverrunModule() + dataset = [dspy.Example(query="test").with_inputs("query")] + + def metric(example, prediction, trace=None): + return 0.0 + + class DirectEvaluate: + def __init__(self, **kwargs): + self.devset = kwargs.get("devset", []) + + def __call__(self, program, metric=None, **kwargs): + results = [] + for example in self.devset: + inputs = {k: example[k] for k in example.inputs()} + prediction = program(**inputs) + score = metric(example, prediction) if metric else None + results.append((example, prediction, score)) + + class Result: + pass + r = Result() + r.results = results + return r + + with patch("dspy.teleprompt.bootstrap_trace.Evaluate", DirectEvaluate): + import dspy as _dspy + with _dspy.context(lm=_dspy.LM(model="openai/gpt-4o-mini"), trace=[]): + results = bootstrap_trace_data( + program=program, + dataset=dataset, + metric=metric, + raise_on_error=False, + ) + + assert len(results) == 1 + pred = results[0]["prediction"] + assert isinstance(pred, FailedPrediction) + assert "Cost budget exceeded" 
in pred.completion_text + # The partial trace from before the crash should be preserved + trace = results[0]["trace"] + assert isinstance(trace, list) + assert len(trace) == 1 # the one entry we appended before crashing + + def test_keyboard_interrupt_not_swallowed(self): + """Test that KeyboardInterrupt is NOT caught (it's BaseException, not Exception).""" + from unittest.mock import patch + + import dspy + from dspy.teleprompt.bootstrap_trace import bootstrap_trace_data + + class InterruptingModule(dspy.Module): + def __init__(self): + super().__init__() + self.predictor = dspy.Predict("query -> answer") + + def forward(self, **kwargs): + raise KeyboardInterrupt() + + program = InterruptingModule() + dataset = [dspy.Example(query="test").with_inputs("query")] + + class DirectEvaluate: + def __init__(self, **kwargs): + self.devset = kwargs.get("devset", []) + + def __call__(self, program, metric=None, **kwargs): + results = [] + for example in self.devset: + inputs = {k: example[k] for k in example.inputs()} + prediction = program(**inputs) + results.append((example, prediction, None)) + + class Result: + pass + r = Result() + r.results = results + return r + + with patch("dspy.teleprompt.bootstrap_trace.Evaluate", DirectEvaluate): + import dspy as _dspy + with _dspy.context(lm=_dspy.LM(model="openai/gpt-4o-mini"), trace=[]): + with pytest.raises(KeyboardInterrupt): + bootstrap_trace_data( + program=program, + dataset=dataset, + raise_on_error=False, + ) + + +# ============================================================================ +# Tests: LocalInterpreter output_fields via setter +# ============================================================================ + + +class TestLocalInterpreterOutputFieldsSetter: + """Test LocalInterpreter output_fields configuration paths.""" + + def test_output_fields_set_after_init(self): + """Test that output_fields can be set after construction and SUBMIT uses them.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + interp.output_fields = [{"name": "answer"}, {"name": "confidence"}] + interp.start() + + result = interp.execute('SUBMIT("hello", "high")') + assert isinstance(result, FinalOutput) + assert result.output == {"answer": "hello", "confidence": "high"} + + def test_output_fields_none_defaults_to_single_output(self): + """Test that without output_fields, single-arg SUBMIT wraps in 'output' key.""" + from dspy.primitives.local_interpreter import LocalInterpreter + + interp = LocalInterpreter() + interp.start() + + result = interp.execute('SUBMIT("hello")') + assert isinstance(result, FinalOutput) + assert result.output == {"output": "hello"} + + +# ============================================================================ +# Depth > 1 Tests: Recursive RLM with LocalInterpreter +# ============================================================================ + +import time as _time + + +class TestSubcallInit: + """Tests for depth/max_depth initialization and routing flag.""" + + def test_depth_max_depth_defaults(self): + """Default depth=0, max_depth=1 means no recursion.""" + rlm = RLM("query -> answer", max_iterations=3) + assert rlm.depth == 0 + assert rlm.max_depth == 1 + + def test_depth_max_depth_stored(self): + """Custom depth and max_depth are stored.""" + rlm = RLM("query -> answer", max_iterations=3, depth=1, max_depth=3) + assert rlm.depth == 1 + assert rlm.max_depth == 3 + + def test_max_depth_1_uses_plain_lm(self): + """With max_depth=1 (default), llm_query does a plain LM 
call.""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["plain response"]) + rlm = RLM("query -> answer", max_iterations=3, max_depth=1, sub_lm=mock_lm) + tools = rlm._make_llm_tools() + + result = tools["llm_query"]("test prompt") + assert result == "plain response" + mock_lm.assert_called_once() + + def test_max_depth_2_at_depth_0_is_recursive(self): + """With max_depth=2, depth=0: _subcall is called (not plain LM).""" + from unittest.mock import MagicMock, patch + + mock_lm = MagicMock(return_value=["should not be called"]) + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, sub_lm=mock_lm) + + # Patch _subcall to capture that it's called instead of plain LM + with patch.object(rlm, "_subcall", return_value="subcall result") as mock_subcall: + tools = rlm._make_llm_tools() + result = tools["llm_query"]("test prompt") + + assert result == "subcall result" + mock_subcall.assert_called_once() + mock_lm.assert_not_called() + + def test_max_depth_2_at_depth_1_is_leaf(self): + """With max_depth=2, depth=1: plain LM call (leaf level).""" + from unittest.mock import MagicMock + + mock_lm = MagicMock(return_value=["leaf response"]) + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, depth=1, sub_lm=mock_lm) + tools = rlm._make_llm_tools() + + result = tools["llm_query"]("test prompt") + assert result == "leaf response" + mock_lm.assert_called_once() + + def test_max_depth_0_raises(self): + """max_depth < 1 should raise ValueError.""" + with pytest.raises(ValueError, match="max_depth must be >= 1"): + RLM("query -> answer", max_depth=0) + + def test_negative_depth_raises(self): + """Negative depth should raise ValueError.""" + with pytest.raises(ValueError, match="depth must be >= 0"): + RLM("query -> answer", depth=-1) + + +class TestSubcallTimeBudgetPropagation: + """Tests for max_time propagation to child RLM, adapted from vanilla RLM test_subcall.py.""" + + def test_child_receives_remaining_timeout(self): + """When parent has max_time=60 and 10s elapsed, child should get ~50s.""" + from unittest.mock import patch + + captured_child_kwargs = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + # Only capture child inits (signature="prompt -> response") + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured_child_kwargs.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, max_time=60.0) + execution_state = {"start_time": _time.monotonic() - 10.0, "iteration": 0} + + with patch.object(RLM, "__init__", capturing_init): + # Child will fail (no LM configured) but we capture the kwargs before that + rlm._subcall("test", execution_state=execution_state) + + assert "max_time" in captured_child_kwargs + remaining = captured_child_kwargs["max_time"] + assert 45.0 < remaining < 55.0, f"Expected ~50s remaining, got {remaining}" + + def test_child_receives_none_timeout_when_parent_has_none(self): + """When parent has no max_time, child should also have None.""" + from unittest.mock import patch + + captured_child_kwargs = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured_child_kwargs.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, max_time=None) + + with patch.object(RLM, "__init__", 
capturing_init): + rlm._subcall("test") + + assert captured_child_kwargs.get("max_time") is None + + def test_subcall_returns_error_when_time_exhausted(self): + """When time budget is already exhausted, _subcall returns error string.""" + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, max_time=10.0) + execution_state = {"start_time": _time.monotonic() - 15.0, "iteration": 0} + + result = rlm._subcall("test", execution_state=execution_state) + assert "Time budget exhausted" in result + + +class TestSubcallCostBudgetPropagation: + """Tests for max_cost propagation to child RLM.""" + + def test_child_receives_remaining_cost(self): + """When parent has max_cost=1.0 and $0.30 spent, child should get ~$0.70.""" + from unittest.mock import patch + + captured_child_kwargs = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured_child_kwargs.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, max_cost=1.0) + execution_state = { + "start_time": _time.monotonic(), + "iteration": 0, + "_get_cost_and_tokens": lambda: (0.30, 5000), + } + + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test", execution_state=execution_state) + + assert "max_cost" in captured_child_kwargs + remaining = captured_child_kwargs["max_cost"] + assert 0.69 < remaining < 0.71, f"Expected ~$0.70, got {remaining}" + + def test_subcall_returns_error_when_cost_exhausted(self): + """When cost budget is already exhausted, _subcall returns error string.""" + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, max_cost=0.01) + execution_state = { + "start_time": _time.monotonic(), + "iteration": 0, + "_get_cost_and_tokens": lambda: (0.02, 1000), + } + + result = rlm._subcall("test", execution_state=execution_state) + assert "Cost budget exhausted" in result + + +class TestSubcallModelOverride: + """Tests for model= parameter override in _subcall.""" + + def test_model_override_sets_child_sub_lm(self): + """When model='flash', child's sub_lm should be the flash LM instance.""" + from unittest.mock import MagicMock, patch + + mock_flash = MagicMock(return_value=["flash"]) + mock_pro = MagicMock(return_value=["pro"]) + + captured_child_kwargs = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured_child_kwargs.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, + sub_lms={"flash": mock_flash, "pro": mock_pro}) + + def resolve_lm(model=None): + if model == "flash": + return mock_flash + if model == "pro": + return mock_pro + return None + + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test", model="flash", resolve_lm=resolve_lm) + + assert captured_child_kwargs.get("sub_lm") is mock_flash + + def test_no_model_override_uses_parent_sub_lm(self): + """Without model override, child inherits parent's sub_lm.""" + from unittest.mock import MagicMock, patch + + mock_parent_sub = MagicMock(return_value=["parent sub"]) + + captured_child_kwargs = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured_child_kwargs.update(kwargs) + 
_original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, sub_lm=mock_parent_sub) + + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + + assert captured_child_kwargs.get("sub_lm") is mock_parent_sub + + +class TestSubcallParameterPropagation: + """Tests for combined parameter propagation to child RLM.""" + + def test_child_inherits_max_iterations(self): + """Child should receive parent's max_iterations.""" + from unittest.mock import patch + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_iterations=15, max_depth=2) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert captured.get("max_iterations") == 15 + + def test_child_inherits_max_llm_calls(self): + """Child should receive parent's max_llm_calls.""" + from unittest.mock import patch + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_llm_calls=25, max_depth=2) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert captured.get("max_llm_calls") == 25 + + def test_child_inherits_user_tools(self): + """Child should receive parent's user-provided tools.""" + from unittest.mock import patch + + def my_tool(x: str) -> str: + return x + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_depth=2, tools=[my_tool]) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert captured.get("tools") is not None + assert len(captured["tools"]) == 1 + + def test_child_depth_incremented(self): + """Child should have depth = parent.depth + 1.""" + from unittest.mock import patch + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_depth=3, depth=0) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert captured.get("depth") == 1 + + def test_child_inherits_max_depth(self): + """Child should receive same max_depth as parent.""" + from unittest.mock import patch + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_depth=3) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert captured.get("max_depth") == 3 + + def test_child_inherits_sub_lms(self): + """Child should receive parent's sub_lms dict.""" + from unittest.mock import MagicMock, patch + + mock_flash = MagicMock() + mock_pro = MagicMock() + 
+ captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_depth=2, sub_lms={"flash": mock_flash, "pro": mock_pro}) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert captured.get("sub_lms") == {"flash": mock_flash, "pro": mock_pro} + + +class TestSubcallInterpreterIsolation: + """Tests that child RLM gets an isolated LocalInterpreter.""" + + def test_child_gets_local_interpreter_when_parent_uses_local(self): + """When parent uses LocalInterpreter, child gets LocalInterpreter.""" + from unittest.mock import patch + + from dspy.primitives.local_interpreter import LocalInterpreter + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + rlm = RLM("query -> answer", max_depth=2, interpreter=LocalInterpreter()) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert isinstance(captured.get("interpreter"), LocalInterpreter) + + def test_child_gets_python_interpreter_when_parent_uses_default(self): + """When parent uses default (PythonInterpreter), child matches.""" + from unittest.mock import patch + + captured = {} + _original_init = RLM.__init__ + + def capturing_init(self_inner, *args, **kwargs): + sig = args[0] if args else kwargs.get("signature", "") + if sig == "prompt -> response": + captured.update(kwargs) + _original_init(self_inner, *args, **kwargs) + + # No interpreter= means parent uses default PythonInterpreter + rlm = RLM("query -> answer", max_depth=2) + with patch.object(RLM, "__init__", capturing_init): + rlm._subcall("test") + assert isinstance(captured.get("interpreter"), PythonInterpreter) + + def test_interpreter_shutdown_on_success(self): + """Child interpreter shutdown() is called after successful completion.""" + from unittest.mock import patch + + from dspy.primitives.local_interpreter import LocalInterpreter + + shutdown_called = [] + _original_shutdown = LocalInterpreter.shutdown + + def tracking_shutdown(self_inner): + shutdown_called.append(True) + _original_shutdown(self_inner) + + with dummy_lm_context([ + {"reasoning": "Done", "code": 'SUBMIT(response="ok")'}, + ]): + # Parent uses LocalInterpreter so child matches + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, + interpreter=LocalInterpreter()) + with patch.object(LocalInterpreter, "shutdown", tracking_shutdown): + rlm._subcall("test") + + assert len(shutdown_called) >= 1 + + def test_interpreter_shutdown_on_error(self): + """Child interpreter shutdown() is called even when child fails.""" + from unittest.mock import patch + + from dspy.primitives.local_interpreter import LocalInterpreter + + shutdown_called = [] + _original_shutdown = LocalInterpreter.shutdown + + def tracking_shutdown(self_inner): + shutdown_called.append(True) + _original_shutdown(self_inner) + + with dummy_lm_context([ + {"reasoning": "Bad", "code": 'raise Exception("boom")'}, + {"response": "fallback"}, # extract fallback + ]): + rlm = RLM("query -> answer", max_iterations=1, max_depth=2, + interpreter=LocalInterpreter()) + with patch.object(LocalInterpreter, "shutdown", tracking_shutdown): + rlm._subcall("test") + + 
assert len(shutdown_called) >= 1 + + +class TestSubcallE2E: + """End-to-end tests for depth>1 using DummyLM + LocalInterpreter.""" + + def test_depth_2_child_submits_response(self): + """Parent calls llm_query, child runs in LocalInterpreter and SUBMITs.""" + with dummy_lm_context([ + {"reasoning": "Answer directly", "code": 'SUBMIT(response="42")'}, + ]): + rlm = RLM("query -> answer", max_iterations=3, max_depth=2) + tools = rlm._make_llm_tools() + result = tools["llm_query"]("What is 6*7?") + assert result == "42" + + def test_depth_2_child_uses_prompt_variable(self): + """Child RLM receives the prompt as a variable it can use in code.""" + with dummy_lm_context([ + {"reasoning": "Use the prompt", "code": 'SUBMIT(response=f"Got: {prompt}")'}, + ]): + rlm = RLM("query -> answer", max_iterations=3, max_depth=2) + tools = rlm._make_llm_tools() + result = tools["llm_query"]("hello world") + assert result == "Got: hello world" + + def test_depth_2_child_inherits_tools(self): + """Child can call parent's user-provided tools.""" + call_log = [] + + def my_tool(x: str) -> str: + call_log.append(x) + return f"tool({x})" + + with dummy_lm_context([ + {"reasoning": "Use tool", "code": 'val = my_tool(x="hi")\nSUBMIT(response=val)'}, + ]): + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, tools=[my_tool]) + tools = rlm._make_llm_tools() + result = tools["llm_query"]("use tool") + assert result == "tool(hi)" + assert call_log == ["hi"] + + def test_depth_2_child_multi_iteration(self): + """Child RLM can take multiple iterations before SUBMITting.""" + with dummy_lm_context([ + {"reasoning": "Explore first", "code": 'x = len(prompt)\nprint(f"length={x}")'}, + {"reasoning": "Now submit", "code": "SUBMIT(response=str(x))"}, + ]): + rlm = RLM("query -> answer", max_iterations=5, max_depth=2) + tools = rlm._make_llm_tools() + result = tools["llm_query"]("hello") + assert result == "5" + + def test_depth_2_child_error_returns_string(self): + """Child failure returns error string, doesn't crash parent.""" + with dummy_lm_context([ + {"reasoning": "Crash", "code": 'raise RuntimeError("boom")'}, + {"response": "recovered"}, # extract fallback + ]): + rlm = RLM("query -> answer", max_iterations=1, max_depth=2) + tools = rlm._make_llm_tools() + result = tools["llm_query"]("test") + # Should get a string back (either "recovered" from extract or error message) + assert isinstance(result, str) + + def test_depth_2_batched_sequential(self): + """llm_query_batched with max_depth=2 runs children sequentially.""" + with dummy_lm_context([ + # Child 1 + {"reasoning": "First", "code": 'SUBMIT(response="a1")'}, + # Child 2 + {"reasoning": "Second", "code": 'SUBMIT(response="a2")'}, + ]): + rlm = RLM("query -> answer", max_iterations=3, max_depth=2, max_llm_calls=20) + tools = rlm._make_llm_tools() + results = tools["llm_query_batched"](["q1", "q2"]) + assert results == ["a1", "a2"] + + @pytest.mark.asyncio + async def test_depth_2_aforward(self): + """aforward() works with max_depth=2 (async parent, sync _subcall).""" + with dummy_lm_context([ + # Parent iter 1: call llm_query + {"reasoning": "Ask", "code": 'result = llm_query("test")\nprint(result)'}, + # Child iter 1: SUBMIT + {"reasoning": "Done", "code": 'SUBMIT(response="async_ok")'}, + # Parent iter 2: SUBMIT + {"reasoning": "Got it", "code": "SUBMIT(result)"}, + ]): + from dspy.primitives.local_interpreter import LocalInterpreter + rlm = RLM("query -> answer", max_iterations=5, max_depth=2, + interpreter=LocalInterpreter()) + result = await 
rlm.aforward(query="async test") + assert result.answer == "async_ok" + + @pytest.mark.deno + def test_depth_2_python_interpreter_parent(self): + """Parent uses PythonInterpreter (Deno), child uses LocalInterpreter. + + llm_query is a host-side tool callback (JSON-RPC from Deno), so _subcall + runs on the host and creates the child with its own LocalInterpreter. + """ + with dummy_lm_context([ + # Parent iter 1: call llm_query (triggers child RLM) + {"reasoning": "Delegate", "code": 'result = llm_query("compute 2+2")\nprint(result)'}, + # Child iter 1 (runs in LocalInterpreter): SUBMIT response + {"reasoning": "Easy", "code": 'SUBMIT(response="4")'}, + # Parent iter 2: SUBMIT the child's result + {"reasoning": "Done", "code": "SUBMIT(result)"}, + ]): + # No interpreter= means default PythonInterpreter (Deno sandbox) + rlm = RLM("query -> answer", max_iterations=5, max_depth=2) + result = rlm(query="What is 2+2?") + assert result.answer == "4" + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/primitives/test_local_interpreter.py b/tests/primitives/test_local_interpreter.py new file mode 100644 index 0000000000..7f5c864ca5 --- /dev/null +++ b/tests/primitives/test_local_interpreter.py @@ -0,0 +1,312 @@ +"""Tests for LocalInterpreter — unsandboxed host-process Python execution.""" + +import pytest + +from dspy.primitives.code_interpreter import CodeInterpreterError, FinalOutput +from dspy.primitives.local_interpreter import LocalInterpreter + +# ============================================================================= +# Basic Execution +# ============================================================================= + + +def test_execute_simple_print(): + with LocalInterpreter() as interp: + result = interp.execute("print('Hello, World!')") + assert result == "Hello, World!" 
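For orientation, here is a minimal usage sketch of the LocalInterpreter surface that the surrounding tests exercise. It is illustrative only and not part of the patch; every behavior shown (return values, state persistence, SUBMIT handling) is inferred from the assertions in this test file rather than from separate documentation.

# Illustrative sketch — not part of the diff. Behaviors taken from the
# assertions in tests/primitives/test_local_interpreter.py.
from dspy.primitives.code_interpreter import FinalOutput
from dspy.primitives.local_interpreter import LocalInterpreter

with LocalInterpreter() as interp:
    interp.execute("x = 10")                        # no output -> returns None; state persists
    assert interp.execute("print(x + 5)") == "15"   # stdout is captured and stripped
    out = interp.execute('SUBMIT(answer="done")')   # SUBMIT yields a FinalOutput
    assert isinstance(out, FinalOutput)
    assert out.output == {"answer": "done"}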
+ + +def test_execute_no_output(): + with LocalInterpreter() as interp: + result = interp.execute("x = 42") + assert result is None + + +def test_execute_multiline(): + with LocalInterpreter() as interp: + code = "a = 3\nb = 4\nprint(a + b)" + result = interp.execute(code) + assert result == "7" + + +def test_import_stdlib(): + with LocalInterpreter() as interp: + result = interp.execute("import math\nprint(math.sqrt(16))") + assert result == "4.0" + + +def test_import_third_party(): + """LocalInterpreter can import any installed package (unlike Pyodide sandbox).""" + with LocalInterpreter() as interp: + result = interp.execute("import numpy as np\nprint(np.array([1,2,3]).sum())") + assert result == "6" + + +# ============================================================================= +# State Persistence +# ============================================================================= + + +def test_state_persists_across_calls(): + with LocalInterpreter() as interp: + interp.execute("x = 10") + interp.execute("x += 5") + result = interp.execute("print(x)") + assert result == "15" + + +def test_state_cleared_on_shutdown(): + interp = LocalInterpreter() + interp.start() + interp.execute("x = 42") + interp.shutdown() + interp.start() + with pytest.raises(CodeInterpreterError, match="NameError"): + interp.execute("print(x)") + interp.shutdown() + + +def test_auto_start(): + """execute() auto-starts if not already started.""" + interp = LocalInterpreter() + result = interp.execute("print('auto')") + assert result == "auto" + interp.shutdown() + + +# ============================================================================= +# Variable Injection +# ============================================================================= + + +def test_variable_injection(): + with LocalInterpreter() as interp: + result = interp.execute("print(number + 1)", variables={"number": 4}) + assert result == "5" + + +def test_variable_injection_objects(): + """Variables are injected AS-IS, not serialized — live Python objects.""" + with LocalInterpreter() as interp: + data = {"key": [1, 2, 3]} + result = interp.execute("print(type(data).__name__, len(data['key']))", variables={"data": data}) + assert result == "dict 3" + + +def test_variable_injection_persists(): + with LocalInterpreter() as interp: + interp.execute("pass", variables={"x": 100}) + result = interp.execute("print(x)") + assert result == "100" + + +# ============================================================================= +# Error Handling +# ============================================================================= + + +def test_syntax_error(): + with LocalInterpreter() as interp: + with pytest.raises(SyntaxError): + interp.execute("def foo(") + + +def test_runtime_error(): + with LocalInterpreter() as interp: + with pytest.raises(CodeInterpreterError, match="ZeroDivisionError"): + interp.execute("1 / 0") + + +def test_name_error(): + with LocalInterpreter() as interp: + with pytest.raises(CodeInterpreterError, match="NameError"): + interp.execute("print(undefined_var)") + + +def test_error_includes_traceback(): + with LocalInterpreter() as interp: + with pytest.raises(CodeInterpreterError, match="Traceback"): + interp.execute("raise RuntimeError('test error')") + + +def test_stdout_restored_after_error(): + """sys.stdout must be restored even if execution raises.""" + import sys + original_stdout = sys.stdout + interp = LocalInterpreter() + interp.start() + with pytest.raises(CodeInterpreterError): + interp.execute("raise 
ValueError('boom')") + assert sys.stdout is original_stdout + interp.shutdown() + + +# ============================================================================= +# SUBMIT / FinalOutput +# ============================================================================= + + +def test_submit_single_arg(): + """Single-arg SUBMIT wraps in {"output": value}.""" + with LocalInterpreter() as interp: + result = interp.execute('SUBMIT("the answer")') + assert isinstance(result, FinalOutput) + assert result.output == {"output": "the answer"} + + +def test_submit_kwargs(): + with LocalInterpreter() as interp: + result = interp.execute('SUBMIT(answer="hello", score=42)') + assert isinstance(result, FinalOutput) + assert result.output == {"answer": "hello", "score": 42} + + +def test_submit_typed_positional(): + """Positional args mapped to output_fields names.""" + output_fields = [ + {"name": "answer", "type": "str"}, + {"name": "confidence", "type": "float"}, + ] + with LocalInterpreter(output_fields=output_fields) as interp: + result = interp.execute('SUBMIT("the answer", 0.95)') + assert isinstance(result, FinalOutput) + assert result.output == {"answer": "the answer", "confidence": 0.95} + + +def test_submit_typed_kwargs(): + output_fields = [ + {"name": "answer", "type": "str"}, + {"name": "confidence", "type": "float"}, + ] + with LocalInterpreter(output_fields=output_fields) as interp: + result = interp.execute('SUBMIT(answer="the answer", confidence=0.95)') + assert isinstance(result, FinalOutput) + assert result.output == {"answer": "the answer", "confidence": 0.95} + + +def test_submit_wrong_arg_count(): + output_fields = [ + {"name": "answer", "type": "str"}, + {"name": "score", "type": "int"}, + ] + with LocalInterpreter(output_fields=output_fields) as interp: + with pytest.raises(CodeInterpreterError, match="takes 2 positional"): + interp.execute("SUBMIT('only one')") + + +def test_submit_no_args(): + with LocalInterpreter() as interp: + with pytest.raises(CodeInterpreterError, match="SUBMIT requires at least one argument"): + interp.execute("SUBMIT()") + + +def test_submit_mixed_args_and_kwargs(): + with LocalInterpreter() as interp: + with pytest.raises(CodeInterpreterError, match="SUBMIT accepts either positional"): + interp.execute('SUBMIT("pos", key="val")') + + +def test_submit_stops_execution(): + """Code after SUBMIT should not execute.""" + with LocalInterpreter() as interp: + result = interp.execute('SUBMIT(answer="done")\nprint("should not print")') + assert isinstance(result, FinalOutput) + assert result.output == {"answer": "done"} + + +# ============================================================================= +# Tools +# ============================================================================= + + +def test_tool_basic(): + def greet(name: str) -> str: + return f"Hello, {name}!" + + with LocalInterpreter(tools={"greet": greet}) as interp: + result = interp.execute('print(greet("World"))') + assert result == "Hello, World!" 
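+
+# Illustrative sketch, not part of the PR: combines two behaviours shown separately in
+# this file (tools are callable inside the interpreter, and injected variables are live
+# Python objects). The tool name `double` is hypothetical.
+def test_tool_with_injected_variable_sketch():
+    def double(values):
+        return [v * 2 for v in values]
+
+    with LocalInterpreter(tools={"double": double}) as interp:
+        result = interp.execute("print(double(nums))", variables={"nums": [1, 2, 3]})
+        assert result == "[2, 4, 6]"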
+ + +def test_tool_default_args(): + def search(query: str, limit: int = 10) -> str: + return f"query={query}, limit={limit}" + + with LocalInterpreter(tools={"search": search}) as interp: + result = interp.execute('print(search("test"))') + assert result == "query=test, limit=10" + result = interp.execute('print(search("test", limit=5))') + assert result == "query=test, limit=5" + + +def test_tools_via_setter(): + """Tools can be added/updated after construction via the tools setter.""" + interp = LocalInterpreter() + + def my_tool() -> str: + return "tool_result" + + interp.tools = {"my_tool": my_tool} + interp.start() + result = interp.execute("print(my_tool())") + assert result == "tool_result" + interp.shutdown() + + +def test_tools_refreshed_each_execute(): + """Tools dict is re-injected on each execute(), so updates are visible.""" + interp = LocalInterpreter() + interp.start() + + interp.tools["counter"] = lambda: "v1" + result = interp.execute("print(counter())") + assert result == "v1" + + interp.tools["counter"] = lambda: "v2" + result = interp.execute("print(counter())") + assert result == "v2" + + interp.shutdown() + + +# ============================================================================= +# Context Manager +# ============================================================================= + + +def test_context_manager(): + with LocalInterpreter() as interp: + assert interp._started is True + result = interp.execute("print('inside')") + assert result == "inside" + assert interp._started is False + + +def test_context_manager_cleanup_on_error(): + with pytest.raises(CodeInterpreterError): + with LocalInterpreter() as interp: + interp.execute("raise RuntimeError('fail')") + assert interp._started is False + + +# ============================================================================= +# Stdout Capture +# ============================================================================= + + +def test_multiple_prints(): + with LocalInterpreter() as interp: + result = interp.execute("print('a')\nprint('b')\nprint('c')") + assert result == "a\nb\nc" + + +def test_trailing_newlines_stripped(): + with LocalInterpreter() as interp: + result = interp.execute("print('hello')") + assert result == "hello" # No trailing newline + + +def test_empty_print(): + with LocalInterpreter() as interp: + result = interp.execute("print()") + assert result == "" # Empty print produces empty string after strip diff --git a/tests/primitives/test_media_serialization.py b/tests/primitives/test_media_serialization.py new file mode 100644 index 0000000000..a419819bdb --- /dev/null +++ b/tests/primitives/test_media_serialization.py @@ -0,0 +1,170 @@ +""" +Tests for RLM sandbox type protocol in python_interpreter.py and type classes. + +Tests the generic _has_rlm_support() helper and the to_sandbox/rlm_preview/sandbox_setup/ +sandbox_assignment protocol on Audio and Image types. + +These tests do NOT require Deno — they test pure Python serialization logic. 
+""" + + +from dspy.primitives.python_interpreter import _has_rlm_support + +# ============================================================================ +# Tests: _has_rlm_support helper +# ============================================================================ + + +class TestHasRlmSupport: + """Tests for the generic _has_rlm_support() protocol check.""" + + def test_audio(self): + from dspy.adapters.types.audio import Audio + audio = Audio(data="dGVzdA==", audio_format="wav") + assert _has_rlm_support(audio) is True + + def test_image(self): + from dspy.adapters.types.image import Image + img = Image(url="data:image/png;base64,iVBORw0KGgo=") + assert _has_rlm_support(img) is True + + def test_string(self): + assert _has_rlm_support("hello") is False + + def test_int(self): + assert _has_rlm_support(42) is False + + def test_none(self): + assert _has_rlm_support(None) is False + + def test_dict(self): + assert _has_rlm_support({"data": "abc"}) is False + + def test_list(self): + assert _has_rlm_support([1, 2, 3]) is False + + +# ============================================================================ +# Tests: Audio RLM sandbox protocol +# ============================================================================ + + +class TestAudioRlmProtocol: + """Tests for Audio.rlm_preview, to_sandbox, sandbox_setup, sandbox_assignment.""" + + def test_rlm_preview_basic(self): + from dspy.adapters.types.audio import Audio + audio = Audio(data="dGVzdA==", audio_format="wav") + preview = audio.rlm_preview() + assert "Audio" in preview + assert "wav" in preview + assert "8" in preview # len("dGVzdA==") == 8 + + def test_rlm_preview_includes_format(self): + from dspy.adapters.types.audio import Audio + audio = Audio(data="AAAA", audio_format="mpeg") + preview = audio.rlm_preview() + assert "mpeg" in preview + + def test_rlm_preview_includes_data_length(self): + from dspy.adapters.types.audio import Audio + audio = Audio(data="A" * 100, audio_format="wav") + preview = audio.rlm_preview() + assert "100" in preview + + def test_to_sandbox_returns_bytes(self): + from dspy.adapters.types.audio import Audio + audio = Audio(data="dGVzdA==", audio_format="wav") + payload = audio.to_sandbox() + assert isinstance(payload, bytes) + assert b"Audio" in payload + assert b"wav" in payload + + def test_sandbox_setup_empty(self): + from dspy.adapters.types.audio import Audio + audio = Audio(data="dGVzdA==", audio_format="wav") + assert audio.sandbox_setup() == "" + + def test_sandbox_assignment(self): + from dspy.adapters.types.audio import Audio + audio = Audio(data="dGVzdA==", audio_format="wav") + code = audio.sandbox_assignment("my_audio", "open('/tmp/data.json').read()") + assert "my_audio" in code + assert "/tmp/data.json" in code + + +# ============================================================================ +# Tests: Image RLM sandbox protocol +# ============================================================================ + + +class TestImageRlmProtocol: + """Tests for Image.rlm_preview, to_sandbox, sandbox_setup, sandbox_assignment.""" + + def test_rlm_preview_base64(self): + from dspy.adapters.types.image import Image + img = Image(url="data:image/png;base64,iVBORw0KGgo=") + preview = img.rlm_preview() + assert "Image" in preview + assert "png" in preview + + def test_rlm_preview_url(self): + from dspy.adapters.types.image import Image + img = Image(url="https://example.com/photo.jpg", download=False) + preview = img.rlm_preview() + assert "Image" in preview + assert "example.com" in preview + 
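+
+    def test_rlm_preview_is_plain_string_sketch(self):
+        """Illustrative sketch, not part of the PR: both the base64 and URL branches
+        return plain strings, so the preview can be printed directly in the sandbox."""
+        from dspy.adapters.types.image import Image
+        data_uri = Image(url="data:image/png;base64,iVBORw0KGgo=")
+        remote = Image(url="https://example.com/photo.jpg", download=False)
+        assert isinstance(data_uri.rlm_preview(), str)
+        assert isinstance(remote.rlm_preview(), str)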
+ def test_to_sandbox_returns_bytes(self): + from dspy.adapters.types.image import Image + img = Image(url="data:image/png;base64,iVBORw0KGgo=") + payload = img.to_sandbox() + assert isinstance(payload, bytes) + assert b"Image" in payload + + def test_sandbox_setup_empty(self): + from dspy.adapters.types.image import Image + img = Image(url="data:image/png;base64,iVBORw0KGgo=") + assert img.sandbox_setup() == "" + + def test_sandbox_assignment(self): + from dspy.adapters.types.image import Image + img = Image(url="data:image/png;base64,iVBORw0KGgo=") + code = img.sandbox_assignment("my_img", "open('/tmp/img.json').read()") + assert "my_img" in code + assert "/tmp/img.json" in code + + +# ============================================================================ +# Tests: Custom type with RLM protocol +# ============================================================================ + + +class TestCustomTypeRlmProtocol: + """Tests that any object implementing the protocol is recognized.""" + + def test_custom_type_with_protocol(self): + class MyType: + def to_sandbox(self): + return b"custom data" + def rlm_preview(self): + return "" + def sandbox_setup(self): + return "" + def sandbox_assignment(self, var_name, data_expr): + return f"{var_name} = {data_expr}" + + obj = MyType() + assert _has_rlm_support(obj) is True + + def test_custom_type_without_protocol(self): + class PlainType: + pass + assert _has_rlm_support(PlainType()) is False + + def test_partial_protocol_not_enough(self): + """Object with rlm_preview but no to_sandbox should NOT have RLM support.""" + class HalfType: + def rlm_preview(self): + return "preview" + assert _has_rlm_support(HalfType()) is False
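+
+    def test_custom_type_roundtrip_sketch(self):
+        """Illustrative sketch, not part of the PR: simulates how an interpreter could
+        combine to_sandbox() and sandbox_assignment() to rebuild a value. The variable
+        name `my_var` and the use of exec() here are assumptions, not RLM internals."""
+        class MyType:
+            def to_sandbox(self):
+                return b"custom data"
+            def rlm_preview(self):
+                return "MyType preview"
+            def sandbox_setup(self):
+                return ""
+            def sandbox_assignment(self, var_name, data_expr):
+                return f"{var_name} = {data_expr}"
+
+        obj = MyType()
+        payload = obj.to_sandbox().decode("utf-8")
+        namespace = {}
+        # data_expr is whatever expression yields the injected payload; repr() stands in for it here.
+        exec(obj.sandbox_assignment("my_var", repr(payload)), namespace)
+        assert namespace["my_var"] == "custom data"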