Skip to content

Commit

Permalink
Add Torchvision and Torchaudio to base images (#1897)
Browse files Browse the repository at this point in the history
* Exclude torchaudio from pip installs

* Use our torch audio from the cog base images
instead

* Install torchvision and torchaudio in base images

* Add the compatible versions of torchvision and
torchaudio to the base image python packages

* Include known packages into the requirements.txt

* Fix torch version modifiers in the requirements

* Set cog base image to false with no base image

* If we don’t find a base image, explicitly set
this flag to false if it was ambiguous before
* This allows the rest of the script to continue
as if the user does not want a base image
  • Loading branch information
8W9aG authored Aug 23, 2024
1 parent 35893d7 commit b76869f
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 24 deletions.
6 changes: 3 additions & 3 deletions pkg/config/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extr
var latest *TorchCompatibility
for _, compat := range TorchCompatibilityMatrix {
compat := compat
if compat.TorchVersion() != ver || compat.CUDA == nil {
if !version.Matches(compat.TorchVersion(), ver) || compat.CUDA == nil {
continue
}
greater, err := versionGreater(*compat.CUDA, cuda)
Expand Down Expand Up @@ -330,7 +330,7 @@ func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extr
return "torch", ver, "", "", nil
}

return "torch", latest.Torch, latest.FindLinks, latest.ExtraIndexURL, nil
return "torch", version.StripModifier(latest.Torch), latest.FindLinks, latest.ExtraIndexURL, nil
}

func torchvisionCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
Expand Down Expand Up @@ -379,7 +379,7 @@ func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extra
return "torchvision", ver, "", "", nil
}

return "torchvision", latest.Torchvision, latest.FindLinks, latest.ExtraIndexURL, nil
return "torchvision", version.StripModifier(latest.Torchvision), latest.FindLinks, latest.ExtraIndexURL, nil
}

// aarch64 packages don't have +cpu suffix: https://download.pytorch.org/whl/torch_stable.html
Expand Down
44 changes: 39 additions & 5 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
var (
	// BuildSourceEpochTimestamp starts at -1.
	// NOTE(review): presumably -1 means "not set" — confirm against the callers.
	BuildSourceEpochTimestamp int64 = -1

	// BuildXCachePath is left empty here; it is not assigned anywhere in the visible code.
	BuildXCachePath string

	// PipPackageNameRegex captures the leading package name of a pip
	// requirement line. The name ends at the first version operator character
	// (">", "=", "<", "~", "!"), environment-marker separator (";"),
	// direct-reference marker ("@"), extras bracket ("["), comment marker ("#")
	// or whitespace. Stopping at "!", ";", "@" and tab keeps requirements such
	// as `torch!=2.0` or `torch;python_version<"3.9"` from yielding names like
	// "torch!" or "torch;".
	PipPackageNameRegex = regexp.MustCompile(`^([^>=<~!;@ \t\n[#]+)`)
)

// TODO(andreas): support conda packages
Expand Down Expand Up @@ -181,6 +182,10 @@ func (c *Config) TorchvisionVersion() (string, bool) {
return c.pythonPackageVersion("torchvision")
}

// TorchaudioVersion looks up the "torchaudio" Python package through
// pythonPackageVersion and passes its result through unchanged.
// NOTE(review): the boolean presumably reports whether the package was found — confirm.
func (c *Config) TorchaudioVersion() (string, bool) {
	ver, found := c.pythonPackageVersion("torchaudio")
	return ver, found
}

// TensorFlowVersion looks up the "tensorflow" Python package through
// pythonPackageVersion and passes its result through unchanged.
// NOTE(review): the boolean presumably reports whether the package was found — confirm.
func (c *Config) TensorFlowVersion() (string, bool) {
	ver, found := c.pythonPackageVersion("tensorflow")
	return ver, found
}
Expand Down Expand Up @@ -307,15 +312,18 @@ func (c *Config) ValidateAndComplete(projectDir string) error {
}

