Skip to content

Commit

Permalink
Handle covalent bond leaving atoms (#234)
Browse files Browse the repository at this point in the history
  • Loading branch information
wukevin authored Dec 10, 2024
1 parent 5f74e11 commit 91adda8
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 28 deletions.
59 changes: 44 additions & 15 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,7 @@ def sorted(self) -> "StructureCandidates":
)


@torch.no_grad()
def run_inference(
def make_all_atom_feature_context(
fasta_file: Path,
*,
output_dir: Path,
Expand All @@ -306,18 +305,8 @@ def run_inference(
msa_server_url: str = "https://api.colabfold.com",
msa_directory: Path | None = None,
constraint_path: Path | None = None,
# expose some params for easy tweaking
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
seed: int | None = None,
device: str | None = None,
low_memory: bool = True,
) -> StructureCandidates:
if output_dir.exists():
assert not any(
output_dir.iterdir()
), f"Output directory {output_dir} is not empty."
torch_device = torch.device(device if device is not None else "cuda:0")
esm_device: torch.device = torch.device("cpu"),
):
assert not (
use_msa_server and msa_directory
), "Cannot specify both MSA server and directory"
Expand Down Expand Up @@ -385,7 +374,7 @@ def run_inference(

# Load ESM embeddings
if use_esm_embeddings:
embedding_context = get_esm_embedding_context(chains, device=torch_device)
embedding_context = get_esm_embedding_context(chains, device=esm_device)
else:
embedding_context = EmbeddingContext.empty(n_tokens=n_actual_tokens)

Expand Down Expand Up @@ -420,6 +409,9 @@ def run_inference(
else:
restraint_context = RestraintContext.empty()

# Handles leaving atoms for glycan bonds in-place
merged_context.drop_glycan_leaving_atoms_inplace()

# Build final feature context
feature_context = AllAtomFeatureContext(
chains=chains,
Expand All @@ -430,6 +422,43 @@ def run_inference(
embedding_context=embedding_context,
restraint_context=restraint_context,
)
return feature_context


@torch.no_grad()
def run_inference(
fasta_file: Path,
*,
output_dir: Path,
use_esm_embeddings: bool = True,
use_msa_server: bool = False,
msa_server_url: str = "https://api.colabfold.com",
msa_directory: Path | None = None,
constraint_path: Path | None = None,
# expose some params for easy tweaking
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
seed: int | None = None,
device: str | None = None,
low_memory: bool = True,
) -> StructureCandidates:
if output_dir.exists():
assert not any(
output_dir.iterdir()
), f"Output directory {output_dir} is not empty."

torch_device = torch.device(device if device is not None else "cuda:0")

feature_context = make_all_atom_feature_context(
fasta_file=fasta_file,
output_dir=output_dir,
use_esm_embeddings=use_esm_embeddings,
use_msa_server=use_msa_server,
msa_server_url=msa_server_url,
msa_directory=msa_directory,
constraint_path=constraint_path,
esm_device=torch_device,
)

return run_folding_on_context(
feature_context,
Expand Down
99 changes: 99 additions & 0 deletions chai_lab/data/dataset/structure/all_atom_structure_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import torch
from torch import Tensor

from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.utils.tensor_utils import (
batch_tensorcode_to_string,
cdist,
tensorcode_to_string,
)
from chai_lab.utils.typing import Bool, Float, Int, UInt8, typecheck
Expand Down Expand Up @@ -107,6 +109,95 @@ def report_bonds(self) -> None:
f"Bond {i} (asym res_idx resname): {asym_a} {res_idx_a} {resname_a} <> {asym_b} {res_idx_b} {resname_b}"
)

@typecheck
def _infer_CO_bonds_within_glycan(
self,
atom_idx: int,
allowed_elements: list[int] | None = None,
) -> Bool[Tensor, "{self.num_atoms}"]:
"""Return mask for atoms that atom_idx might bond to based on distances.
If exclude_polymers is True, then always return no bonds for polymer entities
"""
tok = self.atom_token_index[atom_idx]
res = self.token_residue_index[tok]
asym = self.token_asym_id[tok]

if self.token_entity_type[tok].item() != EntityType.MANUAL_GLYCAN.value:
return torch.zeros(self.num_atoms, dtype=torch.bool)

mask = (
(self.atom_residue_index == res)
& (self.atom_asym_id == asym)
& self.atom_exists_mask
)

# This field contains reference conformers for each residue
# Pairwise distances are therefore valid within each residue
distances = cdist(self.atom_gt_coords)
assert distances.shape == (self.num_atoms, self.num_atoms)
distances[torch.arange(self.num_atoms), torch.arange(self.num_atoms)] = (
torch.inf
)

is_allowed_element = (
torch.isin(
self.atom_ref_element, test_elements=torch.tensor(allowed_elements)
)
if allowed_elements is not None
else torch.ones_like(mask)
)
# Canonical bond length for C-O is 1.43 angstroms; add a bit of headroom
bond_candidates = (distances[atom_idx] < 1.5) & mask & is_allowed_element
return bond_candidates

def drop_glycan_leaving_atoms_inplace(self) -> None:
"""Drop OH groups that leave upon bond formation by setting atom_exists_mask."""
# For each of the bonds, identify the atoms within bond radius and guess which are leaving
oxygen = 8
for i, (atom_a, atom_b) in enumerate(zip(*self.atom_covalent_bond_indices)):
# Find the C-O bonds
[bond_candidates_b] = torch.where(
self._infer_CO_bonds_within_glycan(
atom_b.item(), allowed_elements=[oxygen]
)
)
# Filter to bonds that link to terminal atoms
# NOTE do not specify element here
bonds_b = [
candidate
for candidate in bond_candidates_b.tolist()
if (self._infer_CO_bonds_within_glycan(candidate).sum() == 1)
]
# If there are multiple such bonds, we can't infer which to drop
if len(bonds_b) == 1:
[b_bond] = bonds_b
self.atom_exists_mask[b_bond] = False
logger.info(
f"Bond {i} right: Dropping latter atom in bond {self.atom_residue_index[atom_b]} {self.atom_ref_name[atom_b]} -> {self.atom_residue_index[b_bond]} {self.atom_ref_name[b_bond]}"
)
continue # Only identify one leaving atom per bond

# Repeat the above for atom_a if we didn't find anything for atom B
[bond_candidates_a] = torch.where(
self._infer_CO_bonds_within_glycan(
atom_a.item(), allowed_elements=[oxygen]
)
)
# Filter to bonds that link to terminal atoms
bonds_a = [
candidate
for candidate in bond_candidates_a.tolist()
if (self._infer_CO_bonds_within_glycan(candidate).sum() == 1)
]
# If there are multiple such bonds, we can't infer which to drop
if len(bonds_a) == 1:
[a_bond] = bonds_a
self.atom_exists_mask[a_bond] = False
logger.info(
f"Bond {i} left: Dropping latter atom in bond {self.atom_residue_index[atom_a]} {self.atom_ref_element[atom_a]} -> {self.atom_residue_index[a_bond]} {self.atom_ref_element[a_bond]}"
)

def pad(
self,
n_tokens: int,
Expand Down Expand Up @@ -321,6 +412,14 @@ def num_atoms(self) -> int:
(n_atoms,) = self.atom_token_index.shape
return n_atoms

@property
def atom_residue_index(self) -> Int[Tensor, "n_atoms"]:
return self.token_residue_index[self.atom_token_index]

@property
def atom_asym_id(self) -> Int[Tensor, "n_atoms"]:
return self.token_asym_id[self.atom_token_index]

def to_dict(self) -> dict[str, torch.Tensor]:
return asdict(self)

Expand Down
4 changes: 3 additions & 1 deletion chai_lab/data/parsing/glycans.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ def __post_init__(self):

@property
def src_atom_name(self) -> str:
return f"C{self.src_atom}"
"""Links between sugars are O-glycosidic bonds; we use src O dst C."""
return f"O{self.src_atom}"

@property
def dst_atom_name(self) -> str:
"""Links between sugars are O-glycosidic bonds; we use src O dst C."""
return f"C{self.dst_atom}"


Expand Down
2 changes: 1 addition & 1 deletion examples/glycosylation/1ac5.fasta
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
>protein|1AC5
LPSSEEYKVAYELLPGLSEVPDPSNIPQMHAGHIPLRSEDADEQDSSDLEYFFWKFTNNDSNGNVDRPLIIWLNGGPGCSSMDGALVESGPFRVNSDGKLYLNEGSWISKGDLLFIDQPTGTGFSVEQNKDEGKIDKNKFDEDLEDVTKHFMDFLENYFKIFPEDLTRKIILSGESYAGQYIPFFANAILNHNKFSKIDGDTYDLKALLIGNGWIDPNTQSLSYLPFAMEKKLIDESNPNFKHLTNAHENCQNLINSASTDEAAHFSYQECENILNLLLSYTRESSQKGTADCLNMYNFNLKDSYPSCGMNWPKDISFVSKFFSTPGVIDSLHLDSDKIDHWKECTNSVGTKLSNPISKPSIHLLPGLLESGIEIVLFNGDKDLICNNKGVLDTIDNLKWGGIKGFSDDAVSFDWIHKSKSTDDSEEFSGYVKYDRNLTFVSVYNASHMVPFDKSLVSRGIVDIYSNDVMIIDNNGKNVMITT
>glycan|two-sugar
NAG(1-4 NAG)
NAG(4-1 NAG)
>glycan|one-sugar
NAG
22 changes: 14 additions & 8 deletions examples/glycosylation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ Now, a glycan is also covalently bound to a residue; to specify this, we include

chainA|res_idxA|chainB|res_idxB|connection_type|confidence|min_distance_angstrom|max_distance_angstrom|comment|restraint_id
|---|---|---|---|---|---|---|---|---|---|
A|N436@N|B|@C4|covalent|1.0|0.0|0.0|protein-glycan|bond1
A|N436@N|B|@C1|covalent|1.0|0.0|0.0|protein-glycan|bond1

Breaking this down, this specifies that the within chain A (the first entry in the fasta), the "N" residue at the 436-th position (1-indexed) as indicated by the "N436" prefix is bound, via its nitrogen "N" atom as indicated by the "@N" suffix, to the C4 atom in the first glycan ("@C4"). Ring numbering follows standard glycan ring number schemas. For other ligands, you will need check how the specific version of `rdkit` that we use in `chai-lab` (run `uv pip list | grep rdkit` for version) assigns atom names and use the same atom names to specify your bonds. In addition, note that the min and max distance fields are ignored, as is the confidence field.
Breaking this down, this specifies that the within chain A (the first entry in the fasta), the "N" residue at the 436-th position (1-indexed) as indicated by the "N436" prefix is bound, via its nitrogen "N" atom as indicated by the "@N" suffix, to the C1 atom in the first glycan ("@C1"). Ring numbering follows standard glycan ring number schemas. For other ligands, you will need check how the specific version of `rdkit` that we use in `chai-lab` (run `uv pip list | grep rdkit` for version) assigns atom names and use the same atom names to specify your bonds. In addition, note that the min and max distance fields are ignored, as is the confidence field.


### Multi-ring glycan
Expand All @@ -37,30 +37,36 @@ Working through a more complex example, let's say we have a two-ring ligand such
>protein|example-protein
...N...
>glycan|example-dual-sugar
NAG(1-4 NAG)
NAG(4-1 NAG)
```

This syntax specifies that the root of the glycan is the leading `NAG` ring. The parentheses indicate that we are attaching another molecule to the ring directly preceding the parentheses. The `1-4` syntax "draws" a bond between the C1 atom of the previous "root" `NAG` and the C4 atom of the subsequent `NAG` ring. To specify how this glycan ought to be connected to the protein, we again use the restraints file to specify a residue and atom to which the glycan is bound, and the carbon atom within the root glycan ring that is bound.
This syntax specifies that the root of the glycan is the leading `NAG` ring. The parentheses indicate that we are attaching another molecule to the ring directly preceding the parentheses. The `4-1` syntax "draws" a bond between the O4 atom of the previous "root" `NAG` and the C1 atom of the subsequent `NAG` ring. Note that this syntax, when read left-to-right, is "building out" the glycan from the root sugar outwards.

To specify how this glycan ought to be connected to the protein, we again use the restraints file to specify a residue and atom to which the glycan is bound, and the carbon atom within the root glycan ring that is bound.

chainA|res_idxA|chainB|res_idxB|connection_type|confidence|min_distance_angstrom|max_distance_angstrom|comment|restraint_id
|---|---|---|---|---|---|---|---|---|---|
A|N436@N|B|@C4|covalent|1.0|0.0|0.0|protein-glycan|bond1
A|N436@N|B|@C1|covalent|1.0|0.0|0.0|protein-glycan|bond1

You can chain this syntax to create longer ligands:
```
>glycan|4-NAG-in-a-linear-chain
NAG(1-4 NAG(1-4 NAG(1-4 NAG)))
NAG(4-1 NAG(4-1 NAG(4-1 NAG)))
```

...and to create branched ligands
```
>glycan|branched-glycan
NAG(1-4 NAG(1-4 NAG))(3-4 MAN)
NAG(4-1 NAG(4-1 BMA(3-1 MAN)(6-1 MAN)))
```
This branched example has a root `NAG` ring with a branch with two more `NAG` rings and a branch with a single `MAN` ring. For additional examples, please refer to the examples tested in the `tests/test_glycans.py` test file.
This branched example has a root `NAG` ring followed by a `NAG` and a `BMA`, which then branches to two `MAN` rings. For additional examples of this syntax, please refer to the examples in `tests/test_glycans.py`.

### Example

We have included an example of how glycans can be specified under `predict_glycosylated.py` in this directory, along with its corresponding `bonds.restraints` csv file. This example is based on the PDB structure [1AC5](https://www.rcsb.org/structure/1ac5). The predicted structrue (colored, glycans in purple and orange, protein in green) from this script should look like the following when aligned with the ground truth 1AC5 structure (gray):

![glycan example prediction](./output.png)

### A note on leaving atoms

One might notice that in the above example, we are specifying CCD codes for sugar rings and connecting them to each other and an amino acid residue via various bonds. A subtle point is that the reference conformer for these sugar rings include OH hydroxyl groups that leave when bonds are formed. Under the hood, Chai-1 tries to automatically find and remove these atoms (see `AllAtomStructureContext.drop_glycan_leaving_atoms_inplace` for implementation), but this logic only drops leaving hydroxyl groups within glycan sugar rings. For other, non-sugar covalently attached ligands, please specify a SMILES string without the leaving atoms. If this does not work for your use case, please open a GitHub issue.
4 changes: 2 additions & 2 deletions examples/glycosylation/bonds.restraints
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
chainA,res_idxA,chainB,res_idxB,connection_type,confidence,min_distance_angstrom,max_distance_angstrom,comment,restraint_id
A,N437@N,B,@C4,covalent,1.0,0.0,0.0,protein-glycan,bond1
A,N445@N,C,@C4,covalent,1.0,0.0,0.0,protein-glycan,bond2
A,N437@N,B,@C1,covalent,1.0,0.0,0.0,protein-glycan,bond1
A,N445@N,C,@C1,covalent,1.0,0.0,0.0,protein-glycan,bond2
Binary file modified examples/glycosylation/output.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
57 changes: 56 additions & 1 deletion tests/test_glycans.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# Copyright (c) 2024 Chai Discovery, Inc.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.
from collections import Counter
from pathlib import Path
from tempfile import TemporaryDirectory

import pytest

from chai_lab.chai1 import make_all_atom_feature_context
from chai_lab.data.parsing.glycans import _glycan_string_to_sugars_and_bonds


Expand All @@ -22,13 +27,20 @@ def test_complex_parsing():

assert bond1.src_sugar_index == 0
assert bond1.dst_sugar_index == 1
assert bond1.src_atom == 6
assert bond1.dst_atom == 1
assert bond2.src_sugar_index == 0
assert bond2.dst_sugar_index == 2
assert bond3.src_sugar_index == 2
assert bond2.src_atom == 4
assert bond2.dst_atom == 1
assert bond3.src_sugar_index == 2
assert bond3.dst_sugar_index == 3
assert bond3.src_atom == 6
assert bond3.dst_atom == 1
assert bond4.src_sugar_index == 3
assert bond4.dst_sugar_index == 4
assert bond4.src_atom == 6
assert bond4.dst_atom == 1


def test_complex_parsing_2():
Expand All @@ -51,3 +63,46 @@ def test_complex_parsing_2():
for (expected_src, expected_dst), bond in zip(expected_bonds, bonds, strict=True):
assert bond.src_sugar_index == expected_src
assert bond.dst_sugar_index == expected_dst


def test_glycan_tokenization_with_bond():
"""Test that tokenization works, and that atoms are dropped as expected."""
glycan = ">glycan|foo\nNAG(4-1 NAG)\n"
with TemporaryDirectory() as tmpdir:
tmp_path = Path(tmpdir)

fasta_file = tmp_path / "input.fasta"
fasta_file.write_text(glycan)

output_dir = tmp_path / "out"

feature_context = make_all_atom_feature_context(
fasta_file,
output_dir=output_dir,
use_esm_embeddings=False, # Just a test; no need
)

# Each NAG component is C8 H15 N O6 -> 8 + 1 + 6 = 15 heavy atoms
# The bond between them displaces one oxygen, leaving 2 * 15 - 1 = 29 atoms
assert feature_context.structure_context.atom_exists_mask.sum() == 29
# We originally constructed all atoms in dropped the atoms that leave
assert feature_context.structure_context.atom_exists_mask.numel() == 30
elements = Counter(
feature_context.structure_context.atom_ref_element[
feature_context.structure_context.atom_exists_mask
].tolist()
)
assert elements[6] == 16 # 6 = Carbon
assert elements[7] == 2 # 7 = Nitrogen
assert elements[8] == 11 # 8 = Oxygen

# Single bond feature between O and C
left, right = feature_context.structure_context.atom_covalent_bond_indices
assert left.numel() == right.numel() == 1
bond_elements = set(
[
feature_context.structure_context.atom_ref_element[left].item(),
feature_context.structure_context.atom_ref_element[right].item(),
]
)
assert bond_elements == {8, 6}

0 comments on commit 91adda8

Please sign in to comment.