Skip to content

Commit

Permalink
Create image captioning service
Browse files Browse the repository at this point in the history
  • Loading branch information
derneuere committed Dec 9, 2023
1 parent 2486f34 commit 605cb59
Show file tree
Hide file tree
Showing 16 changed files with 122 additions and 23 deletions.
11 changes: 0 additions & 11 deletions api/im2txt/download.sh

This file was deleted.

4 changes: 0 additions & 4 deletions api/im2txt/requirements.txt

This file was deleted.

26 changes: 26 additions & 0 deletions api/image_captioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import requests


def generate_caption(image_path, onnx, blip):
    """Request a caption for an image from the local captioning service.

    POSTs a JSON object with the keys ``image_path``, ``onnx`` and ``blip``
    to ``http://localhost:8007/generate-caption`` and returns the value
    stored under ``"caption"`` in the JSON reply.

    Args:
        image_path: Value forwarded as ``image_path``.
        onnx: Value forwarded as ``onnx``.
        blip: Value forwarded as ``blip``.

    Returns:
        The ``"caption"`` entry of the service's JSON response.

    Raises:
        requests.HTTPError: If the service answers with a 4xx/5xx status.
    """
    # Named ``payload`` rather than ``json`` so the stdlib module name is not
    # shadowed inside this function.
    payload = {
        "image_path": image_path,
        "onnx": onnx,
        "blip": blip,
    }
    response = requests.post(
        "http://localhost:8007/generate-caption", json=payload
    )
    # Fail with the real HTTP status instead of a confusing JSON-decode error
    # when the service replies with an error and a non-JSON body.
    response.raise_for_status()

    return response.json()["caption"]


def unload_model():
    """Send a GET to the captioning service's ``/unload-model`` endpoint.

    The response is not inspected and nothing is returned.
    """
    unload_url = "http://localhost:8007/unload-model"
    requests.get(unload_url)


def export_onnx(encoder_path, decoder_path):
    """Send the ONNX export request to the captioning service.

    Issues a GET to ``/export-onnx`` on ``localhost:8007`` carrying a JSON
    body with ``encoder_path`` and ``decoder_path``. The response is not
    inspected and nothing is returned.
    """
    requests.get(
        "http://localhost:8007/export-onnx",
        json={"encoder_path": encoder_path, "decoder_path": decoder_path},
    )
4 changes: 2 additions & 2 deletions api/models/photo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from api.face_recognition import get_face_encodings, get_face_locations
from api.geocode import GEOCODE_VERSION
from api.geocode.geocode import reverse_geocode
from api.im2txt.sample import Im2txt
from api.image_captioning import generate_caption
from api.models.file import File
from api.models.user import User, get_deleted_user
from api.places365.places365 import place365_instance
Expand Down Expand Up @@ -175,7 +175,7 @@ def _generate_captions_im2txt(self, commit=True):
if site_config.CAPTIONING_MODEL == "blip_base_capfilt_large":
blip = True

caption = Im2txt(blip=blip).generate_caption(image_path, onnx)
caption = generate_caption(image_path=image_path, blip=blip, onnx=onnx)
caption = (
caption.replace("<start>", "").replace("<end>", "").strip().lower()
)
Expand Down
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import onnxruntime as ort
import torch
from django.conf import settings
from numpy import asarray
from PIL import Image
from torchvision import transforms
Expand All @@ -17,9 +16,10 @@
embed_size = 256
hidden_size = 512
num_layers = 1
im2txt_models_path = settings.IM2TXT_ROOT
im2txt_onnx_models_path = settings.IM2TXT_ONNX_ROOT
blip_models_path = settings.BLIP_ROOT

im2txt_models_path = "/protected_media/data_models/im2txt"
im2txt_onnx_models_path = "/protected_media/data_models/im2txt_onnx"
blip_models_path = "/protected_media/data_models/blip"

encoder_path = os.path.join(im2txt_models_path, "models", "encoder-10-1000.ckpt")
decoder_path = os.path.join(im2txt_models_path, "models", "decoder-10-1000.ckpt")
Expand Down Expand Up @@ -117,10 +117,12 @@ def load_models(self, onnx=False):
# self.decoder = torch.compile(self.decoder)

def unload_models(self):
self.encoder.__del__()
self.decoder.__del__()
del self.encoder
del self.decoder
del self.model
self.encoder = None
self.decoder = None
self.model = None

