@@ -339,45 +339,6 @@ def kill_job_or_invocation(id: str) -> list[dict[str, Any]]:
     db = ExecutionDB()
     results = []
 
-    def kill_single_job(job_id: str, job_data: JobData) -> dict[str, Any]:
-        """Helper function to kill a single job."""
-        try:
-            executor_cls = get_executor(job_data.executor)
-            if hasattr(executor_cls, "kill_job"):
-                executor_cls.kill_job(job_id)
-                # Success - job was killed
-                return {
-                    "invocation": job_data.invocation_id,
-                    "job_id": job_id,
-                    "status": "killed",
-                    "data": {"result": "Successfully killed job"},
-                }
-            else:
-                return {
-                    "invocation": job_data.invocation_id,
-                    "job_id": job_id,
-                    "status": "error",
-                    "data": {
-                        "error": f"Executor {job_data.executor} does not support killing jobs"
-                    },
-                }
-        except (ValueError, RuntimeError) as e:
-            # Expected errors from kill_job
-            return {
-                "invocation": job_data.invocation_id,
-                "job_id": job_id,
-                "status": "error",
-                "data": {"error": str(e)},
-            }
-        except Exception as e:
-            # Unexpected errors
-            return {
-                "invocation": job_data.invocation_id,
-                "job_id": job_id,
-                "status": "error",
-                "data": {"error": f"Unexpected error: {str(e)}"},
-            }
-
     # Determine if this is a job ID or invocation ID
     if "." in id:
         # This is a job ID - kill single job
@@ -391,11 +352,11 @@ def kill_single_job(job_id: str, job_data: JobData) -> dict[str, Any]:
"data": {},
}
]
results.append(kill_single_job(id, job_data))
jobs_to_kill = {id: job_data}
else:
# This is an invocation ID - kill all jobs in the invocation
jobs = db.get_jobs(id)
if not jobs:
jobs_to_kill = db.get_jobs(id)
if not jobs_to_kill:
return [
{
"invocation": id,
@@ -405,9 +366,84 @@ def kill_single_job(job_id: str, job_data: JobData) -> dict[str, Any]:
                 }
             ]
 
-    # Kill each job in the invocation
-    for job_id, job_data in jobs.items():
-        results.append(kill_single_job(job_id, job_data))
+    # Group jobs by executor to optimize kill operations
+    jobs_by_executor: dict[str, list[tuple[str, JobData]]] = {}
+    for job_id, job_data in jobs_to_kill.items():
+        executor = job_data.executor
+        if executor not in jobs_by_executor:
+            jobs_by_executor[executor] = []
+        jobs_by_executor[executor].append((job_id, job_data))
+
+    # Kill jobs grouped by executor (optimization: one call per executor)
+    for executor, job_list in jobs_by_executor.items():
+        try:
+            executor_cls = get_executor(executor)
+            if not hasattr(executor_cls, "kill_jobs"):
+                # Executor doesn't support killing jobs
+                for job_id, job_data in job_list:
+                    results.append(
+                        {
+                            "invocation": job_data.invocation_id,
+                            "job_id": job_id,
+                            "status": "error",
+                            "data": {
+                                "error": f"Executor {executor} does not support killing jobs"
+                            },
+                        }
+                    )
+                continue
+
+            # Extract job IDs for batch kill
+            job_ids = [job_id for job_id, _ in job_list]
+
+            try:
+                # OPTIMIZATION: Kill all jobs for this executor in one call
+                executor_cls.kill_jobs(job_ids)
+
+                # Success - all jobs were killed
+                for job_id, job_data in job_list:
+                    results.append(
+                        {
+                            "invocation": job_data.invocation_id,
+                            "job_id": job_id,
+                            "status": "killed",
+                            "data": {"result": "Successfully killed job"},
+                        }
+                    )
+            except (ValueError, RuntimeError) as e:
+                # Expected errors from kill_jobs - mark all jobs as error
+                # Note: kill_jobs may have killed some jobs before failing
+                for job_id, job_data in job_list:
+                    results.append(
+                        {
+                            "invocation": job_data.invocation_id,
+                            "job_id": job_id,
+                            "status": "error",
+                            "data": {"error": str(e)},
+                        }
+                    )
+            except Exception as e:
+                # Unexpected errors
+                for job_id, job_data in job_list:
+                    results.append(
+                        {
+                            "invocation": job_data.invocation_id,
+                            "job_id": job_id,
+                            "status": "error",
+                            "data": {"error": f"Unexpected error: {str(e)}"},
+                        }
+                    )
+        except ValueError as e:
+            # Error getting executor class
+            for job_id, job_data in job_list:
+                results.append(
+                    {
+                        "invocation": job_data.invocation_id,
+                        "job_id": job_id,
+                        "status": "error",
+                        "data": {"error": str(e)},
+                    }
+                )
 
     return results

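The practical effect of the change above: killing an invocation now issues one batched kill call per executor instead of one call per job. A minimal, runnable sketch of the grouping pattern; the executor names and dotted job IDs here are hypothetical stand-ins, not from this PR:

from collections import defaultdict

# Hypothetical job_id -> executor mapping standing in for JobData records.
jobs = {
    "inv1.0": "slurm",
    "inv1.1": "slurm",
    "inv1.2": "lepton",
}

# Group job IDs by executor, mirroring the loop added above.
jobs_by_executor: dict[str, list[str]] = defaultdict(list)
for job_id, executor in jobs.items():
    jobs_by_executor[executor].append(job_id)

# One batched kill_jobs call per executor instead of one kill_job per job.
for executor, job_ids in sorted(jobs_by_executor.items()):
    print(f"{executor}: kill_jobs({job_ids})")  # 2 calls for 3 jobs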
@@ -21,7 +21,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 from omegaconf import DictConfig
 
@@ -81,15 +81,15 @@ def get_status(id: str) -> list[ExecutionStatus]:

     @staticmethod
     @abstractmethod
-    def kill_job(job_id: str) -> None:
-        """Kill a job by its ID.
+    def kill_jobs(job_ids: List[str]) -> None:
+        """Kill one or more jobs by their IDs.
 
         Args:
-            job_id: The job ID to kill.
+            job_ids: List of job IDs to kill.
 
         Raises:
-            ValueError: If job is not found or invalid.
-            RuntimeError: If job cannot be killed.
+            ValueError: If any job is not found or invalid.
+            RuntimeError: If any job cannot be killed.
 
         Raises:
             NotImplementedError: If not implemented by a subclass.
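Because `kill_jobs` replaces `kill_job` as the abstract method, every concrete executor must now supply a batch implementation. A minimal sketch of a conforming subclass, using a simplified stand-in for the base interface above; the `LocalExecutor` name and its process registry are hypothetical:

import subprocess
from abc import ABC, abstractmethod
from typing import Dict, List


class Executor(ABC):
    """Simplified stand-in for the abstract executor interface above."""

    @staticmethod
    @abstractmethod
    def kill_jobs(job_ids: List[str]) -> None: ...


class LocalExecutor(Executor):
    """Hypothetical executor tracking live processes by job ID."""

    _procs: Dict[str, subprocess.Popen] = {}

    @staticmethod
    def kill_jobs(job_ids: List[str]) -> None:
        # Validate the whole batch first so one bad ID fails early.
        missing = [j for j in job_ids if j not in LocalExecutor._procs]
        if missing:
            raise ValueError(f"Jobs not found: {missing}")
        for job_id in job_ids:
            LocalExecutor._procs.pop(job_id).terminate()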
@@ -20,6 +20,7 @@

 import os
 import time
+import warnings
 from pathlib import Path
 from typing import List
 
@@ -699,64 +700,98 @@ def get_status(id: str) -> List[ExecutionStatus]:
         return [ExecutionStatus(id=id, state=state, progress=progress_info)]
 
     @staticmethod
-    def kill_job(job_id: str) -> None:
-        """Kill Lepton evaluation jobs and clean up endpoints.
+    def kill_jobs(job_ids: List[str]) -> None:
+        """Kill one or more Lepton evaluation jobs and clean up endpoints.
 
         Args:
-            job_id: The job ID to kill.
+            job_ids: List of job IDs to kill.
 
         Raises:
-            ValueError: If job is not found or invalid.
-            RuntimeError: If job cannot be killed.
+            ValueError: If any job is not found or invalid.
+            RuntimeError: If any job cannot be killed.
         """
+        if not job_ids:
+            return
+
         db = ExecutionDB()
-        job_data = db.get_job(job_id)
-        if job_data is None:
-            raise ValueError(f"Job {job_id} not found")
+        job_data_list = []
+        lepton_job_names = []
 
-        if job_data.executor != "lepton":
-            raise ValueError(
-                f"Job {job_id} is not a Lepton job (executor: {job_data.executor})"
-            )
+        # Validate all jobs first
+        for job_id in job_ids:
+            job_data = db.get_job(job_id)
+            if job_data is None:
+                raise ValueError(f"Job {job_id} not found")
 
-        # Cancel the specific Lepton job
-        lepton_job_name = job_data.data.get("lepton_job_name")
+            if job_data.executor != "lepton":
+                raise ValueError(
+                    f"Job {job_id} is not a Lepton job (executor: {job_data.executor})"
+                )
 
-        if lepton_job_name:
+            lepton_job_name = job_data.data.get("lepton_job_name")
+            if not lepton_job_name:
+                raise ValueError(f"No Lepton job name found for job {job_id}")
+
+            job_data_list.append((job_id, job_data))
+            lepton_job_names.append(lepton_job_name)
+
+        # OPTIMIZATION: Cancel all Lepton jobs
+        errors = []
+        killed_jobs = []
+        for job_id, job_data in job_data_list:
+            lepton_job_name = job_data.data.get("lepton_job_name", "")
             cancel_success = delete_lepton_job(lepton_job_name)
             if cancel_success:
                 print(f"✅ Cancelled Lepton job: {lepton_job_name}")
                 # Mark job as killed in database
                 job_data.data["status"] = "killed"
                 job_data.data["killed_time"] = time.time()
                 db.write_job(job_data)
+                killed_jobs.append((job_id, job_data))
             else:
                 # Use common helper to get informative error message based on job status
                 status_list = LeptonExecutor.get_status(job_id)
                 current_status = status_list[0].state if status_list else None
                 error_msg = LeptonExecutor.get_kill_failure_message(
                     job_id, f"lepton_job: {lepton_job_name}", current_status
                 )
-                raise RuntimeError(error_msg)
-        else:
-            raise ValueError(f"No Lepton job name found for job {job_id}")
+                errors.append(error_msg)
 
+        if errors:
+            raise RuntimeError("; ".join(errors))
+
+        print(f"🛑 Killed {len(killed_jobs)} Lepton job(s)")
+
+        # For killed jobs, clean up endpoints if they're no longer in use
+        # Group by endpoint to avoid redundant checks
+        endpoints_to_check = {}
+        for job_id, job_data in killed_jobs:
+            endpoint_name = job_data.data.get("endpoint_name")
+            if endpoint_name:
+                if endpoint_name not in endpoints_to_check:
+                    endpoints_to_check[endpoint_name] = []
+                endpoints_to_check[endpoint_name].append((job_id, job_data))
+
+        # Check each endpoint and clean up if no other jobs are using it
+        for endpoint_name, jobs_using_endpoint in endpoints_to_check.items():
+            # Get all jobs from the invocation(s) that might use this endpoint
+            invocation_ids = {
+                job_data.invocation_id for _, job_data in jobs_using_endpoint
+            }
+            all_jobs_in_invocations = {}
+            for inv_id in invocation_ids:
+                all_jobs_in_invocations.update(db.get_jobs(inv_id))
 
-        print(f"🛑 Killed Lepton job {job_id}")
-
-        # For individual jobs, also clean up the dedicated endpoint for this task
-        # Check if this was the last job using this specific endpoint
-        endpoint_name = job_data.data.get("endpoint_name")
-        if endpoint_name:
-            # Check if any other jobs are still using this endpoint
-            jobs = db.get_jobs(job_data.invocation_id)
+            killed_job_ids = {job_id for job_id, _ in jobs_using_endpoint}
             other_jobs_using_endpoint = [
                 j
-                for j in jobs.values()
+                for j in all_jobs_in_invocations.values()
                 if (
                     j.data.get("endpoint_name") == endpoint_name
                     and j.data.get("status")
                     not in ["killed", "failed", "succeeded", "cancelled"]
-                    and j.job_id != job_id
+                    and j.job_id not in killed_job_ids
                 )
             ]

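Note the failure semantics introduced above: cancellation failures are accumulated in `errors` rather than raised immediately, so one failed job no longer aborts cancellation of the remaining jobs, and a single `RuntimeError` joining all messages is raised afterwards. A small self-contained sketch of the same accumulate-then-raise pattern; the `cancel` callable is a hypothetical stand-in for `delete_lepton_job`:

def cancel_all(names: list[str], cancel) -> list[str]:
    """Attempt every cancellation, then report all failures at once."""
    errors: list[str] = []
    killed: list[str] = []
    for name in names:
        if cancel(name):
            killed.append(name)
        else:
            errors.append(f"failed to cancel {name}")
    if errors:
        raise RuntimeError("; ".join(errors))
    return killed

# E.g. only "job-b" fails: both jobs are still attempted, then one error is raised:
# cancel_all(["job-a", "job-b"], lambda n: n != "job-b")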
@@ -773,8 +808,27 @@ def kill_job(job_id: str) -> None:
                 print(
                     f"📌 Keeping endpoint {endpoint_name} (still used by {len(other_jobs_using_endpoint)} other jobs)"
                 )
-        else:
-            print("📌 No dedicated endpoint to clean up for this job")
+
+    @staticmethod
+    def kill_job(job_id: str) -> None:
+        """Kill a single Lepton evaluation job and clean up endpoints.
+
+        .. deprecated::
+            This method is deprecated. Use :meth:`kill_jobs` instead.
+
+        Args:
+            job_id: The job ID to kill.
+
+        Raises:
+            ValueError: If job is not found or invalid.
+            RuntimeError: If job cannot be killed.
+        """
+        warnings.warn(
+            "kill_job is deprecated. Use kill_jobs instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        LeptonExecutor.kill_jobs([job_id])
 
 
 def _create_evaluation_launch_script(
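The deprecation shim keeps existing single-job callers working while routing them through the batch API. A usage sketch of the warning surface, assuming `LeptonExecutor` is importable and using a hypothetical job ID that may not exist in the ExecutionDB:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    try:
        # Old single-job entry point; forwards to kill_jobs(["inv1.0"]).
        LeptonExecutor.kill_job("inv1.0")  # hypothetical job ID
    except ValueError:
        pass  # raised if the hypothetical job is not in the ExecutionDB

# The deprecation warning is emitted before the forwarded call runs.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)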