Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion rlm/environments/local_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,22 @@
from rlm.environments.base_env import NonIsolatedEnv

# =============================================================================
# Safe Builtins
# Safe Builtins and Protected Names
# =============================================================================

# Protected scaffold names that cannot be overwritten by user code
SCAFFOLD_NAMES = frozenset(
{
"context",
"llm_query",
"llm_query_batched",
"FINAL_VAR",
"SHOW_VARS",
"finish",
"finish_var",
}
)

# Safe builtins - blocks dangerous operations like eval/exec/input
_SAFE_BUILTINS = {
# Core types and functions
Expand Down Expand Up @@ -159,11 +172,26 @@ def setup(self):
# Track LLM calls made during code execution
self._pending_llm_calls: list[RLMChatCompletion] = []

# Track finish state for programmatic completion (§3.1 fix)
self._finish_answer: str | None = None

# Add helper functions
self.globals["FINAL_VAR"] = self._final_var
self.globals["SHOW_VARS"] = self._show_vars
self.globals["llm_query"] = self._llm_query
self.globals["llm_query_batched"] = self._llm_query_batched
self.globals["finish"] = self._finish
self.globals["finish_var"] = self._finish_var

# Store original scaffold bindings for protection (§3.2 fix)
self._scaffold_bindings = {
"FINAL_VAR": self._final_var,
"SHOW_VARS": self._show_vars,
"llm_query": self._llm_query,
"llm_query_batched": self._llm_query_batched,
"finish": self._finish,
"finish_var": self._finish_var,
}

def _final_var(self, variable_name: str) -> str:
"""Return the value of a variable as a final answer."""
Expand Down Expand Up @@ -192,6 +220,32 @@ def _show_vars(self) -> str:
return "No variables created yet. Use ```repl``` blocks to create variables."
return f"Available variables: {available}"

def _finish(self, value: Any) -> str:
"""Signal completion with a final answer (programmatic FINAL)."""
self._finish_answer = str(value)
return f"Task completed with answer: {self._finish_answer}"

def _finish_var(self, variable_name: str) -> str:
"""Signal completion with a variable's value (programmatic FINAL_VAR)."""
variable_name = variable_name.strip().strip("\"'")
if variable_name in self.locals:
self._finish_answer = str(self.locals[variable_name])
return f"Task completed with answer from variable '{variable_name}': {self._finish_answer}"

# Provide helpful error message with available variables
available = [k for k in self.locals.keys() if not k.startswith("_")]
if available:
return (
f"Error: Variable '{variable_name}' not found. "
f"Available variables: {available}. "
f"You must create and assign a variable BEFORE calling finish_var on it."
)
return (
f"Error: Variable '{variable_name}' not found. "
f"No variables have been created yet. "
f"You must create and assign a variable in a REPL block BEFORE calling finish_var on it."
)

def _llm_query(self, prompt: str, model: str | None = None) -> str:
"""Query the LM via socket connection to the handler.

Expand Down Expand Up @@ -287,6 +341,9 @@ def add_context(
# Alias context_0 as 'context' for backward compatibility
if context_index == 0:
self.execute_code(f"context = {var_name}")
# Protect 'context' in scaffold bindings
if hasattr(self, "_scaffold_bindings"):
self._scaffold_bindings["context"] = self.locals.get("context")

self._context_count = max(self._context_count, context_index + 1)
return context_index
Expand Down Expand Up @@ -370,6 +427,22 @@ def execute_code(self, code: str) -> REPLResult:
if key not in self.globals and not key.startswith("_"):
self.locals[key] = value

# Restore protected scaffold bindings if overwritten (§3.2 fix)
if hasattr(self, "_scaffold_bindings"):
for name in SCAFFOLD_NAMES:
if name in self._scaffold_bindings:
# Check if binding was overwritten in globals
if name in self.globals and self.globals[name] != self._scaffold_bindings[name]:
self.globals[name] = self._scaffold_bindings[name]
# Check if binding was overwritten in locals
if name in self.locals and self.locals[name] != self._scaffold_bindings[name]:
# For context, update the scaffold binding to track current value
if name == "context":
self._scaffold_bindings[name] = self.locals[name]
else:
# Restore other scaffold bindings
self.locals[name] = self._scaffold_bindings[name]

stdout = stdout_buf.getvalue()
stderr = stderr_buf.getvalue()
except Exception as e:
Expand Down
21 changes: 17 additions & 4 deletions rlm/utils/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,20 @@ def find_final_answer(text: str, environment: "BaseEnv | None" = None) -> str |
Returns:
The final answer string, or None if no final answer pattern is found
"""
# Check for FINAL_VAR pattern first - must be at start of line
final_var_pattern = r"^\s*FINAL_VAR\((.*?)\)"
# Check if finish() or finish_var() were used programmatically (preferred)
if environment is not None:
try:
answer = getattr(environment, "_finish_answer", None)
if answer is not None and isinstance(answer, str):
return answer
except Exception:
pass # Fall through to regex patterns

# Fallback to regex patterns for backward compatibility with models trained on FINAL() convention

# Check for FINAL_VAR pattern first - must be at start of line, non-greedy
# Use DOTALL to allow newlines within parentheses, but non-greedy to stop at first )
final_var_pattern = r"^\s*FINAL_VAR\((.*?)\)\s*$"
match = re.search(final_var_pattern, text, re.MULTILINE | re.DOTALL)
if match:
variable_name = match.group(1).strip().strip('"').strip("'")
Expand All @@ -54,8 +66,9 @@ def find_final_answer(text: str, environment: "BaseEnv | None" = None) -> str |
return None

# Check for FINAL pattern - must be at start of line
# Use greedy matching to capture content with nested parentheses
final_pattern = r"^\s*FINAL\((.*)\)\s*$"
# Non-greedy (.*?) with DOTALL allows multiline but stops at first closing )
# This fixes the over-capture bug while still supporting multiline content
final_pattern = r"^\s*FINAL\((.*?)\)\s*$"
match = re.search(final_pattern, text, re.MULTILINE | re.DOTALL)
if match:
return match.group(1).strip()
Expand Down