Add programmatic finish()/finish_var() helpers and scaffold-name protection
to the local REPL, and tighten FINAL()/FINAL_VAR() parsing.

Reviewer notes:
- Line structure and indentation of this patch were reconstructed from a
  whitespace-collapsed copy; verify context lines against the target files
  (or apply with `git apply --ignore-whitespace`).
- The `index` post-image hashes predate the revisions made to the
  execute_code hunk of local_repl.py and to both hunks of parsing.py.

diff --git a/rlm/environments/local_repl.py b/rlm/environments/local_repl.py
index c37cd2c3..dfa94760 100644
--- a/rlm/environments/local_repl.py
+++ b/rlm/environments/local_repl.py
@@ -16,9 +16,22 @@ from rlm.environments.base_env import NonIsolatedEnv
 
 
 # =============================================================================
-# Safe Builtins
+# Safe Builtins and Protected Names
 # =============================================================================
 
+# Protected scaffold names that cannot be overwritten by user code
+SCAFFOLD_NAMES = frozenset(
+    {
+        "context",
+        "llm_query",
+        "llm_query_batched",
+        "FINAL_VAR",
+        "SHOW_VARS",
+        "finish",
+        "finish_var",
+    }
+)
+
 # Safe builtins - blocks dangerous operations like eval/exec/input
 _SAFE_BUILTINS = {
     # Core types and functions
@@ -159,11 +172,26 @@ def setup(self):
         # Track LLM calls made during code execution
         self._pending_llm_calls: list[RLMChatCompletion] = []
 
+        # Track finish state for programmatic completion (§3.1 fix)
+        self._finish_answer: str | None = None
+
         # Add helper functions
         self.globals["FINAL_VAR"] = self._final_var
         self.globals["SHOW_VARS"] = self._show_vars
         self.globals["llm_query"] = self._llm_query
         self.globals["llm_query_batched"] = self._llm_query_batched
+        self.globals["finish"] = self._finish
+        self.globals["finish_var"] = self._finish_var
+
+        # Store original scaffold bindings for protection (§3.2 fix)
+        self._scaffold_bindings = {
+            "FINAL_VAR": self._final_var,
+            "SHOW_VARS": self._show_vars,
+            "llm_query": self._llm_query,
+            "llm_query_batched": self._llm_query_batched,
+            "finish": self._finish,
+            "finish_var": self._finish_var,
+        }
 
     def _final_var(self, variable_name: str) -> str:
         """Return the value of a variable as a final answer."""
@@ -192,6 +220,32 @@ def _show_vars(self) -> str:
             return "No variables created yet. Use ```repl``` blocks to create variables."
         return f"Available variables: {available}"
 
+    def _finish(self, value: Any) -> str:
+        """Signal completion with a final answer (programmatic FINAL)."""
+        self._finish_answer = str(value)
+        return f"Task completed with answer: {self._finish_answer}"
+
+    def _finish_var(self, variable_name: str) -> str:
+        """Signal completion with a variable's value (programmatic FINAL_VAR)."""
+        variable_name = variable_name.strip().strip("\"'")
+        if variable_name in self.locals:
+            self._finish_answer = str(self.locals[variable_name])
+            return f"Task completed with answer from variable '{variable_name}': {self._finish_answer}"
+
+        # Provide helpful error message with available variables
+        available = [k for k in self.locals.keys() if not k.startswith("_")]
+        if available:
+            return (
+                f"Error: Variable '{variable_name}' not found. "
+                f"Available variables: {available}. "
+                f"You must create and assign a variable BEFORE calling finish_var on it."
+            )
+        return (
+            f"Error: Variable '{variable_name}' not found. "
+            f"No variables have been created yet. "
+            f"You must create and assign a variable in a REPL block BEFORE calling finish_var on it."
+        )
+
     def _llm_query(self, prompt: str, model: str | None = None) -> str:
         """Query the LM via socket connection to the handler.
 
@@ -287,6 +341,9 @@ def add_context(
         # Alias context_0 as 'context' for backward compatibility
         if context_index == 0:
             self.execute_code(f"context = {var_name}")
+            # Protect 'context' in scaffold bindings
+            if hasattr(self, "_scaffold_bindings"):
+                self._scaffold_bindings["context"] = self.locals.get("context")
 
         self._context_count = max(self._context_count, context_index + 1)
         return context_index
@@ -370,6 +427,27 @@ def execute_code(self, code: str) -> REPLResult:
                     if key not in self.globals and not key.startswith("_"):
                         self.locals[key] = value
 
+                # Restore protected scaffold bindings if overwritten (§3.2 fix)
+                if hasattr(self, "_scaffold_bindings"):
+                    for name in SCAFFOLD_NAMES:
+                        if name not in self._scaffold_bindings:
+                            continue
+                        if name == "context":
+                            # 'context' may be rebound by user code, so track
+                            # its latest value instead of restoring it. No
+                            # value comparison: `!=` is elementwise for
+                            # array-like payloads and fails in a boolean test.
+                            if name in self.locals:
+                                self._scaffold_bindings[name] = self.locals[name]
+                            continue
+                        original = self._scaffold_bindings[name]
+                        # Rebind helpers unconditionally: the result matches a
+                        # compare-then-restore, without evaluating `!=`.
+                        if name in self.globals:
+                            self.globals[name] = original
+                        if name in self.locals:
+                            self.locals[name] = original
+
                 stdout = stdout_buf.getvalue()
                 stderr = stderr_buf.getvalue()
             except Exception as e:
diff --git a/rlm/utils/parsing.py b/rlm/utils/parsing.py
index f0bdd40d..da98027c 100644
--- a/rlm/utils/parsing.py
+++ b/rlm/utils/parsing.py
@@ -40,8 +40,23 @@ def find_final_answer(text: str, environment: "BaseEnv | None" = None) -> str |
     Returns:
         The final answer string, or None if no final answer pattern is found
     """
-    # Check for FINAL_VAR pattern first - must be at start of line
-    final_var_pattern = r"^\s*FINAL_VAR\((.*?)\)"
+    # Check if finish() or finish_var() were used programmatically (preferred).
+    # getattr with a default already covers a missing attribute, so no
+    # try/except is needed; isinstance also rejects the unset (None) state.
+    # NOTE(review): nothing in this function clears _finish_answer, so an
+    # environment reused for another run may need a reset - confirm with callers.
+    if environment is not None:
+        answer = getattr(environment, "_finish_answer", None)
+        if isinstance(answer, str):
+            return answer
+
+    # Fallback to regex patterns for backward compatibility with models trained on FINAL() convention
+
+    # Check for FINAL_VAR pattern first - must be alone on its line(s).
+    # The argument is a variable name, so parentheses are excluded from the
+    # capture: a lazy (.*?) with DOTALL and the trailing $ anchor would run
+    # past the first ) to a later line that happens to end with ).
+    final_var_pattern = r"^\s*FINAL_VAR\(([^()]*)\)\s*$"
     match = re.search(final_var_pattern, text, re.MULTILINE | re.DOTALL)
     if match:
         variable_name = match.group(1).strip().strip('"').strip("'")
@@ -54,8 +69,10 @@ def find_final_answer(text: str, environment: "BaseEnv | None" = None) -> str |
         return None
 
     # Check for FINAL pattern - must be at start of line
-    # Use greedy matching to capture content with nested parentheses
-    final_pattern = r"^\s*FINAL\((.*)\)\s*$"
+    # Lazy (.*?) with DOTALL allows multiline content; because of the trailing
+    # \s*$ anchor the match ends at the first ) followed only by whitespace up
+    # to a line end, so nested parentheses are kept without over-capturing
+    final_pattern = r"^\s*FINAL\((.*?)\)\s*$"
     match = re.search(final_pattern, text, re.MULTILINE | re.DOTALL)
     if match:
         return match.group(1).strip()