// PythonRequirementsForArch returns a requirements.txt file with all the GPU packages resolved for given OS and architecture.
func (c *Config) PythonRequirementsForArch(goos string, goarch string, excludePackages []string) (string, error) {
func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePackages []string) (string, error) {
packages := []string{}
findLinksSet := map[string]bool{}
extraIndexURLSet := map[string]bool{}
for _, pkg := range c.Build.pythonRequirementsContent {
if slices.ContainsString(excludePackages, pkg) {
continue
}

includePackageNames := []string{}
for _, pkg := range includePackages {
includePackageNames = append(includePackageNames, packageName(pkg))
}

// Include all the requirements and remove our include packages if they exist
for _, pkg := range c.Build.pythonRequirementsContent {
archPkg, findLinksList, extraIndexURLs, err := c.pythonPackageForArch(pkg, goos, goarch)
if err != nil {
return "", err
Expand All @@ -331,8 +339,26 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, excludePa
extraIndexURLSet[u] = true
}
}

packageName := packageName(archPkg)
if packageName != "" {
foundIdx := -1
for i, includePkg := range includePackageNames {
if includePkg == packageName {
foundIdx = i
break
}
}
if foundIdx != -1 {
includePackageNames = append(includePackageNames[:foundIdx], includePackageNames[foundIdx+1:]...)
includePackages = append(includePackages[:foundIdx], includePackages[foundIdx+1:]...)
}
}
}

// If we still have some include packages add them in
packages = append(packages, includePackages...)

// Create final requirements.txt output
// Put index URLs first
lines := []string{}
Expand Down Expand Up @@ -576,3 +602,11 @@ func sliceContains(slice []string, s string) bool {
}
return false
}

