Skip to content

Commit

Permalink
Handle CUDA version modifiers on torch (#1876)
Browse files Browse the repository at this point in the history
* Handle CUDA version modifiers on torch

* If we are given a torch version with a CUDA version
modifier such as 2.0.3+cu118, we fail if we do not
explicitly support this version.
* Instead, extract the CUDA version from the
version tag (11.8) and perform our CUDA
compatibility lookup with those two versions.

* Update pkg/config/compatibility_test.go

Co-authored-by: Mattt <[email protected]>
Signed-off-by: Will Sackfield <[email protected]>

* Update pkg/config/compatibility.go

Co-authored-by: Mattt <[email protected]>
Signed-off-by: Will Sackfield <[email protected]>

* Fix indentation

---------

Signed-off-by: Will Sackfield <[email protected]>
Signed-off-by: Will Sackfield <[email protected]>
Co-authored-by: Mattt <[email protected]>
  • Loading branch information
8W9aG and mattt authored Aug 15, 2024
1 parent 8dcd976 commit 07cd53b
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 0 deletions.
47 changes: 47 additions & 0 deletions pkg/config/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,55 @@ func generateTorchMinorVersionCompatibilityMatrix(matrix []TorchCompatibility) [

}

// cudaVersionFromTorchPlusVersion splits a torch version that may carry a CUDA
// modifier (for example "2.0.1+cu118") into a CUDA version and a torch version.
//
// It returns (cudaVersion, torchVersion). When ver contains no '+', when the
// text after the last '+' does not start with "cu", or when fewer than two
// characters follow "cu", the CUDA version is empty and ver is returned
// unchanged. Otherwise a dot is inserted before the final character of the
// text following "cu" ("118" -> "11.8"), and the torch version is everything
// before the first '+'.
func cudaVersionFromTorchPlusVersion(ver string) (string, string) {
	const modifierPrefix = "cu"

	firstPlus := strings.Index(ver, "+")
	if firstPlus < 0 {
		// No modifier section at all.
		return "", ver
	}

	// Only the section after the final '+' is inspected for a CUDA tag.
	modifier := ver[strings.LastIndex(ver, "+")+1:]
	if !strings.HasPrefix(modifier, modifierPrefix) {
		return "", ver
	}

	digits := modifier[len(modifierPrefix):]
	if len(digits) < 2 {
		// Too short to place a dot between major and minor parts.
		return "", ver
	}

	split := len(digits) - 1
	return digits[:split] + "." + digits[split:], ver[:firstPlus]
}

func cudasFromTorch(ver string) ([]string, error) {
cudas := []string{}

// Check the version modifier on torch (such as +cu118)
cudaVer, ver := cudaVersionFromTorchPlusVersion(ver)
if len(cudaVer) > 0 {
for _, compat := range TorchCompatibilityMatrix {
if compat.CUDA == nil {
continue
}
if ver == compat.TorchVersion() && *compat.CUDA == cudaVer {
cudas = append(cudas, *compat.CUDA)
return cudas, nil
}
}
}

for _, compat := range TorchCompatibilityMatrix {
if ver == compat.TorchVersion() && compat.CUDA != nil {
cudas = append(cudas, *compat.CUDA)
Expand Down
7 changes: 7 additions & 0 deletions pkg/config/compatibility_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ func TestGenerateTorchMinorVersionCompatibilityMatrix(t *testing.T) {
require.Equal(t, expected, actual)
}

// TestCudasFromTorchWithCUVersionModifier checks that a torch version carrying
// a "+cu118" modifier yields "11.8" as the first CUDA version.
func TestCudasFromTorchWithCUVersionModifier(t *testing.T) {
	cudas, err := cudasFromTorch("2.0.1+cu118")
	// Assert on the error before touching the result so a lookup failure is
	// reported as an error rather than as an empty slice.
	require.NoError(t, err)
	require.NotEmpty(t, cudas)
	// Expected value first, actual second, so failure output is labelled correctly.
	require.Equal(t, "11.8", cudas[0])
}

// stringp returns a pointer to a fresh string holding the value of s.
func stringp(s string) *string {
	p := new(string)
	*p = s
	return p
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Build configuration for a GPU project that pins torch with a CUDA version modifier.
build:
gpu: true
python_version: "3.9"
python_packages:
# The "+cu118" suffix is the CUDA version modifier this project exercises.
- "torch==2.0.1+cu118"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from cog import BasePredictor


class Predictor(BasePredictor):
    """Minimal predictor that greets its input string."""

    def predict(self, s: str) -> str:
        """Return ``s`` prefixed with ``"hello "``."""
        greeting = "hello "
        return greeting + s
10 changes: 10 additions & 0 deletions test-integration/test_integration/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,16 @@ def test_build_base_image_sha(docker_image):
assert base_layer_hash in layers


def test_torch_2_0_3_cu118_base_image(docker_image):
    """Build the torch-cuda-baseimage-project fixture with ``--use-cog-base-image``.

    The test passes when ``cog build`` exits with status 0.

    NOTE(review): the test name says 2.0.3, but the fixture config visible in
    this change appears to pin torch==2.0.1+cu118 -- confirm which is intended.
    """
    project_dir = Path(__file__).parent / "fixtures/torch-cuda-baseimage-project"
    build_process = subprocess.run(
        ["cog", "build", "-t", docker_image, "--use-cog-base-image"],
        cwd=project_dir,
        capture_output=True,
    )
    # Output is captured, so include stderr in the failure message; otherwise a
    # failed build leaves nothing to diagnose.
    assert build_process.returncode == 0, build_process.stderr.decode(errors="replace")


def test_torch_1_13_0_base_image_fallback(docker_image):
project_dir = Path(__file__).parent / "fixtures/torch-baseimage-project"
build_process = subprocess.run(
Expand Down

0 comments on commit 07cd53b

Please sign in to comment.