diff --git a/.gitignore b/.gitignore index 454fd13920..ed7e00fbfc 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,4 @@ pkg/dockerfile/embed/*.whl docs/README.md docs/CONTRIBUTING.md venv +base-image diff --git a/pkg/cli/build.go b/pkg/cli/build.go index 812fa7f5dc..e4761a5210 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -23,6 +23,8 @@ var buildUseCudaBaseImage string var buildDockerfileFile string var buildUseCogBaseImage bool +const useCogBaseImageFlagKey = "use-cog-base-image" + func newBuildCommand() *cobra.Command { cmd := &cobra.Command{ Use: "build", @@ -63,13 +65,7 @@ func buildCommand(cmd *cobra.Command, args []string) error { return err } - if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, buildUseCogBaseImage); err != nil { - if buildUseCogBaseImage && cmd.Flags().Changed("use-cog-base-image") { - console.Infof("Build failed with Cog base image enabled by default. " + - "If you want to build without using pre-built base images, " + - "try `cog build --use-cog-base-image=false`.") - } - + if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, DetermineUseCogBaseImage(cmd)); err != nil { return err } @@ -116,7 +112,7 @@ func addDockerfileFlag(cmd *cobra.Command) { } func addUseCogBaseImageFlag(cmd *cobra.Command) { - cmd.Flags().BoolVar(&buildUseCogBaseImage, "use-cog-base-image", true, "Use pre-built Cog base image for faster cold boots") + cmd.Flags().BoolVar(&buildUseCogBaseImage, useCogBaseImageFlagKey, true, "Use pre-built Cog base image for faster cold boots") } func addBuildTimestampFlag(cmd *cobra.Command) { @@ -125,7 +121,7 @@ func addBuildTimestampFlag(cmd *cobra.Command) { } func checkMutuallyExclusiveFlags(cmd *cobra.Command, args []string) error { - flags := []string{"use-cog-base-image", "use-cuda-base-image", "dockerfile"} + flags := []string{useCogBaseImageFlagKey, "use-cuda-base-image", "dockerfile"} var flagsSet []string for _, flag := range flags { if cmd.Flag(flag).Changed { @@ -137,3 +133,12 @@ func checkMutuallyExclusiveFlags(cmd *cobra.Command, args []string) error { } return nil } + +func DetermineUseCogBaseImage(cmd *cobra.Command) *bool { + if !cmd.Flags().Changed(useCogBaseImageFlagKey) { + return nil + } + useCogBaseImage := new(bool) + *useCogBaseImage = buildUseCogBaseImage + return useCogBaseImage +} diff --git a/pkg/cli/debug.go b/pkg/cli/debug.go index 2ef601da64..f433e0b173 100644 --- a/pkg/cli/debug.go +++ b/pkg/cli/debug.go @@ -48,7 +48,10 @@ func cmdDockerfile(cmd *cobra.Command, args []string) error { }() generator.SetUseCudaBaseImage(buildUseCudaBaseImage) - generator.SetUseCogBaseImage(buildUseCogBaseImage) + useCogBaseImage := DetermineUseCogBaseImage(cmd) + if useCogBaseImage != nil { + generator.SetUseCogBaseImage(*useCogBaseImage) + } if buildSeparateWeights { if imageName == "" { diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index c106678d54..ee32d664e6 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -73,7 +73,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { return err } - if imageName, err = image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, buildUseCogBaseImage, buildProgressOutput); err != nil { + if imageName, err = image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput); err != nil { return err } diff --git a/pkg/cli/push.go b/pkg/cli/push.go index 6e088b1c7f..61a4de2a87 100644 --- a/pkg/cli/push.go +++ b/pkg/cli/push.go @@ -56,12 +56,7 @@ func push(cmd *cobra.Command, args []string) error { } } - if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, buildUseCogBaseImage); err != nil { - if buildUseCogBaseImage && cmd.Flags().Changed("use-cog-base-image") { - console.Infof("Push failed with Cog base image enabled by default. " + - "If you want to push without using pre-built base images, " + - "try `cog push --use-cog-base-image=false`.") - } + if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, DetermineUseCogBaseImage(cmd)); err != nil { return err } diff --git a/pkg/cli/run.go b/pkg/cli/run.go index 17a31a1ad8..32f4b2cc29 100644 --- a/pkg/cli/run.go +++ b/pkg/cli/run.go @@ -54,7 +54,7 @@ func run(cmd *cobra.Command, args []string) error { if err != nil { return err } - imageName, err := image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, buildUseCogBaseImage, buildProgressOutput) + imageName, err := image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput) if err != nil { return err } diff --git a/pkg/cli/train.go b/pkg/cli/train.go index 22d0693540..a33c92e74a 100644 --- a/pkg/cli/train.go +++ b/pkg/cli/train.go @@ -62,7 +62,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { return err } - if imageName, err = image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, buildUseCogBaseImage, buildProgressOutput); err != nil { + if imageName, err = image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput); err != nil { return err } diff --git a/pkg/dockerfile/generator.go b/pkg/dockerfile/generator.go index 6c3f78f8ef..ee17620fa4 100644 --- a/pkg/dockerfile/generator.go +++ b/pkg/dockerfile/generator.go @@ -52,7 +52,7 @@ type Generator struct { GOARCH string useCudaBaseImage bool - useCogBaseImage bool + useCogBaseImage *bool // absolute path to tmpDir, a directory that will be cleaned up tmpDir string @@ -93,7 +93,7 @@ func NewGenerator(config *config.Config, dir string) (*Generator, error) { relativeTmpDir: relativeTmpDir, fileWalker: filepath.Walk, useCudaBaseImage: true, - useCogBaseImage: false, + useCogBaseImage: nil, }, nil } @@ -103,11 +103,16 @@ func (g *Generator) SetUseCudaBaseImage(argumentValue string) { } func (g *Generator) SetUseCogBaseImage(useCogBaseImage bool) { - g.useCogBaseImage = useCogBaseImage + g.useCogBaseImage = new(bool) + *g.useCogBaseImage = useCogBaseImage } func (g *Generator) IsUsingCogBaseImage() bool { - return g.useCogBaseImage + useCogBaseImage := g.useCogBaseImage + if useCogBaseImage != nil { + return *useCogBaseImage + } + return true } func (g *Generator) generateInitialSteps() (string, error) { @@ -128,7 +133,7 @@ func (g *Generator) generateInitialSteps() (string, error) { return "", err } - if g.useCogBaseImage { + if g.IsUsingCogBaseImage() { pipInstalls, err := g.pipInstalls() if err != nil { return "", err @@ -264,37 +269,14 @@ func (g *Generator) Cleanup() error { } func (g *Generator) BaseImage() (string, error) { - if g.useCogBaseImage { - var changed bool - var err error - - cudaVersion := g.Config.Build.CUDA - - pythonVersion := g.Config.Build.PythonVersion - pythonVersion, changed, err = stripPatchVersion(pythonVersion) - if err != nil { - return "", err - } - if changed { - console.Warnf("Stripping patch version from Python version %s to %s", g.Config.Build.PythonVersion, pythonVersion) - } - - torchVersion, _ := g.Config.TorchVersion() - torchVersion, changed, err = stripPatchVersion(torchVersion) - if err != nil { - return "", err + if g.IsUsingCogBaseImage() { + baseImage, err := g.determineBaseImageName() + if err == nil || g.useCogBaseImage != nil { + return baseImage, err } - if changed { - console.Warnf("Stripping patch version from Torch version %s to %s", g.Config.Build.PythonVersion, pythonVersion) - } - - // validate that the base image configuration exists - imageGenerator, err := NewBaseImageGenerator(cudaVersion, pythonVersion, torchVersion) if err != nil { - return "", err + console.Warnf("Could not find a suitable base image, continuing without base image support (%v).", err) } - baseImage := BaseImageName(imageGenerator.cudaVersion, imageGenerator.pythonVersion, imageGenerator.torchVersion) - return baseImage, nil } if g.Config.Build.GPU && g.useCudaBaseImage { @@ -336,7 +318,7 @@ func (g *Generator) aptInstalls() (string, error) { return "", nil } - if g.useCogBaseImage { + if g.IsUsingCogBaseImage() { packages = slices.FilterString(packages, func(pkg string) bool { return !slices.ContainsString(baseImageSystemPackages, pkg) }) @@ -348,7 +330,7 @@ func (g *Generator) aptInstalls() (string, error) { } func (g *Generator) installPython() (string, error) { - if g.Config.Build.GPU && g.useCudaBaseImage && !g.useCogBaseImage { + if g.Config.Build.GPU && g.useCudaBaseImage && !g.IsUsingCogBaseImage() { return g.installPythonCUDA() } return "", nil @@ -487,7 +469,7 @@ func (g *Generator) copyPipPackagesFromInstallStage() string { // return "COPY --from=deps --link /dep COPY --from=deps /src" // ...except it's actually /root/.pyenv/versions/3.8.17/lib/python3.8/site-packages py := g.Config.Build.PythonVersion - if g.Config.Build.GPU && (g.useCudaBaseImage || g.useCogBaseImage) { + if g.Config.Build.GPU && (g.useCudaBaseImage || g.IsUsingCogBaseImage()) { // this requires buildkit! // we should check for buildkit and otherwise revert to symlinks or copying into /src // we mount to avoid copying, which avoids having two copies in this layer @@ -596,6 +578,39 @@ func (g *Generator) GenerateWeightsManifest() (*weights.Manifest, error) { return m, nil } +func (g *Generator) determineBaseImageName() (string, error) { + var changed bool + var err error + + cudaVersion := g.Config.Build.CUDA + + pythonVersion := g.Config.Build.PythonVersion + pythonVersion, changed, err = stripPatchVersion(pythonVersion) + if err != nil { + return "", err + } + if changed { + console.Warnf("Stripping patch version from Python version %s to %s", g.Config.Build.PythonVersion, pythonVersion) + } + + torchVersion, _ := g.Config.TorchVersion() + torchVersion, changed, err = stripPatchVersion(torchVersion) + if err != nil { + return "", err + } + if changed { + console.Warnf("Stripping patch version from Torch version %s to %s", g.Config.Build.PythonVersion, pythonVersion) + } + + // validate that the base image configuration exists + imageGenerator, err := NewBaseImageGenerator(cudaVersion, pythonVersion, torchVersion) + if err != nil { + return "", err + } + baseImage := BaseImageName(imageGenerator.cudaVersion, imageGenerator.pythonVersion, imageGenerator.torchVersion) + return baseImage, nil +} + func stripPatchVersion(versionString string) (string, bool, error) { if versionString == "" { return "", false, nil diff --git a/pkg/dockerfile/generator_test.go b/pkg/dockerfile/generator_test.go index b3dab6487e..77e73e8ff5 100644 --- a/pkg/dockerfile/generator_test.go +++ b/pkg/dockerfile/generator_test.go @@ -91,6 +91,7 @@ predict: predict.py:Predictor gen, err := NewGenerator(conf, tmpDir) require.NoError(t, err) + gen.SetUseCogBaseImage(false) _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") require.NoError(t, err) @@ -123,6 +124,7 @@ predict: predict.py:Predictor require.NoError(t, conf.ValidateAndComplete("")) gen, err := NewGenerator(conf, tmpDir) require.NoError(t, err) + gen.SetUseCogBaseImage(false) _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") require.NoError(t, err) @@ -167,6 +169,7 @@ predict: predict.py:Predictor gen, err := NewGenerator(conf, tmpDir) require.NoError(t, err) + gen.SetUseCogBaseImage(false) _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") require.NoError(t, err) @@ -218,6 +221,7 @@ predict: predict.py:Predictor gen, err := NewGenerator(conf, tmpDir) require.NoError(t, err) + gen.SetUseCogBaseImage(false) _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") require.NoError(t, err) @@ -268,6 +272,7 @@ build: gen, err := NewGenerator(conf, tmpDir) require.NoError(t, err) + gen.SetUseCogBaseImage(false) _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") require.NoError(t, err) @@ -303,6 +308,7 @@ build: gen, err := NewGenerator(conf, tmpDir) require.NoError(t, err) + gen.SetUseCogBaseImage(false) _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") require.NoError(t, err) fmt.Println(actual) @@ -356,6 +362,7 @@ predict: predict.py:Predictor gen, err := NewGenerator(conf, tmpDir) require.NoError(t, err) + gen.SetUseCogBaseImage(false) gen.fileWalker = func(root string, walkFn filepath.WalkFunc) error { for _, path := range []string{"checkpoints/large-a", "models/large-b", "root-large"} { @@ -453,6 +460,7 @@ predict: predict.py:Predictor gen, err := NewGenerator(conf, tmpDir) require.NoError(t, err) + gen.SetUseCogBaseImage(false) actual, err := gen.GenerateDockerfileWithoutSeparateWeights() require.NoError(t, err) diff --git a/pkg/image/build.go b/pkg/image/build.go index 56d233b325..73c4c1ffb0 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -28,7 +28,7 @@ const bundledSchemaPy = ".cog/schema.py" // Build a Cog model from a config // // This is separated out from docker.Build(), so that can be as close as possible to the behavior of 'docker build'. -func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, separateWeights bool, useCudaBaseImage string, progressOutput string, schemaFile string, dockerfileFile string, useCogBaseImage bool) error { +func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, separateWeights bool, useCudaBaseImage string, progressOutput string, schemaFile string, dockerfileFile string, useCogBaseImage *bool) error { console.Infof("Building Docker image from environment in cog.yaml as %s...", imageName) // remove bundled schema files that may be left from previous builds @@ -56,7 +56,9 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, } }() generator.SetUseCudaBaseImage(useCudaBaseImage) - generator.SetUseCogBaseImage(useCogBaseImage) + if useCogBaseImage != nil { + generator.SetUseCogBaseImage(*useCogBaseImage) + } if generator.IsUsingCogBaseImage() { cogBaseImageName, err = generator.BaseImage() @@ -222,7 +224,7 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, return nil } -func BuildBase(cfg *config.Config, dir string, useCudaBaseImage string, useCogBaseImage bool, progressOutput string) (string, error) { +func BuildBase(cfg *config.Config, dir string, useCudaBaseImage string, useCogBaseImage *bool, progressOutput string) (string, error) { // TODO: better image management so we don't eat up disk space // https://github.com/replicate/cog/issues/80 imageName := config.BaseDockerImageName(dir) @@ -239,7 +241,9 @@ func BuildBase(cfg *config.Config, dir string, useCudaBaseImage string, useCogBa }() generator.SetUseCudaBaseImage(useCudaBaseImage) - generator.SetUseCogBaseImage(useCogBaseImage) + if useCogBaseImage != nil { + generator.SetUseCogBaseImage(*useCogBaseImage) + } dockerfileContents, err := generator.GenerateModelBase() if err != nil { diff --git a/test-integration/test_integration/fixtures/torch-baseimage-project/cog.yaml b/test-integration/test_integration/fixtures/torch-baseimage-project/cog.yaml new file mode 100644 index 0000000000..0256e2aa36 --- /dev/null +++ b/test-integration/test_integration/fixtures/torch-baseimage-project/cog.yaml @@ -0,0 +1,6 @@ +build: + gpu: true + python_version: "3.9" + python_packages: + - "torch==1.13.0" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/torch-baseimage-project/openapi.json b/test-integration/test_integration/fixtures/torch-baseimage-project/openapi.json new file mode 100644 index 0000000000..a10a192304 --- /dev/null +++ b/test-integration/test_integration/fixtures/torch-baseimage-project/openapi.json @@ -0,0 +1,531 @@ +{ + "info": { + "title": "Cog", + "version": "0.1.0" + }, + "paths": { + "/": { + "get": { + "summary": "Root", + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Root Get" + } + } + }, + "description": "Successful Response" + } + }, + "operationId": "root__get" + } + }, + "/shutdown": { + "post": { + "summary": "Start Shutdown", + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Start Shutdown Shutdown Post" + } + } + }, + "description": "Successful Response" + } + }, + "operationId": "start_shutdown_shutdown_post" + } + }, + "/predictions": { + "post": { + "summary": "Predict", + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PredictionResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "parameters": [ + { + "in": "header", + "name": "prefer", + "schema": { + "type": "string", + "title": "Prefer" + }, + "required": false + } + ], + "description": "Run a single prediction on the model", + "operationId": "predict_predictions_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PredictionRequest" + } + } + } + } + } + }, + "/health-check": { + "get": { + "summary": "Healthcheck", + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Healthcheck Health Check Get" + } + } + }, + "description": "Successful Response" + } + }, + "operationId": "healthcheck_health_check_get" + } + }, + "/predictions/{prediction_id}": { + "put": { + "summary": "Predict Idempotent", + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PredictionResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "parameters": [ + { + "in": "path", + "name": "prediction_id", + "schema": { + "type": "string", + "title": "Prediction ID" + }, + "required": true + }, + { + "in": "header", + "name": "prefer", + "schema": { + "type": "string", + "title": "Prefer" + }, + "required": false + } + ], + "description": "Run a single prediction on the model (idempotent creation).", + "operationId": "predict_idempotent_predictions__prediction_id__put", + "requestBody": { + "content": { + "application/json": { + "schema": { + "allOf": [ + { + "$ref": "#/components/schemas/PredictionRequest" + } + ], + "title": "Prediction Request" + } + } + }, + "required": true + } + } + }, + "/predictions/{prediction_id}/cancel": { + "post": { + "summary": "Cancel", + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Cancel Predictions Prediction Id Cancel Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "parameters": [ + { + "in": "path", + "name": "prediction_id", + "schema": { + "type": "string", + "title": "Prediction ID" + }, + "required": true + } + ], + "description": "Cancel a running prediction", + "operationId": "cancel_predictions__prediction_id__cancel_post" + } + } + }, + "openapi": "3.1.0", + "components": { + "schemas": { + "Input": { + "type": "object", + "title": "Input", + "properties": { + "seed": { + "type": "integer", + "title": "Seed", + "x-order": 8, + "description": "Random seed. Leave blank to randomize the seed" + }, + "width": { + "allOf": [ + { + "$ref": "#/components/schemas/width" + } + ], + "default": 768, + "x-order": 2, + "description": "Width of generated image in pixels. Needs to be a multiple of 64" + }, + "height": { + "allOf": [ + { + "$ref": "#/components/schemas/height" + } + ], + "default": 768, + "x-order": 1, + "description": "Height of generated image in pixels. Needs to be a multiple of 64" + }, + "prompt": { + "type": "string", + "title": "Prompt", + "default": "a vision of paradise. unreal engine", + "x-order": 0, + "description": "Input prompt" + }, + "scheduler": { + "allOf": [ + { + "$ref": "#/components/schemas/scheduler" + } + ], + "default": "DPMSolverMultistep", + "x-order": 7, + "description": "Choose a scheduler." + }, + "num_outputs": { + "type": "integer", + "title": "Num Outputs", + "default": 1, + "maximum": 4, + "minimum": 1, + "x-order": 4, + "description": "Number of images to generate." + }, + "guidance_scale": { + "type": "number", + "title": "Guidance Scale", + "default": 7.5, + "maximum": 20, + "minimum": 1, + "x-order": 6, + "description": "Scale for classifier-free guidance" + }, + "negative_prompt": { + "type": "string", + "title": "Negative Prompt", + "x-order": 3, + "description": "Specify things to not see in the output" + }, + "num_inference_steps": { + "type": "integer", + "title": "Num Inference Steps", + "default": 50, + "maximum": 500, + "minimum": 1, + "x-order": 5, + "description": "Number of denoising steps" + } + } + }, + "width": { + "enum": [ + 64, + 128, + 192, + 256, + 320, + 384, + 448, + 512, + 576, + 640, + 704, + 768, + 832, + 896, + 960, + 1024 + ], + "type": "integer", + "title": "width", + "description": "An enumeration." + }, + "Output": { + "type": "array", + "items": { + "type": "string", + "format": "uri" + }, + "title": "Output" + }, + "Status": { + "enum": [ + "starting", + "processing", + "succeeded", + "canceled", + "failed" + ], + "type": "string", + "title": "Status", + "description": "An enumeration." + }, + "height": { + "enum": [ + 64, + 128, + 192, + 256, + 320, + 384, + 448, + 512, + 576, + 640, + 704, + 768, + 832, + 896, + 960, + 1024 + ], + "type": "integer", + "title": "height", + "description": "An enumeration." + }, + "scheduler": { + "enum": [ + "DDIM", + "K_EULER", + "DPMSolverMultistep", + "K_EULER_ANCESTRAL", + "PNDM", + "KLMS" + ], + "type": "string", + "title": "scheduler", + "description": "An enumeration." + }, + "WebhookEvent": { + "enum": [ + "start", + "output", + "logs", + "completed" + ], + "type": "string", + "title": "WebhookEvent", + "description": "An enumeration." + }, + "ValidationError": { + "type": "object", + "title": "ValidationError", + "required": [ + "loc", + "msg", + "type" + ], + "properties": { + "loc": { + "type": "array", + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "title": "Location" + }, + "msg": { + "type": "string", + "title": "Message" + }, + "type": { + "type": "string", + "title": "Error Type" + } + } + }, + "PredictionRequest": { + "type": "object", + "title": "PredictionRequest", + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "input": { + "$ref": "#/components/schemas/Input" + }, + "webhook": { + "type": "string", + "title": "Webhook", + "format": "uri", + "maxLength": 65536, + "minLength": 1 + }, + "created_at": { + "type": "string", + "title": "Created At", + "format": "date-time" + }, + "output_file_prefix": { + "type": "string", + "title": "Output File Prefix" + }, + "webhook_events_filter": { + "type": "array", + "items": { + "$ref": "#/components/schemas/WebhookEvent" + }, + "default": [ + "completed", + "logs", + "output", + "start" + ], + "uniqueItems": true + } + } + }, + "PredictionResponse": { + "type": "object", + "title": "PredictionResponse", + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "logs": { + "type": "string", + "title": "Logs", + "default": "" + }, + "error": { + "type": "string", + "title": "Error" + }, + "input": { + "$ref": "#/components/schemas/Input" + }, + "output": { + "$ref": "#/components/schemas/Output" + }, + "status": { + "$ref": "#/components/schemas/Status" + }, + "metrics": { + "type": "object", + "title": "Metrics" + }, + "version": { + "type": "string", + "title": "Version" + }, + "created_at": { + "type": "string", + "title": "Created At", + "format": "date-time" + }, + "started_at": { + "type": "string", + "title": "Started At", + "format": "date-time" + }, + "completed_at": { + "type": "string", + "title": "Completed At", + "format": "date-time" + } + } + }, + "HTTPValidationError": { + "type": "object", + "title": "HTTPValidationError", + "properties": { + "detail": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "title": "Detail" + } + } + } + } + } +} diff --git a/test-integration/test_integration/fixtures/torch-baseimage-project/predict.py b/test-integration/test_integration/fixtures/torch-baseimage-project/predict.py new file mode 100644 index 0000000000..44f6992b01 --- /dev/null +++ b/test-integration/test_integration/fixtures/torch-baseimage-project/predict.py @@ -0,0 +1,6 @@ +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + return "hello " + s diff --git a/test-integration/test_integration/test_build.py b/test-integration/test_integration/test_build.py index fa5c9136f0..90d28a2fc1 100644 --- a/test-integration/test_integration/test_build.py +++ b/test-integration/test_integration/test_build.py @@ -240,3 +240,49 @@ def test_build_base_image_sha(docker_image): base_layer_hash = labels["run.cog.cog-base-image-last-layer-sha"] layers = image[0]["RootFS"]["Layers"] assert base_layer_hash in layers + + +def test_torch_1_13_0_base_image_fallback(docker_image): + project_dir = Path(__file__).parent / "fixtures/torch-baseimage-project" + build_process = subprocess.run( + ["cog", "build", "-t", docker_image, "--openapi-schema", "openapi.json"], + cwd=project_dir, + capture_output=True, + ) + assert build_process.returncode == 0 + + +def test_torch_1_13_0_base_image_fail(docker_image): + project_dir = Path(__file__).parent / "fixtures/torch-baseimage-project" + build_process = subprocess.run( + [ + "cog", + "build", + "-t", + docker_image, + "--openapi-schema", + "openapi.json", + "--use-cog-base-image", + ], + cwd=project_dir, + capture_output=True, + ) + assert build_process.returncode == 1 + + +def test_torch_1_13_0_base_image_fail_explicit(docker_image): + project_dir = Path(__file__).parent / "fixtures/torch-baseimage-project" + build_process = subprocess.run( + [ + "cog", + "build", + "-t", + docker_image, + "--openapi-schema", + "openapi.json", + "--use-cog-base-image=false", + ], + cwd=project_dir, + capture_output=True, + ) + assert build_process.returncode == 0