Skip to content

Commit

Permalink
Add pest detection code to main
Browse files Browse the repository at this point in the history
  • Loading branch information
KastanDay committed Mar 28, 2024
1 parent a450b12 commit d325d9a
Showing 1 changed file with 154 additions and 0 deletions.
154 changes: 154 additions & 0 deletions ai_ta_backend/modal/pest_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
Run with: $ modal serve ai_ta_backend/modal/pest_detection.py
Deploy with with: $ modal deploy ai_ta_backend/modal/pest_detection.py
Just send a post request here: https://kastanday--v2-pest-detection-yolo-model-predict.modal.run/
with body:
{
"image_urls": [
"https://www.arborday.org/trees/health/pests/images/figure-whiteflies-1.jpg",
"https://www.arborday.org/trees/health/pests/images/figure-japanese-beetle-3.jpg"
]
}
Inspired by https://modal.com/docs/examples/webcam#prediction-function
"""
import inspect
import json
import os
import traceback
import uuid
from tempfile import NamedTemporaryFile
from typing import List

import modal
from fastapi import Request
from modal import Secret, Stub, build, enter, web_endpoint

# Simpler image, but slower cold starts: modal.Image.from_registry('ultralytics/ultralytics:latest-cpu')
image = (
modal.Image.debian_slim(python_version="3.10").apt_install("libgl1-mesa-glx", "libglib2.0-0")
# .run_commands(["apt-get install -y libgl1-mesa-glx libglib2.0-0 wget"])
.pip_install(
"opencv-python",
"torch==2.2.0",
"ultralytics==8.1.0",
"torchvision==0.17.0",
"boto3==1.28.79",
"fastapi==0.109.2",
"pillow",
))
stub = Stub("v2_pest_detection_yolo", image=image)

# Imports needed inside the image
with image.imports():
import inspect
import os
import traceback
import uuid
from tempfile import NamedTemporaryFile
from typing import List

import boto3
import requests
from PIL import Image
from ultralytics import YOLO


@stub.cls(cpu=1, image=image, secrets=[Secret.from_name("uiuc-chat-aws")])
class Model:
"""
1. Build (bake things into the image for faster subsequent startups)
2. Enter (run once per container start)
3. Web Endpoint (serve a web endpoint)
"""

@build()
def download_model(self):
"""Model weights are downloaded once at image build time using the @build hook and saved into the image. 'Baking' models into the modal.Image at build time provided the fastest cold start. """
model_url = "https://assets.kastan.ai/pest_detection_model_weights.pt"
response = requests.get(model_url)

model_path = "/cache/pest_detection_model_weights.pt"
os.makedirs("/cache/", exist_ok=True)
with open(model_path, 'wb') as f:
f.write(response.content)

@enter()
def run_this_on_container_startup(self):
"""Runs once per container start. Like __init__ but for the container."""
self.model = YOLO('/cache/pest_detection_model_weights.pt')
self.s3_client = boto3.client(
's3',
aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
)

@web_endpoint(method="POST")
async def predict(self, request: Request):
"""
This used to use the method decorator
Run the pest detection plugin on an image.
"""
print("Inside predict() endpoint")

print("Request: ", request)
input = await request.json()
print("Request.json(): ", input)
image_urls = input.get('image_urls', [])
print(f"Image URLS (no parsing): {image_urls}")

# image_urls = json.loads(image_urls)
# print(f"json parsed Image URLS: {image_urls}")

try:
# Run the plugin
annotated_images = self._detect_pests(image_urls)
print(f"annotated_images found: {len(annotated_images)}")
results = []
# Generate a unique ID for the request
unique_id = uuid.uuid4()
# self.posthog.capture('distinct_id_of_the_user',
# event='run_pest_detection_invoked',
# properties={
# 'image_urls': image_urls,
# 'unique_id': unique_id,
# })
for index, image in enumerate(annotated_images):
# Infer the file extension from the image URL or set a default
file_extension = '.png'
image_format = file_extension[1:].upper()

with NamedTemporaryFile(mode='wb', suffix=file_extension) as tmpfile:
# Save the image with the specified format
image.save(tmpfile, format=image_format)
tmpfile.flush() # Ensure all data is written to the file
tmpfile.seek(0) # Move the file pointer to the start of the file
# Include UUID and index in the s3_path
s3_path = f'pest_detection/annotated-{unique_id}-{index}{file_extension}'
# Upload the file to S3
with open(tmpfile.name, 'rb') as file_data:
self.s3_client.upload_fileobj(Fileobj=file_data, Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path)
results.append(s3_path)
return results
except Exception as e:
err = f"❌❌ Error in (pest_detection): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n{traceback.format_exc()}" # type: ignore
print(err)
# sentry_sdk.capture_exception(e)
return err

def _detect_pests(self, image_paths: List[str]) -> List[Image.Image]:
""" Run pest detection on the given images. """
# Run inference
results = self.model(image_paths) # results object with inference results

annotated_images = []

# Extract annotated images from the results object
# Flatten the results object to get the annotated images for each input image
for result_set in results:
for r in result_set:
im_array = r.plot() # plot a BGR numpy array of predictions
im = Image.fromarray(im_array[..., ::-1]) # RGB PIL image (annotated with bounding boxes and class labels)
annotated_images.append(im)
return annotated_images

0 comments on commit d325d9a

Please sign in to comment.