diff --git a/chai_lab/data/dataset/embeddings/esm.py b/chai_lab/data/dataset/embeddings/esm.py
index b522daf..175a7db 100644
--- a/chai_lab/data/dataset/embeddings/esm.py
+++ b/chai_lab/data/dataset/embeddings/esm.py
@@ -31,11 +31,18 @@ def esm_model(model_name: str, device):
     """Context transiently keeps ESM model on specified device."""
     from transformers import EsmModel
+    from transformers.utils.import_utils import is_torch_bf16_gpu_available
 
     if len(_esm_model) == 0:
         # lazy loading of the model
         _esm_model.append(
-            EsmModel.from_pretrained(model_name, cache_dir=esm_cache_folder)
+            EsmModel.from_pretrained(
+                model_name,
+                cache_dir=esm_cache_folder,
+                torch_dtype=(
+                    torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
+                ),
+            )
         )
 
     [model] = _esm_model
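
Below is a minimal standalone sketch (not part of the patch) of the dtype-selection pattern this diff introduces: load the ESM model in bfloat16 when the current GPU supports it, otherwise fall back to float16. The checkpoint name "facebook/esm2_t8_8M_UR50D" and the omission of cache_dir are illustrative assumptions, not taken from the diff.

    import torch
    from transformers import EsmModel
    from transformers.utils.import_utils import is_torch_bf16_gpu_available

    # Prefer bfloat16 where the GPU supports it, else use float16.
    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16

    # Assumed (small) ESM-2 checkpoint, for illustration only.
    model = EsmModel.from_pretrained("facebook/esm2_t8_8M_UR50D", torch_dtype=dtype)
    print(next(model.parameters()).dtype)  # torch.bfloat16 or torch.float16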