diff --git a/mlcube/mlcube/parser.py b/mlcube/mlcube/parser.py index 17ee426..0ef794d 100644 --- a/mlcube/mlcube/parser.py +++ b/mlcube/mlcube/parser.py @@ -128,6 +128,9 @@ def parse_extra_arg( if parsed_args.get("gpus", None): if platform == "docker": runner_run_args["--gpus"] = parsed_args["gpus"] + os.environ["CUDA_VISIBLE_DEVICES"] = parsed_args[ + "gpus" + ] else: runner_run_args["--nv"] = "" os.environ["SINGULARITYENV_CUDA_VISIBLE_DEVICES"] = parsed_args[