Skip to content

Commit 758022b

Browse files
committed
Address PR feedback
1 parent aa57921 commit 758022b

File tree

2 files changed

+17
-31
lines changed

2 files changed

+17
-31
lines changed

chai_lab/chai1.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -409,8 +409,8 @@ def make_all_atom_feature_context(
409409
else:
410410
restraint_context = RestraintContext.empty()
411411

412-
# Handles leaving atoms for bonds in-place
413-
merged_context.drop_leaving_atoms()
412+
# Handles leaving atoms for glycan bonds in-place
413+
merged_context.drop_glycan_leaving_atoms_inplace()
414414

415415
# Build final feature context
416416
feature_context = AllAtomFeatureContext(

chai_lab/data/dataset/structure/all_atom_structure_context.py

+15-29
Original file line number | Diff line number | Diff line change
@@ -110,11 +110,10 @@ def report_bonds(self) -> None:
110110
)
111111

112112
@typecheck
113-
def _infer_bonds_within_conformer(
113+
def _infer_CO_bonds_within_glycan(
114114
self,
115115
atom_idx: int,
116116
allowed_elements: list[int] | None = None,
117-
exclude_polymers: bool = True,
118117
) -> Bool[Tensor, "{self.num_atoms}"]:
119118
"""Return mask for atoms that atom_idx might bond to based on distances.
120119
@@ -124,12 +123,7 @@ def _infer_bonds_within_conformer(
124123
res = self.token_residue_index[tok]
125124
asym = self.token_asym_id[tok]
126125

127-
if exclude_polymers and self.token_residue_type[tok] in (
128-
EntityType.PROTEIN.value,
129-
EntityType.DNA.value,
130-
EntityType.RNA.value,
131-
EntityType.POLYMER_HYBRID.value,
132-
):
126+
if self.token_residue_type[tok].item() != EntityType.MANUAL_GLYCAN.value:
133127
return torch.zeros(self.num_atoms, dtype=torch.bool)
134128

135129
mask = (
@@ -157,56 +151,48 @@ def _infer_bonds_within_conformer(
157151
bond_candidates = (distances[atom_idx] < 1.5) & mask & is_allowed_element
158152
return bond_candidates
159153

160-
def drop_leaving_atoms(self) -> None:
154+
def drop_glycan_leaving_atoms_inplace(self) -> None:
161155
"""Drop OH groups that leave upon bond formation by setting atom_exists_mask."""
162156
# For each of the bonds, identify the atoms within bond radius and guess which are leaving
157+
oxygen = 8
163158
for i, (atom_a, atom_b) in enumerate(zip(*self.atom_covalent_bond_indices)):
164159
# Find the C-O bonds
165-
(bond_candidates_b,) = torch.where(
166-
self._infer_bonds_within_conformer(
167-
atom_b.item(), allowed_elements=[8], exclude_polymers=True
160+
[bond_candidates_b] = torch.where(
161+
self._infer_CO_bonds_within_glycan(
162+
atom_b.item(), allowed_elements=[oxygen]
168163
)
169164
)
170165
# Filter to bonds that link to terminal atoms
166+
# NOTE do not specify element here
171167
bonds_b = [
172168
candidate
173169
for candidate in bond_candidates_b.tolist()
174-
if (
175-
self._infer_bonds_within_conformer(
176-
candidate, exclude_polymers=True
177-
).sum()
178-
== 1
179-
)
170+
if (self._infer_CO_bonds_within_glycan(candidate).sum() == 1)
180171
]
181172
# If there are multiple such bonds, we can't infer which to drop
182173
if len(bonds_b) == 1:
183-
b_bond = bonds_b.pop()
174+
[b_bond] = bonds_b
184175
self.atom_exists_mask[b_bond] = False
185176
logger.info(
186177
f"Bond {i} right: Dropping latter atom in bond {self.atom_residue_index[atom_b]} {self.atom_ref_name[atom_b]} -> {self.atom_residue_index[b_bond]} {self.atom_ref_name[b_bond]}"
187178
)
188179
continue # Only identify one leaving atom per bond
189180

190181
# Repeat the above for atom_a if we didn't find anything for atom B
191-
(bond_candidates_a,) = torch.where(
192-
self._infer_bonds_within_conformer(
193-
atom_a.item(), allowed_elements=[8], exclude_polymers=True
182+
[bond_candidates_a] = torch.where(
183+
self._infer_CO_bonds_within_glycan(
184+
atom_a.item(), allowed_elements=[oxygen]
194185
)
195186
)
196187
# Filter to bonds that link to terminal atoms
197188
bonds_a = [
198189
candidate
199190
for candidate in bond_candidates_a.tolist()
200-
if (
201-
self._infer_bonds_within_conformer(
202-
candidate, exclude_polymers=True
203-
).sum()
204-
== 1
205-
)
191+
if (self._infer_CO_bonds_within_glycan(candidate).sum() == 1)
206192
]
207193
# If there are multiple such bonds, we can't infer which to drop
208194
if len(bonds_a) == 1:
209-
a_bond = bonds_a.pop()
195+
[a_bond] = bonds_a
210196
self.atom_exists_mask[a_bond] = False
211197
logger.info(
212198
f"Bond {i} left: Dropping latter atom in bond {self.atom_residue_index[atom_a]} {self.atom_ref_element[atom_a]} -> {self.atom_residue_index[a_bond]} {self.atom_ref_element[a_bond]}"

0 commit comments

Comments
 (0)