diff --git a/runners/mlcube_docker/mlcube_docker/docker_run.py b/runners/mlcube_docker/mlcube_docker/docker_run.py index 3b96ea3..2c84140 100644 --- a/runners/mlcube_docker/mlcube_docker/docker_run.py +++ b/runners/mlcube_docker/mlcube_docker/docker_run.py @@ -267,7 +267,21 @@ def run(self) -> None: run_args += " " + extra_args valid_gpu_flag = "--gpus" in self.mlcube.runner and self.mlcube.runner["--gpus"] is not None - cuda_visible_devices = self.mlcube.runner["--gpus"] if valid_gpu_flag else num_gpus + + + if valid_gpu_flag: + cuda_visible_devices = self.mlcube.runner["--gpus"] + if "device" in cuda_visible_devices: + cuda_visible_devices = cuda_visible_devices.replace("device=", "") + else: + cuda_visible_devices = num_gpus + if num_gpus == 0: + cuda_visible_devices = "" + + if cuda_visible_devices.isnumeric(): + cuda_visible_devices = str(list(range(cuda_visible_devices))) + cuda_visible_devices = cuda_visible_devices.replace(" ", "").replace("[","").replace("]","") + run_args += f" --env CUDA_VISIBLE_DEVICES={cuda_visible_devices}" if "entrypoint" in self.mlcube.tasks[self.task]: