diff --git a/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/api/functional.py b/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/api/functional.py
index 047e852e..9283ab3f 100644
--- a/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/api/functional.py
+++ b/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/api/functional.py
@@ -339,45 +339,6 @@ def kill_job_or_invocation(id: str) -> list[dict[str, Any]]:
     db = ExecutionDB()
     results = []
 
-    def kill_single_job(job_id: str, job_data: JobData) -> dict[str, Any]:
-        """Helper function to kill a single job."""
-        try:
-            executor_cls = get_executor(job_data.executor)
-            if hasattr(executor_cls, "kill_job"):
-                executor_cls.kill_job(job_id)
-                # Success - job was killed
-                return {
-                    "invocation": job_data.invocation_id,
-                    "job_id": job_id,
-                    "status": "killed",
-                    "data": {"result": "Successfully killed job"},
-                }
-            else:
-                return {
-                    "invocation": job_data.invocation_id,
-                    "job_id": job_id,
-                    "status": "error",
-                    "data": {
-                        "error": f"Executor {job_data.executor} does not support killing jobs"
-                    },
-                }
-        except (ValueError, RuntimeError) as e:
-            # Expected errors from kill_job
-            return {
-                "invocation": job_data.invocation_id,
-                "job_id": job_id,
-                "status": "error",
-                "data": {"error": str(e)},
-            }
-        except Exception as e:
-            # Unexpected errors
-            return {
-                "invocation": job_data.invocation_id,
-                "job_id": job_id,
-                "status": "error",
-                "data": {"error": f"Unexpected error: {str(e)}"},
-            }
-
     # Determine if this is a job ID or invocation ID
     if "." in id:
         # This is a job ID - kill single job
@@ -391,11 +352,11 @@ def kill_single_job(job_id: str, job_data: JobData) -> dict[str, Any]:
                     "data": {},
                 }
             ]
-        results.append(kill_single_job(id, job_data))
+        jobs_to_kill = {id: job_data}
     else:
         # This is an invocation ID - kill all jobs in the invocation
-        jobs = db.get_jobs(id)
-        if not jobs:
+        jobs_to_kill = db.get_jobs(id)
+        if not jobs_to_kill:
             return [
                 {
                     "invocation": id,
@@ -405,9 +366,84 @@
                 }
             ]
 
-    # Kill each job in the invocation
-    for job_id, job_data in jobs.items():
-        results.append(kill_single_job(job_id, job_data))
+    # Group jobs by executor to optimize kill operations
+    jobs_by_executor: dict[str, list[tuple[str, JobData]]] = {}
+    for job_id, job_data in jobs_to_kill.items():
+        executor = job_data.executor
+        if executor not in jobs_by_executor:
+            jobs_by_executor[executor] = []
+        jobs_by_executor[executor].append((job_id, job_data))
+
+    # Kill jobs grouped by executor (optimization: one call per executor)
+    for executor, job_list in jobs_by_executor.items():
+        try:
+            executor_cls = get_executor(executor)
+            if not hasattr(executor_cls, "kill_jobs"):
+                # Executor doesn't support killing jobs
+                for job_id, job_data in job_list:
+                    results.append(
+                        {
+                            "invocation": job_data.invocation_id,
+                            "job_id": job_id,
+                            "status": "error",
+                            "data": {
+                                "error": f"Executor {executor} does not support killing jobs"
+                            },
+                        }
+                    )
+                continue
+
+            # Extract job IDs for batch kill
+            job_ids = [job_id for job_id, _ in job_list]
+
+            try:
+                # OPTIMIZATION: Kill all jobs for this executor in one call
+                executor_cls.kill_jobs(job_ids)
+
+                # Success - all jobs were killed
+                for job_id, job_data in job_list:
+                    results.append(
+                        {
+                            "invocation": job_data.invocation_id,
+                            "job_id": job_id,
+                            "status": "killed",
+                            "data": {"result": "Successfully killed job"},
+                        }
+                    )
+            except (ValueError, RuntimeError) as e:
+                # Expected errors from kill_jobs - mark all jobs as errors
+                # Note: kill_jobs may have killed some jobs before failing
+                for job_id, job_data in job_list:
+                    results.append(
+                        {
+                            "invocation": job_data.invocation_id,
+                            "job_id": job_id,
+                            "status": "error",
+                            "data": {"error": str(e)},
+                        }
+                    )
+            except Exception as e:
+                # Unexpected errors
+                for job_id, job_data in job_list:
+                    results.append(
+                        {
+                            "invocation": job_data.invocation_id,
+                            "job_id": job_id,
+                            "status": "error",
+                            "data": {"error": f"Unexpected error: {str(e)}"},
+                        }
+                    )
+        except ValueError as e:
+            # Error getting executor class
+            for job_id, job_data in job_list:
+                results.append(
+                    {
+                        "invocation": job_data.invocation_id,
+                        "job_id": job_id,
+                        "status": "error",
+                        "data": {"error": str(e)},
+                    }
+                )
 
     return results
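
Note: the per-job result contract is unchanged by this refactor; only the dispatch is
batched per executor. A minimal usage sketch (the invocation id below is hypothetical
and assumes matching jobs exist in the local ExecutionDB):

    from nemo_evaluator_launcher.api.functional import kill_job_or_invocation

    # One call kills every job in the invocation, dispatched once per executor.
    for result in kill_job_or_invocation("8f3a21c9"):
        # Each entry carries: invocation, job_id, status ("killed", "error",
        # or "not_found"), and a data payload with details.
        print(result["job_id"], result["status"], result["data"])
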
diff --git a/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/base.py b/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/base.py
index 7e3ab844..51a7b373 100644
--- a/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/base.py
+++ b/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/base.py
@@ -21,7 +21,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 from omegaconf import DictConfig
 
@@ -81,15 +81,15 @@ def get_status(id: str) -> list[ExecutionStatus]:
 
     @staticmethod
    @abstractmethod
-    def kill_job(job_id: str) -> None:
-        """Kill a job by its ID.
+    def kill_jobs(job_ids: List[str]) -> None:
+        """Kill one or more jobs by their IDs.
 
         Args:
-            job_id: The job ID to kill.
+            job_ids: List of job IDs to kill.
 
         Raises:
-            ValueError: If job is not found or invalid.
-            RuntimeError: If job cannot be killed.
+            ValueError: If any job is not found or invalid.
+            RuntimeError: If any job cannot be killed.
 
         Raises:
             NotImplementedError: If not implemented by a subclass.
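
The new abstract contract is easiest to see in a subclass. A minimal sketch (the
executor class here is hypothetical and other abstract methods are omitted; the base
class name is assumed, since the hunk above does not show the class declaration):

    from nemo_evaluator_launcher.executors.base import BaseExecutor  # name assumed

    class NoopExecutor(BaseExecutor):  # hypothetical, for illustration only
        @staticmethod
        def kill_jobs(job_ids: list[str]) -> None:
            # Validate everything up front, then act, mirroring the concrete
            # executors below: a bad id fails fast before any job is touched.
            for job_id in job_ids:
                if "." not in job_id:
                    raise ValueError(f"Job {job_id} not found")
            # ...issue one batched kill call covering all job_ids here...
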
""" + if not job_ids: + return + db = ExecutionDB() - job_data = db.get_job(job_id) - if job_data is None: - raise ValueError(f"Job {job_id} not found") + job_data_list = [] + lepton_job_names = [] - if job_data.executor != "lepton": - raise ValueError( - f"Job {job_id} is not a Lepton job (executor: {job_data.executor})" - ) + # Validate all jobs first + for job_id in job_ids: + job_data = db.get_job(job_id) + if job_data is None: + raise ValueError(f"Job {job_id} not found") - # Cancel the specific Lepton job - lepton_job_name = job_data.data.get("lepton_job_name") + if job_data.executor != "lepton": + raise ValueError( + f"Job {job_id} is not a Lepton job (executor: {job_data.executor})" + ) - if lepton_job_name: + lepton_job_name = job_data.data.get("lepton_job_name") + if not lepton_job_name: + raise ValueError(f"No Lepton job name found for job {job_id}") + + job_data_list.append((job_id, job_data)) + lepton_job_names.append(lepton_job_name) + + # OPTIMIZATION: Cancel all Lepton jobs + errors = [] + killed_jobs = [] + for job_id, job_data in job_data_list: + lepton_job_name = job_data.data.get("lepton_job_name", "") cancel_success = delete_lepton_job(lepton_job_name) if cancel_success: print(f"✅ Cancelled Lepton job: {lepton_job_name}") @@ -730,6 +747,7 @@ def kill_job(job_id: str) -> None: job_data.data["status"] = "killed" job_data.data["killed_time"] = time.time() db.write_job(job_data) + killed_jobs.append((job_id, job_data)) else: # Use common helper to get informative error message based on job status status_list = LeptonExecutor.get_status(job_id) @@ -737,26 +755,43 @@ def kill_job(job_id: str) -> None: error_msg = LeptonExecutor.get_kill_failure_message( job_id, f"lepton_job: {lepton_job_name}", current_status ) - raise RuntimeError(error_msg) - else: - raise ValueError(f"No Lepton job name found for job {job_id}") + errors.append(error_msg) + + if errors: + raise RuntimeError("; ".join(errors)) + + print(f"🛑 Killed {len(killed_jobs)} Lepton job(s)") + + # For killed jobs, clean up endpoints if they're no longer in use + # Group by endpoint to avoid redundant checks + endpoints_to_check = {} + for job_id, job_data in killed_jobs: + endpoint_name = job_data.data.get("endpoint_name") + if endpoint_name: + if endpoint_name not in endpoints_to_check: + endpoints_to_check[endpoint_name] = [] + endpoints_to_check[endpoint_name].append((job_id, job_data)) + + # Check each endpoint and clean up if no other jobs are using it + for endpoint_name, jobs_using_endpoint in endpoints_to_check.items(): + # Get all jobs from the invocation(s) that might use this endpoint + invocation_ids = { + job_data.invocation_id for _, job_data in jobs_using_endpoint + } + all_jobs_in_invocations = {} + for inv_id in invocation_ids: + all_jobs_in_invocations.update(db.get_jobs(inv_id)) - print(f"🛑 Killed Lepton job {job_id}") - - # For individual jobs, also clean up the dedicated endpoint for this task - # Check if this was the last job using this specific endpoint - endpoint_name = job_data.data.get("endpoint_name") - if endpoint_name: # Check if any other jobs are still using this endpoint - jobs = db.get_jobs(job_data.invocation_id) + killed_job_ids = {job_id for job_id, _ in jobs_using_endpoint} other_jobs_using_endpoint = [ j - for j in jobs.values() + for j in all_jobs_in_invocations.values() if ( j.data.get("endpoint_name") == endpoint_name and j.data.get("status") not in ["killed", "failed", "succeeded", "cancelled"] - and j.job_id != job_id + and j.job_id not in killed_job_ids ) ] @@ -773,8 
diff --git a/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/local/executor.py b/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/local/executor.py
index f7bd492c..0f244411 100644
--- a/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/local/executor.py
+++ b/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/local/executor.py
@@ -26,6 +26,7 @@
 import shutil
 import subprocess
 import time
+import warnings
 from typing import List, Optional
 
 import jinja2
@@ -462,77 +463,116 @@ def get_status(id: str) -> List[ExecutionStatus]:
         ]
 
     @staticmethod
-    def kill_job(job_id: str) -> None:
-        """Kill a local job.
+    def kill_jobs(job_ids: List[str]) -> None:
+        """Kill one or more local jobs.
 
         Args:
-            job_id: The job ID (e.g., abc123.0) to kill.
+            job_ids: List of job IDs (e.g., ['abc123.0', 'abc123.1']) to kill.
 
         Raises:
-            ValueError: If job is not found or invalid.
-            RuntimeError: If Docker container cannot be stopped.
+            ValueError: If any job is not found or invalid.
+            RuntimeError: If any job cannot be killed.
         """
+        if not job_ids:
+            return
+
         db = ExecutionDB()
-        job_data = db.get_job(job_id)
+        job_data_list = []
+        container_names = []
 
-        if job_data is None:
-            raise ValueError(f"Job {job_id} not found")
+        # Validate all jobs first
+        for job_id in job_ids:
+            job_data = db.get_job(job_id)
 
-        if job_data.executor != "local":
-            raise ValueError(
-                f"Job {job_id} is not a local job (executor: {job_data.executor})"
-            )
+            if job_data is None:
+                raise ValueError(f"Job {job_id} not found")
+
+            if job_data.executor != "local":
+                raise ValueError(
+                    f"Job {job_id} is not a local job (executor: {job_data.executor})"
+                )
+
+            container_name = job_data.data.get("container")
+            if not container_name:
+                raise ValueError(f"No container name found for job {job_id}")
 
-        # Get container name from database
-        container_name = job_data.data.get("container")
-        if not container_name:
-            raise ValueError(f"No container name found for job {job_id}")
+            job_data_list.append((job_id, job_data))
+            container_names.append(container_name)
 
         killed_something = False
 
-        # First, try to stop the Docker container if it's running
-        result = subprocess.run(
-            shlex.split(f"docker stop {container_name}"),
-            capture_output=True,
-            text=True,
-            timeout=30,
-        )
-        if result.returncode == 0:
-            killed_something = True
-        # Don't raise error if container doesn't exist (might be still pulling)
-
-        # Find and kill Docker processes for this container
-        result = subprocess.run(
-            shlex.split(f"pkill -f 'docker run.*{container_name}'"),
-            capture_output=True,
-            text=True,
-            timeout=10,
-        )
-        if result.returncode == 0:
-            killed_something = True
-
-        # If we successfully killed something, mark as killed
-        if killed_something:
-            job_data.data["killed"] = True
-            db.write_job(job_data)
-            LocalExecutor._add_to_killed_jobs(job_data.invocation_id, job_id)
-            return
+        # OPTIMIZATION: Stop all containers in one command
+        if container_names:
+            containers_str = " ".join(container_names)
+            result = subprocess.run(
+                shlex.split(f"docker stop {containers_str}"),
+                capture_output=True,
+                text=True,
+                timeout=30,
+            )
+            if result.returncode == 0:
+                killed_something = True
+            # Don't raise error if containers don't exist (might be still pulling)
+
+            # Find and kill Docker processes for all containers
+            # Use a pattern that matches any of the container names
+            pattern = "|".join(container_names)
+            result = subprocess.run(
+                shlex.split(f"pkill -f 'docker run.*({pattern})'"),
+                capture_output=True,
+                text=True,
+                timeout=10,
+            )
+            if result.returncode == 0:
+                killed_something = True
+
+        # Mark jobs as killed in database
+        errors = []
+        for job_id, job_data in job_data_list:
+            if killed_something:
+                job_data.data["killed"] = True
+                db.write_job(job_data)
+                LocalExecutor._add_to_killed_jobs(job_data.invocation_id, job_id)
+            else:
+                # If nothing was killed, check if this is a pending job
+                status_list = LocalExecutor.get_status(job_id)
+                if status_list and status_list[0].state == ExecutionState.PENDING:
+                    # For pending jobs, mark as killed even though there's nothing to kill yet
+                    job_data.data["killed"] = True
+                    db.write_job(job_data)
+                    LocalExecutor._add_to_killed_jobs(job_data.invocation_id, job_id)
+                else:
+                    # Use common helper to get informative error message based on job status
+                    current_status = status_list[0].state if status_list else None
+                    container_name = job_data.data.get("container", "")
+                    error_msg = LocalExecutor.get_kill_failure_message(
+                        job_id, f"container: {container_name}", current_status
+                    )
+                    errors.append(error_msg)
 
-        # If nothing was killed, check if this is a pending job
-        status_list = LocalExecutor.get_status(job_id)
-        if status_list and status_list[0].state == ExecutionState.PENDING:
-            # For pending jobs, mark as killed even though there's nothing to kill yet
-            job_data.data["killed"] = True
-            db.write_job(job_data)
-            LocalExecutor._add_to_killed_jobs(job_data.invocation_id, job_id)
-            return
+        if errors:
+            raise RuntimeError("; ".join(errors))
+
+    @staticmethod
+    def kill_job(job_id: str) -> None:
+        """Kill a single local job.
 
-        # Use common helper to get informative error message based on job status
-        current_status = status_list[0].state if status_list else None
-        error_msg = LocalExecutor.get_kill_failure_message(
-            job_id, f"container: {container_name}", current_status
+        .. deprecated::
+            This method is deprecated. Use :meth:`kill_jobs` instead.
+
+        Args:
+            job_id: The job ID (e.g., abc123.0) to kill.
+
+        Raises:
+            ValueError: If job is not found or invalid.
+            RuntimeError: If Docker container cannot be stopped.
+        """
+        warnings.warn(
+            "kill_job is deprecated. Use kill_jobs instead.",
+            DeprecationWarning,
+            stacklevel=2,
        )
-        raise RuntimeError(error_msg)
+        LocalExecutor.kill_jobs([job_id])
 
     @staticmethod
     def _add_to_killed_jobs(invocation_id: str, job_id: str) -> None:
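
One caveat on the batched pkill above: the joined container names are interpreted as an
extended regular expression, so a name containing regex metacharacters could match more
(or less) than intended. A defensive variant, as a sketch using only the stdlib:

    import re

    # Escape each name so the alternation matches container names literally,
    # e.g. 'docker run.*(eval-abc123-0|eval-abc123-1)'.
    pattern = "|".join(re.escape(name) for name in container_names)
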
""" + if not job_ids: + return + db = ExecutionDB() - job_data = db.get_job(job_id) + job_data_list = [] + container_names = [] - if job_data is None: - raise ValueError(f"Job {job_id} not found") + # Validate all jobs first + for job_id in job_ids: + job_data = db.get_job(job_id) - if job_data.executor != "local": - raise ValueError( - f"Job {job_id} is not a local job (executor: {job_data.executor})" - ) + if job_data is None: + raise ValueError(f"Job {job_id} not found") + + if job_data.executor != "local": + raise ValueError( + f"Job {job_id} is not a local job (executor: {job_data.executor})" + ) + + container_name = job_data.data.get("container") + if not container_name: + raise ValueError(f"No container name found for job {job_id}") - # Get container name from database - container_name = job_data.data.get("container") - if not container_name: - raise ValueError(f"No container name found for job {job_id}") + job_data_list.append((job_id, job_data)) + container_names.append(container_name) killed_something = False - # First, try to stop the Docker container if it's running - result = subprocess.run( - shlex.split(f"docker stop {container_name}"), - capture_output=True, - text=True, - timeout=30, - ) - if result.returncode == 0: - killed_something = True - # Don't raise error if container doesn't exist (might be still pulling) - - # Find and kill Docker processes for this container - result = subprocess.run( - shlex.split(f"pkill -f 'docker run.*{container_name}'"), - capture_output=True, - text=True, - timeout=10, - ) - if result.returncode == 0: - killed_something = True - - # If we successfully killed something, mark as killed - if killed_something: - job_data.data["killed"] = True - db.write_job(job_data) - LocalExecutor._add_to_killed_jobs(job_data.invocation_id, job_id) - return + # OPTIMIZATION: Stop all containers in one command + if container_names: + containers_str = " ".join(container_names) + result = subprocess.run( + shlex.split(f"docker stop {containers_str}"), + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode == 0: + killed_something = True + # Don't raise error if containers don't exist (might be still pulling) + + # Find and kill Docker processes for all containers + # Use a pattern that matches any of the container names + pattern = "|".join(container_names) + result = subprocess.run( + shlex.split(f"pkill -f 'docker run.*({pattern})'"), + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + killed_something = True + + # Mark jobs as killed in database + errors = [] + for job_id, job_data in job_data_list: + if killed_something: + job_data.data["killed"] = True + db.write_job(job_data) + LocalExecutor._add_to_killed_jobs(job_data.invocation_id, job_id) + else: + # If nothing was killed, check if this is a pending job + status_list = LocalExecutor.get_status(job_id) + if status_list and status_list[0].state == ExecutionState.PENDING: + # For pending jobs, mark as killed even though there's nothing to kill yet + job_data.data["killed"] = True + db.write_job(job_data) + LocalExecutor._add_to_killed_jobs(job_data.invocation_id, job_id) + else: + # Use common helper to get informative error message based on job status + current_status = status_list[0].state if status_list else None + container_name = job_data.data.get("container", "") + error_msg = LocalExecutor.get_kill_failure_message( + job_id, f"container: {container_name}", current_status + ) + errors.append(error_msg) - # If nothing was killed, check if this 
+
+        # Kill jobs grouped by connection
+        errors = []
+        for (username, hostname, socket), job_list in jobs_by_connection.items():
+            slurm_job_ids = [
+                job_data.data.get("slurm_job_id")
+                for _, job_data in job_list
+                if job_data.data.get("slurm_job_id")
+            ]
+
+            if not slurm_job_ids:
+                continue
+
+            # OPTIMIZATION: Query status AND kill in ONE SSH call for all jobs on this host
+            slurm_status_dict, result = _kill_slurm_job(
+                slurm_job_ids=slurm_job_ids,
+                username=username,
+                hostname=hostname,
+                socket=socket,
+            )
+
+            # Mark jobs as killed in database if kill succeeded
+            if result.returncode == 0:
+                for job_id, job_data in job_list:
+                    job_data.data["killed"] = True
+                    db.write_job(job_data)
+            else:
+                # Collect errors for all jobs that failed
+                for job_id, job_data in job_list:
+                    slurm_job_id = job_data.data.get("slurm_job_id")
+                    current_status = None
+                    # Get status for this specific job from the status dict
+                    if slurm_status_dict and slurm_job_id in slurm_status_dict:
+                        slurm_status = slurm_status_dict[slurm_job_id]
+                        current_status = (
+                            SlurmExecutor._map_slurm_state_to_execution_state(
+                                slurm_status
+                            )
+                        )
+                    error_msg = SlurmExecutor.get_kill_failure_message(
+                        job_id,
+                        f"slurm_job_id: {slurm_job_id}",
+                        current_status,
+                    )
+                    errors.append(error_msg)
+
+        if errors:
+            raise RuntimeError("; ".join(errors))
+
+    @staticmethod
+    def kill_job(job_id: str) -> None:
+        """Kill a single SLURM job.
+
+        .. deprecated::
+            This method is deprecated. Use :meth:`kill_jobs` instead.
+
+        Args:
+            job_id: The job ID (e.g., abc123.0) to kill.
+
+        Raises:
+            ValueError: If job is not found or invalid.
+            RuntimeError: If job cannot be killed.
+        """
+        warnings.warn(
+            "kill_job is deprecated. Use kill_jobs instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        SlurmExecutor.kill_jobs([job_id])
 
 
 def _create_slurm_sbatch_script(
""" + if not job_ids: + return + db = ExecutionDB() - job_data = db.get_job(job_id) - if job_data is None: - raise ValueError(f"Job {job_id} not found") + # Group jobs by connection parameters (username, hostname, socket) + # so we can kill jobs on the same host in one command + jobs_by_connection: Dict[ + tuple[str, str, str | None], List[tuple[str, JobData]] + ] = {} - if job_data.executor != "slurm": - raise ValueError( - f"Job {job_id} is not a slurm job (executor: {job_data.executor})" - ) + for job_id in job_ids: + job_data = db.get_job(job_id) - # OPTIMIZATION: Query status AND kill in ONE SSH call - slurm_status, result = _kill_slurm_job( - slurm_job_ids=[job_data.data.get("slurm_job_id")], - username=job_data.data.get("username"), - hostname=job_data.data.get("hostname"), - socket=job_data.data.get("socket"), - ) + if job_data is None: + raise ValueError(f"Job {job_id} not found") - # Mark job as killed in database if kill succeeded - if result.returncode == 0: - job_data.data["killed"] = True - db.write_job(job_data) - else: - # Use the pre-fetched status for better error message - current_status = None - if slurm_status: - current_status = SlurmExecutor._map_slurm_state_to_execution_state( - slurm_status + if job_data.executor != "slurm": + raise ValueError( + f"Job {job_id} is not a slurm job (executor: {job_data.executor})" ) - error_msg = SlurmExecutor.get_kill_failure_message( - job_id, - f"slurm_job_id: {job_data.data.get('slurm_job_id')}", - current_status, + + connection_key = ( + job_data.data.get("username", ""), + job_data.data.get("hostname", ""), + job_data.data.get("socket"), ) - raise RuntimeError(error_msg) + + if connection_key not in jobs_by_connection: + jobs_by_connection[connection_key] = [] + jobs_by_connection[connection_key].append((job_id, job_data)) + + # Kill jobs grouped by connection + errors = [] + for (username, hostname, socket), job_list in jobs_by_connection.items(): + slurm_job_ids = [ + job_data.data.get("slurm_job_id") + for _, job_data in job_list + if job_data.data.get("slurm_job_id") + ] + + if not slurm_job_ids: + continue + + # OPTIMIZATION: Query status AND kill in ONE SSH call for all jobs on this host + slurm_status_dict, result = _kill_slurm_job( + slurm_job_ids=slurm_job_ids, + username=username, + hostname=hostname, + socket=socket, + ) + + # Mark jobs as killed in database if kill succeeded + if result.returncode == 0: + for job_id, job_data in job_list: + job_data.data["killed"] = True + db.write_job(job_data) + else: + # Collect errors for all jobs that failed + for job_id, job_data in job_list: + slurm_job_id = job_data.data.get("slurm_job_id") + current_status = None + # Get status for this specific job from the status dict + if slurm_status_dict and slurm_job_id in slurm_status_dict: + slurm_status = slurm_status_dict[slurm_job_id] + current_status = ( + SlurmExecutor._map_slurm_state_to_execution_state( + slurm_status + ) + ) + error_msg = SlurmExecutor.get_kill_failure_message( + job_id, + f"slurm_job_id: {slurm_job_id}", + current_status, + ) + errors.append(error_msg) + + if errors: + raise RuntimeError("; ".join(errors)) + + @staticmethod + def kill_job(job_id: str) -> None: + """Kill a single SLURM job. + + .. deprecated:: + This method is deprecated. Use :meth:`kill_jobs` instead. + + Args: + job_id: The job ID (e.g., abc123.0) to kill. + + Raises: + ValueError: If job is not found or invalid. + RuntimeError: If job cannot be killed. + """ + warnings.warn( + "kill_job is deprecated. 
@@ -1007,11 +1076,18 @@ def _kill_slurm_job(
     # Parse the sacct output (before scancel runs)
     sacct_output = completed_process.stdout.decode("utf-8")
     sacct_output_lines = sacct_output.strip().split("\n")
-    slurm_status = None
-    if sacct_output_lines and len(slurm_job_ids) == 1:
-        slurm_status = _parse_slurm_job_status(slurm_job_ids[0], sacct_output_lines)
-
-    return slurm_status, completed_process
+    slurm_status_dict: Dict[str, str] = {}
+    if sacct_output_lines:
+        # Parse status for all jobs
+        for slurm_job_id in slurm_job_ids:
+            try:
+                status = _parse_slurm_job_status(slurm_job_id, sacct_output_lines)
+                slurm_status_dict[slurm_job_id] = status
+            except Exception:
+                # If parsing fails for a job, skip it
+                pass
+
+    return slurm_status_dict if slurm_status_dict else None, completed_process
 
 
 def _parse_slurm_job_status(slurm_job_id: str, sacct_output_lines: List[str]) -> str:
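
For context, the reshaped helper now reports sacct status per SLURM job id instead of a
single string. A usage sketch (the ids, username, and hostname are hypothetical; assumes
SSH connectivity, with or without a control socket):

    status_dict, proc = _kill_slurm_job(
        slurm_job_ids=["123456", "123457"],
        username="jdoe",
        hostname="login.cluster.example.com",
        socket=None,
    )
    if proc.returncode != 0 and status_dict:
        # e.g. {"123456": "RUNNING", "123457": "COMPLETED"}
        for slurm_id, state in status_dict.items():
            print(slurm_id, state)
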