diff --git a/rlm/core/rlm.py b/rlm/core/rlm.py
index 55a3cf58..1de29776 100644
--- a/rlm/core/rlm.py
+++ b/rlm/core/rlm.py
@@ -13,7 +13,7 @@
     RLMIteration,
     RLMMetadata,
 )
-from rlm.environments import BaseEnv, get_environment
+from rlm.environments import BaseEnv, SupportsPersistence, get_environment
 from rlm.logger import RLMLogger, VerbosePrinter
 from rlm.utils.parsing import (
     find_code_blocks,
@@ -51,6 +51,7 @@ def __init__(
         other_backend_kwargs: list[dict[str, Any]] | None = None,
         logger: RLMLogger | None = None,
         verbose: bool = False,
+        persistent: bool = False,
     ):
         """
         Args:
@@ -66,6 +67,7 @@
             other_backend_kwargs: The kwargs to pass to the other client backends (ordered to match other_backends).
             logger: The logger to use for the RLM.
             verbose: Whether to print verbose output in rich to console.
+            persistent: If True, reuse the environment across completion() calls for multi-turn conversations.
         """
         # Store config for spawning per-completion
         self.backend = backend
@@ -84,6 +86,14 @@
         self.logger = logger
         self.verbose = VerbosePrinter(enabled=verbose)
 
+        # Persistence support
+        self.persistent = persistent
+        self._persistent_env: SupportsPersistence | None = None
+
+        # Validate persistence support at initialization
+        if self.persistent:
+            self._validate_persistent_environment_support()
+
         # Log metadata if logger is provided
         if self.logger or verbose:
             metadata = RLMMetadata(
@@ -108,7 +118,9 @@
     def _spawn_completion_context(self, prompt: str | dict[str, Any]):
         """
         Spawn an LM handler and environment for a single completion call.
-        Cleans up both when the context exits.
+
+        When persistent=True, the environment is reused across calls.
+        When persistent=False (default), a fresh environment is created for each call.
         """
         # Create client and wrap in handler
         client: BaseLM = get_client(self.backend, self.backend_kwargs)
@@ -122,20 +134,32 @@
 
         lm_handler.start()
 
-        # Pass handler address to environment so it can make llm_query() calls
-        env_kwargs = self.environment_kwargs.copy()
-        env_kwargs["lm_handler_address"] = (lm_handler.host, lm_handler.port)
-        env_kwargs["context_payload"] = prompt
+        # Environment: reuse if persistent, otherwise create fresh
+        if self.persistent and self._persistent_env is not None:
+            environment = self._persistent_env
+            # Defensive check: ensure environment supports persistence methods
+            if not self._env_supports_persistence(environment):
+                raise RuntimeError(
+                    f"Persistent environment of type '{type(environment).__name__}' does not implement "
+                    f"the SupportsPersistence protocol (update_handler_address, add_context, get_context_count, "
+                    f"add_history, get_history_count). This should have been caught at initialization."
+                )
+            environment.update_handler_address((lm_handler.host, lm_handler.port))
+            environment.add_context(prompt)
+        else:
+            env_kwargs = self.environment_kwargs.copy()
+            env_kwargs["lm_handler_address"] = (lm_handler.host, lm_handler.port)
+            env_kwargs["context_payload"] = prompt
+            environment: BaseEnv = get_environment(self.environment_type, env_kwargs)
 
-        # Initialize the environment
-        environment: BaseEnv = get_environment(self.environment_type, env_kwargs)
+        if self.persistent:
+            self._persistent_env = environment
 
         try:
             yield lm_handler, environment
         finally:
-            # Cleanup
             lm_handler.stop()
-            if hasattr(environment, "cleanup"):
+            if not self.persistent and hasattr(environment, "cleanup"):
                 environment.cleanup()
 
     def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]:
@@ -177,7 +201,19 @@ def completion(
 
         for i in range(self.max_iterations):
             # Current prompt = message history + additional prompt suffix
-            current_prompt = message_history + [build_user_prompt(root_prompt, i)]
+            context_count = (
+                environment.get_context_count()
+                if isinstance(environment, SupportsPersistence)
+                else 1
+            )
+            history_count = (
+                environment.get_history_count()
+                if isinstance(environment, SupportsPersistence)
+                else 0
+            )
+            current_prompt = message_history + [
+                build_user_prompt(root_prompt, i, context_count, history_count)
+            ]
 
             iteration: RLMIteration = self._completion_turn(
                 prompt=current_prompt,
@@ -201,6 +237,11 @@
                 usage = lm_handler.get_usage_summary()
                 self.verbose.print_final_answer(final_answer)
                 self.verbose.print_summary(i + 1, time_end - time_start, usage.to_dict())
+
+                # Store message history in persistent environment
+                if self.persistent and isinstance(environment, SupportsPersistence):
+                    environment.add_history(message_history)
+
                 return RLMChatCompletion(
                     root_model=self.backend_kwargs.get("model_name", "unknown")
                     if self.backend_kwargs
@@ -223,6 +264,11 @@
         usage = lm_handler.get_usage_summary()
         self.verbose.print_final_answer(final_answer)
         self.verbose.print_summary(self.max_iterations, time_end - time_start, usage.to_dict())
+
+        # Store message history in persistent environment
+        if self.persistent and isinstance(environment, SupportsPersistence):
+            environment.add_history(message_history)
+
         return RLMChatCompletion(
             root_model=self.backend_kwargs.get("model_name", "unknown")
             if self.backend_kwargs
@@ -292,3 +338,49 @@ def _fallback_answer(self, message: str | dict[str, Any]) -> str:
         client: BaseLM = get_client(self.backend, self.backend_kwargs)
         response = client.completion(message)
         return response
+
+    def _validate_persistent_environment_support(self) -> None:
+        """
+        Validate that the configured environment type supports persistent mode.
+
+        Persistent mode requires environments to implement the SupportsPersistence protocol:
+        - update_handler_address(address): Update LM handler address between calls
+        - add_context(payload, index): Add new context for multi-turn conversations
+        - get_context_count(): Return the number of loaded contexts
+        - add_history(history, index): Store a completed conversation's message history
+        - get_history_count(): Return the number of stored histories
+
+        Currently only 'local' (LocalREPL) supports these methods.
+
+        Raises:
+            ValueError: If the environment type does not support persistent mode.
+        """
+        # Known environments that support persistence
+        persistent_supported_environments = {"local"}
+
+        if self.environment_type not in persistent_supported_environments:
+            raise ValueError(
+                f"persistent=True is not supported for environment type '{self.environment_type}'. "
+                f"Persistent mode requires environments that implement the SupportsPersistence protocol: "
+                f"update_handler_address(), add_context(), get_context_count(), add_history(), and get_history_count(). "
+                f"Supported environments: {sorted(persistent_supported_environments)}"
+            )
+
+    @staticmethod
+    def _env_supports_persistence(env: BaseEnv) -> bool:
+        """Check if an environment instance supports persistent mode methods."""
+        return isinstance(env, SupportsPersistence)
+
+    def close(self) -> None:
+        """Clean up persistent environment. Call when done with multi-turn conversations."""
+        if self._persistent_env is not None:
+            if hasattr(self._persistent_env, "cleanup"):
+                self._persistent_env.cleanup()
+            self._persistent_env = None
+
+    def __enter__(self) -> "RLM":
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
+        self.close()
+        return False
diff --git a/rlm/environments/__init__.py b/rlm/environments/__init__.py
index 9d2719f2..6a70431d 100644
--- a/rlm/environments/__init__.py
+++ b/rlm/environments/__init__.py
@@ -1,8 +1,10 @@
 from typing import Any, Literal
 
-from rlm.environments.base_env import BaseEnv
+from rlm.environments.base_env import BaseEnv, SupportsPersistence
 from rlm.environments.local_repl import LocalREPL
 
+__all__ = ["BaseEnv", "LocalREPL", "SupportsPersistence", "get_environment"]
+
 
 def get_environment(
     environment: Literal["local", "modal", "docker"],
diff --git a/rlm/environments/base_env.py b/rlm/environments/base_env.py
index 6a99c64a..963018b4 100644
--- a/rlm/environments/base_env.py
+++ b/rlm/environments/base_env.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Any, Protocol, runtime_checkable
 
 from rlm.core.types import REPLResult
 
@@ -9,7 +10,8 @@ class BaseEnv(ABC):
     where isolated environments are on a separate machine from the LM.
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, persistent: bool = False, **kwargs):
+        self.persistent = persistent
         self.kwargs = kwargs
 
     @abstractmethod
@@ -31,8 +33,8 @@ class IsolatedEnv(BaseEnv, ABC):
     guaranteeing complete isolation from the LM process.
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, persistent: bool = False, **kwargs):
+        super().__init__(persistent=persistent, **kwargs)
 
     @abstractmethod
     def setup(self):
@@ -54,8 +56,8 @@ class NonIsolatedEnv(BaseEnv, ABC):
     as a subprocess.
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, persistent: bool = False, **kwargs):
+        super().__init__(persistent=persistent, **kwargs)
 
     @abstractmethod
    def setup(self):
@@ -68,3 +70,112 @@ def load_context(self, context_payload: dict | list | str):
     @abstractmethod
     def execute_code(self, code: str) -> REPLResult:
         raise NotImplementedError
+
+
+@runtime_checkable
+class SupportsPersistence(Protocol):
+    """Protocol for environments that support persistent multi-turn sessions.
+
+    CHECKING SUPPORT:
+        Use isinstance(env, SupportsPersistence) to check if an environment
+        supports persistence capabilities.
+
+    IMPLEMENTING THIS PROTOCOL:
+        To add persistence to your environment, implement these 5 methods.
+        See tests/test_local_repl_persistent.py for expected behavior.
+
+    VERSIONING BEHAVIOR:
+        Contexts and histories are versioned with numeric suffixes:
+        - Contexts  -> context_0, context_1, context_2, ...
+        - Histories -> history_0, history_1, history_2, ...
+ + ALIASING BEHAVIOR: + The unversioned names always point to index 0: + - context -> context_0 (first context) + - history -> history_0 (first history) + + EXAMPLE IMPLEMENTATION: + See rlm/environments/local_repl.py for a complete reference. + + TESTS: + - Unit tests: tests/test_local_repl_persistent.py + - Integration tests: tests/test_multi_turn_integration.py + + Run: uv run pytest tests/test_local_repl_persistent.py -v + """ + + def update_handler_address(self, address: tuple[str, int]) -> None: + """Update the LM handler address for nested LLM calls. + + Called by RLM when the handler address changes between completions. + Store the address so llm_query() calls from executed code can reach + the LM handler. + + Args: + address: (host, port) tuple for the LM handler server. + """ + ... + + def add_context( + self, context_payload: dict | list | str, context_index: int | None = None + ) -> int: + """Add a context payload, making it available as context_N in code. + + Versioning: + - context_index=None: auto-increment (0, 1, 2, ...) + - context_index=N: use specific index N + + Storage: + Must store so executed code can access: + - context_0, context_1, etc. (versioned) + - context (alias to context_0) + + Args: + context_payload: The context data (string, dict, or list). + context_index: Optional specific index, or None to auto-increment. + + Returns: + The index used (for auto-increment, returns the assigned index). + """ + ... + + def get_context_count(self) -> int: + """Return the number of contexts added so far. + + Used by RLM to inform the model how many contexts are available. + """ + ... + + def add_history( + self, message_history: list[dict[str, Any]], history_index: int | None = None + ) -> int: + """Add a message history, making it available as history_N in code. + + Versioning: + - history_index=None: auto-increment (0, 1, 2, ...) + - history_index=N: use specific index N + + Storage: + Must store so executed code can access: + - history_0, history_1, etc. (versioned) + - history (alias to history_0) + + IMPORTANT: Store a deep copy, not a reference. The caller may + modify the list after calling this method. + + Args: + message_history: List of message dicts (role, content). + history_index: Optional specific index, or None to auto-increment. + + Returns: + The index used. + """ + ... + + def get_history_count(self) -> int: + """Return the number of histories added so far. + + Used by RLM to inform the model how many conversation histories + are available. + """ + ... 
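Example (not part of the patch): a minimal sketch of a third-party environment that would satisfy SupportsPersistence. The class name and in-memory storage here are hypothetical; LocalREPL in rlm/environments/local_repl.py is the reference implementation in this diff.

    # Hypothetical illustration -- not included in this patch.
    import copy
    from typing import Any

    from rlm.environments import SupportsPersistence

    class InMemoryEnv:
        """Stores contexts and histories in plain dicts instead of a REPL namespace."""

        def __init__(self) -> None:
            self._contexts: dict[int, Any] = {}
            self._histories: dict[int, list[dict[str, Any]]] = {}
            self._address: tuple[str, int] | None = None

        def update_handler_address(self, address: tuple[str, int]) -> None:
            self._address = address

        def add_context(self, context_payload: dict | list | str, context_index: int | None = None) -> int:
            if context_index is None:
                context_index = len(self._contexts)  # auto-increment
            self._contexts[context_index] = context_payload
            return context_index

        def get_context_count(self) -> int:
            return len(self._contexts)

        def add_history(self, message_history: list[dict[str, Any]], history_index: int | None = None) -> int:
            if history_index is None:
                history_index = len(self._histories)  # auto-increment
            # Deep copy, per the protocol contract: the caller may mutate the list later.
            self._histories[history_index] = copy.deepcopy(message_history)
            return history_index

        def get_history_count(self) -> int:
            return len(self._histories)

    # SupportsPersistence is @runtime_checkable, so the isinstance check is
    # structural; no inheritance from the protocol is required:
    assert isinstance(InMemoryEnv(), SupportsPersistence)
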
diff --git a/rlm/environments/docker_repl.py b/rlm/environments/docker_repl.py index 6dd8c00a..19714857 100644 --- a/rlm/environments/docker_repl.py +++ b/rlm/environments/docker_repl.py @@ -180,9 +180,14 @@ def __init__( lm_handler_address: tuple[str, int] | None = None, context_payload: dict | list | str | None = None, setup_code: str | None = None, + persistent: bool = False, **kwargs, ): - super().__init__(**kwargs) + if persistent: + raise NotImplementedError( + "Persistent REPLs are currently not supported for environment: DockerREPL" + ) + super().__init__(persistent=persistent, **kwargs) self.image = image self.lm_handler_address = lm_handler_address @@ -292,13 +297,13 @@ def execute_code(self, code: str) -> REPLResult: ) def cleanup(self): - if self.container_id: + if hasattr(self, "container_id") and self.container_id: subprocess.run(["docker", "stop", self.container_id], capture_output=True) self.container_id = None - if self.proxy_server: + if hasattr(self, "proxy_server") and self.proxy_server: self.proxy_server.shutdown() self.proxy_server = None - if os.path.exists(self.temp_dir): + if hasattr(self, "temp_dir") and os.path.exists(self.temp_dir): import shutil shutil.rmtree(self.temp_dir, ignore_errors=True) diff --git a/rlm/environments/local_repl.py b/rlm/environments/local_repl.py index b8183800..05de1d3d 100644 --- a/rlm/environments/local_repl.py +++ b/rlm/environments/local_repl.py @@ -1,3 +1,4 @@ +import copy import io import json import os @@ -122,14 +123,17 @@ def __init__( lm_handler_address: tuple[str, int] | None = None, context_payload: dict | list | str | None = None, setup_code: str | None = None, + persistent: bool = False, **kwargs, ): - super().__init__(**kwargs) + super().__init__(persistent=persistent, **kwargs) self.lm_handler_address = lm_handler_address self.original_cwd = os.getcwd() self.temp_dir = tempfile.mkdtemp(prefix=f"repl_env_{uuid.uuid4()}_") self._lock = threading.Lock() + self._context_count: int = 0 + self._history_count: int = 0 # Setup globals, locals, and modules in environment. self.setup() @@ -222,20 +226,87 @@ def _llm_query_batched(self, prompts: list[str], model: str | None = None) -> li return [f"Error: LM query failed - {e}"] * len(prompts) def load_context(self, context_payload: dict | list | str): - """Load context into the environment.""" + """Load context into the environment as context_0 (and 'context' alias).""" + self.add_context(context_payload, 0) + + def add_context( + self, context_payload: dict | list | str, context_index: int | None = None + ) -> int: + """ + Add a context with versioned variable name. + + Args: + context_payload: The context data to add + context_index: Optional explicit index. If None, auto-increments. + + Returns: + The context index used. 
+ """ + if context_index is None: + context_index = self._context_count + + var_name = f"context_{context_index}" + if isinstance(context_payload, str): - context_path = os.path.join(self.temp_dir, "context.txt") + context_path = os.path.join(self.temp_dir, f"context_{context_index}.txt") with open(context_path, "w") as f: f.write(context_payload) - self.execute_code(f"with open(r'{context_path}', 'r') as f:\n context = f.read()") + self.execute_code(f"with open(r'{context_path}', 'r') as f:\n {var_name} = f.read()") else: - context_path = os.path.join(self.temp_dir, "context.json") + context_path = os.path.join(self.temp_dir, f"context_{context_index}.json") with open(context_path, "w") as f: json.dump(context_payload, f) self.execute_code( - f"import json\nwith open(r'{context_path}', 'r') as f:\n context = json.load(f)" + f"import json\nwith open(r'{context_path}', 'r') as f:\n {var_name} = json.load(f)" ) + # Alias context_0 as 'context' for backward compatibility + if context_index == 0: + self.execute_code(f"context = {var_name}") + + self._context_count = max(self._context_count, context_index + 1) + return context_index + + def update_handler_address(self, address: tuple[str, int]) -> None: + """Update the LM handler address for a new completion call.""" + self.lm_handler_address = address + + def get_context_count(self) -> int: + """Return the number of contexts loaded.""" + return self._context_count + + def add_history( + self, message_history: list[dict[str, Any]], history_index: int | None = None + ) -> int: + """ + Store a conversation's message history as a versioned variable. + + Args: + message_history: The list of message dicts from a completion call + history_index: Optional explicit index. If None, auto-increments. + + Returns: + The history index used. 
+ """ + if history_index is None: + history_index = self._history_count + + var_name = f"history_{history_index}" + + # Store deep copy to avoid reference issues with nested dicts + self.locals[var_name] = copy.deepcopy(message_history) + + # Alias history_0 as 'history' for convenience + if history_index == 0: + self.locals["history"] = self.locals[var_name] + + self._history_count = max(self._history_count, history_index + 1) + return history_index + + def get_history_count(self) -> int: + """Return the number of conversation histories stored.""" + return self._history_count + @contextmanager def _capture_output(self): """Thread-safe context manager to capture stdout/stderr.""" @@ -265,22 +336,21 @@ def execute_code(self, code: str) -> REPLResult: # Clear pending LLM calls from previous execution self._pending_llm_calls = [] - with self._capture_output() as (stdout_buf, stderr_buf): - with self._temp_cwd(): - try: - combined = {**self.globals, **self.locals} - exec(code, combined, combined) - - # Update locals with new variables - for key, value in combined.items(): - if key not in self.globals and not key.startswith("_"): - self.locals[key] = value - - stdout = stdout_buf.getvalue() - stderr = stderr_buf.getvalue() - except Exception as e: - stdout = stdout_buf.getvalue() - stderr = stderr_buf.getvalue() + f"\n{type(e).__name__}: {e}" + with self._capture_output() as (stdout_buf, stderr_buf), self._temp_cwd(): + try: + combined = {**self.globals, **self.locals} + exec(code, combined, combined) + + # Update locals with new variables + for key, value in combined.items(): + if key not in self.globals and not key.startswith("_"): + self.locals[key] = value + + stdout = stdout_buf.getvalue() + stderr = stderr_buf.getvalue() + except Exception as e: + stdout = stdout_buf.getvalue() + stderr = stderr_buf.getvalue() + f"\n{type(e).__name__}: {e}" return REPLResult( stdout=stdout, diff --git a/rlm/environments/modal_repl.py b/rlm/environments/modal_repl.py index 2acfed18..82fa24cf 100644 --- a/rlm/environments/modal_repl.py +++ b/rlm/environments/modal_repl.py @@ -309,9 +309,14 @@ def __init__( lm_handler_address: tuple[str, int] | None = None, context_payload: dict | list | str | None = None, setup_code: str | None = None, + persistent: bool = False, **kwargs, ): - super().__init__(**kwargs) + if persistent: + raise NotImplementedError( + "Persistent REPLs are currently not supported for environment: ModalREPL" + ) + super().__init__(persistent=persistent, **kwargs) self.app_name = app_name self.timeout = timeout diff --git a/rlm/environments/prime_repl.py b/rlm/environments/prime_repl.py index 19a082ca..0d88ce90 100644 --- a/rlm/environments/prime_repl.py +++ b/rlm/environments/prime_repl.py @@ -8,9 +8,14 @@ def __init__( context_payload: dict | list | str | None = None, sandbox_name: str | None = None, api_key: str | None = None, + persistent: bool = False, **kwargs, ): - pass + if persistent: + raise NotImplementedError( + "Persistent REPLs are currently not supported for environment: PrimeREPL" + ) + super().__init__(persistent=persistent, **kwargs) def setup(self): pass diff --git a/rlm/utils/prompts.py b/rlm/utils/prompts.py index 3d0bd624..f69b2292 100644 --- a/rlm/utils/prompts.py +++ b/rlm/utils/prompts.py @@ -116,15 +116,31 @@ def build_rlm_system_prompt( USER_PROMPT_WITH_ROOT = """Think step-by-step on what to do using the REPL environment (which contains the context) to answer the original prompt: \"{root_prompt}\".\n\nContinue using the REPL environment, which has the `context` 
variable, and querying sub-LLMs by writing to ```repl``` tags, and determine your answer. Your next action:"""
 
 
-def build_user_prompt(root_prompt: str | None = None, iteration: int = 0) -> dict[str, str]:
+def build_user_prompt(
+    root_prompt: str | None = None,
+    iteration: int = 0,
+    context_count: int = 1,
+    history_count: int = 0,
+) -> dict[str, str]:
     if iteration == 0:
         safeguard = "You have not interacted with the REPL environment or seen your prompt / context yet. Your next action should be to look through and figure out how to answer the prompt, so don't just provide a final answer yet.\n\n"
         prompt = safeguard + (
             USER_PROMPT_WITH_ROOT.format(root_prompt=root_prompt) if root_prompt else USER_PROMPT
         )
-        return {"role": "user", "content": prompt}
     else:
         prompt = "The history before is your previous interactions with the REPL environment. " + (
             USER_PROMPT_WITH_ROOT.format(root_prompt=root_prompt) if root_prompt else USER_PROMPT
         )
-        return {"role": "user", "content": prompt}
+
+    # Inform model about multiple contexts if present
+    if context_count > 1:
+        prompt += f"\n\nNote: You have {context_count} contexts available (context_0 through context_{context_count - 1})."
+
+    # Inform model about prior conversation histories if present
+    if history_count > 0:
+        if history_count == 1:
+            prompt += "\n\nNote: You have 1 prior conversation history available in the `history` variable."
+        else:
+            prompt += f"\n\nNote: You have {history_count} prior conversation histories available (history_0 through history_{history_count - 1})."
+
+    return {"role": "user", "content": prompt}
diff --git a/tests/test_local_repl.py b/tests/test_local_repl.py
index f03c523c..787a5b54 100644
--- a/tests/test_local_repl.py
+++ b/tests/test_local_repl.py
@@ -193,3 +193,53 @@ def test_temp_dir_created_and_cleaned(self):
         assert os.path.exists(temp_dir)
         repl.cleanup()
         assert not os.path.exists(temp_dir)
+
+
+class TestLocalREPLSimulatingRLMNoPersistence:
+    """
+    Tests simulating RLM's non-persistent completion behavior.
+
+    When RLM is configured without persistent=True (the default), each
+    completion() call spawns a fresh environment and destroys it afterward.
+    This test suite simulates that behavior to prove variables don't survive
+    across RLM completions.
+
+    Why this matters: This is NOT just testing that two Python objects don't
+    share state (trivially true). This simulates the actual RLM workflow where
+    environments are created and destroyed per completion.
+    """
+
+    def test_simulated_rlm_completions_reset_environment(self):
+        """
+        Simulates 2 RLM completions to show env resets between calls.
+
+        Without persistent=True, RLM creates a fresh environment for each
+        completion, so state doesn't carry over.
+        """
+        completion_1_env = LocalREPL()
+        completion_1_env.execute_code("important_result = 42")
+        assert completion_1_env.locals["important_result"] == 42
+        completion_1_env.cleanup()
+
+        completion_2_env = LocalREPL()
+        result = completion_2_env.execute_code("print(important_result)")
+
+        assert "NameError" in result.stderr
+        assert "important_result" in result.stderr
+        completion_2_env.cleanup()
+
+    def test_simulated_rlm_completions_functions_not_preserved(self):
+        """
+        Simulates 2 RLM completions to show functions don't persist.
+        
+ """ + completion_1_env = LocalREPL() + completion_1_env.execute_code("def my_helper(): return 'useful'") + assert completion_1_env.execute_code("print(my_helper())").stdout.strip() == "useful" + completion_1_env.cleanup() + + completion_2_env = LocalREPL() + result = completion_2_env.execute_code("my_helper()") + + assert "NameError" in result.stderr + assert "my_helper" in result.stderr + completion_2_env.cleanup() diff --git a/tests/test_local_repl_persistent.py b/tests/test_local_repl_persistent.py new file mode 100644 index 00000000..f654679d --- /dev/null +++ b/tests/test_local_repl_persistent.py @@ -0,0 +1,220 @@ +"""Tests for LocalREPL persistence features. + +These tests verify LocalREPL's multi-context and multi-history capabilities +which support the persistent=True mode in RLM for multi-turn conversations. +""" + +from rlm.environments.local_repl import LocalREPL + + +class TestLocalREPLMultiContext: + """Tests for multi-context support in persistent mode.""" + + def test_add_context_versioning(self): + """Test that add_context creates versioned variables.""" + repl = LocalREPL() + repl.add_context("First", 0) + repl.add_context("Second", 1) + assert repl.locals["context_0"] == "First" + assert repl.locals["context_1"] == "Second" + assert repl.locals["context"] == "First" + assert repl.get_context_count() == 2 + repl.cleanup() + + def test_update_handler_address(self): + """Test handler address can be updated.""" + repl = LocalREPL(lm_handler_address=("127.0.0.1", 5000)) + repl.update_handler_address(("127.0.0.1", 6000)) + assert repl.lm_handler_address == ("127.0.0.1", 6000) + repl.cleanup() + + def test_add_context_auto_increment(self): + """Test that add_context auto-increments when no index provided.""" + repl = LocalREPL() + idx1 = repl.add_context("First") + idx2 = repl.add_context("Second") + assert idx1 == 0 + assert idx2 == 1 + assert repl.locals["context_0"] == "First" + assert repl.locals["context_1"] == "Second" + assert repl.get_context_count() == 2 + repl.cleanup() + + def test_contexts_accessible_in_code(self): + """Test that multiple contexts can be accessed in code execution.""" + repl = LocalREPL() + repl.add_context("Document A content") + repl.add_context("Document B content") + + result = repl.execute_code("combined = f'{context_0} + {context_1}'") + assert result.stderr == "" + assert repl.locals["combined"] == "Document A content + Document B content" + repl.cleanup() + + def test_context_alias_points_to_first(self): + """Test that 'context' always aliases context_0.""" + repl = LocalREPL() + repl.add_context("First") + repl.add_context("Second") + repl.add_context("Third") + + result = repl.execute_code("is_first = context == context_0") + assert result.stderr == "" + assert repl.locals["is_first"] is True + repl.cleanup() + + +class TestLocalREPLHistory: + """Tests for message history storage in LocalREPL for persistent sessions.""" + + def test_add_history_basic(self): + """Test that add_history stores message history correctly.""" + repl = LocalREPL() + + history = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + index = repl.add_history(history) + + assert index == 0 + assert "history_0" in repl.locals + assert "history" in repl.locals # alias + assert repl.locals["history_0"] == history + assert repl.locals["history"] == history + assert repl.get_history_count() == 1 + + repl.cleanup() + + def test_add_multiple_histories(self): + """Test adding 
multiple conversation histories.""" + repl = LocalREPL() + + history1 = [{"role": "user", "content": "First conversation"}] + history2 = [{"role": "user", "content": "Second conversation"}] + + repl.add_history(history1) + repl.add_history(history2) + + assert repl.get_history_count() == 2 + assert repl.locals["history_0"] == history1 + assert repl.locals["history_1"] == history2 + assert repl.locals["history"] == history1 # alias stays on first + + repl.cleanup() + + def test_history_accessible_via_code(self): + """Test that stored history is accessible via code execution.""" + repl = LocalREPL() + + history = [{"role": "user", "content": "Test message"}] + repl.add_history(history) + + result = repl.execute_code("msg = history[0]['content']") + assert result.stderr == "" + assert repl.locals["msg"] == "Test message" + + repl.cleanup() + + def test_history_is_copy(self): + """Test that stored history is a copy, not a reference.""" + repl = LocalREPL() + + history = [{"role": "user", "content": "Original"}] + repl.add_history(history) + + history[0]["content"] = "Modified" + + assert repl.locals["history_0"][0]["content"] == "Original" + + repl.cleanup() + + def test_can_iterate_histories_in_code(self): + """Test iterating through multiple histories in code.""" + repl = LocalREPL() + + repl.add_history([{"role": "user", "content": "Query 1"}]) + repl.add_history([{"role": "user", "content": "Query 2"}]) + repl.add_history([{"role": "user", "content": "Query 3"}]) + + code = """ +all_contents = [ + history_0[0]['content'], + history_1[0]['content'], + history_2[0]['content'], +] +""" + result = repl.execute_code(code) + assert result.stderr == "" + assert repl.locals["all_contents"] == ["Query 1", "Query 2", "Query 3"] + + repl.cleanup() + + +class TestLocalREPLPersistentState: + """Tests for state persistence across multiple operations in a single REPL instance.""" + + def test_variables_persist_with_contexts(self): + """Variables and contexts should coexist.""" + repl = LocalREPL() + + repl.add_context("My context data") + repl.execute_code("summary = context.upper()") + assert repl.locals["summary"] == "MY CONTEXT DATA" + + repl.add_context("New context") + + assert repl.locals["summary"] == "MY CONTEXT DATA" + assert repl.locals["context_1"] == "New context" + + repl.cleanup() + + def test_variables_persist_with_histories(self): + """Variables and histories should coexist.""" + repl = LocalREPL() + + repl.add_history([{"role": "user", "content": "Hello"}]) + repl.execute_code("extracted = history[0]['content']") + assert repl.locals["extracted"] == "Hello" + + repl.add_history([{"role": "user", "content": "World"}]) + + assert repl.locals["extracted"] == "Hello" + assert repl.locals["history_1"][0]["content"] == "World" + + repl.cleanup() + + def test_full_persistent_session_simulation(self): + """Simulate a multi-turn persistent session.""" + repl = LocalREPL() + + repl.add_context("Document: Sales were $1000") + repl.execute_code("sales = 1000") + + repl.add_context("Document: Costs were $400") + result = repl.execute_code("profit = sales - 400") + assert result.stderr == "" + assert repl.locals["profit"] == 600 + + repl.add_history( + [ + {"role": "user", "content": "What were the sales?"}, + {"role": "assistant", "content": "Sales were $1000"}, + ] + ) + + code = """ +summary = f"Sales: {context_0}, Costs: {context_1}, Profit: {profit}" +prev_question = history_0[0]['content'] +""" + result = repl.execute_code(code) + assert result.stderr == "" + assert "Profit: 600" in 
repl.locals["summary"] + assert repl.locals["prev_question"] == "What were the sales?" + + assert repl.get_context_count() == 2 + assert repl.get_history_count() == 1 + + repl.cleanup() diff --git a/tests/test_multi_turn_integration.py b/tests/test_multi_turn_integration.py new file mode 100644 index 00000000..f55de39b --- /dev/null +++ b/tests/test_multi_turn_integration.py @@ -0,0 +1,395 @@ +"""Integration tests for multi-turn persistent REPL sessions. + +Tests that multiple LM completion calls in one RLM session: +1. Share the same environment +2. Accumulate contexts (context_0, context_1, ...) +3. Accumulate histories (history_0, history_1, ...) +4. Preserve variables across calls +5. Properly inform the model about available contexts/histories +""" + +from unittest.mock import Mock, patch + +import pytest + +import rlm.core.rlm as rlm_module +from rlm import RLM +from rlm.core.types import ModelUsageSummary, UsageSummary + + +def create_mock_lm(responses: list[str]) -> Mock: + """Create a mock LM that returns responses in order.""" + mock = Mock() + mock.completion.side_effect = list(responses) + mock.get_usage_summary.return_value = UsageSummary( + model_usage_summaries={ + "mock": ModelUsageSummary(total_calls=1, total_input_tokens=100, total_output_tokens=50) + } + ) + mock.get_last_usage.return_value = mock.get_usage_summary.return_value + return mock + + +class TestMultiTurnPersistentEnvironment: + """Tests for environment persistence across completion calls.""" + + def test_environment_reused_in_persistent_mode(self): + """Verify the same environment instance is reused across completion calls.""" + responses = ["FINAL(answer from call)"] + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(responses) + mock_get_client.return_value = mock_lm + + with RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + persistent=True, + ) as rlm: + rlm.completion("First context") + first_env = rlm._persistent_env + + mock_lm.completion.side_effect = list(responses) + + rlm.completion("Second context") + second_env = rlm._persistent_env + + assert first_env is second_env + assert first_env is not None + + def test_context_accumulation_across_calls(self): + """Verify contexts accumulate: context_0, context_1, etc.""" + responses = ["FINAL(got it)"] + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(responses) + mock_get_client.return_value = mock_lm + + with RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + persistent=True, + ) as rlm: + rlm.completion("First document") + mock_lm.completion.side_effect = list(responses) + rlm.completion("Second document") + mock_lm.completion.side_effect = list(responses) + rlm.completion("Third document") + + env = rlm._persistent_env + assert env.get_context_count() == 3 + assert env.locals["context_0"] == "First document" + assert env.locals["context_1"] == "Second document" + assert env.locals["context_2"] == "Third document" + assert env.locals["context"] == "First document" + + def test_history_accumulation_across_calls(self): + """Verify message histories accumulate: history_0, history_1, etc.""" + responses = ["FINAL(done)"] + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(responses) + mock_get_client.return_value = mock_lm + + with RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + persistent=True, + ) as rlm: + rlm.completion("Context A") + mock_lm.completion.side_effect = 
list(responses) + rlm.completion("Context B") + mock_lm.completion.side_effect = list(responses) + rlm.completion("Context C") + + env = rlm._persistent_env + assert env.get_history_count() == 3 + assert "history_0" in env.locals + assert "history_1" in env.locals + assert "history_2" in env.locals + assert isinstance(env.locals["history_0"], list) + assert len(env.locals["history_0"]) > 0 + assert env.locals["history"] == env.locals["history_0"] + + def test_variable_persistence_across_completions(self): + """Variables computed in one completion should be available in subsequent ones.""" + first_responses = [ + "Let me compute something\n```repl\ncomputed_value = 42 * 2\nprint(computed_value)\n```", + "FINAL(84)", + ] + second_responses = [ + "```repl\nresult = computed_value + 10\nprint(result)\n```", + "FINAL(94)", + ] + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(first_responses) + mock_get_client.return_value = mock_lm + + with RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + persistent=True, + ) as rlm: + rlm.completion("Compute 42 * 2") + assert rlm._persistent_env.locals.get("computed_value") == 84 + + mock_lm.completion.side_effect = list(second_responses) + rlm.completion("Add 10 to the previous result") + + assert rlm._persistent_env.locals.get("computed_value") == 84 + assert rlm._persistent_env.locals.get("result") == 94 + + +class TestMultiTurnPromptAwareness: + """Tests that prompts correctly inform the model about contexts/histories.""" + + def test_prompt_includes_context_count(self): + """Model should be informed about available contexts.""" + responses = ["FINAL(ok)"] + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(responses) + mock_get_client.return_value = mock_lm + + with RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + persistent=True, + ) as rlm: + rlm.completion("First") + mock_lm.completion.side_effect = list(responses) + rlm.completion("Second") + + last_prompt = mock_lm.completion.call_args[0][0] + user_messages = [m for m in last_prompt if m.get("role") == "user"] + user_content = " ".join(m.get("content", "") for m in user_messages) + + assert "2 contexts" in user_content or "context_0" in user_content + + def test_prompt_includes_history_count(self): + """Model should be informed about available histories.""" + responses = ["FINAL(ok)"] + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(responses) + mock_get_client.return_value = mock_lm + + with RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + persistent=True, + ) as rlm: + rlm.completion("First task") + mock_lm.completion.side_effect = list(responses) + rlm.completion("Second task") + + last_prompt = mock_lm.completion.call_args[0][0] + user_messages = [m for m in last_prompt if m.get("role") == "user"] + user_content = " ".join(m.get("content", "") for m in user_messages) + + assert "history" in user_content.lower() + + +class TestMultiTurnCodeExecution: + """Tests for code execution in multi-turn sessions.""" + + def test_can_access_previous_context_in_code(self): + """Code should be able to reference earlier contexts.""" + first_responses = ["FINAL(first done)"] + second_responses = [ + "```repl\nprint(f'First: {context_0}, Second: {context_1}')\n```", + "FINAL(printed both)", + ] + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(first_responses) + 
mock_get_client.return_value = mock_lm + + with RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + persistent=True, + ) as rlm: + rlm.completion("Document A") + + mock_lm.completion.side_effect = list(second_responses) + rlm.completion("Document B") + + env = rlm._persistent_env + assert env.locals["context_0"] == "Document A" + assert env.locals["context_1"] == "Document B" + + def test_can_access_history_in_code(self): + """Code should be able to reference stored histories.""" + first_responses = ["FINAL(first)"] + second_responses = [ + "```repl\nprint(f'History entries: {len(history)}')\n```", + "FINAL(accessed history)", + ] + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(first_responses) + mock_get_client.return_value = mock_lm + + with RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + persistent=True, + ) as rlm: + rlm.completion("First query") + + mock_lm.completion.side_effect = list(second_responses) + rlm.completion("Second query") + + env = rlm._persistent_env + assert "history" in env.locals + assert isinstance(env.locals["history"], list) + + +class TestNonPersistentMode: + """Tests to ensure non-persistent mode still works correctly.""" + + def test_non_persistent_creates_fresh_environment(self): + """Non-persistent mode should create new environment each call.""" + responses = ["FINAL(done)"] + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(responses) + mock_get_client.return_value = mock_lm + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + persistent=False, + ) + + rlm.completion("First") + assert rlm._persistent_env is None + + mock_lm.completion.side_effect = list(responses) + rlm.completion("Second") + assert rlm._persistent_env is None + + def test_default_is_non_persistent(self): + """Default behavior should be non-persistent.""" + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + ) + assert rlm.persistent is False + + +class TestPersistentModeResourceManagement: + """Tests for proper resource cleanup in persistent mode.""" + + def test_context_manager_cleanup(self): + """Environment should be cleaned up when exiting context manager.""" + responses = ["FINAL(done)"] + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(responses) + mock_get_client.return_value = mock_lm + + with RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + persistent=True, + ) as rlm: + rlm.completion("Test") + assert rlm._persistent_env is not None + + assert rlm._persistent_env is None + + def test_explicit_close(self): + """Calling close() should clean up persistent environment.""" + responses = ["FINAL(done)"] + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(responses) + mock_get_client.return_value = mock_lm + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + persistent=True, + ) + rlm.completion("Test") + assert rlm._persistent_env is not None + + rlm.close() + assert rlm._persistent_env is None + + +class TestPersistentModeValidation: + """Tests for persistent mode validation.""" + + def test_unsupported_environment_raises_error(self): + """Persistent mode should raise error for unsupported environments.""" + with pytest.raises(ValueError, match="persistent=True is not supported"): + RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + environment="docker", # Not 
supported for persistent + persistent=True, + ) + + def test_local_environment_supported(self): + """Local environment should support persistent mode.""" + # Should not raise + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + environment="local", + persistent=True, + ) + assert rlm.persistent is True + + +class TestMultiTurnEndToEnd: + """End-to-end tests simulating realistic multi-turn usage.""" + + def test_three_turn_conversation(self): + """Simulate a 3-turn conversation with context accumulation.""" + turn1_responses = [ + "Looking at the first document\n```repl\ndoc1_summary = 'Has info about cats'\nprint(doc1_summary)\n```", + "FINAL(Summarized first doc)", + ] + turn2_responses = [ + "Looking at second document and comparing\n```repl\ndoc2_summary = 'Has info about dogs'\nprint(f'Doc1: {doc1_summary}, Doc2: {doc2_summary}')\n```", + "FINAL(Compared both docs)", + ] + turn3_responses = [ + "Final synthesis\n```repl\nfinal = f'Combined: {doc1_summary} and {doc2_summary} from context_2'\nprint(final)\n```", + "FINAL(synthesized all)", + ] + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(turn1_responses) + mock_get_client.return_value = mock_lm + + with RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + persistent=True, + ) as rlm: + result1 = rlm.completion("First document about cats") + assert "Summarized" in result1.response + + mock_lm.completion.side_effect = list(turn2_responses) + result2 = rlm.completion("Second document about dogs") + assert "Compared" in result2.response + + mock_lm.completion.side_effect = list(turn3_responses) + result3 = rlm.completion("Synthesize everything") + assert "synthesized" in result3.response + + env = rlm._persistent_env + assert env.get_context_count() == 3 + assert env.get_history_count() == 3 + assert env.locals.get("doc1_summary") == "Has info about cats" + assert env.locals.get("doc2_summary") == "Has info about dogs" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
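
Example (not part of the patch): an end-to-end usage sketch of the new persistent mode. The backend kwargs, model name, and prompts are placeholders; environment="local" is currently the only environment that supports persistence.

    # Hypothetical usage sketch -- not included in this patch.
    from rlm import RLM

    with RLM(
        backend="openai",
        backend_kwargs={"model_name": "gpt-4o-mini"},  # placeholder model
        environment="local",
        persistent=True,
    ) as rlm:
        # Turn 1: the prompt is loaded as context_0 (aliased as `context`).
        first = rlm.completion("Summarize this report: ...")

        # Turn 2: the same LocalREPL is reused. The new prompt becomes
        # context_1, and turn 1's messages are stored as history_0 (aliased
        # as `history`), so REPL variables from turn 1 remain available.
        second = rlm.completion("Compare it against this second report: ...")
    # Exiting the `with` block calls close(), which cleans up the shared environment.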