
Commit 48b4ef7
Merge pull request #5 from chaidiscovery/alex/chailab
move ESM to/from GPU once per complex, not once per chain
arogozhnikov authored Sep 10, 2024
2 parents 9ffc6ad + cb6ba13 commit 48b4ef7
Showing 1 changed file with 32 additions and 18 deletions.
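
In short, the device transfer is hoisted out of the per-chain loop: the ESM weights cross the CPU/GPU boundary once per complex, and every unique protein sequence is embedded inside that single context. A minimal, hypothetical sketch of the pattern (toy model; model_on and embed_all are illustrative names, not the repository's API):

from contextlib import contextmanager

import torch


@contextmanager
def model_on(model: torch.nn.Module, device: torch.device):
    # move weights to the target device once, hand the model out, then return it to CPU
    model.to(device)
    model.eval()
    yield model
    model.to("cpu")


def embed_all(inputs: list[torch.Tensor], model: torch.nn.Module, device: torch.device):
    # one CPU<->GPU round trip for the model, amortized over all inputs
    outputs = []
    with torch.no_grad(), model_on(model, device) as m:
        for x in inputs:
            outputs.append(m(x.to(device)).to("cpu"))
    return outputs


# e.g. embed_all([torch.randn(4, 8) for _ in range(3)], torch.nn.Linear(8, 16), torch.device("cpu"))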
chai_lab/data/dataset/embeddings/esm.py: 32 additions & 18 deletions (50 changes)
@@ -33,44 +33,58 @@ def esm_model(model_name: str, device):
     model.to(device)
     model.eval()
     yield model
-    # move model back to CPU
-    model.to("cpu")
+    model.to("cpu")  # move model back to CPU when done


-def embedding_context_from_sequence(seq: str, device) -> EmbeddingContext:
+def _get_esm_contexts_for_sequences(
+    prot_sequences: set[str], device
+) -> dict[str, EmbeddingContext]:
+    if len(prot_sequences) == 0:
+        return {}  # skip loading ESM
+
     # local import, requires huggingface transformers
     from transformers import EsmTokenizer

     model_name = "facebook/esm2_t36_3B_UR50D"
     tokenizer = EsmTokenizer.from_pretrained(model_name)

-    inputs = tokenizer(seq, return_tensors="pt")
-    inputs = move_data_to_device(dict(**inputs), device=device)
+    seq2embedding_context = {}

     with torch.no_grad():
         with esm_model(model_name=model_name, device=device) as model:
-            outputs = model(**inputs)
+            for seq in prot_sequences:
+                inputs = tokenizer(seq, return_tensors="pt")
+                inputs = move_data_to_device(dict(**inputs), device=device)
+                outputs = model(**inputs)
+                # remove BOS/EOS, back to CPU
+                esm_embeddings = outputs.last_hidden_state[0, 1:-1].to("cpu")
+                seq_len, _emb_dim = esm_embeddings.shape
+                assert seq_len == len(seq)

-    # remove BOS/EOS, back to CPU
-    esm_embeddings = outputs.last_hidden_state[0, 1:-1].to("cpu")
-    seq_len, _emb_dim = esm_embeddings.shape
-    assert seq_len == len(seq)
-    return EmbeddingContext(esm_embeddings=esm_embeddings)
+                seq2embedding_context[seq] = EmbeddingContext(
+                    esm_embeddings=esm_embeddings
+                )
+
+    return seq2embedding_context


 @typecheck
 def get_esm_embedding_context(chains: list[Chain], device) -> EmbeddingContext:
     # device is used for computing, but result is still on CPU
-    chain_embs = []

+    protein_seq2emb_context = _get_esm_contexts_for_sequences(
+        prot_sequences=set(
+            chain.entity_data.sequence
+            for chain in chains
+            if chain.entity_data.entity_type == EntityType.PROTEIN
+        ),
+        device=device,
+    )
+
+    chain_embs = []
     for chain in chains:
         if chain.entity_data.entity_type == EntityType.PROTEIN:
-            emb = embedding_context_from_sequence(
-                # modified residues represented as X
-                seq=chain.entity_data.sequence,
-                device=device,
-            )
-            chain_embs.append(emb)
+            chain_embs.append(protein_seq2emb_context[chain.entity_data.sequence])
         else:
             # embed non-proteins with zeros
             chain_embs.append(
[diff truncated]
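
A usage sketch for the new helper, assuming it is imported from the module shown above and a CUDA device is available (it is an internal helper, so callers outside the repo would normally go through get_esm_embedding_context); embeddings come back on CPU, keyed by sequence:

import torch

from chai_lab.data.dataset.embeddings.esm import _get_esm_contexts_for_sequences

# two illustrative sequences; duplicate sequences across chains collapse into the set
seqs = {"MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "GSHMRGSAA"}
contexts = _get_esm_contexts_for_sequences(prot_sequences=seqs, device=torch.device("cuda:0"))
for seq, ctx in contexts.items():
    print(len(seq), ctx.esm_embeddings.shape)  # per-residue ESM embeddings, on CPU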
