115 changes: 115 additions & 0 deletions python/ray/dashboard/modules/reporter/reporter_head.py
@@ -407,6 +407,121 @@ async def get_task_cpu_profile(
headers={"Content-Type": "text/html"},
)

@routes.get("/task/gpu_profile")
async def get_task_gpu_profile(
self, req: aiohttp.web.Request
) -> aiohttp.web.Response:
"""Retrieves the Torch GPU profile trace for a specific task.
Note that one worker process works on one task at a time
or one worker works on multiple async tasks.

This is a Torch-specific API. It is not supported for other frameworks.

Params:
task_id: Required. The ID of the task.
attempt_number: Required. The attempt number of the task.
node_id: Required. The ID of the node.
num_iterations: Optional. Number of training steps to profile, i.e., the
number of calls to torch's Optimizer.step(). Defaults to 4.

Returns:
A redirect to the log API to download the GPU profiling trace file.

Raises:
ValueError: If the "task_id", "attempt_number", or "node_id" parameter
is missing from the request query.
ValueError: If the worker starts working on another task while the
profile is being retrieved.
aiohttp.web.HTTPInternalServerError: If there is an internal server error
during the profile retrieval, or if one of the following happens:
(1) The GPU profiling dependencies are not installed on the target node.
(2) The target node doesn't have GPUs.
(3) The GPU profiling fails or times out.
The output contains a description of the error. For example, trying to
profile a non-Torch training process results in an error.
"""
if "task_id" not in req.query:
raise ValueError("task_id is required")
if "attempt_number" not in req.query:
raise ValueError("task's attempt number is required")
if "node_id" not in req.query:
raise ValueError("node_id is required")

task_id = req.query.get("task_id")
attempt_number = req.query.get("attempt_number")
node_id_hex = req.query.get("node_id")

# Profile for num_iterations training steps (calls to optimizer.step())
num_iterations = int(req.query.get("num_iterations", 4))

addrs = await self._get_stub_address_by_node_id(NodeID.from_hex(node_id_hex))
if not addrs:
raise aiohttp.web.HTTPInternalServerError(
text=f"Failed to get agent address for node {node_id_hex}"
)
node_id, ip, http_port, grpc_port = addrs
reporter_stub = self._make_stub(build_address(ip, grpc_port))

try:
(pid, _) = await self.get_worker_details_for_running_task(
task_id, attempt_number
)
except ValueError as e:
raise aiohttp.web.HTTPInternalServerError(text=str(e))

logger.info(
f"Sending GPU profiling request to {build_address(ip, grpc_port)}, pid {pid}, for {task_id}. "
f"Profiling for {num_iterations} training steps."
)

reply = await reporter_stub.GpuProfiling(
reporter_pb2.GpuProfilingRequest(
pid=pid, num_iterations=num_iterations
)
)

"""
In order to truly confirm whether there are any other tasks
running during the profiling, we need to retrieve all tasks
that are currently running or have finished, and then parse
the task events (i.e., their start and finish times) to check
for any potential overlap. However, this process can be quite
extensive, so here we will make our best efforts to check
for any overlapping tasks. Therefore, we will check if
the task is still running
"""
try:
(_, worker_id) = await self.get_worker_details_for_running_task(
task_id, attempt_number
)
except ValueError as e:
raise aiohttp.web.HTTPInternalServerError(text=str(e))

if not reply.success:
return aiohttp.web.HTTPInternalServerError(text=reply.output)
logger.info("Returning profiling response, size {}".format(len(reply.output)))

task_ids_in_a_worker = await self.get_task_ids_running_in_a_worker(worker_id)
if len(task_ids_in_a_worker) > 1:
logger.warning(
f"Task {task_id} is running in a worker process that is running multiple tasks: {task_ids_in_a_worker}"
)

filepath = str(reply.output)
download_filename = Path(filepath).name

query = urlencode(
{
"node_ip": ip,
"filename": filepath,
"download_filename": download_filename,
"lines": "-1",
}
)
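# Redirect the caller to the log API, which streams the trace file back.
# The resulting URL looks roughly like
#   /api/v0/logs/file?node_ip=<ip>&filename=<trace-path>&download_filename=<name>&lines=-1
# ("lines=-1" requests the entire file rather than only the last N lines).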
redirect_url = f"/api/v0/logs/file?{query}"
raise aiohttp.web.HTTPFound(redirect_url)

@routes.get("/worker/traceback")
async def get_traceback(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
"""Retrieves the traceback information for a specific worker.