Extract clip embeddings to its own service
derneuere committed Dec 6, 2023
1 parent 14ac7b2 commit 2486f34
Showing 8 changed files with 100 additions and 43 deletions.
7 changes: 2 additions & 5 deletions api/batch_jobs.py
@@ -12,7 +12,7 @@
from api.ml_models import download_models
from api.models.long_running_job import LongRunningJob
from api.models.photo import Photo
-from api.semantic_search.semantic_search import semantic_search_instance
+from api.semantic_search import create_clip_embeddings


def create_batch_job(job_type, user):
@@ -74,9 +74,7 @@ def batch_calculate_clip_embedding(job_id, user):
        if len(valid_objs) == 0:
            continue

-        imgs_emb, magnitudes = semantic_search_instance.calculate_clip_embeddings(
-            imgs
-        )
+        imgs_emb, magnitudes = create_clip_embeddings(imgs)

        for obj, img_emb, magnitude in zip(valid_objs, imgs_emb, magnitudes):
            obj.clip_embeddings = img_emb.tolist()
@@ -87,7 +85,6 @@ def batch_calculate_clip_embedding(job_id, user):
        lrj.result = {"progress": {"current": done_count, "target": count}}
        lrj.save()

-    semantic_search_instance.unload()
    build_image_similarity_index(user)
    lrj.finished_at = datetime.now().replace(tzinfo=pytz.utc)
    lrj.finished = True
4 changes: 2 additions & 2 deletions api/filters.py
@@ -8,7 +8,7 @@

import api.util as util
from api.image_similarity import search_similar_embedding
-from api.semantic_search.semantic_search import semantic_search_instance
+from api.semantic_search import calculate_query_embeddings


class SemanticSearchFilter(filters.SearchFilter):
@@ -26,7 +26,7 @@ def filter_queryset(self, request, queryset, view):
        if request.user.semantic_search_topk > 0:
            query = request.query_params.get("search")
            start = datetime.datetime.now()
-            emb, magnitude = semantic_search_instance.calculate_query_embeddings(query)
+            emb, magnitude = calculate_query_embeddings(query)
            elapsed = (datetime.datetime.now() - start).total_seconds()
            util.logger.info(
                "finished calculating query embedding - took %.2f seconds" % (elapsed)
20 changes: 0 additions & 20 deletions api/models/photo.py
@@ -23,7 +23,6 @@
from api.models.file import File
from api.models.user import User, get_deleted_user
from api.places365.places365 import place365_instance
-from api.semantic_search.semantic_search import semantic_search_instance
from api.thumbnails import (
    createAnimatedThumbnail,
    createThumbnail,
@@ -219,25 +218,6 @@ def _save_captions(self, commit=True, caption=None):
            util.logger.warning("could not save captions for image %s" % image_path)
            return False

-    def _generate_clip_embeddings(self, commit=True):
-        image_path = self.thumbnail_big.path
-        if not self.clip_embeddings and image_path:
-            try:
-                img_emb, magnitude = semantic_search_instance.calculate_clip_embeddings(
-                    image_path
-                )
-                self.clip_embeddings = img_emb
-                self.clip_embeddings_magnitude = magnitude
-                if commit:
-                    self.save()
-                util.logger.info(
-                    "generated clip embeddings for image %s." % (image_path)
-                )
-            except Exception:
-                util.logger.exception(
-                    "could not generate clip embeddings for image %s" % image_path
-                )
-
    def _generate_captions(self, commit):
        try:
            image_path = self.thumbnail_big.path
37 changes: 37 additions & 0 deletions api/semantic_search.py
@@ -0,0 +1,37 @@
import numpy as np
import requests
from django.conf import settings

dir_clip_ViT_B_32_model = settings.CLIP_ROOT


def create_clip_embeddings(imgs):
    json = {
        "imgs": imgs,
        "model": dir_clip_ViT_B_32_model,
    }
    clip_embeddings = requests.post(
        "http://localhost:8006/clip-embeddings", json=json
    ).json()

    imgs_emb = clip_embeddings["imgs_emb"]
    magnitudes = clip_embeddings["magnitudes"]

    # Convert Python lists to NumPy arrays
    imgs_emb = [np.array(enc) for enc in imgs_emb]

    return imgs_emb, magnitudes


def calculate_query_embeddings(query):
    json = {
        "query": query,
        "model": dir_clip_ViT_B_32_model,
    }
    query_embedding = requests.post(
        "http://localhost:8006/query-embeddings", json=json
    ).json()

    emb = query_embedding["emb"]
    magnitude = query_embedding["magnitude"]
    return emb, magnitude
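
These two wrappers replace the old in-process semantic_search_instance calls, so the call sites in api/batch_jobs.py and api/filters.py above shrink to one-liners. A minimal usage sketch, assuming the service from service/clip_embeddings/main.py is already listening on port 8006 (the thumbnail paths are hypothetical):

from api.semantic_search import calculate_query_embeddings, create_clip_embeddings

# Embed a batch of thumbnails; returns NumPy arrays plus their magnitudes.
imgs_emb, magnitudes = create_clip_embeddings(
    ["/data/thumbnails_big/a.jpg", "/data/thumbnails_big/b.jpg"]  # hypothetical paths
)

# Embed a free-text query for comparison against the stored embeddings.
emb, magnitude = calculate_query_embeddings("dog playing in the snow")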
File renamed without changes.
48 changes: 48 additions & 0 deletions service/clip_embeddings/main.py
@@ -0,0 +1,48 @@
import gevent
from flask import Flask, request
from gevent.pywsgi import WSGIServer
from semantic_search.semantic_search import semantic_search_instance

app = Flask(__name__)


def log(message):
    print("clip embeddings: {}".format(message))


@app.route("/clip-embeddings", methods=["POST"])
def create_clip_embeddings():
    try:
        data = request.get_json()
        imgs = data["imgs"]
        model = data["model"]
    except Exception as e:
        print(str(e))
        return "", 400
    imgs_emb, magnitudes = semantic_search_instance.calculate_clip_embeddings(
        imgs, model
    )
    # Convert NumPy arrays to Python lists
    imgs_emb_list = [enc.tolist() for enc in imgs_emb]
    magnitudes = [float(m) for m in magnitudes]
    return {"imgs_emb": imgs_emb_list, "magnitudes": magnitudes}, 201


@app.route("/query-embeddings", methods=["POST"])
def calculate_query_embeddings():
    try:
        data = request.get_json()
        query = data["query"]
        model = data["model"]
    except Exception as e:
        print(str(e))
        return "", 400
    emb, magnitude = semantic_search_instance.calculate_query_embeddings(query, model)
    return {"emb": emb, "magnitude": magnitude}, 201


if __name__ == "__main__":
    log("service starting")
    server = WSGIServer(("0.0.0.0", 8006), app)
    server_thread = gevent.spawn(server.serve_forever)
    gevent.joinall([server_thread])
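
A quick smoke test of the new endpoints might look like the sketch below, assuming the service is running locally (the model path is a placeholder; in the app it is whatever settings.CLIP_ROOT points to):

import requests

payload = {
    "imgs": ["/data/thumbnails_big/a.jpg"],  # hypothetical thumbnail path
    "model": "/models/clip-ViT-B-32",        # placeholder for settings.CLIP_ROOT
}
resp = requests.post("http://localhost:8006/clip-embeddings", json=payload)
assert resp.status_code == 201               # the service answers 201 on success
print(resp.json()["magnitudes"])             # one float per input image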
Empty file.
api/semantic_search/semantic_search.py → service/clip_embeddings/semantic_search/semantic_search.py (renamed, with changes)
@@ -2,19 +2,14 @@

import numpy as np
import PIL
-from django.conf import settings
from sentence_transformers import SentenceTransformer

-from api.util import logger
-
-dir_clip_ViT_B_32_model = settings.CLIP_ROOT
-

class SemanticSearch:
    model_is_loaded = False

-    def load(self):
-        self.load_model()
+    def load(self, model):
+        self.load_model(model)
        self.model_is_loaded = True
        pass

@@ -24,28 +19,28 @@ def unload(self):
        self.model_is_loaded = False
        pass

-    def load_model(self):
-        self.model = SentenceTransformer(dir_clip_ViT_B_32_model)
+    def load_model(self, model):
+        self.model = SentenceTransformer(model)

-    def calculate_clip_embeddings(self, img_paths):
+    def calculate_clip_embeddings(self, img_paths, model):
        import torch

        if not self.model_is_loaded:
-            self.load()
+            self.load(model)
        imgs = []
        if type(img_paths) is list:
            for path in img_paths:
                try:
                    img = PIL.Image.open(path)
                    imgs.append(img)
                except PIL.UnidentifiedImageError:
-                    logger.info("Error loading image: {}".format(path))
+                    print("Error loading image: {}".format(path))
        else:
            try:
                img = PIL.Image.open(img_paths)
                imgs.append(img)
            except PIL.UnidentifiedImageError:
-                logger.info("Error loading image: {}".format(img_paths))
+                print("Error loading image: {}".format(img_paths))

        try:
            imgs_emb = self.model.encode(imgs, batch_size=32, convert_to_tensor=True)
@@ -71,12 +66,12 @@ def calculate_clip_embeddings(self, img_paths):

            return img_emb, magnitude
        except Exception as e:
-            logger.error("Error in calculating clip embeddings: {}".format(e))
+            print("Error in calculating clip embeddings: {}".format(e))
            raise e

-    def calculate_query_embeddings(self, query):
+    def calculate_query_embeddings(self, query, model):
        if not self.model_is_loaded:
-            self.load()
+            self.load(model)

        query_emb = self.model.encode([query], convert_to_tensor=True)[0].tolist()
        magnitude = np.linalg.norm(query_emb)
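
Each embedding is returned together with its magnitude, so the search side can score images against a query with a plain dot product instead of re-normalizing vectors. A sketch of that downstream computation (the function name is illustrative, not part of this diff):

import numpy as np

def cosine_similarity(img_emb, img_magnitude, query_emb, query_magnitude):
    # cos(theta) = (a · b) / (|a| * |b|), using the precomputed magnitudes
    return float(np.dot(img_emb, query_emb) / (img_magnitude * query_magnitude))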
