diff --git a/docs/infer_json_format.md b/docs/infer_json_format.md index f03b708..c97df98 100644 --- a/docs/infer_json_format.md +++ b/docs/infer_json_format.md @@ -20,10 +20,16 @@ Here is an overview of the JSON file format: The JSON file consists of a list of dictionaries, where each dictionary represents a set of sequences you want to model. Even if you are modeling only one set of sequences, the top-level structure should still be a list. -Each dictionary contains the following three keys: +Each dictionary contains the following keys: * `name`: A string representing the name of the inference job. * `sequences`: A list of dictionaries that describe the entities (e.g., proteins, DNA, RNA, small molecules, and ions) involved in the inference. * `covalent_bonds`: An optional list of dictionaries that define the covalent bonds between atoms from different entities. +* `userCCD`: Optional inline CCD mmCIF text that defines additional chemical components for this job. +* `userCCDPath`: Optional path to a CCD mmCIF file that defines additional chemical components for this job. Relative paths are resolved from the input JSON file location. + +`userCCD` and `userCCDPath` are mutually exclusive. They can be used when a modification or CCD ligand is not present in the default Protenix CCD cache. For example, a custom protein PTM can be provided in a user CCD file and referenced from `proteinChain.modifications` by setting `ptmType` to `CCD_` followed by the component ID, e.g. `"ptmType": "CCD_UAA"`. + +The user CCD component must contain the atom names used by the PTM or covalent-bond fields. Protein PTM components should be peptide-linking CCD entries (for example, a `_chem_comp.type` of `L-PEPTIDE LINKING`; a component whose type is exactly `PEPTIDE-LIKE` is treated as a ligand) with standard backbone atom names such as `N`, `CA`, `C`, and `O`, and correct leaving-atom flags for atoms such as `OXT` when present. Details of `sequences` and `covalent_bonds` are provided below. @@ -395,4 +401,4 @@ The contents of each output file are as follows: - `has_clash` - Boolean flag indicating if there are steric clashes in the predicted structure. 
- `disorder` - Predicted regions of intrinsic disorder within the protein, highlighting residues that may be flexible or unstructured. - `ranking_score` - Predicted confidence score for ranking complexes. Higher values indicate greater confidence. - - `num_recycles`: Number of recycling steps used during inference. \ No newline at end of file + - `num_recycles`: Number of recycling steps used during inference. diff --git a/examples/custom_ptm_components.cif b/examples/custom_ptm_components.cif new file mode 100644 index 0000000..1f2324f --- /dev/null +++ b/examples/custom_ptm_components.cif @@ -0,0 +1,66 @@ +data_UAA +# +_chem_comp.id UAA +_chem_comp.name 'USER DEFINED AMINO ACID' +_chem_comp.type 'L-PEPTIDE LINKING' +_chem_comp.pdbx_type ATOMP +_chem_comp.formula 'C3 H7 N O2' +_chem_comp.mon_nstd_parent_comp_id ? +_chem_comp.pdbx_synonyms ? +_chem_comp.pdbx_formal_charge 0 +_chem_comp.pdbx_initial_date 2026-05-10 +_chem_comp.pdbx_modified_date 2026-05-10 +_chem_comp.pdbx_ambiguous_flag N +_chem_comp.pdbx_release_status REL +_chem_comp.pdbx_replaced_by ? +_chem_comp.pdbx_replaces ? +_chem_comp.formula_weight 89.094 +_chem_comp.one_letter_code K +_chem_comp.three_letter_code UAA +_chem_comp.pdbx_model_coordinates_db_code ? +_chem_comp.pdbx_model_coordinates_details ? +_chem_comp.pdbx_ideal_coordinates_details ? +_chem_comp.pdbx_ideal_coordinates_missing_flag N +_chem_comp.pdbx_model_coordinates_missing_flag N +_chem_comp.pdbx_processing_site ? 
+# +loop_ +_chem_comp_atom.comp_id +_chem_comp_atom.atom_id +_chem_comp_atom.alt_atom_id +_chem_comp_atom.type_symbol +_chem_comp_atom.charge +_chem_comp_atom.pdbx_align +_chem_comp_atom.pdbx_aromatic_flag +_chem_comp_atom.pdbx_leaving_atom_flag +_chem_comp_atom.pdbx_stereo_config +_chem_comp_atom.model_Cartn_x +_chem_comp_atom.model_Cartn_y +_chem_comp_atom.model_Cartn_z +_chem_comp_atom.pdbx_model_Cartn_x_ideal +_chem_comp_atom.pdbx_model_Cartn_y_ideal +_chem_comp_atom.pdbx_model_Cartn_z_ideal +_chem_comp_atom.pdbx_component_atom_id +_chem_comp_atom.pdbx_component_comp_id +_chem_comp_atom.pdbx_ordinal +UAA N N N 0 1 N N N 0.000 0.000 0.000 0.000 0.000 0.000 N UAA 1 +UAA CA CA C 0 1 N N N 1.450 0.000 0.000 1.450 0.000 0.000 CA UAA 2 +UAA C C C 0 1 N N N 2.020 1.410 0.000 2.020 1.410 0.000 C UAA 3 +UAA O O O 0 1 N N N 1.320 2.400 0.000 1.320 2.400 0.000 O UAA 4 +UAA CB CB C 0 1 N N N 1.980 -0.780 -1.200 1.980 -0.780 -1.200 CB UAA 5 +UAA OXT OXT O 0 1 N Y N 3.250 1.560 0.000 3.250 1.560 0.000 OXT UAA 6 +# +loop_ +_chem_comp_bond.comp_id +_chem_comp_bond.atom_id_1 +_chem_comp_bond.atom_id_2 +_chem_comp_bond.value_order +_chem_comp_bond.pdbx_aromatic_flag +_chem_comp_bond.pdbx_stereo_config +_chem_comp_bond.pdbx_ordinal +UAA N CA SING N N 1 +UAA CA C SING N N 2 +UAA C O DOUB N N 3 +UAA CA CB SING N N 4 +UAA C OXT SING N N 5 +# diff --git a/examples/example_custom_ptm_userccd.json b/examples/example_custom_ptm_userccd.json new file mode 100644 index 0000000..e7fba73 --- /dev/null +++ b/examples/example_custom_ptm_userccd.json @@ -0,0 +1,20 @@ +[ + { + "name": "custom_ptm_userccd", + "userCCDPath": "custom_ptm_components.cif", + "sequences": [ + { + "proteinChain": { + "sequence": "AKT", + "count": 1, + "modifications": [ + { + "ptmType": "CCD_UAA", + "ptmPosition": 2 + } + ] + } + } + ] + } +] diff --git a/protenix/data/core/ccd.py b/protenix/data/core/ccd.py index 4c6e6d4..5bc0e3a 100644 --- a/protenix/data/core/ccd.py +++ b/protenix/data/core/ccd.py @@ -17,7 +17,7 @@ 
import pickle from collections import defaultdict from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import biotite import biotite.structure as struc @@ -310,7 +310,9 @@ def get_ccd_ref_info( # Modified from biotite to use consistent ccd components file def _connect_inter_residue( - atoms: AtomArray, residue_starts: np.ndarray + atoms: AtomArray, + residue_starts: np.ndarray, + get_mol_type_fn: Callable[[str], str] = get_mol_type, ) -> struc.BondList: """ Create a :class:`BondList` containing the bonds between adjacent @@ -355,8 +357,8 @@ def _connect_inter_residue( continue # Get link type for this residue from RCSB components.cif - curr_link = get_mol_type(res_names[curr_start_i]) - next_link = get_mol_type(res_names[next_start_i]) + curr_link = get_mol_type_fn(res_names[curr_start_i]) + next_link = get_mol_type_fn(res_names[next_start_i]) if curr_link == "protein" and next_link in "protein": curr_connect_atom_name = "C" diff --git a/protenix/data/core/custom_ccd.py b/protenix/data/core/custom_ccd.py new file mode 100644 index 0000000..9d405d2 --- /dev/null +++ b/protenix/data/core/custom_ccd.py @@ -0,0 +1,283 @@ +# Copyright 2024 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tempfile +from pathlib import Path +from typing import Any, Mapping + +import biotite +import gemmi +import numpy as np +import rdkit +from biotite.structure import AtomArray +from biotite.structure.io import pdbx +from pdbeccdutils.core import ccd_reader +from rdkit import Chem +from rdkit.Chem import AllChem + +from protenix.data.core import ccd + + +class CCDProvider: + """Per-job CCD overlay with fallback to the global Protenix CCD cache.""" + + def __init__(self, user_ccd_path: Path | None = None) -> None: + self.user_ccd_path = user_ccd_path + self._user_biotite_cif: pdbx.CIFFile | None = None + self._user_gemmi_cif: gemmi.cif.Document | None = None + self._custom_mols: dict[str, Chem.Mol] = {} + self._tmpdir: tempfile.TemporaryDirectory | None = None + if user_ccd_path is not None: + self._load_user_ccd(user_ccd_path) + + @classmethod + def from_job( + cls, + job: Mapping[str, Any], + base_dir: str | Path | None = None, + ) -> "CCDProvider": + """Build a provider from optional job-level userCCD/userCCDPath fields.""" + has_user_ccd = "userCCD" in job and job["userCCD"] is not None + has_user_ccd_path = "userCCDPath" in job and job["userCCDPath"] is not None + if has_user_ccd and has_user_ccd_path: + raise ValueError('Only one of "userCCD" and "userCCDPath" may be set.') + if has_user_ccd: + user_ccd = str(job["userCCD"]) + if not user_ccd.strip(): + raise ValueError('"userCCD" can not be empty.') + tmpdir = tempfile.TemporaryDirectory(prefix="protenix_user_ccd_") + path = Path(tmpdir.name) / "user_components.cif" + path.write_text(user_ccd, encoding="utf-8") + provider = cls(path) + provider._tmpdir = tmpdir + return provider + if has_user_ccd_path: + user_ccd_path = str(job["userCCDPath"]) + if not user_ccd_path.strip(): + raise ValueError('"userCCDPath" can not be empty.') + path = Path(user_ccd_path).expanduser() + if not path.is_absolute() and base_dir is not None: + path = Path(base_dir) / path + if not path.exists() or not path.is_file(): + 
raise FileNotFoundError(f'userCCDPath does not exist: "{path}"') + return cls(path) + return cls() + + def has_user_components(self) -> bool: + return self._user_biotite_cif is not None + + @property + def user_codes(self) -> set[str]: + if self._user_biotite_cif is None: + return set() + return set(self._user_biotite_cif.keys()) + + def get_custom_mols(self) -> dict[str, Chem.Mol]: + return self._custom_mols + + def get_component_atom_array( + self, + ccd_code: str, + keep_leaving_atoms: bool = False, + keep_hydrogens: bool = False, + ) -> AtomArray | None: + if ccd_code not in self.user_codes: + return ccd.get_component_atom_array( + ccd_code, + keep_leaving_atoms=keep_leaving_atoms, + keep_hydrogens=keep_hydrogens, + ) + + assert self._user_biotite_cif is not None + try: + comp = pdbx.get_component( + self._user_biotite_cif, + data_block=ccd_code, + use_ideal_coord=True, + allow_missing_coord=True, + ) + except biotite.InvalidFileError as exc: + raise ValueError( + f"Can not parse user CCD component {ccd_code}: {exc}" + ) from exc + + atom_category = self._user_biotite_cif[ccd_code]["chem_comp_atom"] + try: + leaving_atom_flag = atom_category["pdbx_leaving_atom_flag"].as_array() + except KeyError: + leaving_atom_flag = np.array(["N"] * len(comp)) + comp.set_annotation("leaving_atom_flag", leaving_atom_flag == "Y") + + for atom_id in ["alt_atom_id", "pdbx_component_atom_id"]: + try: + comp.set_annotation(atom_id, atom_category[atom_id].as_array()) + except KeyError: + comp.set_annotation(atom_id, comp.atom_name.copy()) + + if not keep_leaving_atoms: + comp = comp[~comp.leaving_atom_flag] + if not keep_hydrogens: + comp = comp[~np.isin(comp.element, ["H", "D"])] + + comp.central_to_leaving_groups = ccd._map_central_to_leaving_groups(comp) + return comp + + def get_mol_type(self, ccd_code: str) -> str: + if ccd_code not in self.user_codes: + return ccd.get_mol_type(ccd_code) + + assert self._user_biotite_cif is not None + link_type = ( + 
self._user_biotite_cif[ccd_code]["chem_comp"]["type"].as_item().upper() + ) + if "PEPTIDE" in link_type and link_type != "PEPTIDE-LIKE": + return "protein" + if "DNA" in link_type: + return "dna" + if "RNA" in link_type: + return "rna" + return "ligand" + + def get_one_letter_code(self, ccd_code: str) -> str | None: + if ccd_code not in self.user_codes: + return ccd.get_one_letter_code(ccd_code) + + assert self._user_biotite_cif is not None + one = self._user_biotite_cif[ccd_code]["chem_comp"]["one_letter_code"].as_item() + if one == "?": + return None + return one + + def get_component_rdkit_mol(self, ccd_code: str) -> Chem.Mol | None: + if ccd_code not in self.user_codes: + return ccd.get_component_rdkit_mol(ccd_code) + if ccd_code not in self._custom_mols: + self._custom_mols[ccd_code] = self._build_user_rdkit_mol(ccd_code) + return self._custom_mols[ccd_code] + + def get_ccd_ref_info( + self, + ccd_code: str, + return_perm: bool = True, + return_atomic_number: bool = False, + ) -> dict[str, Any]: + if ccd_code not in self.user_codes: + return ccd.get_ccd_ref_info( + ccd_code, + return_perm=return_perm, + return_atomic_number=return_atomic_number, + ) + mol = self.get_component_rdkit_mol(ccd_code) + if mol is None: + return {} + return ccd.get_ccd_ref_info( + ccd_code, + return_perm=return_perm, + ccd_mols=((ccd_code, mol),), + return_atomic_number=return_atomic_number, + ) + + def _load_user_ccd(self, user_ccd_path: Path) -> None: + try: + self._user_biotite_cif = pdbx.CIFFile.read(str(user_ccd_path)) + self._user_gemmi_cif = gemmi.cif.read(str(user_ccd_path)) + except Exception as exc: + raise ValueError( + f"Failed to parse user CCD file {user_ccd_path}: {exc}" + ) from exc + if not list(self._user_biotite_cif.keys()): + raise ValueError( + f"User CCD file {user_ccd_path} does not contain components." 
+ ) + for code in self._user_biotite_cif.keys(): + self._validate_user_component(code) + + def _validate_user_component(self, ccd_code: str) -> None: + assert self._user_biotite_cif is not None + block = self._user_biotite_cif[ccd_code] + for category in ["chem_comp", "chem_comp_atom", "chem_comp_bond"]: + try: + block[category] + except KeyError as exc: + raise ValueError( + f'User CCD component "{ccd_code}" is missing required ' + f'category "{category}".' + ) from exc + atom_category = block["chem_comp_atom"] + atom_names = atom_category["atom_id"].as_array() + if len(set(atom_names)) != len(atom_names): + raise ValueError( + f'User CCD component "{ccd_code}" has duplicate atom names.' + ) + + def _build_user_rdkit_mol(self, ccd_code: str) -> Chem.Mol: + assert self._user_gemmi_cif is not None + try: + ccd_block = self._user_gemmi_cif[ccd_code] + except KeyError as exc: + raise ValueError(f'User CCD component "{ccd_code}" not found.') from exc + + try: + ccd_reader_result = ccd_reader._parse_pdb_mmcif(ccd_block, sanitize=True) + except Exception as exc: + raise ValueError( + "Failed to build RDKit molecule for user CCD component " + f'"{ccd_code}": {exc}' + ) from exc + mol = ccd_reader_result.component.mol + mol.atom_map = {atom.GetProp("name"): atom.GetIdx() for atom in mol.GetAtoms()} + mol.name = ccd_code + mol.sanitized = ccd_reader_result.sanitized + mol.ref_conf_id = 0 + mol.ref_conf_type = "ideal" + + num_atom = mol.GetNumAtoms() + if num_atom == 0: + mol.ref_mask = np.zeros(0, dtype=bool) + return mol + + mol.ref_mask = self._build_ref_mask(ccd_block, mol) + + if not mol.sanitized: + return mol + options = AllChem.ETKDGv3() + options.clearConfs = False + try: + conf_id = AllChem.EmbedMolecule(mol, options) + if conf_id >= 0: + mol.ref_conf_id = conf_id + mol.ref_conf_type = "rdkit" + mol.ref_mask[:] = True + except Exception: + pass + return mol + + @staticmethod + def _build_ref_mask(ccd_block: gemmi.cif.Block, mol: rdkit.Chem.Mol) -> np.ndarray: + 
ref_mask = np.ones(mol.GetNumAtoms(), dtype=bool) + atoms = ccd_block.find( + "_chem_comp_atom.", ["atom_id", "pdbx_model_Cartn_x_ideal"] + ) + if len(atoms) != mol.GetNumAtoms(): + return ref_mask + ref_mask[:] = False + for row in atoms: + atom_id = gemmi.cif.as_string(row["_chem_comp_atom.atom_id"]) + atom_idx = mol.atom_map[atom_id] + x_ideal = row["_chem_comp_atom.pdbx_model_Cartn_x_ideal"] + ref_mask[atom_idx] = x_ideal != "?" + return ref_mask + + +DEFAULT_CCD_PROVIDER = CCDProvider() diff --git a/protenix/data/core/parser.py b/protenix/data/core/parser.py index 08061f5..e7ab7c9 100644 --- a/protenix/data/core/parser.py +++ b/protenix/data/core/parser.py @@ -21,7 +21,7 @@ from collections import Counter, defaultdict from datetime import datetime from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union warnings.filterwarnings( "ignore", message="Category 'chem_comp_bond' not found. No bonds will be parsed" @@ -2683,7 +2683,9 @@ class AddAtomArrayAnnot(object): @staticmethod def add_token_mol_type( - atom_array: AtomArray, sequences: dict[str, str] + atom_array: AtomArray, + sequences: dict[str, str], + get_mol_type_fn: Callable[[str], str] = ccd.get_mol_type, ) -> AtomArray: """ Add molecule types in atom_arry.mol_type based on ccd pdbx_type. 
@@ -2705,7 +2707,7 @@ def add_token_mol_type( continue res_name = atom_array.res_name[start] - mol_types[start:stop] = ccd.get_mol_type(res_name) + mol_types[start:stop] = get_mol_type_fn(res_name) atom_array.set_annotation("mol_type", mol_types) return atom_array @@ -2928,7 +2930,10 @@ def add_ref_space_uid(atom_array: AtomArray) -> AtomArray: return atom_array @staticmethod - def add_cano_seq_resname(atom_array: AtomArray) -> AtomArray: + def add_cano_seq_resname( + atom_array: AtomArray, + get_one_letter_code_fn: Callable[[str], str | None] = ccd.get_one_letter_code, + ) -> AtomArray: """ Assign to each atom the three-letter residue name (resname) corresponding to its place in the canonical sequences. @@ -2950,7 +2955,7 @@ def add_cano_seq_resname(atom_array: AtomArray) -> AtomArray: mol_type = atom_array.mol_type[start] resname = atom_array.res_name[start] - one_letter_code = ccd.get_one_letter_code(resname) + one_letter_code = get_one_letter_code_fn(resname) if one_letter_code is None or len(one_letter_code) != 1: # Some non-standard residues cannot be mapped back to one standard residue. 
one_letter_code = "X" if mol_type == "protein" else "N" diff --git a/protenix/data/inference/infer_dataloader.py b/protenix/data/inference/infer_dataloader.py index c4d4040..9bfbbbc 100644 --- a/protenix/data/inference/infer_dataloader.py +++ b/protenix/data/inference/infer_dataloader.py @@ -18,6 +18,7 @@ import time import traceback import warnings +from pathlib import Path from typing import Any, Mapping import torch @@ -158,6 +159,7 @@ def process_one( sample2feat = SampleDictToFeatures( single_sample_dict, extract_features_for_tfg=self.configs.sample_diffusion.guidance.enable, + input_json_dir=Path(self.input_json_path).parent, ) features_dict, atom_array, token_array = sample2feat.get_feature_dict() features_dict["distogram_rep_atom_mask"] = torch.Tensor( diff --git a/protenix/data/inference/json_parser.py b/protenix/data/inference/json_parser.py index 4005ef1..18ea5aa 100644 --- a/protenix/data/inference/json_parser.py +++ b/protenix/data/inference/json_parser.py @@ -16,6 +16,7 @@ import random import warnings from collections import Counter +from pathlib import Path from typing import Any import biotite.structure as struc @@ -25,6 +26,7 @@ from rdkit.Chem import AllChem from protenix.data.core import ccd +from protenix.data.core.custom_ccd import CCDProvider, DEFAULT_CCD_PROVIDER from protenix.utils.logger import get_logger logger = get_logger(__name__) @@ -75,7 +77,9 @@ } -def add_reference_features(atom_array: AtomArray) -> AtomArray: +def add_reference_features( + atom_array: AtomArray, ccd_provider: CCDProvider = DEFAULT_CCD_PROVIDER +) -> AtomArray: """ Add reference features of each resiude to atom_array @@ -103,11 +107,23 @@ def add_reference_features(atom_array: AtomArray) -> AtomArray: ref_mask[start:stop] = 1 continue - ref_info = ccd.get_ccd_ref_info(res_name) + ref_info = ccd_provider.get_ccd_ref_info(res_name) if ref_info: atom_sub_idx = [ *map(ref_info["atom_map"].get, atom_array.atom_name[start:stop]) ] + if any(idx is None for idx in 
atom_sub_idx): + missing_atoms = [ + atom_name + for atom_name, idx in zip( + atom_array.atom_name[start:stop], atom_sub_idx + ) + if idx is None + ] + raise ValueError( + f"Reference info for CCD {res_name} is missing atoms: " + f"{missing_atoms}" + ) ref_pos[start:stop] = ref_info["coord"][atom_sub_idx] ref_charge[start:stop] = ref_info["charge"][atom_sub_idx] ref_mask[start:stop] = ref_info["mask"][atom_sub_idx] @@ -120,7 +136,9 @@ def add_reference_features(atom_array: AtomArray) -> AtomArray: return atom_array -def _remove_non_std_ccd_leaving_atoms(atom_array: AtomArray) -> AtomArray: +def _remove_non_std_ccd_leaving_atoms( + atom_array: AtomArray, ccd_provider: CCDProvider = DEFAULT_CCD_PROVIDER +) -> AtomArray: """ Check polymer connections and remove non-standard leaving atoms @@ -148,12 +166,18 @@ def _remove_non_std_ccd_leaving_atoms(atom_array: AtomArray) -> AtomArray: f"all leaving atoms will be removed for both residues." ) for idx, res_name in zip([res_id, res_id + 1], [res_name_i, res_name_j]): - staying_atoms = ccd.get_component_atom_array( + component = ccd_provider.get_component_atom_array( res_name, keep_leaving_atoms=False, keep_hydrogens=False - ).atom_name - if idx == 1 and ccd.get_mol_type(res_name) in ("dna", "rna"): + ) + if component is None: + raise ValueError(f"Can not parse CCD component {res_name}.") + staying_atoms = component.atom_name + if idx == 1 and ccd_provider.get_mol_type(res_name) in ("dna", "rna"): staying_atoms = np.append(staying_atoms, ["OP3"]) - if idx == atom_array.res_id[-1] and ccd.get_mol_type(res_name) == "protein": + if ( + idx == atom_array.res_id[-1] + and ccd_provider.get_mol_type(res_name) == "protein" + ): staying_atoms = np.append(staying_atoms, ["OXT"]) leaving_atoms |= (atom_array.res_id == idx) & ( ~np.isin(atom_array.atom_name, staying_atoms) @@ -178,7 +202,11 @@ def find_range_by_index(starts: np.ndarray, atom_index: int) -> tuple[int, int]: raise ValueError(f"atom_index {atom_index} not found in starts 
{starts}") -def remove_leaving_atoms(atom_array: AtomArray, bond_count: dict) -> AtomArray: +def remove_leaving_atoms( + atom_array: AtomArray, + bond_count: dict, + ccd_provider: CCDProvider = DEFAULT_CCD_PROVIDER, +) -> AtomArray: """ Remove leaving atoms based on ccd info @@ -195,7 +223,7 @@ def remove_leaving_atoms(atom_array: AtomArray, bond_count: dict) -> AtomArray: res_name = atom_array.res_name[centre_idx] centre_name = atom_array.atom_name[centre_idx] - comp = ccd.get_component_atom_array( + comp = ccd_provider.get_component_atom_array( res_name, keep_leaving_atoms=True, keep_hydrogens=False ) if comp is None: @@ -272,7 +300,9 @@ def _add_bonds_to_terminal_residues(atom_array: AtomArray) -> AtomArray: return atom_array -def _build_polymer_atom_array(ccd_seqs: list[str]) -> tuple[AtomArray, struc.BondList]: +def _build_polymer_atom_array( + ccd_seqs: list[str], ccd_provider: CCDProvider = DEFAULT_CCD_PROVIDER +) -> tuple[AtomArray, struc.BondList]: """ Build polymer atom_array from ccd codes, but not remove leaving atoms @@ -286,13 +316,20 @@ def _build_polymer_atom_array(ccd_seqs: list[str]) -> tuple[AtomArray, struc.Bon chain = struc.AtomArray(0) for res_id, res_name in enumerate(ccd_seqs): # Keep all leaving atoms, will remove leaving atoms later - residue = ccd.get_component_atom_array( + residue = ccd_provider.get_component_atom_array( res_name, keep_leaving_atoms=True, keep_hydrogens=False ) + if residue is None: + raise ValueError( + f"Can not parse CCD component {res_name}. If this is a custom " + "component, provide it with userCCD or userCCDPath." 
+ ) residue.res_id[:] = res_id + 1 chain += residue res_starts = struc.get_residue_starts(chain, add_exclusive_stop=True) - polymer_bonds = ccd._connect_inter_residue(chain, res_starts) + polymer_bonds = ccd._connect_inter_residue( + chain, res_starts, get_mol_type_fn=ccd_provider.get_mol_type + ) if chain.bonds is None: chain.bonds = polymer_bonds @@ -306,14 +343,39 @@ def _build_polymer_atom_array(ccd_seqs: list[str]) -> tuple[AtomArray, struc.Bon bond_count[i] = bond_count.get(i, 0) + 1 bond_count[j] = bond_count.get(j, 0) + 1 - chain = remove_leaving_atoms(chain, bond_count) + chain = remove_leaving_atoms(chain, bond_count, ccd_provider=ccd_provider) - chain = _remove_non_std_ccd_leaving_atoms(chain) + chain = _remove_non_std_ccd_leaving_atoms(chain, ccd_provider=ccd_provider) return chain -def build_polymer(entity_info: dict) -> dict: +def _validate_modification_position( + position: int, seq_len: int, code: str, position_key: str +) -> int: + index = int(position) - 1 + if index < 0 or index >= seq_len: + raise ValueError( + f"{position_key} for {code} must be in [1, {seq_len}], got {position}." + ) + return index + + +def _normalize_modification_code(code: str) -> str: + if not isinstance(code, str): + raise ValueError( + f"Modification code must be a string, got {type(code).__name__}." 
+ ) + if code.startswith("CCD_"): + code = code[4:] + if not code: + raise ValueError("Modification code can not be empty.") + return code + + +def build_polymer( + entity_info: dict, ccd_provider: CCDProvider = DEFAULT_CCD_PROVIDER +) -> dict: """ Build a polymer from a polymer info dict example: { @@ -333,34 +395,32 @@ def build_polymer(entity_info: dict) -> dict: ccd_seqs = [PROTEIN_1to3[x] for x in info["sequence"]] if modifications := info.get("modifications"): for m in modifications: - index = m["ptmPosition"] - 1 - mtype = m["ptmType"] - if mtype.startswith("CCD_"): - ccd_seqs[index] = mtype[4:] - else: - raise ValueError(f"unknown modification type: {mtype}") + mtype = _normalize_modification_code(m["ptmType"]) + index = _validate_modification_position( + m["ptmPosition"], len(ccd_seqs), mtype, "ptmPosition" + ) + ccd_seqs[index] = mtype if glycans := info.get("glycans"): logger.warning(f"glycans not supported: {glycans}") - chain_array = _build_polymer_atom_array(ccd_seqs) + chain_array = _build_polymer_atom_array(ccd_seqs, ccd_provider=ccd_provider) elif poly_type in ("dnaSequence", "rnaSequence"): map_1to3 = DNA_1to3 if poly_type == "dnaSequence" else RNA_1to3 ccd_seqs = [map_1to3[x] for x in info["sequence"]] if modifications := info.get("modifications"): for m in modifications: - index = m["basePosition"] - 1 - mtype = m["modificationType"] - if mtype.startswith("CCD_"): - ccd_seqs[index] = mtype[4:] - else: - raise ValueError(f"unknown modification type: {mtype}") - chain_array = _build_polymer_atom_array(ccd_seqs) + mtype = _normalize_modification_code(m["modificationType"]) + index = _validate_modification_position( + m["basePosition"], len(ccd_seqs), mtype, "basePosition" + ) + ccd_seqs[index] = mtype + chain_array = _build_polymer_atom_array(ccd_seqs, ccd_provider=ccd_provider) else: raise ValueError( "polymer type must be proteinChain, dnaSequence or rnaSequence" ) - chain_array = add_reference_features(chain_array) + chain_array = 
add_reference_features(chain_array, ccd_provider=ccd_provider) return {"atom_array": chain_array} @@ -539,7 +599,9 @@ def target(mol, q): return atom_info -def build_ligand(entity_info: dict) -> dict: +def build_ligand( + entity_info: dict, ccd_provider: CCDProvider = DEFAULT_CCD_PROVIDER +) -> dict: """ Build a ligand from a ligand entity info dict example1: { @@ -585,9 +647,14 @@ def build_ligand(entity_info: dict) -> dict: atom_array = AtomArray(0) res_ids = [] for idx, code in enumerate(ccd_code): - ccd_atom_array = ccd.get_component_atom_array( + ccd_atom_array = ccd_provider.get_component_atom_array( code, keep_leaving_atoms=True, keep_hydrogens=False ) + if ccd_atom_array is None: + raise ValueError( + f"Can not parse CCD component {code}. If this is a custom " + "component, provide it with userCCD or userCCDPath." + ) atom_array += ccd_atom_array res_id = idx + 1 res_ids += [res_id] * len(ccd_atom_array) @@ -600,13 +667,17 @@ def build_ligand(entity_info: dict) -> dict: else: atom_info = smiles_to_atom_info(ligand_str) atom_info["atom_array"].res_id[:] = 1 - atom_info["atom_array"] = add_reference_features(atom_info["atom_array"]) + atom_info["atom_array"] = add_reference_features( + atom_info["atom_array"], ccd_provider=ccd_provider + ) # add a fake sequence for ligand, which is used for msa featurizer atom_info["sequence"] = "-" * len(atom_info["atom_array"]) return atom_info -def add_entity_atom_array(single_job_dict: dict) -> dict: +def add_entity_atom_array( + single_job_dict: dict, input_json_dir: str | Path | None = None +) -> dict: """ Add atom_array to each entity in single_job_dict @@ -617,18 +688,20 @@ def add_entity_atom_array(single_job_dict: dict) -> dict: dict: deepcopy and updated job dict with atom_array """ single_job_dict = copy.deepcopy(single_job_dict) + ccd_provider = CCDProvider.from_job(single_job_dict, base_dir=input_json_dir) sequences = single_job_dict["sequences"] single_job_dict["ccd_mols"] = {} + 
single_job_dict["_ccd_provider"] = ccd_provider smiles_ligand_count = 0 for entity_info in sequences: if info := entity_info.get("proteinChain"): - atom_info = build_polymer(entity_info) + atom_info = build_polymer(entity_info, ccd_provider=ccd_provider) elif info := entity_info.get("dnaSequence"): - atom_info = build_polymer(entity_info) + atom_info = build_polymer(entity_info, ccd_provider=ccd_provider) elif info := entity_info.get("rnaSequence"): - atom_info = build_polymer(entity_info) + atom_info = build_polymer(entity_info, ccd_provider=ccd_provider) elif info := entity_info.get("ligand"): - atom_info = build_ligand(entity_info) + atom_info = build_ligand(entity_info, ccd_provider=ccd_provider) if not info["ligand"].startswith("CCD_"): smiles_ligand_count += 1 assert smiles_ligand_count <= 99, "too many smiles ligands" @@ -639,10 +712,11 @@ def add_entity_atom_array(single_job_dict: dict) -> dict: "mol" ] elif info := entity_info.get("ion"): - atom_info = build_ligand(entity_info) + atom_info = build_ligand(entity_info, ccd_provider=ccd_provider) else: raise ValueError( "entity type must be proteinChain, dnaSequence, rnaSequence, ligand or ion" ) info.update(atom_info) + single_job_dict["ccd_mols"].update(ccd_provider.get_custom_mols()) return single_job_dict diff --git a/protenix/data/inference/json_to_feature.py b/protenix/data/inference/json_to_feature.py index 57208e2..c89ffe5 100644 --- a/protenix/data/inference/json_to_feature.py +++ b/protenix/data/inference/json_to_feature.py @@ -13,6 +13,7 @@ # limitations under the License. 
import copy +from pathlib import Path from typing import Any import numpy as np @@ -36,11 +37,17 @@ class SampleDictToFeatures: def __init__( - self, single_sample_dict: dict[str, Any], extract_features_for_tfg: bool = False + self, + single_sample_dict: dict[str, Any], + extract_features_for_tfg: bool = False, + input_json_dir: str | Path | None = None, ) -> None: self.extract_features_for_tfg = extract_features_for_tfg self.single_sample_dict = single_sample_dict - self.input_dict = add_entity_atom_array(single_sample_dict) + self.input_dict = add_entity_atom_array( + single_sample_dict, input_json_dir=input_json_dir + ) + self.ccd_provider = self.input_dict["_ccd_provider"] self.entity_poly_type_and_seqs = self.get_entity_poly_type_and_seqs() self.entity_poly_type = self.entity_poly_type_and_seqs["entity_poly_type"] self.entity_to_sequences = self.entity_poly_type_and_seqs["entity_to_sequences"] @@ -279,13 +286,14 @@ def add_bonds_between_entities(self, atom_array: AtomArray) -> AtomArray: bond_count[atom_idx1] = bond_count.get(atom_idx1, 0) + 1 bond_count[atom_idx2] = bond_count.get(atom_idx2, 0) + 1 - atom_array = remove_leaving_atoms(atom_array, bond_count) + atom_array = remove_leaving_atoms( + atom_array, bond_count, ccd_provider=self.ccd_provider + ) return atom_array - @staticmethod def add_atom_array_attributes( - atom_array: AtomArray, entity_poly_type: dict[str, str] + self, atom_array: AtomArray, entity_poly_type: dict[str, str] ) -> AtomArray: """ Add attributes to the Biotite AtomArray. @@ -297,12 +305,19 @@ def add_atom_array_attributes( Returns: AtomArray: Biotite Atom array with attributes added. 
""" - atom_array = AddAtomArrayAnnot.add_token_mol_type(atom_array, entity_poly_type) + atom_array = AddAtomArrayAnnot.add_token_mol_type( + atom_array, + entity_poly_type, + get_mol_type_fn=self.ccd_provider.get_mol_type, + ) atom_array = AddAtomArrayAnnot.add_centre_atom_mask(atom_array) atom_array = AddAtomArrayAnnot.add_atom_mol_type_mask(atom_array) atom_array = AddAtomArrayAnnot.add_distogram_rep_atom_mask(atom_array) atom_array = AddAtomArrayAnnot.add_plddt_m_rep_atom_mask(atom_array) - atom_array = AddAtomArrayAnnot.add_cano_seq_resname(atom_array) + atom_array = AddAtomArrayAnnot.add_cano_seq_resname( + atom_array, + get_one_letter_code_fn=self.ccd_provider.get_one_letter_code, + ) atom_array = AddAtomArrayAnnot.add_tokatom_idx(atom_array) atom_array = AddAtomArrayAnnot.add_modified_res_mask(atom_array) atom_array = AddAtomArrayAnnot.unique_chain_and_add_ids(atom_array) diff --git a/protenix/web_service/colab_request_parser.py b/protenix/web_service/colab_request_parser.py index 33e9ee5..d9eeaab 100644 --- a/protenix/web_service/colab_request_parser.py +++ b/protenix/web_service/colab_request_parser.py @@ -139,6 +139,9 @@ def get_data_json(self) -> str: "name": (self.request["name"]), "covalent_bonds": self.request["covalent_bonds"], } + for ccd_key in ("userCCD", "userCCDPath"): + if ccd_key in self.request: + input_json_dict[ccd_key] = self.request[ccd_key] input_json_path = opjoin(self.request_dir, "inputs.json") sequences = [] @@ -172,6 +175,7 @@ def get_data_json(self) -> str: ccd.RKDIT_MOL_PKL = Path(cache_paths["ccd_components_rdkit_mol_file"]) sample2feat = SampleDictToFeatures( tmp_json_dict, + input_json_dir=self.request_dir, ) atom_array = sample2feat.get_atom_array() num_atoms = len(atom_array) diff --git a/tests/test_user_ccd_ptm.py b/tests/test_user_ccd_ptm.py new file mode 100644 index 0000000..a336306 --- /dev/null +++ b/tests/test_user_ccd_ptm.py @@ -0,0 +1,216 @@ +# ruff: noqa: E501 +import textwrap + +import numpy as np +import pytest + 
+from protenix.data.inference.json_to_feature import SampleDictToFeatures + + +def _component_block(code: str, one_letter_code: str) -> str: + return textwrap.dedent( + f""" + data_{code} + # + _chem_comp.id {code} + _chem_comp.name '{code} TEST COMPONENT' + _chem_comp.type 'L-PEPTIDE LINKING' + _chem_comp.pdbx_type ATOMP + _chem_comp.formula 'C3 H7 N O2' + _chem_comp.mon_nstd_parent_comp_id ? + _chem_comp.pdbx_synonyms ? + _chem_comp.pdbx_formal_charge 0 + _chem_comp.pdbx_initial_date 2026-05-10 + _chem_comp.pdbx_modified_date 2026-05-10 + _chem_comp.pdbx_ambiguous_flag N + _chem_comp.pdbx_release_status REL + _chem_comp.pdbx_replaced_by ? + _chem_comp.pdbx_replaces ? + _chem_comp.formula_weight 89.094 + _chem_comp.one_letter_code {one_letter_code} + _chem_comp.three_letter_code {code} + _chem_comp.pdbx_model_coordinates_db_code ? + _chem_comp.pdbx_model_coordinates_details ? + _chem_comp.pdbx_ideal_coordinates_details ? + _chem_comp.pdbx_ideal_coordinates_missing_flag N + _chem_comp.pdbx_model_coordinates_missing_flag N + _chem_comp.pdbx_processing_site ? 
+ # + loop_ + _chem_comp_atom.comp_id + _chem_comp_atom.atom_id + _chem_comp_atom.alt_atom_id + _chem_comp_atom.type_symbol + _chem_comp_atom.charge + _chem_comp_atom.pdbx_align + _chem_comp_atom.pdbx_aromatic_flag + _chem_comp_atom.pdbx_leaving_atom_flag + _chem_comp_atom.pdbx_stereo_config + _chem_comp_atom.model_Cartn_x + _chem_comp_atom.model_Cartn_y + _chem_comp_atom.model_Cartn_z + _chem_comp_atom.pdbx_model_Cartn_x_ideal + _chem_comp_atom.pdbx_model_Cartn_y_ideal + _chem_comp_atom.pdbx_model_Cartn_z_ideal + _chem_comp_atom.pdbx_component_atom_id + _chem_comp_atom.pdbx_component_comp_id + _chem_comp_atom.pdbx_ordinal + {code} N N N 0 1 N N N 0.000 0.000 0.000 0.000 0.000 0.000 N {code} 1 + {code} CA CA C 0 1 N N N 1.450 0.000 0.000 1.450 0.000 0.000 CA {code} 2 + {code} C C C 0 1 N N N 2.020 1.410 0.000 2.020 1.410 0.000 C {code} 3 + {code} O O O 0 1 N N N 1.320 2.400 0.000 1.320 2.400 0.000 O {code} 4 + {code} CB CB C 0 1 N N N 1.980 -0.780 -1.200 1.980 -0.780 -1.200 CB {code} 5 + {code} OXT OXT O 0 1 N Y N 3.250 1.560 0.000 3.250 1.560 0.000 OXT {code} 6 + # + loop_ + _chem_comp_bond.comp_id + _chem_comp_bond.atom_id_1 + _chem_comp_bond.atom_id_2 + _chem_comp_bond.value_order + _chem_comp_bond.pdbx_aromatic_flag + _chem_comp_bond.pdbx_stereo_config + _chem_comp_bond.pdbx_ordinal + {code} N CA SING N N 1 + {code} CA C SING N N 2 + {code} C O DOUB N N 3 + {code} CA CB SING N N 4 + {code} C OXT SING N N 5 + # + """ + ) + + +def _user_ccd_text() -> str: + return "\n".join( + [ + _component_block("ALA", "A"), + _component_block("UAA", "K"), + _component_block("THR", "T"), + ] + ) + + +def _job_with_user_ccd(**extra): + job = { + "name": "custom_ptm", + "sequences": [ + { + "proteinChain": { + "sequence": "AKT", + "count": 1, + "modifications": [{"ptmType": "CCD_UAA", "ptmPosition": 2}], + } + } + ], + } + job.update(extra) + return job + + +def _atom_index(atom_array, res_id: int, atom_name: str) -> int: + indices = np.where( + (atom_array.res_id == res_id) & 
(atom_array.atom_name == atom_name) + )[0] + assert len(indices) == 1 + return int(indices[0]) + + +def _has_bond(atom_array, atom_i: int, atom_j: int) -> bool: + bonds = atom_array.bonds.as_array()[:, :2] + return any(set(pair) == {atom_i, atom_j} for pair in bonds) + + +def test_user_ccd_path_supports_custom_internal_protein_ptm(tmp_path): + ccd_path = tmp_path / "components.cif" + ccd_path.write_text(_user_ccd_text()) + sample = SampleDictToFeatures( + _job_with_user_ccd(userCCDPath="components.cif"), input_json_dir=tmp_path + ) + + atom_array = sample.get_atom_array() + + custom_mask = atom_array.res_name == "UAA" + assert custom_mask.any() + assert set(atom_array.mol_type[custom_mask]) == {"protein"} + assert not np.any(custom_mask & (atom_array.atom_name == "OXT")) + assert np.all(atom_array.modified_res_mask[custom_mask] == 1) + assert set(atom_array.cano_seq_resname[custom_mask]) == {"LYS"} + + assert _has_bond( + atom_array, + _atom_index(atom_array, 1, "C"), + _atom_index(atom_array, 2, "N"), + ) + assert _has_bond( + atom_array, + _atom_index(atom_array, 2, "C"), + _atom_index(atom_array, 3, "N"), + ) + assert np.all(atom_array.ref_mask[custom_mask] == 1) + + +def test_inline_user_ccd_matches_path_user_ccd(tmp_path): + ccd_path = tmp_path / "components.cif" + ccd_text = _user_ccd_text() + ccd_path.write_text(ccd_text) + + from_path = SampleDictToFeatures( + _job_with_user_ccd(userCCDPath="components.cif"), input_json_dir=tmp_path + ).get_atom_array() + from_inline = SampleDictToFeatures( + _job_with_user_ccd(userCCD=ccd_text), input_json_dir=tmp_path + ).get_atom_array() + + np.testing.assert_array_equal(from_path.res_name, from_inline.res_name) + np.testing.assert_array_equal(from_path.atom_name, from_inline.atom_name) + np.testing.assert_array_equal(from_path.mol_type, from_inline.mol_type) + + +def test_user_ccd_custom_ptm_builds_feature_dict_with_geometry(tmp_path): + ccd_path = tmp_path / "components.cif" + ccd_path.write_text(_user_ccd_text()) + 
sample = SampleDictToFeatures( + _job_with_user_ccd(userCCDPath="components.cif"), + extract_features_for_tfg=True, + input_json_dir=tmp_path, + ) + + feature_dict, atom_array, token_array = sample.get_feature_dict() + + assert "UAA" in sample.input_dict["ccd_mols"] + assert len(token_array) == int(np.sum(atom_array.centre_atom_mask)) + assert feature_dict["ref_pos"].shape[0] == len(atom_array) + assert "pairwise_distance_index" in feature_dict + + +def test_user_ccd_rejects_mutually_exclusive_sources(tmp_path): + ccd_path = tmp_path / "components.cif" + ccd_path.write_text(_user_ccd_text()) + + with pytest.raises(ValueError, match="Only one of"): + SampleDictToFeatures( + _job_with_user_ccd(userCCD=_user_ccd_text(), userCCDPath="components.cif"), + input_json_dir=tmp_path, + ) + + +def test_user_ccd_rejects_missing_relative_path(tmp_path): + with pytest.raises(FileNotFoundError, match="userCCDPath"): + SampleDictToFeatures( + _job_with_user_ccd(userCCDPath="missing.cif"), input_json_dir=tmp_path + ) + + +def test_user_ccd_rejects_empty_inline_ccd(tmp_path): + with pytest.raises(ValueError, match="userCCD"): + SampleDictToFeatures(_job_with_user_ccd(userCCD=""), input_json_dir=tmp_path) + + +def test_custom_ptm_position_is_validated(tmp_path): + ccd_path = tmp_path / "components.cif" + ccd_path.write_text(_user_ccd_text()) + job = _job_with_user_ccd(userCCDPath="components.cif") + job["sequences"][0]["proteinChain"]["modifications"][0]["ptmPosition"] = 99 + + with pytest.raises(ValueError, match="ptmPosition"): + SampleDictToFeatures(job, input_json_dir=tmp_path)