diff --git a/chai_lab/data/dataset/embeddings/esm.py b/chai_lab/data/dataset/embeddings/esm.py
index dd12ea8..22c2632 100644
--- a/chai_lab/data/dataset/embeddings/esm.py
+++ b/chai_lab/data/dataset/embeddings/esm.py
@@ -33,44 +33,58 @@ def esm_model(model_name: str, device):
     model.to(device)
     model.eval()
     yield model
-    # move model back to CPU
-    model.to("cpu")
+    model.to("cpu")  # move model back to CPU when done
 
 
-def embedding_context_from_sequence(seq: str, device) -> EmbeddingContext:
+def _get_esm_contexts_for_sequences(
+    prot_sequences: set[str], device
+) -> dict[str, EmbeddingContext]:
+    if len(prot_sequences) == 0:
+        return {}  # skip loading ESM
+
     # local import, requires huggingface transformers
     from transformers import EsmTokenizer
 
     model_name = "facebook/esm2_t36_3B_UR50D"
     tokenizer = EsmTokenizer.from_pretrained(model_name)
-    inputs = tokenizer(seq, return_tensors="pt")
-    inputs = move_data_to_device(dict(**inputs), device=device)
+    seq2embedding_context = {}
 
     with torch.no_grad():
         with esm_model(model_name=model_name, device=device) as model:
-            outputs = model(**inputs)
+            for seq in prot_sequences:
+                inputs = tokenizer(seq, return_tensors="pt")
+                inputs = move_data_to_device(dict(**inputs), device=device)
+                outputs = model(**inputs)
+                # remove BOS/EOS, back to CPU
+                esm_embeddings = outputs.last_hidden_state[0, 1:-1].to("cpu")
+                seq_len, _emb_dim = esm_embeddings.shape
+                assert seq_len == len(seq)
+
+                seq2embedding_context[seq] = EmbeddingContext(
+                    esm_embeddings=esm_embeddings
+                )
 
-    # remove BOS/EOS, back to CPU
-    esm_embeddings = outputs.last_hidden_state[0, 1:-1].to("cpu")
-    seq_len, _emb_dim = esm_embeddings.shape
-    assert seq_len == len(seq)
-    return EmbeddingContext(esm_embeddings=esm_embeddings)
+    return seq2embedding_context
 
 
 @typecheck
 def get_esm_embedding_context(chains: list[Chain], device) -> EmbeddingContext:
     # device is used for computing, but result is still on CPU
 
-    chain_embs = []
+    protein_seq2emb_context = _get_esm_contexts_for_sequences(
+        prot_sequences=set(
+            chain.entity_data.sequence
+            for chain in chains
+            if chain.entity_data.entity_type == EntityType.PROTEIN
+        ),
+        device=device,
+    )
+
+    chain_embs = []
 
     for chain in chains:
         if chain.entity_data.entity_type == EntityType.PROTEIN:
-            emb = embedding_context_from_sequence(
-                # modified residues represented as X
-                seq=chain.entity_data.sequence,
-                device=device,
-            )
-            chain_embs.append(emb)
+            chain_embs.append(protein_seq2emb_context[chain.entity_data.sequence])
         else:
             # embed non-proteins with zeros
             chain_embs.append(
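
For reference, a minimal usage sketch of the new helper (not part of the diff): it assumes the module path chai_lab.data.dataset.embeddings.esm shown above, that the underscore-prefixed helper is importable, and illustrative placeholder sequences; it only calls what the diff defines and shows that duplicate sequences are collapsed so each unique protein sequence gets a single ESM forward pass.

# Hypothetical usage sketch, not part of the change above.
import torch

from chai_lab.data.dataset.embeddings.esm import _get_esm_contexts_for_sequences

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Chains sharing a sequence contribute one entry to the set, so the
# facebook/esm2_t36_3B_UR50D model runs once per unique sequence.
unique_seqs = {"MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "GSHMGSGS"}  # placeholder sequences
seq2ctx = _get_esm_contexts_for_sequences(prot_sequences=unique_seqs, device=device)

for seq, ctx in seq2ctx.items():
    # per-residue embeddings with BOS/EOS already stripped, shape (len(seq), emb_dim)
    print(len(seq), tuple(ctx.esm_embeddings.shape))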