Skip to content

Commit

Permalink
build: replicate image cog support
Browse files Browse the repository at this point in the history
  • Loading branch information
atrifat committed Aug 12, 2024
1 parent 1a97492 commit aa02433
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 0 deletions.
4 changes: 4 additions & 0 deletions topic-classification-cog/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# COG Implementation
COG implementation for Replicate.

A demo on Replicate is available at [atrifat/topic-classification](https://replicate.com/atrifat/topic-classification).
29 changes: 29 additions & 0 deletions topic-classification-cog/cog.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml

build:
# set to true if your model requires a GPU
gpu: true
cuda: "11.8"

# a list of ubuntu apt packages to install
# system_packages:
# - "libgl1-mesa-glx"
# - "libglib2.0-0"

# python version in the form '3.11' or '3.11.4'
python_version: "3.9"

# a list of packages in the format <package-name>==<version>
python_packages:
- "transformers==4.42.3"
- "torch==2.3.1"
- "pandas==2.1.1"
- "numpy==1.26.4"
# commands run after the environment is set up
# run:
# - "echo env is ready!"
# - "echo another command if needed"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
54 changes: 54 additions & 0 deletions topic-classification-cog/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Prediction interface for Cog ⚙️
# https://cog.run/python

from cog import BasePredictor, Input, Path
import torch
from transformers import pipeline
import pandas as pd
import datetime
import json


class Predictor(BasePredictor):
    """Cog predictor wrapping a multi-label tweet-topic classification pipeline."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        model_path = "cardiffnlp/twitter-roberta-base-dec2021-tweet-topic-multi-all"
        # transformers convention: 0 = first CUDA device, -1 = CPU.
        self.device = 0 if torch.cuda.is_available() else -1
        # Bug fix: self.device was computed (and reported in predict()'s output)
        # but never passed to pipeline(), so the model always ran on CPU even
        # though cog.yaml requests a GPU.
        self.model = pipeline(
            "text-classification",
            model=model_path,
            tokenizer=model_path,
            device=self.device,
        )

    def predict(
        self,
        query: str = Input(description="Text input"),
    ) -> str:
        """Run a single prediction on the model.

        `query` is either a plain text string or a JSON-encoded list of
        strings. For a plain string, the JSON-serialized result of the single
        classification is returned; for a list, a JSON object containing the
        elapsed time, the device used, and the per-item results is returned.
        """
        request_type: type = type(query)
        try:
            parsed = json.loads(query)
            if isinstance(parsed, list):
                # Batch request: classify each element of the list.
                data = parsed
                request_type = list
            else:
                # Valid JSON but not a list (e.g. a number or an object):
                # fall back to treating the raw query string as one input.
                data = [query]
        except Exception as e:
            # Not valid JSON at all: single plain-text input.
            print(e)
            data = [query]

        start_time = datetime.datetime.now()

        # Truncate long inputs so they fit the model's 512-token context window.
        tokenizer_kwargs = {"truncation": True, "max_length": 512}
        all_result = self.model(data, batch_size=128, top_k=3, **tokenizer_kwargs)

        end_time = datetime.datetime.now()
        elapsed_time = end_time - start_time

        output = {
            "time": str(elapsed_time),
            "device": self.device,
            "result": all_result,
        }

        # Plain-string requests return only the first (sole) result to keep
        # the response shape backward compatible.
        return json.dumps(all_result[0]) if request_type is str else json.dumps(output)

0 comments on commit aa02433

Please sign in to comment.