Skip to content

Commit

Permalink
Handle converting of tensors conditionally
Browse files Browse the repository at this point in the history
  • Loading branch information
derneuere committed Sep 17, 2023
1 parent 8499911 commit 8a4ba7a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
2 changes: 0 additions & 2 deletions api/batch_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ def create_batch_job(job_type, user):
def batch_calculate_clip_embedding(job_id, user):
import torch

# Only supports CPU
torch.device("cpu")
lrj = LongRunningJob.objects.get(job_id=job_id)
lrj.started_at = datetime.now().replace(tzinfo=pytz.utc)

Expand Down
27 changes: 20 additions & 7 deletions api/semantic_search/semantic_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def load_model(self):
self.model = SentenceTransformer(dir_clip_ViT_B_32_model)

def calculate_clip_embeddings(self, img_paths):
import torch

if not self.model_is_loaded:
self.load()
imgs = []
Expand All @@ -47,14 +49,25 @@ def calculate_clip_embeddings(self, img_paths):

try:
imgs_emb = self.model.encode(imgs, batch_size=32, convert_to_tensor=True)

if type(img_paths) is list:
magnitudes = map(np.linalg.norm, imgs_emb)

return imgs_emb, magnitudes
if torch.cuda.is_available():
if type(img_paths) is list:
magnitudes = list(
map(lambda x: np.linalg.norm(x.cpu().numpy()), imgs_emb)
)

return imgs_emb, magnitudes
else:
img_emb = imgs_emb[0].cpu().numpy().tolist()
magnitude = np.linalg.norm(img_emb)

return img_emb, magnitude
else:
img_emb = imgs_emb[0].tolist()
magnitude = np.linalg.norm(img_emb)
if type(img_paths) is list:
magnitudes = map(np.linalg.norm, imgs_emb)
return imgs_emb, magnitudes
else:
img_emb = imgs_emb[0].tolist()
magnitude = np.linalg.norm(img_emb)

return img_emb, magnitude
except Exception as e:
Expand Down

0 comments on commit 8a4ba7a

Please sign in to comment.