From 8a4ba7a8e5d1cb64760315122961ef680c16fae3 Mon Sep 17 00:00:00 2001
From: Niaz
Date: Sun, 17 Sep 2023 13:12:55 +0200
Subject: [PATCH] Handle converting of tensors conditionally

---
 api/batch_jobs.py                      |  2 --
 api/semantic_search/semantic_search.py | 27 +++++++++++++++++++-------
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/api/batch_jobs.py b/api/batch_jobs.py
index d876f75b98..d7971a103b 100644
--- a/api/batch_jobs.py
+++ b/api/batch_jobs.py
@@ -32,8 +32,6 @@ def create_batch_job(job_type, user):
 def batch_calculate_clip_embedding(job_id, user):
     import torch
 
-    # Only supports CPU
-    torch.device("cpu")
     lrj = LongRunningJob.objects.get(job_id=job_id)
     lrj.started_at = datetime.now().replace(tzinfo=pytz.utc)
diff --git a/api/semantic_search/semantic_search.py b/api/semantic_search/semantic_search.py
index ce1d557d7e..88ad519adc 100644
--- a/api/semantic_search/semantic_search.py
+++ b/api/semantic_search/semantic_search.py
@@ -28,6 +28,8 @@ def load_model(self):
         self.model = SentenceTransformer(dir_clip_ViT_B_32_model)
 
     def calculate_clip_embeddings(self, img_paths):
+        import torch
+
         if not self.model_is_loaded:
             self.load()
         imgs = []
@@ -47,14 +49,25 @@ def calculate_clip_embeddings(self, img_paths):
         try:
             imgs_emb = self.model.encode(imgs, batch_size=32, convert_to_tensor=True)
-
-            if type(img_paths) is list:
-                magnitudes = map(np.linalg.norm, imgs_emb)
-
-                return imgs_emb, magnitudes
+            if torch.cuda.is_available():
+                if type(img_paths) is list:
+                    magnitudes = list(
+                        map(lambda x: np.linalg.norm(x.cpu().numpy()), imgs_emb)
+                    )
+
+                    return imgs_emb, magnitudes
+                else:
+                    img_emb = imgs_emb[0].cpu().numpy().tolist()
+                    magnitude = np.linalg.norm(img_emb)
+
+                    return img_emb, magnitude
             else:
-                img_emb = imgs_emb[0].tolist()
-                magnitude = np.linalg.norm(img_emb)
+                if type(img_paths) is list:
+                    magnitudes = map(np.linalg.norm, imgs_emb)
+                    return imgs_emb, magnitudes
+                else:
+                    img_emb = imgs_emb[0].tolist()
+                    magnitude = np.linalg.norm(img_emb)
 
                 return img_emb, magnitude
         except Exception as e: