Skip to content

Commit

Permalink
Update examples for new Cog types
Browse files Browse the repository at this point in the history
  • Loading branch information
bfirsh committed Jan 14, 2022
1 parent 1680c8e commit 5a9de9b
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 28 deletions.
18 changes: 8 additions & 10 deletions blur/predict.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import tempfile
from pathlib import Path

from cog import BasePredictor, Input, Path
from PIL import Image, ImageFilter
import cog


class Predictor(cog.Predictor):
def setup(self):
pass

@cog.input("input", type=Path, help="Input image")
@cog.input("blur", type=float, help="Blur radius", default=5)
def predict(self, input, blur):
class Predictor(BasePredictor):
def predict(
self,
image: Path = Input(description="Input image"),
blur: float = Input(description="Blur radius", default=5),
) -> Path:
if blur == 0:
return input
im = Image.open(str(input))
im = Image.open(str(image))
im = im.filter(ImageFilter.BoxBlur(blur))
out_path = Path(tempfile.mkdtemp()) / "out.png"
im.save(str(out_path))
Expand Down
10 changes: 5 additions & 5 deletions hello-world/predict.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import cog
from cog import BasePredictor, Input

class Predictor(cog.Predictor):

class Predictor(BasePredictor):
def setup(self):
self.prefix = "hello"

@cog.input("input", type=str, help="Text that will get prefixed by 'hello '")
def predict(self, input):
return f"\n\n{self.prefix} {input}\n\n"
def predict(self, text: str = Input(description="Text to prefix with 'hello '")) -> str:
return self.prefix + " " + text
2 changes: 1 addition & 1 deletion resnet/cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ build:
python_packages:
- "pillow==8.3.1"
- "tensorflow==2.5.0"
predict: "predict.py:ResNetPredictor"
predict: "predict.py:Predictor"
27 changes: 15 additions & 12 deletions resnet/predict.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
import cog
from pathlib import Path
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from typing import Any

import numpy as np
from cog import BasePredictor, Input, Path
from tensorflow.keras.applications.resnet50 import (
ResNet50,
decode_predictions,
preprocess_input,
)
from tensorflow.keras.preprocessing import image as keras_image


class ResNetPredictor(cog.Predictor):
class Predictor(BasePredictor):
def setup(self):
"""Load the model into memory to make running multiple predictions efficient"""
self.model = ResNet50(weights='resnet50_weights_tf_dim_ordering_tf_kernels.h5')
self.model = ResNet50(weights="resnet50_weights_tf_dim_ordering_tf_kernels.h5")

# Define the arguments and types the model takes as input
@cog.input("input", type=Path, help="Image to classify")
def predict(self, input):
def predict(self, image: Path = Input(description="Image to classify")) -> Any:
"""Run a single prediction on the model"""
# Preprocess the image
img = image.load_img(input, target_size=(224, 224))
x = image.img_to_array(img)
img = keras_image.load_img(image, target_size=(224, 224))
x = keras_image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
# Run the prediction
preds = self.model.predict(x)
# Return the top 3 predictions
return str(decode_predictions(preds, top=3)[0])
return decode_predictions(preds, top=3)[0]

0 comments on commit 5a9de9b

Please sign in to comment.