// packageName returns the first capture group of PipPackageNameRegex applied
// to pipRequirement, i.e. the leading package name of a pip requirement line.
// It returns "" when the regex does not match or has no capture group.
func packageName(pipRequirement string) string {
	if groups := PipPackageNameRegex.FindStringSubmatch(pipRequirement); len(groups) > 1 {
		return groups[1]
	}
	return ""
}
38 changes: 31 additions & 7 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ foo==1.0.0`), 0o644)
requirements, err := config.PythonRequirementsForArch("", "", []string{})
require.NoError(t, err)
expected := `--find-links https://download.pytorch.org/whl/torch_stable.html
torch==1.7.1+cu110
torchvision==0.8.2+cu110
torch==1.7.1
torchvision==0.8.2
torchaudio==0.7.2
foo==1.0.0`
require.Equal(t, expected, requirements)
Expand Down Expand Up @@ -197,8 +197,8 @@ foo==1.0.0`), 0o644)
requirements, err := config.PythonRequirementsForArch("", "", []string{})
require.NoError(t, err)
expected := `--extra-index-url https://download.pytorch.org/whl/cu116
torch==1.12.1+cu116
torchvision==0.13.1+cu116
torch==1.12.1
torchvision==0.13.1
torchaudio==0.12.1
foo==1.0.0`
require.Equal(t, expected, requirements)
Expand Down Expand Up @@ -406,8 +406,8 @@ func TestPythonPackagesForArchTorchGPU(t *testing.T) {
requirements, err := config.PythonRequirementsForArch("", "", []string{})
require.NoError(t, err)
expected := `--find-links https://download.pytorch.org/whl/torch_stable.html
torch==1.7.1+cu110
torchvision==0.8.2+cu110
torch==1.7.1
torchvision==0.8.2
torchaudio==0.7.2
foo==1.0.0`
require.Equal(t, expected, requirements)
Expand Down Expand Up @@ -491,7 +491,7 @@ func TestPythonPackagesBothTorchAndTensorflow(t *testing.T) {
require.NoError(t, err)
expected := `--extra-index-url https://download.pytorch.org/whl/cu121
tensorflow==2.16.1
torch==2.3.1+cu121`
torch==2.3.1`
require.Equal(t, expected, requirements)
}

Expand Down Expand Up @@ -694,3 +694,27 @@ func TestSplitPinnedPythonRequirement(t *testing.T) {
}
}
}

func TestPythonRequirementsForArchWithAddedPackage(t *testing.T) {
config := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.8",
PythonPackages: []string{
"torch==2.4.0 --extra-index-url=https://download.pytorch.org/whl/cu116",
},
CUDA: "11.6.2",
},
}
err := config.ValidateAndComplete("")
require.NoError(t, err)
require.Equal(t, "11.6.2", config.Build.CUDA)
requirements, err := config.PythonRequirementsForArch("", "", []string{
"torchvision==2.4.0",
})
require.NoError(t, err)
expected := `--extra-index-url https://download.pytorch.org/whl/cu116
torch==2.4.0
torchvision==2.4.0`
require.Equal(t, expected, requirements)
}
30 changes: 29 additions & 1 deletion pkg/dockerfile/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,35 @@ func (g *BaseImageGenerator) makeConfig() (*config.Config, error) {

// pythonPackages returns the pinned pip requirements for the base image.
//
// With no torch version configured it returns an empty slice. Otherwise it
// returns "torch==<version>" followed by, when available, the torchvision and
// torchaudio pins taken from the first entry of config.TorchCompatibilityMatrix
// that both declares the respective package and matches the torch version
// according to version.Matches.
func (g *BaseImageGenerator) pythonPackages() []string {
	if g.torchVersion == "" {
		return []string{}
	}

	pkgs := []string{"torch==" + g.torchVersion}

	// Find torchvision compatibility: first matching entry that declares a
	// torchvision version wins.
	for _, compat := range config.TorchCompatibilityMatrix {
		if len(compat.Torchvision) == 0 || !version.Matches(g.torchVersion, compat.TorchVersion()) {
			continue
		}
		pkgs = append(pkgs, "torchvision=="+compat.Torchvision)
		break
	}

	// Find torchaudio compatibility. This is a separate pass because the first
	// entry declaring torchaudio need not be the one chosen for torchvision.
	for _, compat := range config.TorchCompatibilityMatrix {
		if len(compat.Torchaudio) == 0 || !version.Matches(g.torchVersion, compat.TorchVersion()) {
			continue
		}
		pkgs = append(pkgs, "torchaudio=="+compat.Torchaudio)
		break
	}

	return pkgs
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/dockerfile/base_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dockerfile

import (
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -68,3 +69,10 @@ func TestIsVersionCompatible(t *testing.T) {
compatible := isVersionCompatible("2.3.1+cu121", "2.3")
require.True(t, compatible)
}

// TestPythonPackages verifies that the generator built with the arguments
// below pins torch 2.1.0 together with torchvision 0.16.0 and torchaudio 2.1.0,
// in that order.
func TestPythonPackages(t *testing.T) {
	generator, err := NewBaseImageGenerator("12.1", "3.9", "2.1.0")
	require.NoError(t, err)

	want := []string{"torch==2.1.0", "torchvision==0.16.0", "torchaudio==2.1.0"}
	got := generator.pythonPackages()

	// Report both sides on failure; printing only the actual value under the
	// label "expected" makes a failing run misleading.
	require.Truef(t, reflect.DeepEqual(got, want), "expected %v, got %v", want, got)
}
17 changes: 11 additions & 6 deletions pkg/dockerfile/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,10 @@ func (g *Generator) BaseImage() (string, error) {
if err == nil || g.useCogBaseImage != nil {
return baseImage, err
}
if err != nil {
console.Warnf("Could not find a suitable base image, continuing without base image support (%v).", err)
console.Warnf("Could not find a suitable base image, continuing without base image support (%v).", err)
if g.useCogBaseImage == nil {
g.useCogBaseImage = new(bool)
*g.useCogBaseImage = false
}
}

Expand Down Expand Up @@ -394,14 +396,17 @@ func (g *Generator) installCog() (string, error) {

func (g *Generator) pipInstalls() (string, error) {
var err error
excludePackages := []string{}
includePackages := []string{}
if torchVersion, ok := g.Config.TorchVersion(); ok {
excludePackages = []string{"torch==" + torchVersion}
includePackages = []string{"torch==" + torchVersion}
}
if torchvisionVersion, ok := g.Config.TorchvisionVersion(); ok {
excludePackages = append(excludePackages, "torchvision=="+torchvisionVersion)
includePackages = append(includePackages, "torchvision=="+torchvisionVersion)
}
if torchaudioVersion, ok := g.Config.TorchaudioVersion(); ok {
includePackages = append(includePackages, "torchaudio=="+torchaudioVersion)
}
g.pythonRequirementsContents, err = g.Config.PythonRequirementsForArch(g.GOOS, g.GOARCH, excludePackages)
g.pythonRequirementsContents, err = g.Config.PythonRequirementsForArch(g.GOOS, g.GOARCH, includePackages)
if err != nil {
return "", err
}
Expand Down
9 changes: 7 additions & 2 deletions pkg/dockerfile/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,10 @@ COPY . /src`, expectedTorchVersion)

requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt"))
require.NoError(t, err)
require.Equal(t, "pandas==2.0.3", string(requirements))
expected = fmt.Sprintf(`--extra-index-url https://download.pytorch.org/whl/cu118
torch==%s
pandas==2.0.3`, expectedTorchVersion)
require.Equal(t, expected, string(requirements))
}
}

Expand Down Expand Up @@ -649,5 +652,7 @@ COPY . /src`

requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt"))
require.NoError(t, err)
require.Equal(t, "pandas==2.0.3", string(requirements))
require.Equal(t, `--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.3.1
pandas==2.0.3`, string(requirements))
}

0 comments on commit b76869f

Please sign in to comment.