From e4bef47d0975c61a6f0d2499757a6ba66063d540 Mon Sep 17 00:00:00 2001 From: Sugam Devare Date: Tue, 20 Jan 2026 16:36:15 -0800 Subject: [PATCH 1/5] feat: add action latency --- evaluation/benchmarks/swe_bench/run_infer.py | 25 +++++++++++++++++++- openhands/events/serialization/event.py | 4 +++- openhands/llm/llm.py | 4 ++-- openhands/runtime/base.py | 8 +++++++ 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/evaluation/benchmarks/swe_bench/run_infer.py b/evaluation/benchmarks/swe_bench/run_infer.py index 51c67571bdf9..ae027fc45833 100644 --- a/evaluation/benchmarks/swe_bench/run_infer.py +++ b/evaluation/benchmarks/swe_bench/run_infer.py @@ -55,11 +55,12 @@ from openhands.core.logger import openhands_logger as logger from openhands.core.main import create_runtime, run_controller from openhands.critic import AgentFinishedCritic -from openhands.events.action import CmdRunAction, FileReadAction, MessageAction +from openhands.events.action import CmdRunAction, FileReadAction, MessageAction from openhands.events.observation import ( CmdOutputObservation, ErrorObservation, FileReadObservation, + Observation, ) from openhands.events.serialization.event import event_from_dict, event_to_dict from openhands.runtime.base import Runtime @@ -733,6 +734,9 @@ def process_instance( histories = [event_to_dict(event) for event in state.history] metrics = get_metrics(state) + # Calculate action execution times from history + metrics['action_execution_latencies'] = get_action_execution_latencies(state.history) + # Save the output instruction = message_action.content if message_action.image_urls: @@ -752,6 +756,25 @@ def process_instance( return output +def get_action_execution_latencies(history: list) -> list[dict]: + """Extract execution latencies from observations in the history.""" + latencies = [] + for event in history: + if isinstance(event, Observation): + execution_latency = getattr(event, '_execution_latency', None) + if execution_latency is None: + 
execution_latency = getattr(event, 'execution_latency', None) + if execution_latency is not None: + latencies.append({ + 'observation_type': type(event).__name__, + 'observation_id': str(event.id), + 'latency': float(execution_latency), + 'message': event.message, + 'timestamp': event.timestamp, + }) + return latencies + + def filter_dataset( dataset: pd.DataFrame, filter_column: str, diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py index a95992a4f186..ca3155f06885 100644 --- a/openhands/events/serialization/event.py +++ b/openhands/events/serialization/event.py @@ -23,6 +23,7 @@ 'observation', 'tool_call_metadata', 'llm_metrics', + 'execution_latency', ] UNDERSCORE_KEYS = [ 'id', @@ -31,6 +32,7 @@ 'cause', 'tool_call_metadata', 'llm_metrics', + 'execution_latency', ] DELETE_FROM_TRAJECTORY_EXTRAS = { @@ -71,7 +73,7 @@ def event_from_dict(data: dict[str, Any]) -> 'Event': model_response_dict = value['model_response'] if isinstance(model_response_dict, dict) and 'provider_specific_fields' in model_response_dict: provider_specific_fields = model_response_dict.pop('provider_specific_fields') - + value = ToolCallMetadata(**value) # Add provider_specific_fields back to the model_response diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index ff066f58626a..2198f71cc53c 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -360,7 +360,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: kwargs.pop('extra_body', None) # Record start time for latency measurement - start_time = time.time() + start_time = time.perf_counter() # we don't support streaming here, thus we get a ModelResponse # Suppress httpx deprecation warnings during LiteLLM calls @@ -386,7 +386,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: self.response_headers = resp._response_headers # Calculate and record latency - latency = time.time() - start_time + latency = time.perf_counter() - start_time response_id = resp.get('id', 'unknown') 
self.metrics.add_response_latency(latency, response_id) diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py index 4f563504c786..d8e173b4580c 100644 --- a/openhands/runtime/base.py +++ b/openhands/runtime/base.py @@ -8,6 +8,7 @@ import shutil import string import tempfile +import time from abc import abstractmethod from pathlib import Path from types import MappingProxyType @@ -373,6 +374,9 @@ async def _handle_action(self, event: Action) -> None: # We don't block the command if this is a default timeout action event.set_hard_timeout(self.config.sandbox.timeout, blocking=False) assert event.timeout is not None + + action_start_time = time.perf_counter() + try: await self._export_latest_git_provider_tokens(event) if isinstance(event, MCPAction): @@ -384,10 +388,13 @@ async def _handle_action(self, event: Action) -> None: blocking_val = False event.set_hard_timeout(min(event.timeout,600), blocking=blocking_val) observation = await call_sync_from_async(self.run_action, event) + + observation._execution_latency = time.perf_counter() - action_start_time except PermissionError as e: # Handle PermissionError specially - convert to ErrorObservation # so the agent can receive feedback and continue execution observation = ErrorObservation(content=str(e)) + observation._execution_latency = time.perf_counter() - action_start_time except AgentRuntimeTimeoutError as e: # Handle timeout errors by converting to ErrorObservation # so the agent can receive feedback and try a different approach @@ -409,6 +416,7 @@ async def _handle_action(self, event: Action) -> None: 'Consider trying a different approach, breaking the task into smaller steps, ' 'or optimizing your command to complete within the time limit.' 
) + observation._execution_latency = time.perf_counter() - action_start_time except (httpx.NetworkError, AgentRuntimeDisconnectedError) as e: runtime_status = RuntimeStatus.ERROR_RUNTIME_DISCONNECTED error_message = f'{type(e).__name__}: {str(e)}' From 8acdde3cc6ccdf1e85006da381b7ef73330dd475 Mon Sep 17 00:00:00 2001 From: Sugam Devare Date: Wed, 21 Jan 2026 16:52:50 -0800 Subject: [PATCH 2/5] feat: add kill blocklist --- openhands/runtime/utils/bash.py | 230 +++++++------- openhands/runtime/utils/command_blacklist.py | 233 ++++++++++++++ scripts/test_command_blacklist.py | 317 +++++++++++++++++++ 3 files changed, 673 insertions(+), 107 deletions(-) create mode 100644 openhands/runtime/utils/command_blacklist.py create mode 100644 scripts/test_command_blacklist.py diff --git a/openhands/runtime/utils/bash.py b/openhands/runtime/utils/bash.py index 9f52c7a91968..aa231f945b28 100644 --- a/openhands/runtime/utils/bash.py +++ b/openhands/runtime/utils/bash.py @@ -17,22 +17,23 @@ CmdOutputObservation, ) from openhands.runtime.utils.bash_constants import TIMEOUT_MESSAGE_TEMPLATE +from openhands.runtime.utils.command_blacklist import check_command_blacklist from openhands.utils.shutdown_listener import should_continue -RUNTIME_USERNAME = os.getenv('RUNTIME_USERNAME') -SU_TO_USER = os.getenv('SU_TO_USER', 'true').lower() in ( - '1', - 'true', - 't', - 'yes', - 'y', - 'on', +RUNTIME_USERNAME = os.getenv("RUNTIME_USERNAME") +SU_TO_USER = os.getenv("SU_TO_USER", "true").lower() in ( + "1", + "true", + "t", + "yes", + "y", + "on", ) def split_bash_commands(commands: str) -> list[str]: if not commands.strip(): - return [''] + return [""] try: parsed = bashlex.parse(commands) except ( @@ -43,9 +44,9 @@ def split_bash_commands(commands: str) -> list[str]: ): # Added AttributeError to catch 'str' object has no attribute 'kind' error (issue #8369) logger.debug( - f'Failed to parse bash commands\n' - f'[input]: {commands}\n' - f'The original command will be returned as is.', + 
f"Failed to parse bash commands\n" + f"[input]: {commands}\n" + f"The original command will be returned as is.", exc_info=True, ) # If parsing fails, return the original commands @@ -60,7 +61,7 @@ def split_bash_commands(commands: str) -> list[str]: # Include any text between the last command and this one if start > last_end: between = commands[last_end:start] - logger.debug(f'BASH PARSING between: {between}') + logger.debug(f"BASH PARSING between: {between}") if result: result[-1] += between.rstrip() elif between.strip(): @@ -69,21 +70,21 @@ def split_bash_commands(commands: str) -> list[str]: # Extract the command, preserving original formatting command = commands[start:end].rstrip() - logger.debug(f'BASH PARSING command: {command}') + logger.debug(f"BASH PARSING command: {command}") result.append(command) last_end = end # Add any remaining text after the last command to the last command remaining = commands[last_end:].rstrip() - logger.debug(f'BASH PARSING remaining: {remaining}') + logger.debug(f"BASH PARSING remaining: {remaining}") if last_end < len(commands) and result: result[-1] += remaining - logger.debug(f'BASH PARSING result[-1] += remaining: {result[-1]}') + logger.debug(f"BASH PARSING result[-1] += remaining: {result[-1]}") elif last_end < len(commands): if remaining: result.append(remaining) - logger.debug(f'BASH PARSING result.append(remaining): {result[-1]}') + logger.debug(f"BASH PARSING result.append(remaining): {result[-1]}") return result @@ -91,8 +92,8 @@ def escape_bash_special_chars(command: str) -> str: r"""Escapes characters that have different interpretations in bash vs python. Specifically handles escape sequences like \;, \|, \&, etc. 
""" - if command.strip() == '': - return '' + if command.strip() == "": + return "" try: parts = [] @@ -101,8 +102,8 @@ def escape_bash_special_chars(command: str) -> str: def visit_node(node: Any) -> None: nonlocal last_pos if ( - node.kind == 'redirect' - and hasattr(node, 'heredoc') + node.kind == "redirect" + and hasattr(node, "heredoc") and node.heredoc is not None ): # We're entering a heredoc - preserve everything as-is until we see EOF @@ -116,34 +117,34 @@ def visit_node(node: Any) -> None: last_pos = node.pos[1] return - if node.kind == 'word': + if node.kind == "word": # Get the raw text between the last position and current word between = command[last_pos : node.pos[0]] word_text = command[node.pos[0] : node.pos[1]] # Add the between text, escaping special characters - between = re.sub(r'\\([;&|><])', r'\\\\\1', between) + between = re.sub(r"\\([;&|><])", r"\\\\\1", between) parts.append(between) # Check if word_text is a quoted string or command substitution if ( (word_text.startswith('"') and word_text.endswith('"')) or (word_text.startswith("'") and word_text.endswith("'")) - or (word_text.startswith('$(') and word_text.endswith(')')) - or (word_text.startswith('`') and word_text.endswith('`')) + or (word_text.startswith("$(") and word_text.endswith(")")) + or (word_text.startswith("`") and word_text.endswith("`")) ): # Preserve quoted strings, command substitutions, and heredoc content as-is parts.append(word_text) else: # Escape special chars in unquoted text - word_text = re.sub(r'\\([;&|><])', r'\\\\\1', word_text) + word_text = re.sub(r"\\([;&|><])", r"\\\\\1", word_text) parts.append(word_text) last_pos = node.pos[1] return # Visit child nodes - if hasattr(node, 'parts'): + if hasattr(node, "parts"): for part in node.parts: visit_node(part) @@ -151,7 +152,7 @@ def visit_node(node: Any) -> None: nodes = list(bashlex.parse(command)) for node in nodes: between = command[last_pos : node.pos[0]] - between = re.sub(r'\\([;&|><])', r'\\\\\1', between) 
+ between = re.sub(r"\\([;&|><])", r"\\\\\1", between) parts.append(between) last_pos = node.pos[0] visit_node(node) @@ -159,22 +160,22 @@ def visit_node(node: Any) -> None: # Handle any remaining text after the last word remaining = command[last_pos:] parts.append(remaining) - return ''.join(parts) + return "".join(parts) except (bashlex.errors.ParsingError, NotImplementedError, TypeError): logger.debug( - f'Failed to parse bash commands for special characters escape\n' - f'[input]: {command}\n' - f'The original command will be returned as is.', + f"Failed to parse bash commands for special characters escape\n" + f"[input]: {command}\n" + f"The original command will be returned as is.", exc_info=True, ) return command class BashCommandStatus(Enum): - CONTINUE = 'continue' - COMPLETED = 'completed' - NO_CHANGE_TIMEOUT = 'no_change_timeout' - HARD_TIMEOUT = 'hard_timeout' + CONTINUE = "continue" + COMPLETED = "completed" + NO_CHANGE_TIMEOUT = "no_change_timeout" + HARD_TIMEOUT = "hard_timeout" def _remove_command_prefix(command_output: str, command: str) -> str: @@ -201,12 +202,12 @@ def __init__( def initialize(self) -> None: self.server = libtmux.Server() - _shell_command = '/bin/bash' + _shell_command = "/bin/bash" if SU_TO_USER and self.username in list( - filter(None, [RUNTIME_USERNAME, 'root', 'openhands']) + filter(None, [RUNTIME_USERNAME, "root", "openhands"]) ): # This starts a non-login (new) shell for the given user - _shell_command = f'su {self.username} -' + _shell_command = f"su {self.username} -" # FIXME: we will introduce memory limit using sysbox-runc in coming PR # # otherwise, we are running as the CURRENT USER (e.g., when running LocalRuntime) @@ -218,9 +219,9 @@ def initialize(self) -> None: window_command = _shell_command logger.debug( - f'Initializing bash session in {self.work_dir} with command: {window_command}' + f"Initializing bash session in {self.work_dir} with command: {window_command}" ) - session_name = 
f'openhands-{self.username}-{uuid.uuid4()}' + session_name = f"openhands-{self.username}-{uuid.uuid4()}" self.session = self.server.new_session( session_name=session_name, start_directory=self.work_dir, # This parameter is supported by libtmux @@ -231,17 +232,17 @@ def initialize(self) -> None: # Set history limit to a large number to avoid losing history # https://unix.stackexchange.com/questions/43414/unlimited-history-in-tmux - self.session.set_option('history-limit', str(self.HISTORY_LIMIT), global_=True) + self.session.set_option("history-limit", str(self.HISTORY_LIMIT), global_=True) self.session.history_limit = self.HISTORY_LIMIT # We need to create a new pane because the initial pane's history limit is (default) 2000 _initial_window = self.session.active_window self.window = self.session.new_window( - window_name='bash', + window_name="bash", window_shell=window_command, start_directory=self.work_dir, # This parameter is supported by libtmux ) self.pane = self.window.active_pane - logger.debug(f'pane: {self.pane}; history_limit: {self.session.history_limit}') + logger.debug(f"pane: {self.pane}; history_limit: {self.session.history_limit}") _initial_window.kill() # Configure bash to use simple PS1 and disable PS2 @@ -253,9 +254,9 @@ def initialize(self) -> None: # Store the last command for interactive input handling self.prev_status: BashCommandStatus | None = None - self.prev_output: str = '' + self.prev_output: str = "" self._closed: bool = False - logger.debug(f'Bash session initialized with work dir: {self.work_dir}') + logger.debug(f"Bash session initialized with work dir: {self.work_dir}") # Maintain the current working directory self._cwd = os.path.abspath(self.work_dir) @@ -267,11 +268,11 @@ def __del__(self) -> None: def _get_pane_content(self) -> str: """Capture the current pane content and update the buffer.""" - content = '\n'.join( + content = "\n".join( map( # avoid double newlines lambda line: line.rstrip(), - self.pane.cmd('capture-pane', 
'-J', '-pS', '-').stdout, + self.pane.cmd("capture-pane", "-J", "-pS", "-").stdout, ) ) return content @@ -291,20 +292,20 @@ def _is_special_key(self, command: str) -> bool: """Check if the command is a special key.""" # Special keys are of the form C- _command = command.strip() - return _command.startswith('C-') and len(_command) == 3 + return _command.startswith("C-") and len(_command) == 3 def _clear_screen(self) -> None: """Clear the tmux pane screen and history.""" - self.pane.send_keys('C-l', enter=False) + self.pane.send_keys("C-l", enter=False) time.sleep(0.1) - self.pane.cmd('clear-history') + self.pane.cmd("clear-history") def _get_command_output( self, command: str, raw_command_output: str, metadata: CmdOutputMetadata, - continue_prefix: str = '', + continue_prefix: str = "", ) -> str: """Get the command output with the previous command output removed. @@ -333,8 +334,8 @@ def _handle_completed_command( ) -> CmdOutputObservation: is_special_key = self._is_special_key(command) assert len(ps1_matches) >= 1, ( - f'Expected at least one PS1 metadata block, but got {len(ps1_matches)}.\n' - f'---FULL OUTPUT---\n{pane_content!r}\n---END OF OUTPUT---' + f"Expected at least one PS1 metadata block, but got {len(ps1_matches)}.\n" + f"---FULL OUTPUT---\n{pane_content!r}\n---END OF OUTPUT---" ) metadata = CmdOutputMetadata.from_ps1_match(ps1_matches[-1]) @@ -345,11 +346,11 @@ def _handle_completed_command( # Update the current working directory if it has changed if metadata.working_dir != self._cwd and metadata.working_dir: logger.debug( - f'directory_changed: {self._cwd}; {metadata.working_dir}; {command}' + f"directory_changed: {self._cwd}; {metadata.working_dir}; {command}" ) self._cwd = metadata.working_dir - logger.debug(f'COMMAND OUTPUT: {pane_content}') + logger.debug(f"COMMAND OUTPUT: {pane_content}") # Extract the command output between the two PS1 prompts raw_command_output = self._combine_outputs_between_matches( pane_content, @@ -360,12 +361,12 @@ def 
_handle_completed_command( if get_content_before_last_match: # Count the number of lines in the truncated output num_lines = len(raw_command_output.splitlines()) - metadata.prefix = f'[Previous command outputs are truncated. Showing the last {num_lines} lines of the output below.]\n' + metadata.prefix = f"[Previous command outputs are truncated. Showing the last {num_lines} lines of the output below.]\n" metadata.suffix = ( - f'\n[The command completed with exit code {metadata.exit_code}.]' + f"\n[The command completed with exit code {metadata.exit_code}.]" if not is_special_key - else f'\n[The command completed with exit code {metadata.exit_code}. CTRL+{command[-1].upper()} was sent.]' + else f"\n[The command completed with exit code {metadata.exit_code}. CTRL+{command[-1].upper()} was sent.]" ) command_output = self._get_command_output( command, @@ -373,7 +374,7 @@ def _handle_completed_command( metadata, ) self.prev_status = BashCommandStatus.COMPLETED - self.prev_output = '' # Reset previous command output + self.prev_output = "" # Reset previous command output self._ready_for_next_command() return CmdOutputObservation( content=command_output, @@ -391,22 +392,22 @@ def _handle_nochange_timeout_command( self.prev_status = BashCommandStatus.NO_CHANGE_TIMEOUT if len(ps1_matches) != 1: logger.warning( - 'Expected exactly one PS1 metadata block BEFORE the execution of a command, ' - f'but got {len(ps1_matches)} PS1 metadata blocks:\n---\n{pane_content!r}\n---' + "Expected exactly one PS1 metadata block BEFORE the execution of a command, " + f"but got {len(ps1_matches)} PS1 metadata blocks:\n---\n{pane_content!r}\n---" ) raw_command_output = self._combine_outputs_between_matches( pane_content, ps1_matches ) metadata = CmdOutputMetadata() # No metadata available metadata.suffix = ( - f'\n[The command has no new output after {self.NO_CHANGE_TIMEOUT_SECONDS} seconds. 
' - f'{TIMEOUT_MESSAGE_TEMPLATE}]' + f"\n[The command has no new output after {self.NO_CHANGE_TIMEOUT_SECONDS} seconds. " + f"{TIMEOUT_MESSAGE_TEMPLATE}]" ) command_output = self._get_command_output( command, raw_command_output, metadata, - continue_prefix='[Below is the output of the previous command.]\n', + continue_prefix="[Below is the output of the previous command.]\n", ) return CmdOutputObservation( content=command_output, @@ -424,22 +425,22 @@ def _handle_hard_timeout_command( self.prev_status = BashCommandStatus.HARD_TIMEOUT if len(ps1_matches) != 1: logger.warning( - 'Expected exactly one PS1 metadata block BEFORE the execution of a command, ' - f'but got {len(ps1_matches)} PS1 metadata blocks:\n---\n{pane_content!r}\n---' + "Expected exactly one PS1 metadata block BEFORE the execution of a command, " + f"but got {len(ps1_matches)} PS1 metadata blocks:\n---\n{pane_content!r}\n---" ) raw_command_output = self._combine_outputs_between_matches( pane_content, ps1_matches ) metadata = CmdOutputMetadata() # No metadata available metadata.suffix = ( - f'\n[The command timed out after {timeout} seconds. ' - f'{TIMEOUT_MESSAGE_TEMPLATE}]' + f"\n[The command timed out after {timeout} seconds. 
" + f"{TIMEOUT_MESSAGE_TEMPLATE}]" ) command_output = self._get_command_output( command, raw_command_output, metadata, - continue_prefix='[Below is the output of the previous command.]\n', + continue_prefix="[Below is the output of the previous command.]\n", ) return CmdOutputObservation( @@ -479,25 +480,25 @@ def _combine_outputs_between_matches( return pane_content[ps1_matches[0].end() + 1 :] elif len(ps1_matches) == 0: return pane_content - combined_output = '' + combined_output = "" for i in range(len(ps1_matches) - 1): # Extract content between current and next PS1 prompt output_segment = pane_content[ ps1_matches[i].end() + 1 : ps1_matches[i + 1].start() ] - combined_output += output_segment + '\n' + combined_output += output_segment + "\n" # Add the content after the last PS1 prompt combined_output += pane_content[ps1_matches[-1].end() + 1 :] - logger.debug(f'COMBINED OUTPUT: {combined_output}') + logger.debug(f"COMBINED OUTPUT: {combined_output}") return combined_output def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservation: """Execute a command in the bash session.""" if not self._initialized: - raise RuntimeError('Bash session is not initialized') + raise RuntimeError("Bash session is not initialized") # Strip the command of any leading/trailing whitespace - logger.debug(f'RECEIVED ACTION: {action}') + logger.debug(f"RECEIVED ACTION: {action}") command = action.command.strip() is_input: bool = action.is_input @@ -507,16 +508,16 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati BashCommandStatus.NO_CHANGE_TIMEOUT, BashCommandStatus.HARD_TIMEOUT, }: - if command == '': + if command == "": return CmdOutputObservation( - content='ERROR: No previous running command to retrieve logs from.', - command='', + content="ERROR: No previous running command to retrieve logs from.", + command="", metadata=CmdOutputMetadata(), ) if is_input: return CmdOutputObservation( - content='ERROR: No previous running 
command to interact with.', - command='', + content="ERROR: No previous running command to interact with.", + command="", metadata=CmdOutputMetadata(), ) @@ -525,19 +526,34 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati if len(splited_commands) > 1: return ErrorObservation( content=( - f'ERROR: Cannot execute multiple commands at once.\n' - f'Please run each command separately OR chain them into a single command via && or ;\n' - f'Provided commands:\n{"\n".join(f"({i + 1}) {cmd}" for i, cmd in enumerate(splited_commands))}' + f"ERROR: Cannot execute multiple commands at once.\n" + f"Please run each command separately OR chain them into a single command via && or ;\n" + f"Provided commands:\n{'\n'.join(f'({i + 1}) {cmd}' for i, cmd in enumerate(splited_commands))}" ) ) + # Check if the command is blacklisted (only for non-input commands) + if not is_input and command: + blacklist_result = check_command_blacklist(command) + if blacklist_result.is_blocked: + logger.warning(f"Command blocked by blacklist: {command!r}") + return ErrorObservation( + content=blacklist_result.feedback, + error_id="COMMAND_BLACKLISTED", + ) + if "kill" in command or "rm" in command: + print( + f"[POSSIBLE BLACKLIST]Command {command} bypassed blacklist check", + flush=True, + ) + # Get initial state before sending command initial_pane_output = self._get_pane_content() initial_ps1_matches = CmdOutputMetadata.matches_ps1_metadata( initial_pane_output ) initial_ps1_count = len(initial_ps1_matches) - logger.debug(f'Initial PS1 count: {initial_ps1_count}') + logger.debug(f"Initial PS1 count: {initial_ps1_count}") start_time = time.time() last_change_time = start_time @@ -556,7 +572,7 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati CMD_OUTPUT_PS1_END.rstrip() ) # prev command is not completed and not is_input - and command != '' # not input and not empty command + and command != "" # not input and not empty command ): 
_ps1_matches = CmdOutputMetadata.matches_ps1_metadata(last_pane_output) # Use initial_ps1_matches if _ps1_matches is empty, otherwise use _ps1_matches @@ -570,29 +586,29 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati metadata = CmdOutputMetadata() # No metadata available metadata.suffix = ( f'\n[Your command "{command}" is NOT executed. ' - 'The previous command is still running - You CANNOT send new commands until the previous command is completed. ' - 'By setting `is_input` to `true`, you can interact with the current process: ' - f'{TIMEOUT_MESSAGE_TEMPLATE}]' + "The previous command is still running - You CANNOT send new commands until the previous command is completed. " + "By setting `is_input` to `true`, you can interact with the current process: " + f"{TIMEOUT_MESSAGE_TEMPLATE}]" ) - logger.debug(f'PREVIOUS COMMAND OUTPUT: {raw_command_output}') + logger.debug(f"PREVIOUS COMMAND OUTPUT: {raw_command_output}") command_output = self._get_command_output( command, raw_command_output, metadata, - continue_prefix='[Below is the output of the previous command.]\n', + continue_prefix="[Below is the output of the previous command.]\n", ) return CmdOutputObservation( command=command, content=command_output, metadata=metadata, - hidden=getattr(action, 'hidden', False), + hidden=getattr(action, "hidden", False), ) # Send actual command/inputs to the pane - if command != '': + if command != "": is_special_key = self._is_special_key(command) if is_input: - logger.debug(f'SENDING INPUT TO RUNNING PROCESS: {command!r}') + logger.debug(f"SENDING INPUT TO RUNNING PROCESS: {command!r}") self.pane.send_keys( command, enter=not is_special_key, @@ -600,7 +616,7 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati else: # convert command to raw string command = escape_bash_special_chars(command) - logger.debug(f'SENDING COMMAND: {command!r}') + logger.debug(f"SENDING COMMAND: {command!r}") self.pane.send_keys( 
command, enter=not is_special_key, @@ -609,24 +625,24 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati # Loop until the command completes or times out while should_continue(): _start_time = time.time() - logger.debug(f'GETTING PANE CONTENT at {_start_time}') + logger.debug(f"GETTING PANE CONTENT at {_start_time}") cur_pane_output = self._get_pane_content() logger.debug( - f'PANE CONTENT GOT after {time.time() - _start_time:.2f} seconds' + f"PANE CONTENT GOT after {time.time() - _start_time:.2f} seconds" ) - cur_pane_lines = cur_pane_output.split('\n') + cur_pane_lines = cur_pane_output.split("\n") if len(cur_pane_lines) <= 20: - logger.debug('PANE_CONTENT: {cur_pane_output}') + logger.debug("PANE_CONTENT: {cur_pane_output}") else: - logger.debug(f'BEGIN OF PANE CONTENT: {cur_pane_lines[:10]}') - logger.debug(f'END OF PANE CONTENT: {cur_pane_lines[-10:]}') + logger.debug(f"BEGIN OF PANE CONTENT: {cur_pane_lines[:10]}") + logger.debug(f"END OF PANE CONTENT: {cur_pane_lines[-10:]}") ps1_matches = CmdOutputMetadata.matches_ps1_metadata(cur_pane_output) current_ps1_count = len(ps1_matches) if cur_pane_output != last_pane_output: last_pane_output = cur_pane_output last_change_time = time.time() - logger.debug(f'CONTENT UPDATED DETECTED at {last_change_time}') + logger.debug(f"CONTENT UPDATED DETECTED at {last_change_time}") # 1) Execution completed: # Condition 1: A new prompt has appeared since the command started. @@ -640,7 +656,7 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati command, pane_content=cur_pane_output, ps1_matches=ps1_matches, - hidden=getattr(action, 'hidden', False), + hidden=getattr(action, "hidden", False), ) # Timeout checks should only trigger if a new prompt hasn't appeared yet. 
@@ -650,7 +666,7 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati # We ignore this if the command is *blocking* time_since_last_change = time.time() - last_change_time logger.debug( - f'CHECKING NO CHANGE TIMEOUT ({self.NO_CHANGE_TIMEOUT_SECONDS}s): elapsed {time_since_last_change}. Action blocking: {action.blocking}' + f"CHECKING NO CHANGE TIMEOUT ({self.NO_CHANGE_TIMEOUT_SECONDS}s): elapsed {time_since_last_change}. Action blocking: {action.blocking}" ) if ( not action.blocking @@ -665,10 +681,10 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati # 3) Execution timed out due to hard timeout elapsed_time = time.time() - start_time logger.debug( - f'CHECKING HARD TIMEOUT ({action.timeout}s): elapsed {elapsed_time:.2f}' + f"CHECKING HARD TIMEOUT ({action.timeout}s): elapsed {elapsed_time:.2f}" ) if action.timeout and elapsed_time >= action.timeout: - logger.debug('Hard timeout triggered.') + logger.debug("Hard timeout triggered.") return self._handle_hard_timeout_command( command, pane_content=cur_pane_output, @@ -676,6 +692,6 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati timeout=action.timeout, ) - logger.debug(f'SLEEPING for {self.POLL_INTERVAL} seconds for next poll') + logger.debug(f"SLEEPING for {self.POLL_INTERVAL} seconds for next poll") time.sleep(self.POLL_INTERVAL) - raise RuntimeError('Bash session was likely interrupted...') + raise RuntimeError("Bash session was likely interrupted...") diff --git a/openhands/runtime/utils/command_blacklist.py b/openhands/runtime/utils/command_blacklist.py new file mode 100644 index 000000000000..39f9b3208333 --- /dev/null +++ b/openhands/runtime/utils/command_blacklist.py @@ -0,0 +1,233 @@ +"""Command blacklist for blocking dangerous commands in the runtime. 
+ +This module provides a mechanism to block potentially dangerous commands +that could harm the runtime environment, such as killing the action executor +server or destroying critical system resources. +""" + +import os +import re +from dataclasses import dataclass +from typing import Optional + +from openhands.core.logger import openhands_logger as logger + + +@dataclass +class BlacklistEntry: + """A blacklisted command pattern with associated feedback.""" + + pattern: str # Regex pattern to match + feedback: str # Feedback message to return when blocked + description: str # Human-readable description of what this blocks + enabled: bool = True # Whether this rule is active + + +BLACKLIST_ENABLED = os.getenv( + "OPENHANDS_COMMAND_BLACKLIST_ENABLED", "true" +).lower() in ( + "1", + "true", + "t", + "yes", + "y", + "on", +) + +# The blacklist of dangerous command patterns +COMMAND_BLACKLIST: list[BlacklistEntry] = [ + # === Block ALL killall commands (killall always targets by name, not PID) === + BlacklistEntry( + pattern=r"\bkillall\b", + feedback=( + "ERROR: The `killall` command is not allowed.\n" + "`killall` kills processes by name, which could terminate critical OpenHands processes.\n\n" + "SUGGESTION: Use `ps aux | grep ` to find the specific PID, " + "then use `kill -9 ` to terminate only that process." + ), + description="Blocks all killall commands", + ), + # === Block ALL pkill commands (pkill always targets by pattern, not PID) === + BlacklistEntry( + pattern=r"\bpkill\b", + feedback=( + "ERROR: The `pkill` command is not allowed.\n" + "`pkill` kills processes by pattern matching, which could terminate critical OpenHands processes.\n\n" + "SUGGESTION: Use `ps aux | grep ` to find the specific PID, " + "then use `kill -9 ` to terminate only that process." + ), + description="Blocks all pkill commands", + ), + # === Block kill with command substitution $(...) 
=== + BlacklistEntry( + pattern=r"\bkill\s+.*\$\(", + feedback=( + "ERROR: The `kill` command with command substitution is not allowed.\n" + "Using `kill $(...)` could terminate unintended processes.\n\n" + "SUGGESTION: Use `ps aux | grep ` to find the specific PID, " + "then use `kill -9 ` to terminate only that process." + ), + description="Blocks kill with command substitution $()", + ), + # === Block kill with backtick command substitution === + BlacklistEntry( + pattern=r"\bkill\s+.*`", + feedback=( + "ERROR: The `kill` command with command substitution is not allowed.\n" + "Using kill with backticks could terminate unintended processes.\n\n" + "SUGGESTION: Use `ps aux | grep ` to find the specific PID, " + "then use `kill -9 ` to terminate only that process." + ), + description="Blocks kill with backtick command substitution", + ), + # === Block kill with variables === + BlacklistEntry( + pattern=r"\bkill\s+(-\d+\s+|-[A-Z]+\s+|-SIG[A-Z]+\s+)*\$\w+", + feedback=( + "ERROR: The `kill` command with shell variables is not allowed.\n" + "Using `kill $var` could terminate unintended processes if the variable contains unexpected values.\n\n" + "SUGGESTION: Use `ps aux | grep ` to find the specific PID, " + "then use `kill -9 ` to terminate only that process." + ), + description="Blocks kill with shell variables", + ), + # === Block kill -1 or kill -9 -1 (kills all user processes) === + BlacklistEntry( + pattern=r"\bkill\s+(-\d+\s+|-[A-Z]+\s+|-SIG[A-Z]+\s+)*-1\b", + feedback=( + "ERROR: This command would kill all processes you own.\n" + "`kill -1` or `kill -9 -1` sends a signal to all your processes, including the action executor.\n\n" + "SUGGESTION: Use `ps aux | grep ` to find the specific PID, " + "then use `kill -9 ` to terminate only that process." 
+ ), + description="Blocks kill -1 which kills all user processes", + ), + # === Block kill 0 (kills all processes in the process group) === + BlacklistEntry( + pattern=r"\bkill\s+(-\d+\s+|-[A-Z]+\s+|-SIG[A-Z]+\s+)*0\b", + feedback=( + "ERROR: This command would kill all processes in the current process group.\n" + "`kill 0` sends a signal to all processes in your process group, including the action executor.\n\n" + "SUGGESTION: Use `ps aux | grep ` to find the specific PID, " + "then use `kill -9 ` to terminate only that process." + ), + description="Blocks kill 0 which kills the process group", + ), + + # === Block kill with negative PIDs (process groups) === + # Negative PIDs like -12345 kill entire process groups + # We need to catch: kill -12345, kill -9 -12345, etc. + # But NOT catch: kill -9 12345 (where -9 is signal, 12345 is positive PID) + BlacklistEntry( + pattern=r"\bkill\s+(-[1-9]|-1[0-5]|-[A-Z]+|-SIG[A-Z]+)?\s*-([2-9]\d\d+|[1-9]\d\d\d+)\s*$", + feedback=( + "ERROR: The `kill` command with negative PIDs (process groups) is not allowed.\n" + "Using negative PIDs kills entire process groups, which could terminate critical processes.\n\n" + "SUGGESTION: Use `ps aux | grep ` to find the specific PID, " + "then use `kill -9 ` to terminate only that process." + ), + description="Blocks kill with negative PIDs (process groups)", + ), + # === Dangerous rm commands === + # Block rm -rf / or rm -rf /* (root filesystem) + BlacklistEntry( + pattern=r"\brm\s+(-\w+\s+)*(/\s*$|/\*)", + feedback=( + "ERROR: This command would destroy the entire filesystem.\n" + "`rm -rf /` or `rm -rf /*` is a catastrophically dangerous command.\n\n" + "SUGGESTION: Be specific about which directory you want to remove." 
+ ), + description="Blocks rm -rf / or rm -rf /*", + ), + # Block rm of critical top-level directories (exact match, not subdirs) + BlacklistEntry( + pattern=r"\brm\s+(-\w+\s+)*(/(bin|usr|etc|var|home|root|opt|lib|lib64|sbin|boot|dev|proc|sys))\s*$", + feedback=( + "ERROR: This command would delete critical system directories.\n" + "Deleting root-level directories like /bin, /usr, /etc, etc. is blocked.\n\n" + "SUGGESTION: Be more specific about which files or directories you want to delete. " + "Use absolute paths to the specific items you want to remove." + ), + description="Blocks rm of critical system directories", + ), + # === Commands that could affect the tmux session === + BlacklistEntry( + pattern=r"\btmux\s+(kill-server|kill-session\s+-t\s+openhands)", + feedback=( + "ERROR: This command would terminate the OpenHands tmux session.\n" + "Killing the tmux server or the openhands session would break command execution.\n\n" + "SUGGESTION: If you need to kill a specific process, use `ps aux | grep ` " + "to find the specific PID, then use `kill -9 ` to terminate only that process." + ), + description="Blocks killing the openhands tmux session", + ), + # === Shutdown/reboot commands === + BlacklistEntry( + pattern=r"\b(shutdown|reboot|poweroff|halt|init\s+[06])\b", + feedback=( + "ERROR: System shutdown/reboot commands are blocked.\n" + "These commands would terminate the runtime environment.\n\n" + "This type of command is not allowed in this environment." + ), + description="Blocks system shutdown/reboot commands", + ), + # === Dangerous dd commands === + BlacklistEntry( + pattern=r"\bdd\s+.*of=\s*(/dev/sd[a-z]|/dev/nvme\w*|/dev/hd[a-z]|/dev/null)\b", + feedback=( + "ERROR: This dd command could overwrite disk devices.\n" + "Writing directly to disk devices is blocked to prevent data loss.\n\n" + "SUGGESTION: Use standard file operations instead of dd for file manipulation." 
+ ), + description="Blocks dd commands that write to disk devices", + ), +] + + +@dataclass +class BlacklistCheckResult: + """Result of checking a command against the blacklist.""" + + is_blocked: bool + matched_entry: Optional[BlacklistEntry] = None + feedback: str = "" + + +def check_command_blacklist(command: str) -> BlacklistCheckResult: + """Check if a command matches any blacklisted pattern. + + Args: + command: The bash command to check. + + Returns: + BlacklistCheckResult with is_blocked=True and feedback if blocked, + or is_blocked=False if the command is allowed. + """ + if not BLACKLIST_ENABLED: + logger.debug("Command blacklist is disabled via environment variable") + return BlacklistCheckResult(is_blocked=False) + + # Normalize the command for matching + normalized_command = command.strip() + + for entry in COMMAND_BLACKLIST: + if not entry.enabled: + continue + + try: + if re.search(entry.pattern, normalized_command, re.IGNORECASE): + logger.warning( + f"Command blocked by blacklist: {normalized_command!r}\n" + f"Matched pattern: {entry.pattern}\n" + f"Description: {entry.description}" + ) + return BlacklistCheckResult( + is_blocked=True, + matched_entry=entry, + feedback=entry.feedback, + ) + except re.error as e: + logger.error(f"Invalid regex pattern in blacklist: {entry.pattern}: {e}") + continue + + return BlacklistCheckResult(is_blocked=False) diff --git a/scripts/test_command_blacklist.py b/scripts/test_command_blacklist.py new file mode 100644 index 000000000000..032d123723f0 --- /dev/null +++ b/scripts/test_command_blacklist.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 +"""Standalone test script for command_blacklist.py + +Run this script directly to test all blacklist patterns: + python test_command_blacklist.py + +Or run with verbose output: + python test_command_blacklist.py -v +""" + +import re +import sys +from dataclasses import dataclass + + +# ============================================================================ +# Inline copy 
of the blacklist logic (for standalone testing without imports) +# ============================================================================ + + +@dataclass +class BlacklistEntry: + """A blacklisted command pattern with associated feedback.""" + + pattern: str + feedback: str + description: str + enabled: bool = True + + +COMMAND_BLACKLIST: list[BlacklistEntry] = [ + # === Block ALL killall commands === + BlacklistEntry( + pattern=r"\bkillall\b", + feedback="ERROR: The `killall` command is not allowed.", + description="Blocks all killall commands", + ), + # === Block ALL pkill commands === + BlacklistEntry( + pattern=r"\bpkill\b", + feedback="ERROR: The `pkill` command is not allowed.", + description="Blocks all pkill commands", + ), + # === Block kill with command substitution $(...) === + BlacklistEntry( + pattern=r"\bkill\s+.*\$\(", + feedback="ERROR: The `kill` command with command substitution is not allowed.", + description="Blocks kill with command substitution $()", + ), + # === Block kill with backtick command substitution === + BlacklistEntry( + pattern=r"\bkill\s+.*`", + feedback="ERROR: The `kill` command with command substitution is not allowed.", + description="Blocks kill with backtick command substitution", + ), + # === Block kill with variables === + BlacklistEntry( + pattern=r"\bkill\s+(-\d+\s+|-[A-Z]+\s+|-SIG[A-Z]+\s+)*\$\w+", + feedback="ERROR: The `kill` command with shell variables is not allowed.", + description="Blocks kill with shell variables", + ), + # === Block kill -1 or kill -9 -1 === + BlacklistEntry( + pattern=r"\bkill\s+(-\d+\s+|-[A-Z]+\s+|-SIG[A-Z]+\s+)*-1\b", + feedback="ERROR: This command would kill all processes you own.", + description="Blocks kill -1 which kills all user processes", + ), + # === Block kill 0 === + BlacklistEntry( + pattern=r"\bkill\s+(-\d+\s+|-[A-Z]+\s+|-SIG[A-Z]+\s+)*0\b", + feedback="ERROR: This command would kill all processes in the current process group.", + description="Blocks kill 0 which kills the 
process group", + ), + # === Block kill with negative PIDs === + BlacklistEntry( + pattern=r"\bkill\s+(-[1-9]|-1[0-5]|-[A-Z]+|-SIG[A-Z]+)?\s*-([2-9]\d\d+|[1-9]\d\d\d+)\s*$", + feedback="ERROR: The `kill` command with negative PIDs is not allowed.", + description="Blocks kill with negative PIDs (process groups)", + ), + # === Dangerous rm commands === + BlacklistEntry( + pattern=r"\brm\s+(-\w+\s+)*(/\s*$|/\*)", + feedback="ERROR: This command would destroy the entire filesystem.", + description="Blocks rm -rf / or rm -rf /*", + ), + BlacklistEntry( + pattern=r"\brm\s+(-\w+\s+)*(/(bin|usr|etc|var|home|root|opt|lib|lib64|sbin|boot|dev|proc|sys))\s*$", + feedback="ERROR: This command would delete critical system directories.", + description="Blocks rm of critical system directories", + ), + # === tmux kill commands === + BlacklistEntry( + pattern=r"\btmux\s+(kill-server|kill-session\s+-t\s+openhands)", + feedback="ERROR: This command would terminate the OpenHands tmux session.", + description="Blocks killing the openhands tmux session", + ), + # === Shutdown/reboot commands === + BlacklistEntry( + pattern=r"\b(shutdown|reboot|poweroff|halt|init\s+[06])\b", + feedback="ERROR: System shutdown/reboot commands are blocked.", + description="Blocks system shutdown/reboot commands", + ), + # === Dangerous dd commands === + BlacklistEntry( + pattern=r"\bdd\s+.*of=\s*(/dev/sd[a-z]|/dev/nvme\w*|/dev/hd[a-z]|/dev/null)\b", + feedback="ERROR: This dd command could overwrite disk devices.", + description="Blocks dd commands that write to disk devices", + ), +] + + +def check_command(command: str) -> tuple[bool, str]: + """Check if a command is blocked. 
Returns (is_blocked, description).""" + for entry in COMMAND_BLACKLIST: + if not entry.enabled: + continue + try: + if re.search(entry.pattern, command.strip(), re.IGNORECASE): + return True, entry.description + except re.error: + continue + return False, "" + + +# ============================================================================ +# Test Cases +# ============================================================================ + +# Format: (command, should_be_blocked, description) +TEST_CASES = [ + # === ALLOWED kill commands (specific PIDs) === + ("kill 12345", False, "kill with specific PID"), + ("kill -9 12345", False, "kill -9 with specific PID"), + ("kill -15 12345", False, "kill -15 with specific PID"), + ("kill -TERM 12345", False, "kill -TERM with specific PID"), + ("kill -SIGTERM 12345", False, "kill -SIGTERM with specific PID"), + ("kill -9 12345 67890", False, "kill multiple specific PIDs"), + ("kill %1", False, "kill job spec (allowed)"), + ("kill %2", False, "kill job spec %2 (allowed)"), + # === BLOCKED killall commands === + ("killall python", True, "killall python"), + ("killall -9 python", True, "killall -9 python"), + ("killall uvicorn", True, "killall uvicorn"), + ("killall node", True, "killall node"), + ("killall -TERM myprocess", True, "killall with signal"), + # === BLOCKED pkill commands === + ("pkill python", True, "pkill python"), + ("pkill -9 python", True, "pkill -9 python"), + ("pkill -f python", True, "pkill -f python"), + ("pkill -f 'my script'", True, "pkill -f with pattern"), + ("pkill uvicorn", True, "pkill uvicorn"), + ("pkill -f server", True, "pkill -f server"), + # === BLOCKED kill with command substitution === + ("kill $(pgrep python)", True, "kill with $() substitution"), + ("kill -9 $(pgrep python)", True, "kill -9 with $() substitution"), + ("kill $(cat /tmp/pid)", True, "kill with $() reading file"), + ("kill `pgrep python`", True, "kill with backtick substitution"), + ("kill -9 `pgrep python`", True, "kill -9 with 
backtick substitution"), + # === BLOCKED kill with variables === + ("kill $PID", True, "kill with variable"), + ("kill -9 $PID", True, "kill -9 with variable"), + ("kill $my_pid", True, "kill with underscore variable"), + ("kill -TERM $PID", True, "kill -TERM with variable"), + # === BLOCKED kill -1 (all processes) === + ("kill -1", True, "kill -1"), + ("kill -9 -1", True, "kill -9 -1"), + ("kill -TERM -1", True, "kill -TERM -1"), + ("kill -SIGKILL -1", True, "kill -SIGKILL -1"), + # === BLOCKED kill 0 (process group) === + ("kill 0", True, "kill 0"), + ("kill -9 0", True, "kill -9 0"), + ("kill -TERM 0", True, "kill -TERM 0"), + # === BLOCKED kill with negative PIDs (as target, not signal) === + ("kill -12345", True, "kill -12345 (negative PID as target)"), + ("kill -9 -12345", True, "kill -9 -12345 (negative PID as target)"), + ("kill -200", True, "kill -200 (negative PID as target)"), + # === BLOCKED rm commands === + ("rm -rf /", True, "rm -rf /"), + ("rm -rf /*", True, "rm -rf /*"), + ("rm -r /", True, "rm -r /"), + ("rm /bin", True, "rm /bin"), + ("rm -rf /usr", True, "rm -rf /usr"), + ("rm -rf /etc", True, "rm -rf /etc"), + ("rm -rf /var", True, "rm -rf /var"), + ("rm /home", True, "rm /home"), + ("rm -rf /root", True, "rm -rf /root"), + # === ALLOWED rm commands === + ("rm -rf /tmp/mydir", False, "rm -rf /tmp/mydir (allowed)"), + ("rm file.txt", False, "rm file.txt (allowed)"), + ("rm -rf ./build", False, "rm -rf ./build (allowed)"), + ("rm -rf /workspace/project", False, "rm -rf /workspace/project (allowed)"), + # === BLOCKED tmux commands === + ("tmux kill-server", True, "tmux kill-server"), + ("tmux kill-session -t openhands", True, "tmux kill-session -t openhands"), + # === ALLOWED tmux commands === + ( + "tmux kill-session -t mysession", + False, + "tmux kill-session -t mysession (allowed)", + ), + ("tmux new-session", False, "tmux new-session (allowed)"), + ("tmux list-sessions", False, "tmux list-sessions (allowed)"), + # === BLOCKED shutdown/reboot 
commands === + ("shutdown", True, "shutdown"), + ("shutdown -h now", True, "shutdown -h now"), + ("reboot", True, "reboot"), + ("poweroff", True, "poweroff"), + ("halt", True, "halt"), + ("init 0", True, "init 0"), + ("init 6", True, "init 6"), + # === BLOCKED dd commands === + ("dd if=/dev/zero of=/dev/sda", True, "dd to /dev/sda"), + ("dd if=/dev/zero of=/dev/nvme0n1", True, "dd to /dev/nvme"), + ("dd if=/dev/zero of=/dev/null", True, "dd to /dev/null"), + # === ALLOWED dd commands === + ( + "dd if=/dev/zero of=./testfile bs=1M count=10", + False, + "dd to regular file (allowed)", + ), + # === Other allowed commands === + ("ls -la", False, "ls -la (allowed)"), + ("ps aux", False, "ps aux (allowed)"), + ("ps aux | grep python", False, "ps aux | grep python (allowed)"), + ("cat /etc/passwd", False, "cat /etc/passwd (allowed)"), + ("echo hello", False, "echo hello (allowed)"), + ("python script.py", False, "python script.py (allowed)"), +] + + +# ============================================================================ +# Test Runner +# ============================================================================ + + +def run_tests(verbose: bool = False) -> tuple[int, int, list]: + """Run all test cases and return (passed, failed, failures).""" + passed = 0 + failed = 0 + failures = [] + + for command, should_block, description in TEST_CASES: + is_blocked, matched_desc = check_command(command) + + if is_blocked == should_block: + passed += 1 + if verbose: + status = "BLOCKED" if is_blocked else "ALLOWED" + print(f" ✓ {status}: {command!r} - {description}") + else: + failed += 1 + expected = "BLOCKED" if should_block else "ALLOWED" + actual = "BLOCKED" if is_blocked else "ALLOWED" + failures.append( + { + "command": command, + "description": description, + "expected": expected, + "actual": actual, + "matched_rule": matched_desc if is_blocked else None, + } + ) + if verbose: + print(f" ✗ FAIL: {command!r}") + print(f" Expected: {expected}, Got: {actual}") + if 
is_blocked: + print(f" Matched: {matched_desc}") + + return passed, failed, failures + + +def main(): + verbose = "-v" in sys.argv or "--verbose" in sys.argv + + print("=" * 70) + print("Command Blacklist Test Suite") + print("=" * 70) + print() + + # Show all blacklist rules + print("Blacklist Rules:") + print("-" * 70) + for i, entry in enumerate(COMMAND_BLACKLIST, 1): + print(f" {i}. {entry.description}") + if verbose: + print(f" Pattern: {entry.pattern}") + print() + + # Run tests + print("Running Tests:") + print("-" * 70) + passed, failed, failures = run_tests(verbose) + print() + + # Summary + print("=" * 70) + print(f"Results: {passed} passed, {failed} failed") + print("=" * 70) + + if failures: + print() + print("FAILURES:") + print("-" * 70) + for f in failures: + print(f" Command: {f['command']!r}") + print(f" Description: {f['description']}") + print(f" Expected: {f['expected']}, Got: {f['actual']}") + if f["matched_rule"]: + print(f" Matched Rule: {f['matched_rule']}") + print() + + # Exit with appropriate code + sys.exit(0 if failed == 0 else 1) + + +if __name__ == "__main__": + main() From f12c3fbd66de8a3c4fedb976cdc7b1949b13f2cd Mon Sep 17 00:00:00 2001 From: Sugam Devare Date: Wed, 21 Jan 2026 20:38:43 -0800 Subject: [PATCH 3/5] feat: add atomic write for traj json --- openhands/llm/llm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index 2198f71cc53c..dee7c3a05c80 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -7,6 +7,7 @@ import httpx import uuid +import tempfile from openhands.core.config import LLMConfig @@ -467,8 +468,10 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # Save fncall_messages/response separately _d['fncall_messages'] = original_fncall_messages _d['fncall_response'] = resp - with open(log_file, 'w') as f: + temp_fd, temp_path = tempfile.mkstemp(dir=os.path.dirname(log_file)) + with os.fdopen(temp_fd, 'w') as f: f.write(json.dumps(_d)) + 
os.replace(temp_path, log_file) return resp From dfd04f41c9af452a9a230c7378699c6119bcb2db Mon Sep 17 00:00:00 2001 From: Sugam Devare Date: Fri, 23 Jan 2026 11:03:04 -0800 Subject: [PATCH 4/5] feat: add tmux memory monitor --- evaluation/benchmarks/swe_bench/run_infer.py | 64 ++++++++++ .../benchmarks/swe_bench/scripts/run_infer.sh | 2 + .../codeact_agent/function_calling.py | 4 +- .../agenthub/codeact_agent/tools/bash.py | 5 +- openhands/runtime/utils/bash.py | 112 +++++++++++++++++- 5 files changed, 183 insertions(+), 4 deletions(-) diff --git a/evaluation/benchmarks/swe_bench/run_infer.py b/evaluation/benchmarks/swe_bench/run_infer.py index ae027fc45833..2410db6e579e 100644 --- a/evaluation/benchmarks/swe_bench/run_infer.py +++ b/evaluation/benchmarks/swe_bench/run_infer.py @@ -636,12 +636,76 @@ def complete_runtime( return {'git_patch': git_patch} +def _has_existing_result(eval_output_dir: str, instance_id: str) -> tuple[bool, dict | None]: + completions_dir = os.path.join(eval_output_dir, 'llm_completions', instance_id) + has_completions = False + if os.path.exists(completions_dir): + json_files = [f for f in os.listdir(completions_dir) if f.endswith('.json')] + has_completions = len(json_files) > 0 + + if not has_completions: + return False, None + + output_file = os.path.join(eval_output_dir, 'output.jsonl') + existing_result = None + if os.path.exists(output_file): + try: + with open(output_file, 'r') as f: + for line in f: + try: + result = json.loads(line.strip()) + if result.get('instance_id') == instance_id: + git_patch = result.get('test_result', {}).get('git_patch', '') + if git_patch and git_patch.strip(): + existing_result = result + break + except json.JSONDecodeError: + continue + except Exception as e: + logger.warning(f'Error reading output file for existing result: {e}') + + if has_completions: + return True, existing_result + + return False, None + + def process_instance( instance: pd.Series, metadata: EvalMetadata, reset_logger: bool = True, 
runtime_failure_count: int = 0, ) -> EvalOutput: + + should_skip, existing_result = _has_existing_result(metadata.eval_output_dir, instance.instance_id) + if should_skip: + if existing_result: + return EvalOutput( + instance_id=existing_result.get('instance_id', instance.instance_id), + instruction=existing_result.get('instruction', ''), + instance=existing_result.get('instance', instance.to_dict()), + test_result=existing_result.get('test_result', {}), + metadata=metadata, + history=existing_result.get('history', []), + metrics=existing_result.get('metrics', {}), + error=existing_result.get('error'), + ) + else: + return EvalOutput( + instance_id=instance.instance_id, + instruction='', + instance=instance.to_dict(), + test_result={ + 'git_patch': '', + 'skipped': True, + 'skip_reason': 'completions_exist_no_result', + }, + metadata=metadata, + history=[], + metrics={}, + error=None, + ) + config = get_config(instance, metadata) # Setup the logger properly, so you can run multi-processing to parallelize the evaluation diff --git a/evaluation/benchmarks/swe_bench/scripts/run_infer.sh b/evaluation/benchmarks/swe_bench/scripts/run_infer.sh index fd57a612fdd5..44862bf73081 100755 --- a/evaluation/benchmarks/swe_bench/scripts/run_infer.sh +++ b/evaluation/benchmarks/swe_bench/scripts/run_infer.sh @@ -133,6 +133,8 @@ echo "EVAL_CONDENSER: $EVAL_CONDENSER" echo "EVAL_OUTPUT_DIR: $EVAL_OUTPUT_DIR" echo "SELECTED_ID: $SELECTED_ID" echo "INSTANCE_DICT_PATH: $INSTANCE_DICT_PATH" +echo "TMUX_MEMORY_LIMIT: $TMUX_MEMORY_LIMIT" +echo "COMMAND_EXEC_TIMEOUT: $COMMAND_EXEC_TIMEOUT" # Default to NOT use Hint if [ -z "$USE_HINT_TEXT" ]; then diff --git a/openhands/agenthub/codeact_agent/function_calling.py b/openhands/agenthub/codeact_agent/function_calling.py index b29d47d85184..e8d9a0df17d6 100644 --- a/openhands/agenthub/codeact_agent/function_calling.py +++ b/openhands/agenthub/codeact_agent/function_calling.py @@ -4,6 +4,7 @@ """ import json +from os import getenv from litellm 
import ( ModelResponse, @@ -125,7 +126,8 @@ def response_to_actions( # Set hard timeout if provided (capped at 600 seconds max) if 'timeout' in arguments: try: - action.set_hard_timeout(min(float(arguments['timeout']), 600)) + command_execution_timeout = int(getenv("COMMAND_EXEC_TIMEOUT", "300")) + action.set_hard_timeout(min(float(arguments['timeout']), command_execution_timeout)) except ValueError as e: raise FunctionCallValidationError( f"Invalid float passed to 'timeout' argument: {arguments['timeout']}" diff --git a/openhands/agenthub/codeact_agent/tools/bash.py b/openhands/agenthub/codeact_agent/tools/bash.py index 855a5594eca2..f3edd253b03c 100644 --- a/openhands/agenthub/codeact_agent/tools/bash.py +++ b/openhands/agenthub/codeact_agent/tools/bash.py @@ -1,3 +1,4 @@ +from os import getenv from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk from openhands.agenthub.codeact_agent.tools.prompt import refine_prompt @@ -45,6 +46,8 @@ def create_cmd_run_tool( description = ( _SHORT_BASH_DESCRIPTION if use_short_description else _DETAILED_BASH_DESCRIPTION ) + command_execution_timeout = int(getenv("COMMAND_EXEC_TIMEOUT", "300")) + return ChatCompletionToolParam( type='function', function=ChatCompletionToolParamFunctionChunk( @@ -68,7 +71,7 @@ def create_cmd_run_tool( }, 'timeout': { 'type': 'number', - 'description': 'Optional. Sets a hard timeout in seconds for the command execution. If not provided, the command will use the default soft timeout behavior. Max value is 600 seconds.', + 'description': f'Optional. Sets a hard timeout in seconds for the command execution. If not provided, the command will use the default soft timeout behavior. 
class TmuxMemoryMonitor(threading.Thread):
    """Daemon thread that enforces a memory cap on a tmux server's process tree.

    Every ``interval`` seconds the resident set size (RSS) of the tmux server
    process and all of its descendants is summed. When the total exceeds
    ``limit_mb`` the process tree of every pane is killed; the tmux server
    process itself is not targeted.

    Args:
        tmux_server: Server object exposing ``cmd(...)`` and ``sessions``
            (a ``libtmux.Server`` in the calling code).
        limit_mb: Memory limit in mebibytes.
        interval: Seconds to wait between two measurements.
    """

    def __init__(self, tmux_server, limit_mb, interval=2.0):
        super().__init__(daemon=True)
        self.server = tmux_server
        self.limit_mb = limit_mb
        self.limit_bytes = limit_mb * 1024 * 1024
        self.interval = interval
        self.running = True
        # Lets stop() interrupt any pending wait instead of sleeping it out.
        self._stop_event = threading.Event()
        print(f"[TmuxMemoryMonitor] initialized with limit: {self.limit_mb} MB", flush=True)

    def _get_server_pid(self):
        """Return the tmux server PID as an int, or ``None`` if it cannot be read."""
        try:
            pid_str = self.server.cmd("display-message", "-p", "#{pid}").stdout[0]
            return int(pid_str)
        except Exception:
            # Best effort: any failure talking to tmux means "unknown PID".
            return None

    def get_tree_memory(self, parent_pid):
        """Return the summed RSS in bytes of ``parent_pid`` and its descendants.

        Processes that vanish or cannot be inspected are skipped; ``0`` is
        returned when the root process itself is gone or inaccessible.
        """
        total_mem = 0
        try:
            parent = psutil.Process(parent_pid)
            procs = [parent] + parent.children(recursive=True)
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            return 0
        for p in procs:
            try:
                total_mem += p.memory_info().rss
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                pass
        return total_mem

    @staticmethod
    def _kill_process(proc):
        """Kill one process, tolerating processes that are gone or protected."""
        try:
            proc.kill()
        except psutil.NoSuchProcess:
            pass  # Exited on its own between enumeration and kill.
        except psutil.AccessDenied as e:
            print(f"[TmuxMemoryMonitor] Could not kill PID {proc.pid}: {e}", flush=True)

    def kill_inner_processes(self):
        """Kill the process tree of every pane of every window of every session."""
        try:
            # Loop through all sessions -> windows -> panes
            for session in self.server.sessions:
                for window in session.windows:
                    for pane in window.panes:
                        try:
                            # #{pane_pid} is the PID of the shell or command running in the pane
                            pane_pid_str = pane.cmd(
                                "display-message", "-p", "#{pane_pid}"
                            ).stdout[0]
                            pane_pid = int(pane_pid_str)

                            parent = psutil.Process(pane_pid)
                            children = parent.children(recursive=True)
                        except (psutil.Error, IndexError, ValueError):
                            continue

                        print(
                            f"[TmuxMemoryMonitor] Killing pane {pane.id} (PID: {pane_pid})",
                            flush=True,
                        )

                        # Each kill is guarded individually so that one child
                        # that already exited cannot spare its siblings or the
                        # pane's own process.
                        for child in children:
                            self._kill_process(child)
                        self._kill_process(parent)
        except Exception as e:
            print(f"[TmuxMemoryMonitor] Error killing panes: {e}", flush=True)

    def run(self):
        """Poll memory usage until :meth:`stop` is called."""
        print(f"[TmuxMemoryMonitor] started. Limit: {self.limit_mb} MB", flush=True)
        # Startup delay before the server PID is queried; returns early when
        # the monitor is stopped in the meantime.
        if self._stop_event.wait(1):
            return

        server_pid = self._get_server_pid()
        if not server_pid:
            print("[TmuxMemoryMonitor] Could not determine Tmux Server PID. Monitor aborting.", flush=True)
            return

        while self.running and not self._stop_event.is_set():
            used_bytes = self.get_tree_memory(server_pid)

            if used_bytes > self.limit_bytes:
                used_mb = used_bytes / (1024 * 1024)
                print(
                    f"[TmuxMemoryMonitor] MEMORY LIMIT EXCEEDED: {int(used_mb)}MB > {self.limit_mb}MB",
                    flush=True,
                )
                print(
                    "[TmuxMemoryMonitor] KILLING PROCESSES INSIDE TMUX (Server stays alive)...",
                    flush=True,
                )

                self.kill_inner_processes()
                # Cool-down after a kill before measuring again.
                if self._stop_event.wait(10):
                    break

            if self._stop_event.wait(self.interval):
                break

    def stop(self):
        """Ask the monitor loop to exit and wake it if it is currently waiting."""
        self.running = False
        self._stop_event.set()
str(self.HISTORY_LIMIT), global_=True) @@ -264,6 +368,8 @@ def initialize(self) -> None: def __del__(self) -> None: """Ensure the session is closed when the object is destroyed.""" + if self.memory_monitor: + self.memory_monitor.stop() self.close() def _get_pane_content(self) -> str: @@ -279,6 +385,8 @@ def _get_pane_content(self) -> str: def close(self) -> None: """Clean up the session.""" + if self.memory_monitor: + self.memory_monitor.stop() if self._closed: return self.session.kill() From d6bd12edd828a6af88f8036583a6a8e209315eeb Mon Sep 17 00:00:00 2001 From: Sugam Devare Date: Mon, 2 Feb 2026 16:27:21 -0800 Subject: [PATCH 5/5] feat: add validation failure action --- .../codeact_agent/function_calling.py | 413 +++++++++--------- openhands/core/schema/action.py | 3 + openhands/core/schema/observation.py | 3 + openhands/events/action/__init__.py | 2 + openhands/events/action/agent.py | 26 ++ openhands/events/observation/__init__.py | 2 + openhands/events/observation/agent.py | 19 + openhands/runtime/base.py | 8 + .../action_execution_client.py | 8 + 9 files changed, 283 insertions(+), 201 deletions(-) diff --git a/openhands/agenthub/codeact_agent/function_calling.py b/openhands/agenthub/codeact_agent/function_calling.py index e8d9a0df17d6..8dc0f98d1d35 100644 --- a/openhands/agenthub/codeact_agent/function_calling.py +++ b/openhands/agenthub/codeact_agent/function_calling.py @@ -40,6 +40,7 @@ IPythonRunCellAction, MessageAction, TaskTrackingAction, + ValidationFailureAction, ) from openhands.events.action.agent import CondensationRequestAction from openhands.events.action.mcp import MCPAction @@ -103,226 +104,236 @@ def response_to_actions( for i, tool_call in enumerate(assistant_msg.tool_calls): action: Action logger.debug(f'Tool call in function_calling.py: {tool_call}') + try: - arguments = json.loads(tool_call.function.arguments) - except json.decoder.JSONDecodeError as e: - raise FunctionCallValidationError( - f'Failed to parse tool call arguments: 
{tool_call.function.arguments}' - ) from e - - # ================================================ - # CmdRunTool (Bash) - # ================================================ - - if tool_call.function.name == create_cmd_run_tool()['function']['name']: - if 'command' not in arguments: + try: + arguments = json.loads(tool_call.function.arguments) + except json.decoder.JSONDecodeError as e: raise FunctionCallValidationError( - f'Missing required argument "command" in tool call {tool_call.function.name}' - ) - # convert is_input to boolean - is_input = arguments.get('is_input', 'false') == 'true' - action = CmdRunAction(command=arguments['command'], is_input=is_input) - - # Set hard timeout if provided (capped at 600 seconds max) - if 'timeout' in arguments: - try: - command_execution_timeout = int(getenv("COMMAND_EXEC_TIMEOUT", "300")) - action.set_hard_timeout(min(float(arguments['timeout']), command_execution_timeout)) - except ValueError as e: - raise FunctionCallValidationError( - f"Invalid float passed to 'timeout' argument: {arguments['timeout']}" - ) from e - set_security_risk(action, arguments) - - # ================================================ - # IPythonTool (Jupyter) - # ================================================ - elif tool_call.function.name == IPythonTool['function']['name']: - if 'code' not in arguments: - raise FunctionCallValidationError( - f'Missing required argument "code" in tool call {tool_call.function.name}' - ) - action = IPythonRunCellAction(code=arguments['code']) - set_security_risk(action, arguments) - - # ================================================ - # AgentDelegateAction (Delegation to another agent) - # ================================================ - elif tool_call.function.name == 'delegate_to_browsing_agent': - action = AgentDelegateAction( - agent='BrowsingAgent', - inputs=arguments, - ) + f'Failed to parse tool call arguments: {tool_call.function.arguments}' + ) from e - # 
================================================ - # AgentFinishAction - # ================================================ - elif tool_call.function.name == FinishTool['function']['name']: - action = AgentFinishAction( - final_thought=arguments.get('message', ''), - ) + # ================================================ + # CmdRunTool (Bash) + # ================================================ - # ================================================ - # LLMBasedFileEditTool (LLM-based file editor, deprecated) - # ================================================ - elif tool_call.function.name == LLMBasedFileEditTool['function']['name']: - if 'path' not in arguments: - raise FunctionCallValidationError( - f'Missing required argument "path" in tool call {tool_call.function.name}' - ) - if 'content' not in arguments: - raise FunctionCallValidationError( - f'Missing required argument "content" in tool call {tool_call.function.name}' - ) - action = FileEditAction( - path=arguments['path'], - content=arguments['content'], - start=arguments.get('start', 1), - end=arguments.get('end', -1), - impl_source=arguments.get( - 'impl_source', FileEditSource.LLM_BASED_EDIT - ), - ) - elif ( - tool_call.function.name - == create_str_replace_editor_tool()['function']['name'] - ): - if 'command' not in arguments: - raise FunctionCallValidationError( - f'Missing required argument "command" in tool call {tool_call.function.name}' - ) - if 'path' not in arguments: - raise FunctionCallValidationError( - f'Missing required argument "path" in tool call {tool_call.function.name}' + if tool_call.function.name == create_cmd_run_tool()['function']['name']: + if 'command' not in arguments: + raise FunctionCallValidationError( + f'Missing required argument "command" in tool call {tool_call.function.name}' + ) + # convert is_input to boolean + is_input = arguments.get('is_input', 'false') == 'true' + action = CmdRunAction(command=arguments['command'], is_input=is_input) + + # Set hard timeout if 
provided (capped at 600 seconds max) + if 'timeout' in arguments: + try: + command_execution_timeout = int(getenv("COMMAND_EXEC_TIMEOUT", "300")) + action.set_hard_timeout(min(float(arguments['timeout']), command_execution_timeout)) + except ValueError as e: + raise FunctionCallValidationError( + f"Invalid float passed to 'timeout' argument: {arguments['timeout']}" + ) from e + set_security_risk(action, arguments) + + # ================================================ + # IPythonTool (Jupyter) + # ================================================ + elif tool_call.function.name == IPythonTool['function']['name']: + if 'code' not in arguments: + raise FunctionCallValidationError( + f'Missing required argument "code" in tool call {tool_call.function.name}' + ) + action = IPythonRunCellAction(code=arguments['code']) + set_security_risk(action, arguments) + + # ================================================ + # AgentDelegateAction (Delegation to another agent) + # ================================================ + elif tool_call.function.name == 'delegate_to_browsing_agent': + action = AgentDelegateAction( + agent='BrowsingAgent', + inputs=arguments, ) - path = arguments['path'] - command = arguments['command'] - other_kwargs = { - k: v for k, v in arguments.items() if k not in ['command', 'path'] - } - - if command == 'view': - action = FileReadAction( - path=path, - impl_source=FileReadSource.OH_ACI, - view_range=other_kwargs.get('view_range', None), + + # ================================================ + # AgentFinishAction + # ================================================ + elif tool_call.function.name == FinishTool['function']['name']: + action = AgentFinishAction( + final_thought=arguments.get('message', ''), ) - else: - if 'view_range' in other_kwargs: - # Remove view_range from other_kwargs since it is not needed for FileEditAction - other_kwargs.pop('view_range') - - # Filter out unexpected arguments - valid_kwargs_for_editor = {} - # Get valid parameters 
from the str_replace_editor tool definition - str_replace_editor_tool = create_str_replace_editor_tool() - valid_params = set( - str_replace_editor_tool['function']['parameters'][ - 'properties' - ].keys() + + # ================================================ + # LLMBasedFileEditTool (LLM-based file editor, deprecated) + # ================================================ + elif tool_call.function.name == LLMBasedFileEditTool['function']['name']: + if 'path' not in arguments: + raise FunctionCallValidationError( + f'Missing required argument "path" in tool call {tool_call.function.name}' + ) + if 'content' not in arguments: + raise FunctionCallValidationError( + f'Missing required argument "content" in tool call {tool_call.function.name}' + ) + action = FileEditAction( + path=arguments['path'], + content=arguments['content'], + start=arguments.get('start', 1), + end=arguments.get('end', -1), + impl_source=arguments.get( + 'impl_source', FileEditSource.LLM_BASED_EDIT + ), ) + elif ( + tool_call.function.name + == create_str_replace_editor_tool()['function']['name'] + ): + if 'command' not in arguments: + raise FunctionCallValidationError( + f'Missing required argument "command" in tool call {tool_call.function.name}' + ) + if 'path' not in arguments: + raise FunctionCallValidationError( + f'Missing required argument "path" in tool call {tool_call.function.name}' + ) + path = arguments['path'] + command = arguments['command'] + other_kwargs = { + k: v for k, v in arguments.items() if k not in ['command', 'path'] + } + + if command == 'view': + action = FileReadAction( + path=path, + impl_source=FileReadSource.OH_ACI, + view_range=other_kwargs.get('view_range', None), + ) + else: + if 'view_range' in other_kwargs: + # Remove view_range from other_kwargs since it is not needed for FileEditAction + other_kwargs.pop('view_range') + + # Filter out unexpected arguments + valid_kwargs_for_editor = {} + # Get valid parameters from the str_replace_editor tool definition + 
str_replace_editor_tool = create_str_replace_editor_tool() + valid_params = set( + str_replace_editor_tool['function']['parameters'][ + 'properties' + ].keys() + ) + + for key, value in other_kwargs.items(): + if key in valid_params: + # security_risk is valid but should NOT be part of editor kwargs + if key != 'security_risk': + valid_kwargs_for_editor[key] = value + else: + raise FunctionCallValidationError( + f'Unexpected argument {key} in tool call {tool_call.function.name}. Allowed arguments are: {valid_params}' + ) + + action = FileEditAction( + path=path, + command=command, + impl_source=FileEditSource.OH_ACI, + **valid_kwargs_for_editor, + ) + + set_security_risk(action, arguments) + # ================================================ + # AgentThinkAction + # ================================================ + elif tool_call.function.name == ThinkTool['function']['name']: + action = AgentThinkAction(thought=arguments.get('thought', '')) + + # ================================================ + # CondensationRequestAction + # ================================================ + elif tool_call.function.name == CondensationRequestTool['function']['name']: + action = CondensationRequestAction() + + # ================================================ + # BrowserTool + # ================================================ + elif tool_call.function.name == BrowserTool['function']['name']: + if 'code' not in arguments: + raise FunctionCallValidationError( + f'Missing required argument "code" in tool call {tool_call.function.name}' + ) + action = BrowseInteractiveAction(browser_actions=arguments['code']) + set_security_risk(action, arguments) + + # ================================================ + # TaskTrackingAction + # ================================================ + elif tool_call.function.name == TASK_TRACKER_TOOL_NAME: + if 'command' not in arguments: + raise FunctionCallValidationError( + f'Missing required argument "command" in tool call {tool_call.function.name}' 
+ ) + if arguments['command'] == 'plan' and 'task_list' not in arguments: + raise FunctionCallValidationError( + f'Missing required argument "task_list" for "plan" command in tool call {tool_call.function.name}' + ) + + raw_task_list = arguments.get('task_list', []) + if not isinstance(raw_task_list, list): + raise FunctionCallValidationError( + f'Invalid format for "task_list". Expected a list but got {type(raw_task_list)}.' + ) - for key, value in other_kwargs.items(): - if key in valid_params: - # security_risk is valid but should NOT be part of editor kwargs - if key != 'security_risk': - valid_kwargs_for_editor[key] = value + # Normalize task_list to ensure it's always a list of dictionaries + normalized_task_list = [] + for task_idx, task in enumerate(raw_task_list): + if isinstance(task, dict): + # Task is already in correct format, ensure required fields exist + normalized_task = { + 'id': task.get('id', f'task-{task_idx + 1}'), + 'title': task.get('title', 'Untitled task'), + 'status': task.get('status', 'todo'), + 'notes': task.get('notes', ''), + } else: + # Unexpected format, raise validation error + logger.warning( + f'Unexpected task format in task_list: {type(task)} - {task}' + ) raise FunctionCallValidationError( - f'Unexpected argument {key} in tool call {tool_call.function.name}. Allowed arguments are: {valid_params}' + f'Unexpected task format in task_list: {type(task)}. Each task should be a dictionary.' 
) + normalized_task_list.append(normalized_task) - action = FileEditAction( - path=path, - command=command, - impl_source=FileEditSource.OH_ACI, - **valid_kwargs_for_editor, + action = TaskTrackingAction( + command=arguments['command'], + task_list=normalized_task_list, ) - set_security_risk(action, arguments) - # ================================================ - # AgentThinkAction - # ================================================ - elif tool_call.function.name == ThinkTool['function']['name']: - action = AgentThinkAction(thought=arguments.get('thought', '')) - - # ================================================ - # CondensationRequestAction - # ================================================ - elif tool_call.function.name == CondensationRequestTool['function']['name']: - action = CondensationRequestAction() - - # ================================================ - # BrowserTool - # ================================================ - elif tool_call.function.name == BrowserTool['function']['name']: - if 'code' not in arguments: - raise FunctionCallValidationError( - f'Missing required argument "code" in tool call {tool_call.function.name}' - ) - action = BrowseInteractiveAction(browser_actions=arguments['code']) - set_security_risk(action, arguments) - - # ================================================ - # TaskTrackingAction - # ================================================ - elif tool_call.function.name == TASK_TRACKER_TOOL_NAME: - if 'command' not in arguments: - raise FunctionCallValidationError( - f'Missing required argument "command" in tool call {tool_call.function.name}' - ) - if arguments['command'] == 'plan' and 'task_list' not in arguments: - raise FunctionCallValidationError( - f'Missing required argument "task_list" for "plan" command in tool call {tool_call.function.name}' + # ================================================ + # MCPAction (MCP) + # ================================================ + elif mcp_tool_names and 
tool_call.function.name in mcp_tool_names: + action = MCPAction( + name=tool_call.function.name, + arguments=arguments, ) - - raw_task_list = arguments.get('task_list', []) - if not isinstance(raw_task_list, list): - raise FunctionCallValidationError( - f'Invalid format for "task_list". Expected a list but got {type(raw_task_list)}.' + else: + raise FunctionCallNotExistsError( + f'Tool {tool_call.function.name} is not registered. (arguments: {arguments}). Please check the tool name and retry with an existing tool.' ) - # Normalize task_list to ensure it's always a list of dictionaries - normalized_task_list = [] - for i, task in enumerate(raw_task_list): - if isinstance(task, dict): - # Task is already in correct format, ensure required fields exist - normalized_task = { - 'id': task.get('id', f'task-{i + 1}'), - 'title': task.get('title', 'Untitled task'), - 'status': task.get('status', 'todo'), - 'notes': task.get('notes', ''), - } - else: - # Unexpected format, raise validation error - logger.warning( - f'Unexpected task format in task_list: {type(task)} - {task}' - ) - raise FunctionCallValidationError( - f'Unexpected task format in task_list: {type(task)}. Each task should be a dictionary.' - ) - normalized_task_list.append(normalized_task) - - action = TaskTrackingAction( - command=arguments['command'], - task_list=normalized_task_list, - ) - - # ================================================ - # MCPAction (MCP) - # ================================================ - elif mcp_tool_names and tool_call.function.name in mcp_tool_names: - action = MCPAction( - name=tool_call.function.name, - arguments=arguments, - ) - else: - raise FunctionCallNotExistsError( - f'Tool {tool_call.function.name} is not registered. (arguments: {arguments}). Please check the tool name and retry with an existing tool.' 
+ except FunctionCallValidationError as e: + # Convert validation errors to ValidationFailureAction instead of raising + action = ValidationFailureAction( + function_name=tool_call.function.name, + error_message=str(e), + thought=thought if i == 0 else '', ) - # We only add thought to the first action - if i == 0: + # We only add thought to the first action (if not already added via ValidationFailureAction) + if i == 0 and not isinstance(action, ValidationFailureAction): action = combine_thought(action, thought) # Add metadata for tool calling action.tool_call_metadata = ToolCallMetadata( diff --git a/openhands/core/schema/action.py b/openhands/core/schema/action.py index 168689b0f9bf..2aea98c7ff24 100644 --- a/openhands/core/schema/action.py +++ b/openhands/core/schema/action.py @@ -100,3 +100,6 @@ class ActionType(str, Enum): LOOP_RECOVERY = 'loop_recovery' """Recover dead loop.""" + + VALIDATION_FAILURE = 'validation_failure' + """Represents a validation failure for a function call.""" diff --git a/openhands/core/schema/observation.py b/openhands/core/schema/observation.py index 3f0c71052c51..7f976fca3d8f 100644 --- a/openhands/core/schema/observation.py +++ b/openhands/core/schema/observation.py @@ -61,3 +61,6 @@ class ObservationType(str, Enum): LOOP_DETECTION = 'loop_detection' """Results of a dead-loop detection""" + + VALIDATION_FAILURE = 'validation_failure' + """Result of a validation failure for a function call""" diff --git a/openhands/events/action/__init__.py b/openhands/events/action/__init__.py index 4731d68e9ee1..19b538c828e6 100644 --- a/openhands/events/action/__init__.py +++ b/openhands/events/action/__init__.py @@ -12,6 +12,7 @@ LoopRecoveryAction, RecallAction, TaskTrackingAction, + ValidationFailureAction, ) from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction @@ -47,4 +48,5 @@ 'TaskTrackingAction', 'ActionSecurityRisk', 
@dataclass
class ValidationFailureAction(Action):
    """Action recording that a function call emitted by the LLM failed validation.

    Produced when the LLM outputs an invalid function call (e.g., missing
    required arguments, invalid argument values, malformed JSON).

    Attributes:
        function_name: Name of the function/tool whose call failed validation.
        error_message: Description of the validation failure.
        thought: The agent's explanation of its actions.
        action: The action type, namely ActionType.VALIDATION_FAILURE.
    """

    function_name: str = ''
    error_message: str = ''
    thought: str = ''
    action: str = ActionType.VALIDATION_FAILURE

    @property
    def message(self) -> str:
        # Without a function name, fall back to the generic wording.
        if not self.function_name:
            return f'Validation failure: {self.error_message}'
        return f'Validation failure for {self.function_name}: {self.error_message}'
@dataclass
class ValidationFailureObservation(Observation):
    """Observation produced for a validation failure action.

    Emitted when the LLM outputs an invalid function call (e.g., missing
    required arguments, invalid argument values, malformed JSON).
    """

    function_name: str = ''
    error_message: str = ''
    observation: str = ObservationType.VALIDATION_FAILURE

    @property
    def message(self) -> str:
        # Without a function name, fall back to the generic wording.
        if not self.function_name:
            return f'Validation failure: {self.error_message}'
        return f'Validation failure for {self.function_name}: {self.error_message}'
a/openhands/runtime/impl/action_execution/action_execution_client.py b/openhands/runtime/impl/action_execution/action_execution_client.py index 554a7bfd5be0..f1066ea113eb 100644 --- a/openhands/runtime/impl/action_execution/action_execution_client.py +++ b/openhands/runtime/impl/action_execution/action_execution_client.py @@ -29,6 +29,7 @@ FileReadAction, FileWriteAction, IPythonRunCellAction, + ValidationFailureAction, ) from openhands.events.action.action import Action from openhands.events.action.files import FileEditSource @@ -39,6 +40,7 @@ NullObservation, Observation, UserRejectObservation, + ValidationFailureObservation, ) from openhands.events.serialization import event_to_dict, observation_from_dict from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS @@ -291,6 +293,12 @@ def send_action_for_execution(self, action: Action) -> Observation: if not action.runnable: if isinstance(action, AgentThinkAction): return AgentThinkObservation('Your thought has been logged.') + elif isinstance(action, ValidationFailureAction): + return ValidationFailureObservation( + content=action.error_message, + function_name=action.function_name, + error_message=action.error_message, + ) return NullObservation('') if ( hasattr(action, 'confirmation_state')