def generate_caption(
self,
Expand Down
File renamed without changes.
86 changes: 86 additions & 0 deletions service/image_captioning/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import threading
import time

import gevent
from flask import Flask, request
from gevent.pywsgi import WSGIServer

from api.im2txt.sample import Im2txt

# Flask application on which the route handlers in this module are registered.
app = Flask(__name__)

# Shared Im2txt instance used by the handlers; None while no instance exists.
im2txt_instance = None
# time.time() value of the last request that refreshed it (initialised at
# import); check_inactivity() reads it to measure how long the service has
# been idle.
last_request_time = time.time()


def log(message):
    """Print ``message`` to stdout, prefixed with ``image_captioning: ``."""
    print(f"image_captioning: {message}")


@app.route("/generate-caption", methods=["POST"])
def generate_caption():
    """Generate a caption for the image named in the JSON request body.

    The body must be a JSON object with the keys ``image_path``, ``onnx``
    and ``blip``.

    Returns:
        ``({"caption": <caption>}, 201)`` on success, or ``("", 400)`` when
        the body cannot be read or one of the keys is missing.
    """
    global last_request_time, im2txt_instance
    # Refresh the idle timestamp first so check_inactivity() measures idle
    # time from this request.
    last_request_time = time.time()

    try:
        data = request.get_json()
        image_path = data["image_path"]
        onnx = data["onnx"]
        blip = data["blip"]
    except Exception as e:
        print(str(e))
        return "", 400

    # Work on a local reference: the global can be reset to None by
    # unload_model() while this request is still being handled.
    instance = im2txt_instance
    # An Im2txt is constructed for one ``blip`` setting. Reusing an instance
    # built for the other setting would caption with the wrong model, so the
    # setting is recorded on the instance and the instance is rebuilt on a
    # mismatch (including instances created elsewhere without the marker).
    if instance is None or getattr(instance, "_service_blip", None) != blip:
        instance = Im2txt(blip=blip)
        instance._service_blip = blip
        im2txt_instance = instance

    caption = instance.generate_caption(image_path=image_path, onnx=onnx)
    return {"caption": caption}, 201


@app.route("/unload-model", methods=["GET"])
def unload_model():
    """Unload the shared captioning model, if one is currently held.

    Also called directly by check_inactivity(). Calling it while no
    instance exists is a no-op rather than an error.

    Returns:
        ``("", 200)`` in every case.
    """
    global im2txt_instance
    # Move the reference into a local and clear the global before unloading,
    # so the None case is handled explicitly instead of raising
    # AttributeError (which Flask would turn into a 500).
    instance, im2txt_instance = im2txt_instance, None
    if instance is not None:
        instance.unload_models()
    return "", 200


@app.route("/export-onnx", methods=["GET"])
def export_onnx():
    """Export the captioning model to ONNX at the paths given in the body.

    The body must be a JSON object with the keys ``encoder_path`` and
    ``decoder_path``; they are passed to ``Im2txt.export_onnx`` as
    ``encoder_output_path`` and ``decoder_output_path``.

    Returns:
        ``("", 200)`` after the export call returns, or ``("", 400)`` when
        the body cannot be read or one of the keys is missing.
    """
    global im2txt_instance, last_request_time
    # Count the export as activity. Otherwise check_inactivity() may see a
    # freshly created instance together with an old timestamp and unload the
    # model while it is being exported.
    last_request_time = time.time()

    # Validate the body before creating an instance, so that a malformed
    # request does not construct a model it will never use.
    try:
        data = request.get_json()
        encoder_path = data["encoder_path"]
        decoder_path = data["decoder_path"]
    except Exception as e:
        print(str(e))
        return "", 400

    # Use a local reference so a concurrent reset of the global to None
    # cannot cause a None dereference below.
    instance = im2txt_instance
    if instance is None:
        instance = Im2txt()
        im2txt_instance = instance

    instance.export_onnx(
        encoder_output_path=encoder_path, decoder_output_path=decoder_path
    )
    return "", 200


def check_inactivity():
    """Poll once per second and unload the model after 30 idle seconds.

    Loops forever: after each one-second sleep, if an instance exists and
    more than ``idle_threshold`` seconds have passed since
    ``last_request_time``, it prints a notice and calls unload_model().
    """
    idle_threshold = 30

    while True:
        time.sleep(1)
        # Nothing to unload while no instance exists.
        if im2txt_instance is None:
            continue
        if time.time() - last_request_time > idle_threshold:
            print("Unloading model due to inactivity")
            unload_model()


# Start the idle watcher as a daemon thread: check_inactivity() never
# returns, and a non-daemon thread would keep the interpreter alive after
# the server (main thread) has stopped.
threading.Thread(target=check_inactivity, daemon=True).start()


if __name__ == "__main__":
    log("service starting")
    # Serve the Flask app with gevent's WSGI server on port 8007, all
    # interfaces, and block until the serving greenlet finishes.
    server = WSGIServer(("0.0.0.0", 8007), app)
    server_thread = gevent.spawn(server.serve_forever)
    gevent.joinall([server_thread])

0 comments on commit 605cb59

Please sign in to comment.