Merge pull request #7 from atrifat/feat-replicate-cog-support
Feat Replicate Cog Support
atrifat authored Aug 12, 2024
2 parents 1a97492 + bdcdee7 commit 392d03c
Showing 4 changed files with 150 additions and 0 deletions.
63 changes: 63 additions & 0 deletions .github/workflows/push-topic-classification-cog.yml
@@ -0,0 +1,63 @@
name: Push Topic Classification Cog to Replicate

on:
  # Workflow dispatch allows you to manually trigger the workflow from GitHub.com
  # Go to your repo, click "Actions", click "Push to Replicate", click "Run workflow"
  workflow_dispatch:
    inputs:
      model_name:
        description: 'Enter the model name, like "alice/bunny-detector". If unset, this will default to the value of `image` in cog.yaml.'
  # # Uncomment these lines to trigger the workflow on every push to the main branch
  # push:
  #   branches:
  #     - main

jobs:
  push_to_replicate:
    name: Push Topic Classification Cog to Replicate

    # If your model is large, the default GitHub Actions runner may not
    # have enough disk space. If you need more space you can set up a
    # bigger runner on GitHub.
    runs-on: ubuntu-latest

    steps:
      # This action cleans up disk space to make more room for your
      # model code, weights, etc.
      - name: Free disk space
        uses: jlumbroso/free-disk-space@main
        with:
          tool-cache: false

          # all of these default to true, but feel free to set to
          # "false" if necessary for your workflow
          android: false
          dotnet: false
          haskell: false
          large-packages: true
          docker-images: false
          swap-storage: true

      - name: Checkout
        uses: actions/checkout@v4

      # This action installs Docker buildx and Cog (and optionally CUDA)
      - name: Setup Cog
        uses: replicate/setup-cog@v2
        with:
          # If you set REPLICATE_API_TOKEN in your GitHub repository secrets,
          # the action will authenticate with Replicate automatically so you
          # can push your model
          token: ${{ secrets.REPLICATE_API_TOKEN }}

      # If you trigger the workflow manually, you can specify the model name.
      # If you leave it blank (or if the workflow is triggered by a push), the
      # model name will be derived from the `image` value in cog.yaml.
      - name: Push to Replicate
        run: |
          cd topic-classification-cog
          if [ -n "${{ inputs.model_name }}" ]; then
            cog push r8.im/${{ inputs.model_name }}
          else
            cog push
          fi
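
The same manual trigger can also be issued from a script. Below is a minimal sketch against GitHub's REST API; `OWNER`, `REPO`, and the model name are illustrative placeholders, and a token with `workflow` scope is assumed:

```python
import os
import requests

# Illustrative placeholders -- point these at your fork of this repo
OWNER = "your-user"
REPO = "your-repo"
WORKFLOW = "push-topic-classification-cog.yml"

resp = requests.post(
    f"https://api.github.com/repos/{OWNER}/{REPO}/actions/workflows/{WORKFLOW}/dispatches",
    headers={
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
        "Accept": "application/vnd.github+json",
    },
    # "inputs" maps to the workflow_dispatch inputs defined above
    json={"ref": "main", "inputs": {"model_name": "alice/bunny-detector"}},
)
resp.raise_for_status()  # GitHub returns 204 No Content on success
```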
4 changes: 4 additions & 0 deletions topic-classification-cog/README.md
@@ -0,0 +1,4 @@
# Cog Implementation
Cog implementation for Replicate.

A demo is available on Replicate at [atrifat/topic-classification](https://replicate.com/atrifat/topic-classification).
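
For completeness, a minimal sketch of calling the hosted demo with the official `replicate` Python client; the version pin is omitted for brevity, and in practice you would pin `atrifat/topic-classification:<version>`:

```python
import json
import replicate  # pip install replicate; REPLICATE_API_TOKEN must be set

output = replicate.run(
    "atrifat/topic-classification",
    input={"query": "Ethereum gas fees drop to a two-year low"},
)
# The model returns a JSON string of top-3 topic labels
print(json.loads(output))
```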
29 changes: 29 additions & 0 deletions topic-classification-cog/cog.yaml
@@ -0,0 +1,29 @@
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml

build:
  # set to true if your model requires a GPU
  gpu: true
  cuda: "11.8"

  # a list of ubuntu apt packages to install
  # system_packages:
  #   - "libgl1-mesa-glx"
  #   - "libglib2.0-0"

  # python version in the form '3.11' or '3.11.4'
  python_version: "3.9"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    - "transformers==4.42.3"
    - "torch==2.3.1"
    - "pandas==2.1.1"
    - "numpy==1.26.4"

  # commands run after the environment is set up
  # run:
  #   - "echo env is ready!"
  #   - "echo another command if needed"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
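
Once the image is built, Cog serves predictions over HTTP. A hedged sketch of exercising the container locally follows; the image tag and port mapping are assumptions, and `cog predict -i query=...` is the simpler CLI route:

```python
import json
import requests

# Assumes the image was built and started locally, e.g.:
#   cog build -t topic-classification
#   docker run -d -p 5000:5000 --gpus all topic-classification
resp = requests.post(
    "http://localhost:5000/predictions",
    json={"input": {"query": "OpenAI releases a new model"}},
)
resp.raise_for_status()
# predict() returns a JSON string, so decode the "output" field again
print(json.loads(resp.json()["output"]))
```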
54 changes: 54 additions & 0 deletions topic-classification-cog/predict.py
@@ -0,0 +1,54 @@
# Prediction interface for Cog ⚙️
# https://cog.run/python

from cog import BasePredictor, Input
import torch
from transformers import pipeline
import datetime
import json


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        model_path = "cardiffnlp/twitter-roberta-base-dec2021-tweet-topic-multi-all"
        # Use the first GPU if available, otherwise fall back to CPU
        self.device = 0 if torch.cuda.is_available() else -1
        self.model = pipeline(
            "text-classification",
            model=model_path,
            tokenizer=model_path,
            device=self.device,
        )

    def predict(
        self,
        query: str = Input(
            description="Text input: a plain string, or a JSON array of strings for batch classification"
        ),
    ) -> str:
        """Run a single prediction on the model"""
        # Accept either a plain string or a JSON-encoded list of strings;
        # anything that is not valid JSON is treated as a plain string.
        is_batch = False
        data = [query]
        try:
            parsed = json.loads(query)
            if isinstance(parsed, list):
                data = parsed
                is_batch = True
        except json.JSONDecodeError:
            pass

        start_time = datetime.datetime.now()

        # Truncation arguments are forwarded to the tokenizer by the pipeline
        tokenizer_kwargs = {"truncation": True, "max_length": 512}
        all_result = self.model(data, batch_size=128, top_k=3, **tokenizer_kwargs)

        end_time = datetime.datetime.now()
        elapsed_time = end_time - start_time

        output = {
            "time": str(elapsed_time),
            "device": self.device,
            "result": all_result,
        }

        # A plain-string query returns only its own result; a JSON-array
        # query returns the full output, including timing and device info.
        return json.dumps(output) if is_batch else json.dumps(all_result[0])
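
To illustrate the two input shapes the predictor accepts, a small sketch of how a caller would build them; the example texts are hypothetical, and the outputs shown in comments indicate shape, not real model scores:

```python
import json

# Plain string: the predictor returns just this text's top-3 labels,
# e.g. '[{"label": "news_&_social_concern", "score": 0.97}, ...]'
single_query = "Bitcoin hits a new all-time high"

# JSON-encoded list: triggers batch mode; the predictor returns the full
# object, e.g. '{"time": "...", "device": 0, "result": [[...], [...]]}'
batch_query = json.dumps([
    "Bitcoin hits a new all-time high",
    "The championship final goes to penalties",
])
```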
