121 changes: 111 additions & 10 deletions rlm/core/rlm.py
@@ -51,6 +51,7 @@ def __init__(
other_backend_kwargs: list[dict[str, Any]] | None = None,
logger: RLMLogger | None = None,
verbose: bool = False,
persistent: bool = False,
):
"""
Args:
@@ -66,6 +67,7 @@ def __init__(
other_backend_kwargs: The kwargs to pass to the other client backends (ordered to match other_backends).
logger: The logger to use for the RLM.
verbose: Whether to print verbose output in rich to console.
persistent: If True, reuse the environment across completion() calls for multi-turn conversations.
"""
# Store config for spawning per-completion
self.backend = backend
@@ -84,6 +86,14 @@ def __init__(
self.logger = logger
self.verbose = VerbosePrinter(enabled=verbose)

# Persistence support
self.persistent = persistent
self._persistent_env: BaseEnv | None = None

# Validate persistence support at initialization
if self.persistent:
self._validate_persistent_environment_support()

# Log metadata if logger is provided
if self.logger or verbose:
metadata = RLMMetadata(
@@ -108,7 +118,9 @@ def __init__(
def _spawn_completion_context(self, prompt: str | dict[str, Any]):
"""
Spawn an LM handler and environment for a single completion call.
Cleans up both when the context exits.

When persistent=True, the environment is reused across calls.
When persistent=False (the default), a fresh environment is created for each call.
"""
# Create client and wrap in handler
client: BaseLM = get_client(self.backend, self.backend_kwargs)
@@ -122,20 +134,32 @@ def _spawn_completion_context(self, prompt: str | dict[str, Any]):

lm_handler.start()

# Pass handler address to environment so it can make llm_query() calls
env_kwargs = self.environment_kwargs.copy()
env_kwargs["lm_handler_address"] = (lm_handler.host, lm_handler.port)
env_kwargs["context_payload"] = prompt
# Environment: reuse if persistent, otherwise create fresh
if self.persistent and self._persistent_env is not None:
environment = self._persistent_env
# Defensive check: ensure environment supports persistence methods
if not self._env_supports_persistence(environment):
raise RuntimeError(
f"Persistent environment of type '{type(environment).__name__}' does not "
f"implement required methods (update_handler_address, add_context, get_context_count). "
f"This should have been caught at initialization."
)
environment.update_handler_address((lm_handler.host, lm_handler.port))
environment.add_context(prompt)
else:
env_kwargs = self.environment_kwargs.copy()
env_kwargs["lm_handler_address"] = (lm_handler.host, lm_handler.port)
env_kwargs["context_payload"] = prompt
environment: BaseEnv = get_environment(self.environment_type, env_kwargs)

# Initialize the environment
environment: BaseEnv = get_environment(self.environment_type, env_kwargs)
if self.persistent:
self._persistent_env = environment

try:
yield lm_handler, environment
finally:
# Cleanup
lm_handler.stop()
if hasattr(environment, "cleanup"):
if not self.persistent and hasattr(environment, "cleanup"):
environment.cleanup()

def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]:
@@ -177,7 +201,19 @@ def completion(

for i in range(self.max_iterations):
# Current prompt = message history + additional prompt suffix
current_prompt = message_history + [build_user_prompt(root_prompt, i)]
context_count = (
environment.get_context_count()
if hasattr(environment, "get_context_count")
else 1
)
history_count = (
environment.get_history_count()
if hasattr(environment, "get_history_count")
else 0
)
current_prompt = message_history + [
build_user_prompt(root_prompt, i, context_count, history_count)
]

iteration: RLMIteration = self._completion_turn(
prompt=current_prompt,
@@ -201,6 +237,11 @@ def completion(
usage = lm_handler.get_usage_summary()
self.verbose.print_final_answer(final_answer)
self.verbose.print_summary(i + 1, time_end - time_start, usage.to_dict())

# Store message history in persistent environment
if self.persistent and hasattr(environment, "add_history"):
environment.add_history(message_history)

return RLMChatCompletion(
root_model=self.backend_kwargs.get("model_name", "unknown")
if self.backend_kwargs
@@ -223,6 +264,11 @@ def completion(
usage = lm_handler.get_usage_summary()
self.verbose.print_final_answer(final_answer)
self.verbose.print_summary(self.max_iterations, time_end - time_start, usage.to_dict())

# Store message history in persistent environment
if self.persistent and hasattr(environment, "add_history"):
environment.add_history(message_history)

return RLMChatCompletion(
root_model=self.backend_kwargs.get("model_name", "unknown")
if self.backend_kwargs
@@ -292,3 +338,58 @@ def _fallback_answer(self, message: str | dict[str, Any]) -> str:
client: BaseLM = get_client(self.backend, self.backend_kwargs)
response = client.completion(message)
return response

def _validate_persistent_environment_support(self) -> None:
"""
Validate that the configured environment type supports persistent mode.

Persistent mode requires environments to implement:
- update_handler_address(address): Update LM handler address between calls
- add_context(payload, index): Add new context for multi-turn conversations
- get_context_count(): Return the number of loaded contexts
- add_history(message_history, index): Store a completed call's message history
- get_history_count(): Return the number of stored histories

Currently only 'local' (LocalREPL) supports these methods.

Raises:
ValueError: If the environment type does not support persistent mode.
"""
# Known environments that support persistence
persistent_supported_environments = {"local"}

if self.environment_type not in persistent_supported_environments:
raise ValueError(
f"persistent=True is not supported for environment type '{self.environment_type}'. "
f"Persistent mode requires environments that implement update_handler_address(), "
f"add_context(), and get_context_count(). "
f"Supported environments: {sorted(persistent_supported_environments)}"
)

@staticmethod
def _env_supports_persistence(env: BaseEnv) -> bool:
"""Check if an environment instance supports persistent mode methods."""
return (
hasattr(env, "update_handler_address")
and hasattr(env, "add_context")
and hasattr(env, "get_context_count")
and hasattr(env, "add_history")
and hasattr(env, "get_history_count")
and callable(getattr(env, "update_handler_address", None))
and callable(getattr(env, "add_context", None))
and callable(getattr(env, "get_context_count", None))
and callable(getattr(env, "add_history", None))
and callable(getattr(env, "get_history_count", None))
)

def close(self) -> None:
"""Clean up persistent environment. Call when done with multi-turn conversations."""
if self._persistent_env is not None:
if hasattr(self._persistent_env, "cleanup"):
self._persistent_env.cleanup()
self._persistent_env = None

def __enter__(self) -> "RLM":
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
self.close()
return False
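
Taken together, the rlm.py changes make multi-turn usage look roughly like the sketch below. The constructor arguments mirror the attributes stored in `__init__` above; the backend identifier, model name, and prompts are illustrative placeholders, not values confirmed by this diff.

```python
from rlm.core.rlm import RLM

# Hypothetical two-turn session (backend/model names are placeholders).
with RLM(
    backend="openai",                         # assumed backend identifier
    backend_kwargs={"model_name": "gpt-4o"},  # illustrative model name
    environment_type="local",                 # only 'local' supports persistent=True
    persistent=True,
) as rlm:
    first = rlm.completion("Summarize the attached report.")
    # The second call reuses the same LocalREPL: context_0 and history_0
    # survive, and the new prompt is loaded as context_1.
    second = rlm.completion("How does this compare to the first report?")
# __exit__ calls close(), which cleans up the persistent environment.
```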
111 changes: 90 additions & 21 deletions rlm/environments/local_repl.py
@@ -1,3 +1,4 @@
import copy
import io
import json
import os
@@ -130,6 +131,8 @@ def __init__(
self.original_cwd = os.getcwd()
self.temp_dir = tempfile.mkdtemp(prefix=f"repl_env_{uuid.uuid4()}_")
self._lock = threading.Lock()
self._context_count: int = 0
self._history_count: int = 0

# Setup globals, locals, and modules in environment.
self.setup()
@@ -222,20 +225,87 @@ def _llm_query_batched(self, prompts: list[str], model: str | None = None) -> list[str]:
return [f"Error: LM query failed - {e}"] * len(prompts)

def load_context(self, context_payload: dict | list | str):
"""Load context into the environment."""
"""Load context into the environment as context_0 (and 'context' alias)."""
self.add_context(context_payload, 0)

def add_context(
self, context_payload: dict | list | str, context_index: int | None = None
) -> int:
"""
Add a context with versioned variable name.

Args:
context_payload: The context data to add
context_index: Optional explicit index. If None, auto-increments.

Returns:
The context index used.
"""
if context_index is None:
context_index = self._context_count

var_name = f"context_{context_index}"

if isinstance(context_payload, str):
context_path = os.path.join(self.temp_dir, "context.txt")
context_path = os.path.join(self.temp_dir, f"context_{context_index}.txt")
with open(context_path, "w") as f:
f.write(context_payload)
self.execute_code(f"with open(r'{context_path}', 'r') as f:\n context = f.read()")
self.execute_code(f"with open(r'{context_path}', 'r') as f:\n {var_name} = f.read()")
else:
context_path = os.path.join(self.temp_dir, "context.json")
context_path = os.path.join(self.temp_dir, f"context_{context_index}.json")
with open(context_path, "w") as f:
json.dump(context_payload, f)
self.execute_code(
f"import json\nwith open(r'{context_path}', 'r') as f:\n context = json.load(f)"
f"import json\nwith open(r'{context_path}', 'r') as f:\n {var_name} = json.load(f)"
)

# Alias context_0 as 'context' for backward compatibility
if context_index == 0:
self.execute_code(f"context = {var_name}")

self._context_count = max(self._context_count, context_index + 1)
return context_index

def update_handler_address(self, address: tuple[str, int]) -> None:
"""Update the LM handler address for a new completion call."""
self.lm_handler_address = address

def get_context_count(self) -> int:
"""Return the number of contexts loaded."""
return self._context_count

def add_history(
self, message_history: list[dict[str, Any]], history_index: int | None = None
) -> int:
"""
Store a conversation's message history as a versioned variable.

Args:
message_history: The list of message dicts from a completion call
history_index: Optional explicit index. If None, auto-increments.

Returns:
The history index used.
"""
if history_index is None:
history_index = self._history_count

var_name = f"history_{history_index}"

# Store deep copy to avoid reference issues with nested dicts
self.locals[var_name] = copy.deepcopy(message_history)

# Alias history_0 as 'history' for convenience
if history_index == 0:
self.locals["history"] = self.locals[var_name]

self._history_count = max(self._history_count, history_index + 1)
return history_index

def get_history_count(self) -> int:
"""Return the number of conversation histories stored."""
return self._history_count

@contextmanager
def _capture_output(self):
"""Thread-safe context manager to capture stdout/stderr."""
@@ -265,22 +335,21 @@ def execute_code(self, code: str) -> REPLResult:
# Clear pending LLM calls from previous execution
self._pending_llm_calls = []

with self._capture_output() as (stdout_buf, stderr_buf):
with self._temp_cwd():
try:
combined = {**self.globals, **self.locals}
exec(code, combined, combined)

# Update locals with new variables
for key, value in combined.items():
if key not in self.globals and not key.startswith("_"):
self.locals[key] = value

stdout = stdout_buf.getvalue()
stderr = stderr_buf.getvalue()
except Exception as e:
stdout = stdout_buf.getvalue()
stderr = stderr_buf.getvalue() + f"\n{type(e).__name__}: {e}"
with self._capture_output() as (stdout_buf, stderr_buf), self._temp_cwd():
try:
combined = {**self.globals, **self.locals}
exec(code, combined, combined)

# Update locals with new variables
for key, value in combined.items():
if key not in self.globals and not key.startswith("_"):
self.locals[key] = value

stdout = stdout_buf.getvalue()
stderr = stderr_buf.getvalue()
except Exception as e:
stdout = stdout_buf.getvalue()
stderr = stderr_buf.getvalue() + f"\n{type(e).__name__}: {e}"

return REPLResult(
stdout=stdout,
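
The versioning scheme is easiest to see from the environment side. Below is a minimal sketch of the new LocalREPL methods in isolation, assuming an already-constructed instance `env` (constructor arguments omitted) and illustrative payloads:

```python
# Sketch only: `env` is an existing LocalREPL instance.
env.add_context("quarterly report text")   # becomes context_0, aliased as `context`
env.add_context({"figures": [1.2, 3.4]})   # becomes context_1 (JSON-loaded dict)
env.add_history([
    {"role": "user", "content": "Summarize."},
    {"role": "assistant", "content": "Done."},
])                                         # becomes history_0, aliased as `history`

assert env.get_context_count() == 2
assert env.get_history_count() == 1

# Code executed in the REPL sees every versioned variable:
result = env.execute_code("print(type(context_0), context_1['figures'], len(history))")
print(result.stdout)
```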
22 changes: 19 additions & 3 deletions rlm/utils/prompts.py
@@ -116,15 +116,31 @@ def build_rlm_system_prompt(
USER_PROMPT_WITH_ROOT = """Think step-by-step on what to do using the REPL environment (which contains the context) to answer the original prompt: \"{root_prompt}\".\n\nContinue using the REPL environment, which has the `context` variable, and querying sub-LLMs by writing to ```repl``` tags, and determine your answer. Your next action:"""


def build_user_prompt(root_prompt: str | None = None, iteration: int = 0) -> dict[str, str]:
def build_user_prompt(
root_prompt: str | None = None,
iteration: int = 0,
context_count: int = 1,
history_count: int = 0,
) -> dict[str, str]:
if iteration == 0:
safeguard = "You have not interacted with the REPL environment or seen your prompt / context yet. Your next action should be to look through and figure out how to answer the prompt, so don't just provide a final answer yet.\n\n"
prompt = safeguard + (
USER_PROMPT_WITH_ROOT.format(root_prompt=root_prompt) if root_prompt else USER_PROMPT
)
return {"role": "user", "content": prompt}
else:
prompt = "The history before is your previous interactions with the REPL environment. " + (
USER_PROMPT_WITH_ROOT.format(root_prompt=root_prompt) if root_prompt else USER_PROMPT
)
return {"role": "user", "content": prompt}

# Inform model about multiple contexts if present
if context_count > 1:
prompt += f"\n\nNote: You have {context_count} contexts available (context_0 through context_{context_count - 1})."

# Inform model about prior conversation histories if present
if history_count > 0:
if history_count == 1:
prompt += "\n\nNote: You have 1 prior conversation history available in the `history` variable."
else:
prompt += f"\n\nNote: You have {history_count} prior conversation histories available (history_0 through history_{history_count - 1})."

return {"role": "user", "content": prompt}