Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tests/test_platform_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ def test_nvidia_gpu_linux(monkeypatch):
assert get_torch_platform(gpu_infos) == expected


def test_nvidia_gpu_demotes_to_cu124_for_pinned_torch_below_2_7(monkeypatch):
    # Simulate a Windows/amd64 host with an NVIDIA GPU reporting arch 8.6
    # (a configuration that normally selects the cu128 platform).
    monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
    monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
    monkeypatch.setattr("torchruntime.platform_detection.py_version", (3, 11))
    monkeypatch.setattr("torchruntime.platform_detection.get_nvidia_arch", lambda device_names: 8.6)

    gpu_infos = [GPU(NVIDIA, "NVIDIA", 0x1234, "GeForce", True)]

    # Default selection is cu128, but a torch pin below 2.7 must demote
    # the platform to cu124.
    assert get_torch_platform(gpu_infos) == "cu128"
    assert get_torch_platform(gpu_infos, packages=["torch==2.6.0"]) == "cu124"


def test_nvidia_gpu_mac(monkeypatch):
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Darwin")
monkeypatch.setattr("torchruntime.platform_detection.arch", "arm64")
Expand Down
2 changes: 1 addition & 1 deletion torchruntime/installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def install(packages=[], use_uv=False):
"""

gpu_infos = get_gpus()
torch_platform = get_torch_platform(gpu_infos)
torch_platform = get_torch_platform(gpu_infos, packages=packages)
cmds = get_install_commands(torch_platform, packages)
cmds = get_pip_commands(cmds, use_uv=use_uv)
run_commands(cmds)
103 changes: 100 additions & 3 deletions torchruntime/platform_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,115 @@
import sys
import platform

from packaging.requirements import Requirement
from packaging.version import Version

from .gpu_db import get_nvidia_arch, get_amd_gfx_info
from .consts import AMD, INTEL, NVIDIA, CONTACT_LINK

os_name = platform.system()
arch = platform.machine().lower()
py_version = sys.version_info

# Platform identifiers for the two CUDA wheel indexes this module can select.
_CUDA_12_8_PLATFORM = "cu128"
_CUDA_12_4_PLATFORM = "cu124"
# Earliest release of each torch-family package that ships cu128 wheels;
# requirement pins that exclude these versions must use the cu124 index.
_CUDA_12_8_MIN_VERSIONS = {
    "torch": Version("2.7.0"),
    "torchaudio": Version("2.7.0"),
    "torchvision": Version("0.22.0"),
}


def _parse_release_segments(text):
segments = []
for part in text.split("."):
match = re.match(r"^(\d+)", part)
if not match:
break
segments.append(int(match.group(1)))
return segments


def _upper_bound_for_specifier(specifier):
    """Return ``(upper_bound, inclusive)`` for a single version specifier.

    ``upper_bound`` is a ``packaging.version.Version`` giving the highest
    version the specifier can admit, and ``inclusive`` says whether that
    bound itself is allowed. ``(None, None)`` means the specifier imposes
    no upper bound (e.g. ">=", "!=", or an unparseable form).
    """
    op = specifier.operator
    ver = specifier.version

    if op == "<":
        return Version(ver), False

    if op == "<=":
        return Version(ver), True

    if op == "==":
        if "*" not in ver:
            # Exact pin: the pinned version is the (inclusive) ceiling.
            return Version(ver), True
        # Wildcard pin ("==2.6.*"): exclusive ceiling is the prefix with
        # its final segment bumped, i.e. anything below "2.7".
        head = _parse_release_segments(ver.split("*", 1)[0].rstrip("."))
        if not head:
            return None, None
        head[-1] += 1
        return Version(".".join(map(str, head))), False

    if op == "~=":
        # Compatible release ("~=2.6.0"): drop the last segment, bump the
        # new last one -> exclusive ceiling "2.7" (PEP 440 semantics).
        parts = _parse_release_segments(ver)
        if len(parts) < 2:
            return None, None
        bound = parts[:-1]
        bound[-1] += 1
        return Version(".".join(map(str, bound))), False

    return None, None


def _packages_require_cuda_12_4(packages):
    """Return True if any requested torch-family pin excludes the first
    CUDA 12.8-capable release of that package (per _CUDA_12_8_MIN_VERSIONS),
    meaning the cu124 wheel index must be used instead.

    Args:
        packages (list of str): PEP 508 requirement strings, e.g.
            "torch==2.6.0". Unparseable entries are skipped.
    """
    if not packages:
        return False

    for package in packages:
        try:
            requirement = Requirement(package)
        except Exception:
            # Not a valid requirement string; ignore rather than fail install.
            continue

        # Normalize the distribution name for the lookup table.
        name = requirement.name.lower().replace("_", "-")
        threshold = _CUDA_12_8_MIN_VERSIONS.get(name)
        if not threshold or not requirement.specifier:
            # Unrelated package, or no version constraints at all.
            continue

        threshold_allowed = None
        for specifier in requirement.specifier:
            upper, inclusive = _upper_bound_for_specifier(specifier)
            if not upper:
                # This specifier imposes no upper bound.
                continue

            # Ceiling strictly below the first cu128-capable release.
            if upper < threshold:
                return True

            # Exclusive ceiling exactly at the threshold (e.g. "<2.7.0").
            if upper == threshold and not inclusive:
                return True

            # Inclusive ceiling at the threshold: demote only if the full
            # specifier set nevertheless rejects the threshold version
            # (computed once per requirement and cached).
            if upper == threshold and inclusive:
                if threshold_allowed is None:
                    threshold_allowed = requirement.specifier.contains(threshold, prereleases=True)
                if not threshold_allowed:
                    return True

    return False


def _adjust_cuda_platform_for_requested_packages(torch_platform, packages):
    """Demote cu128 to cu124 when the requested package pins predate
    CUDA 12.8 wheel availability; otherwise return the platform unchanged.
    """
    needs_older_cuda = (
        torch_platform == _CUDA_12_8_PLATFORM and _packages_require_cuda_12_4(packages)
    )
    return _CUDA_12_4_PLATFORM if needs_older_cuda else torch_platform


def get_torch_platform(gpu_infos):
def get_torch_platform(gpu_infos, packages=[]):
"""
Determine the appropriate PyTorch platform to use based on the system architecture, OS, and GPU information.

Args:
gpu_infos (list of `torchruntime.device_db.GPU` instances)
packages (list of str): Optional list of torch/torchvision/torchaudio requirement strings.

Returns:
str: A string representing the platform to use. Possible values:
Expand Down Expand Up @@ -53,9 +148,11 @@ def get_torch_platform(gpu_infos):
integrated_devices.append(device)

if discrete_devices:
return _get_platform_for_discrete(discrete_devices)
torch_platform = _get_platform_for_discrete(discrete_devices)
return _adjust_cuda_platform_for_requested_packages(torch_platform, packages)

return _get_platform_for_integrated(integrated_devices)
torch_platform = _get_platform_for_integrated(integrated_devices)
return _adjust_cuda_platform_for_requested_packages(torch_platform, packages)


def _get_platform_for_discrete(gpu_infos):
Expand Down