diff --git a/blur/predict.py b/blur/predict.py index 6640659..8b98bff 100644 --- a/blur/predict.py +++ b/blur/predict.py @@ -1,20 +1,18 @@ import tempfile -from pathlib import Path +from cog import BasePredictor, Input, Path from PIL import Image, ImageFilter -import cog -class Predictor(cog.Predictor): - def setup(self): - pass - - @cog.input("input", type=Path, help="Input image") - @cog.input("blur", type=float, help="Blur radius", default=5) - def predict(self, input, blur): +class Predictor(BasePredictor): + def predict( + self, + image: Path = Input(description="Input image"), + blur: float = Input(description="Blur radius", default=5), + ) -> Path: if blur == 0: return input - im = Image.open(str(input)) + im = Image.open(str(image)) im = im.filter(ImageFilter.BoxBlur(blur)) out_path = Path(tempfile.mkdtemp()) / "out.png" im.save(str(out_path)) diff --git a/hello-world/predict.py b/hello-world/predict.py index 4627055..efc4bf2 100644 --- a/hello-world/predict.py +++ b/hello-world/predict.py @@ -1,9 +1,9 @@ -import cog +from cog import BasePredictor, Input -class Predictor(cog.Predictor): + +class Predictor(BasePredictor): def setup(self): self.prefix = "hello" - @cog.input("input", type=str, help="Text that will get prefixed by 'hello '") - def predict(self, input): - return f"\n\n{self.prefix} {input}\n\n" + def predict(self, text: str = Input(description="Text to prefix with 'hello '")) -> str: + return self.prefix + " " + text diff --git a/resnet/cog.yaml b/resnet/cog.yaml index ab9f12d..3a4ffe8 100644 --- a/resnet/cog.yaml +++ b/resnet/cog.yaml @@ -3,4 +3,4 @@ build: python_packages: - "pillow==8.3.1" - "tensorflow==2.5.0" -predict: "predict.py:ResNetPredictor" +predict: "predict.py:Predictor" diff --git a/resnet/predict.py b/resnet/predict.py index 897e5af..e56edc7 100644 --- a/resnet/predict.py +++ b/resnet/predict.py @@ -1,26 +1,29 @@ -import cog -from pathlib import Path -from tensorflow.keras.applications.resnet50 import ResNet50 -from tensorflow.keras.preprocessing import image -from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions +from typing import Any + import numpy as np +from cog import BasePredictor, Input, Path +from tensorflow.keras.applications.resnet50 import ( + ResNet50, + decode_predictions, + preprocess_input, +) +from tensorflow.keras.preprocessing import image as keras_image -class ResNetPredictor(cog.Predictor): +class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" - self.model = ResNet50(weights='resnet50_weights_tf_dim_ordering_tf_kernels.h5') + self.model = ResNet50(weights="resnet50_weights_tf_dim_ordering_tf_kernels.h5") # Define the arguments and types the model takes as input - @cog.input("input", type=Path, help="Image to classify") - def predict(self, input): + def predict(self, image: Path = Input(description="Image to classify")) -> Any: """Run a single prediction on the model""" # Preprocess the image - img = image.load_img(input, target_size=(224, 224)) - x = image.img_to_array(img) + img = keras_image.load_img(image, target_size=(224, 224)) + x = keras_image.img_to_array(img) x = np.expand_dims(x, axis=0) x = preprocess_input(x) # Run the prediction preds = self.model.predict(x) # Return the top 3 predictions - return str(decode_predictions(preds, top=3)[0]) + return decode_predictions(preds, top=3)[0]