diff --git a/python/ray/dashboard/modules/reporter/reporter_head.py b/python/ray/dashboard/modules/reporter/reporter_head.py
index 72ba0e02cd88..2ce44d78020b 100644
--- a/python/ray/dashboard/modules/reporter/reporter_head.py
+++ b/python/ray/dashboard/modules/reporter/reporter_head.py
@@ -407,6 +407,119 @@ async def get_task_cpu_profile(
             headers={"Content-Type": "text/html"},
         )
 
+    @routes.get("/task/gpu_profile")
+    async def get_task_gpu_profile(
+        self, req: aiohttp.web.Request
+    ) -> aiohttp.web.Response:
+        """Retrieves the Torch GPU profiling trace for a specific task.
+
+        Note that a worker process works on one task at a time, or on
+        multiple tasks concurrently if they are async.
+
+        This is a Torch-specific API; it is not supported for other frameworks.
+
+        Params:
+            task_id: Required. The ID of the task.
+            attempt_number: Required. The attempt number of the task.
+            node_id: Required. The ID of the node.
+            num_iterations: Optional. The number of training steps to profile,
+                i.e., the number of calls to the torch Optimizer.step().
+                Defaults to 4.
+
+        Returns:
+            A redirect to the log API to download the GPU profiling trace file.
+
+        Raises:
+            ValueError: If the "task_id", "attempt_number", or "node_id"
+                parameter is missing from the request query.
+            aiohttp.web.HTTPInternalServerError: If one of the following happens:
+                (1) The GPU profiling dependencies are not installed on the
+                    target node.
+                (2) The target node doesn't have GPUs.
+                (3) The GPU profiling fails or times out.
+                (4) The worker begins working on another task during the
+                    profile retrieval.
+                The output contains a description of the error. For example,
+                trying to profile a non-Torch training process results in an
+                error.
+        """
+        if "task_id" not in req.query:
+            raise ValueError("task_id is required")
+        if "attempt_number" not in req.query:
+            raise ValueError("task's attempt number is required")
+        if "node_id" not in req.query:
+            raise ValueError("node_id is required")
+
+        task_id = req.query.get("task_id")
+        attempt_number = req.query.get("attempt_number")
+        node_id_hex = req.query.get("node_id")
+
+        # Profile for num_iterations training steps (calls to optimizer.step()).
+        num_iterations = int(req.query.get("num_iterations", 4))
+
+        addrs = await self._get_stub_address_by_node_id(NodeID.from_hex(node_id_hex))
+        if not addrs:
+            raise aiohttp.web.HTTPInternalServerError(
+                text=f"Failed to get agent address for node {node_id_hex}"
+            )
+        node_id, ip, http_port, grpc_port = addrs
+        reporter_stub = self._make_stub(build_address(ip, grpc_port))
+
+        try:
+            (pid, _) = await self.get_worker_details_for_running_task(
+                task_id, attempt_number
+            )
+        except ValueError as e:
+            raise aiohttp.web.HTTPInternalServerError(text=str(e))
+
+        logger.info(
+            f"Sending GPU profiling request to {build_address(ip, grpc_port)}, "
+            f"pid {pid}, for task {task_id}. "
+            f"Profiling for {num_iterations} training steps."
+        )
+
+        reply = await reporter_stub.GpuProfiling(
+            reporter_pb2.GpuProfilingRequest(pid=pid, num_iterations=num_iterations)
+        )
+
+        # To truly confirm that no other task ran during the profiling, we
+        # would need to retrieve all tasks that are currently running or have
+        # finished and parse their task events (i.e., start and finish times)
+        # to check for any overlap. That is expensive, so we make a
+        # best-effort check instead and only verify that the task is still
+        # running on the same worker.
+        try:
+            (_, worker_id) = await self.get_worker_details_for_running_task(
+                task_id, attempt_number
+            )
+        except ValueError as e:
+            raise aiohttp.web.HTTPInternalServerError(text=str(e))
+
+        if not reply.success:
+            return aiohttp.web.HTTPInternalServerError(text=reply.output)
+        logger.info("Returning profiling response, size {}".format(len(reply.output)))
+
+        task_ids_in_a_worker = await self.get_task_ids_running_in_a_worker(worker_id)
+        if len(task_ids_in_a_worker) > 1:
+            logger.warning(
+                f"Task {task_id} is running in a worker process that is "
+                f"running multiple tasks: {task_ids_in_a_worker}"
+            )
+
+        filepath = str(reply.output)
+        download_filename = Path(filepath).name
+
+        query = urlencode(
+            {
+                "node_ip": ip,
+                "filename": filepath,
+                "download_filename": download_filename,
+                "lines": "-1",
+            }
+        )
+        redirect_url = f"/api/v0/logs/file?{query}"
+        raise aiohttp.web.HTTPFound(redirect_url)
+
     @routes.get("/worker/traceback")
     async def get_traceback(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
         """Retrieves the traceback information for a specific worker.
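For reference, a minimal client-side sketch of how the new endpoint could be called, not part of the patch itself; the dashboard address, task ID, attempt number, and node ID below are placeholders:

    import requests

    DASHBOARD_URL = "http://127.0.0.1:8265"  # placeholder: your dashboard address
    params = {
        "task_id": "<task_id_hex>",  # placeholder
        "attempt_number": 0,
        "node_id": "<node_id_hex>",  # placeholder
        "num_iterations": 4,  # number of Optimizer.step() calls to trace
    }

    # The endpoint replies with an HTTP 302 redirect to /api/v0/logs/file;
    # requests follows it by default and downloads the trace, while a 500
    # response carries a description of the failure in its body.
    resp = requests.get(f"{DASHBOARD_URL}/task/gpu_profile", params=params)
    if resp.ok:
        with open("gpu_profile_trace.json", "wb") as f:
            f.write(resp.content)
    else:
        print(resp.status_code, resp.text)
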
diff --git a/python/ray/dashboard/modules/reporter/tests/test_reporter.py b/python/ray/dashboard/modules/reporter/tests/test_reporter.py
index 18dda866ec1f..e388abce1552 100644
--- a/python/ray/dashboard/modules/reporter/tests/test_reporter.py
+++ b/python/ray/dashboard/modules/reporter/tests/test_reporter.py
@@ -1132,6 +1132,107 @@ def verify():
     wait_for_condition(verify, timeout=10)
 
 
+@pytest.mark.skipif(
+    os.environ.get("RAY_MINIMAL") == "1",
+    reason="This test is not supposed to work for minimal installation.",
+)
+def test_get_gpu_profile_non_running_task(shutdown_only):
+    """
+    Verify that we return an error for a non-running task.
+    """
+    address_info = ray.init()
+    webui_url = format_web_url(address_info["webui_url"])
+
+    @ray.remote
+    def f():
+        pass
+
+    ray.get([f.remote() for _ in range(5)])
+
+    params = {
+        "task_id": TASK["task_id"],
+        "attempt_number": TASK["attempt_number"],
+        "node_id": TASK["node_id"],
+    }
+
+    # Make sure the API returns an error for a task that is not running.
+    def verify():
+        with pytest.raises(requests.exceptions.HTTPError) as exc_info:
+            resp = requests.get(f"{webui_url}/task/gpu_profile", params=params)
+            resp.raise_for_status()
+        assert isinstance(exc_info.value, requests.exceptions.HTTPError)
+        return True
+
+    wait_for_condition(verify, timeout=10)
+
+
+@pytest.mark.skipif(
+    os.environ.get("RAY_MINIMAL") == "1",
+    reason="This test is not supposed to work for minimal installation.",
+)
+def test_get_gpu_profile_running_task(shutdown_only):
+    """
+    Verify that we can request the GPU profile for a running task.
+    Note: Profiling itself may fail if GPUs or the GPU profiling dependencies
+    are not available, but the test still verifies that the API endpoint is
+    accessible and handles the request correctly.
+    """
+    address_info = ray.init(include_dashboard=True)
+    webui_url = format_web_url(address_info["webui_url"])
+
+    @ray.remote
+    def f():
+        pass
+
+    @ray.remote
+    def long_running_task():
+        print("Long-running task began.")
+        time.sleep(1000)
+        print("Long-running task completed.")
+
+    ray.get([f.remote() for _ in range(5)])
+
+    task = long_running_task.remote()
+
+    params = {
+        "task_id": task.task_id().hex(),
+        "attempt_number": 0,
+        "node_id": ray.get_runtime_context().get_node_id(),
+        "num_iterations": 2,
+    }
+
+    def verify():
+        resp = requests.get(
+            f"{webui_url}/task/gpu_profile", params=params, allow_redirects=False
+        )
+        print(f"resp.status_code: {resp.status_code}")
+        print(f"resp.text: {resp.text[:200] if resp.text else 'No text'}")
+
+        # The GPU profiling request can:
+        # 1. Succeed and return a redirect (HTTP 302) to download the trace file.
+        # 2. Fail with an error (HTTP 500) if GPUs/dependencies are not available.
+        # 3. Return 404 if the task isn't registered yet (temporary state).
+        # Either 302 or 500 indicates the API endpoint is working correctly.
+        if resp.status_code == 404:
+            # The task might not be ready yet; retry.
+            return False
+
+        assert resp.status_code in [
+            302,
+            500,
+        ], f"Unexpected status code: {resp.status_code}, text: {resp.text[:200]}"
+
+        # If it's a redirect, verify it points to the logs API.
+        if resp.status_code == 302:
+            assert "Location" in resp.headers
+            assert "/api/v0/logs/file" in resp.headers["Location"]
+        # If it's an error, verify it's a proper error response.
+        elif resp.status_code == 500:
+            assert len(resp.text) > 0
+        return True
+
+    wait_for_condition(verify, timeout=20)
+
+
 @pytest.mark.skipif(
     os.environ.get("RAY_MINIMAL") == "1",
     reason="This test is not supposed to work for minimal installation.",
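The tests above use a sleep-based task because CI machines typically have no GPUs. For context, the kind of Torch training task this endpoint is meant to profile could look like the following sketch; the model, data, and resource settings are illustrative assumptions, not part of this patch:

    import ray
    import torch

    @ray.remote(num_gpus=1)
    def train():
        # A toy training loop: each optimizer.step() call is one profiled
        # iteration, matching the num_iterations parameter of /task/gpu_profile.
        model = torch.nn.Linear(128, 1).cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        loss_fn = torch.nn.MSELoss()
        for _ in range(1000):
            x = torch.randn(64, 128, device="cuda")
            y = torch.randn(64, 1, device="cuda")
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()

    ray.init()
    ref = train.remote()
    # While the task runs, GET /task/gpu_profile with its task_id,
    # attempt_number, and node_id to capture a trace of a few iterations.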