Skip to content

Commit

Permalink
refine torch install
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Nov 21, 2024
1 parent e5fcd19 commit f82673c
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions tools/ci_build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2074,16 +2074,12 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
if is_windows():
cwd = os.path.join(cwd, config)

# Install PyTorch which is required for transformers tests, and optional for some python tests.
if args.enable_transformers_tool_test and not args.disable_contrib_ops and not args.use_rocm:
index_url = "https://download.pytorch.org/whl/cpu"
if args.use_cuda and is_linux():
index_url = "https://download.pytorch.org/whl/cu124"
if args.cuda_version and version_to_tuple(args.cuda_version) < (12, 0):
index_url = "https://download.pytorch.org/whl/cu118"

# PyTorch is required for transformers tests, and optional for some python tests.
# Install cpu only version of torch when cuda is not enabled in Linux.
extra = [] if args.use_cuda and is_linux() else ["--index-url", "https://download.pytorch.org/whl/cpu"]
run_subprocess(
[sys.executable, "-m", "pip", "install", "torch", "--index-url", index_url],
[sys.executable, "-m", "pip", "install", "torch", *extra],
cwd=cwd,
dll_path=dll_path,
python_path=python_path,
Expand Down

0 comments on commit f82673c

Please sign in to comment.