diff --git a/docs/colabfold_compatible_msa.md b/docs/colabfold_compatible_msa.md index f3f4c8c..04e10d8 100644 --- a/docs/colabfold_compatible_msa.md +++ b/docs/colabfold_compatible_msa.md @@ -7,6 +7,38 @@ Here's an example: python3 scripts/colabfold_msa.py examples/dimer.fasta dimer_colabfold_msa --db1 uniref30_2103_db --db3 colabfold_envdb_202108_db --mmseqs_path ``` +#### Using Existing Monomer A3Ms Directly + +If you already generated one A3M file per monomer, you can reference those files directly in the inference JSON with `monomerMsaPath`: + +```json +[ + { + "name": "two_chain_complex_with_cached_monomer_msas", + "sequences": [ + { + "proteinChain": { + "sequence": "ACDEFGHIK", + "count": 1, + "monomerMsaPath": "/abs/path/to/chain_a.a3m" + } + }, + { + "proteinChain": { + "sequence": "LMNPQRSTV", + "count": 1, + "monomerMsaPath": "/abs/path/to/chain_b.a3m" + } + } + ] + } +] +``` + +Protenix uses all valid rows in each file as unpaired MSA signal. Rows whose headers contain recognized species or taxonomy identifiers are also used for internal taxonomy-based pairing across chains. Compatible examples include Protenix-style `UniRef100__/` headers, UniProt-style headers, and headers with explicit taxonomy tags such as `OX=9606`. + +If a public ColabFold A3M does not contain species or taxonomy identifiers, Protenix will still use it as unpaired MSA, but it cannot create meaningful paired rows from that file alone. Existing split outputs remain supported through `pairedMsaPath` and `unpairedMsaPath`. + #### Configuring Colabfold_search Installation of colabfold and mmseqs2 is required. diff --git a/docs/infer_json_format.md b/docs/infer_json_format.md index f03b708..404c4a3 100644 --- a/docs/infer_json_format.md +++ b/docs/infer_json_format.md @@ -51,6 +51,7 @@ There are 5 kinds of supported sequences: "ptmPosition": 5 } ], + "monomerMsaPath": "/path/to/monomer_colabfold.a3m", "pairedMsaPath": "/path/to/pairing.a3m", "unpairedMsaPath": "/path/to/non_pairing.a3m", "templatesPath": "/path/to/hmmsearch.a3m" @@ -63,13 +64,16 @@ There are 5 kinds of supported sequences: * `modifications`: An optional list of dictionaries that describe post-translational modifications. * `ptmType`: A string containing CCD code of the modification. * `ptmPosition`: The position of the modified amino acid (integer). +* `monomerMsaPath`: The path to a single precomputed monomer MSA file, typically an externally generated `.a3m`. Protenix will use all valid rows as unpaired MSA signal and will also use rows with recognized species/taxonomy identifiers for internal MSA pairing. **Absolute paths are recommended.** * `pairedMsaPath`: The path to a precomputed MSA file used for pairing (typically `pairing.a3m`). **Absolute paths are recommended.** * `unpairedMsaPath`: The path to a precomputed non-pairing MSA file (typically `non_pairing.a3m`). **Absolute paths are recommended.** * `templatesPath`: The path to a precomputed template file. Supported formats include `.a3m` (e.g., generated by `hmmsearch`) and `.hhr`. **Absolute paths are recommended.** * For `.a3m` examples, see `examples/examples_with_template/example_9fm7.json`. * For `.hhr` examples, see `examples/examples_with_template/example_mgyp004658859411.json`. -> 💡 **Note**: `pairedMsaPath`, `unpairedMsaPath`, and `templatesPath` are all **Optional**. If these fields are not provided, the model will proceed with inference without using the corresponding features, which may lead to a potential decrease in prediction accuracy. +> 💡 **Note**: `monomerMsaPath`, `pairedMsaPath`, `unpairedMsaPath`, and `templatesPath` are all **Optional**. If these fields are not provided, the model will proceed with inference without using the corresponding features, which may lead to a potential decrease in prediction accuracy. + +> 💡 **MSA precedence**: If `pairedMsaPath` or `unpairedMsaPath` is provided, Protenix uses those explicit split MSA files. Otherwise, if `monomerMsaPath` is provided, Protenix logically separates the single monomer A3M into paired-capable rows and unpaired rows in memory. Species/taxonomy identifiers are required for internal pairing; speciesless rows are still used as unpaired MSA. > ⚠️ **Note**: The previous `msa` field, which used a dictionary format (e.g., `"msa": {"precomputed_msa_dir": "...", "pairing_db": "uniref100"}`), is still compatible but is being deprecated. For an example of this old format, see `examples/example.json`. It is recommended to use the new fields `pairedMsaPath` and `unpairedMsaPath` instead. @@ -395,4 +399,4 @@ The contents of each output file are as follows: - `has_clash` - Boolean flag indicating if there are steric clashes in the predicted structure. - `disorder` - Predicted regions of intrinsic disorder within the protein, highlighting residues that may be flexible or unstructured. - `ranking_score` - Predicted confidence score for ranking complexes. Higher values indicate greater confidence. - - `num_recycles`: Number of recycling steps used during inference. \ No newline at end of file + - `num_recycles`: Number of recycling steps used during inference. diff --git a/examples/example_monomer_msa_path.json b/examples/example_monomer_msa_path.json new file mode 100644 index 0000000..a3de5a4 --- /dev/null +++ b/examples/example_monomer_msa_path.json @@ -0,0 +1,21 @@ +[ + { + "name": "example_monomer_msa_path", + "sequences": [ + { + "proteinChain": { + "sequence": "ACDEFGHIK", + "count": 1, + "monomerMsaPath": "./examples/monomer_msa/chain_a.a3m" + } + }, + { + "proteinChain": { + "sequence": "LMNPQRSTV", + "count": 1, + "monomerMsaPath": "./examples/monomer_msa/chain_b.a3m" + } + } + ] + } +] diff --git a/examples/monomer_msa/chain_a.a3m b/examples/monomer_msa/chain_a.a3m new file mode 100644 index 0000000..9fd320e --- /dev/null +++ b/examples/monomer_msa/chain_a.a3m @@ -0,0 +1,6 @@ +>query +ACDEFGHIK +>UniRef100_CHAINA_9606/ +ACDEFGHIK +>environmental_hit_a +ACD-FGHIK diff --git a/examples/monomer_msa/chain_b.a3m b/examples/monomer_msa/chain_b.a3m new file mode 100644 index 0000000..6d3156e --- /dev/null +++ b/examples/monomer_msa/chain_b.a3m @@ -0,0 +1,6 @@ +>query +LMNPQRSTV +>UniRef100_CHAINB_9606/ +LMNPQRSTV +>environmental_hit_b +LMN-QRSTV diff --git a/protenix/data/msa/msa_featurizer.py b/protenix/data/msa/msa_featurizer.py index 926d5aa..2e32743 100644 --- a/protenix/data/msa/msa_featurizer.py +++ b/protenix/data/msa/msa_featurizer.py @@ -33,7 +33,7 @@ NUM_SEQ_NUM_RES_MSA_FEATURES, RawMsa, ) -from protenix.utils.file_io import load_json_cached +from protenix.data.msa.msa_input import split_monomer_msa from protenix.utils.logger import get_logger logger = get_logger(__name__) @@ -400,6 +400,8 @@ def __init__( ) -> None: self.dataset_name = dataset_name super().__init__() + from protenix.utils.file_io import load_json_cached + # Initialize source managers for protein and RNA self.prot_mgr = MSASourceManager( prot_msadir_raw_paths, @@ -616,6 +618,30 @@ def make_msa_feature( if p_a3m is None and c.get("pairedMsaPath"): with open(c["pairedMsaPath"]) as f: p_a3m = f.read() + if u_a3m is None and (p_a3m is None) and c.get("monomerMsaPath"): + with open(c["monomerMsaPath"]) as f: + monomer_a3m = f.read() + monomer_msa = split_monomer_msa( + query_sequence=seq, + a3m=monomer_a3m, + source_name=c["monomerMsaPath"], + ) + p_a3m = monomer_msa.paired_a3m + u_a3m = monomer_msa.unpaired_a3m + if monomer_msa.pairable_rows == 0: + logger.warning( + "monomerMsaPath for protein entity %s contains no " + "recognized species/taxonomy identifiers; using it as " + "unpaired MSA only.", + eid + 1, + ) + if monomer_msa.invalid_rows > 0: + logger.warning( + "monomerMsaPath for protein entity %s skipped %s rows " + "whose aligned length did not match the query sequence.", + eid + 1, + monomer_msa.invalid_rows, + ) if u_a3m is None and (p_a3m is None): if c.get("msa"): msa_dir = c["msa"].get("precomputed_msa_dir") diff --git a/protenix/data/msa/msa_input.py b/protenix/data/msa/msa_input.py new file mode 100644 index 0000000..77a2906 --- /dev/null +++ b/protenix/data/msa/msa_input.py @@ -0,0 +1,95 @@ +# Copyright 2024 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from dataclasses import dataclass + +from protenix.data.constants import MSA_PROTEIN_SEQ_TO_ID +from protenix.data.msa.msa_utils import extract_species_id +from protenix.data.tools.common import parse_fasta + + +@dataclass(frozen=True) +class MonomerMsaInput: + """Logical split of one monomer A3M into paired and unpaired MSA streams.""" + + paired_a3m: str + unpaired_a3m: str + total_rows: int + pairable_rows: int + unpairable_rows: int + invalid_rows: int + + +def _aligned_len(seq: str) -> int: + """Count A3M aligned columns using the same protein alphabet as MSA encoding.""" + return sum(1 for c in seq if c in MSA_PROTEIN_SEQ_TO_ID) + + +def _to_a3m(rows: Sequence[tuple[str, str]]) -> str: + return "".join(f">{description}\n{sequence}\n" for description, sequence in rows) + + +def split_monomer_msa( + query_sequence: str, + a3m: str, + source_name: str = "", +) -> MonomerMsaInput: + """ + Split one monomer A3M into Protenix paired and unpaired MSA streams. + + All valid non-query rows are kept as unpaired MSA signal. Rows with an + extractable species/taxonomy identifier are also exposed to the existing + taxonomy-based pairing path. + """ + sequences, descriptions = parse_fasta(a3m) + if not sequences: + raise ValueError(f"{source_name} does not contain any A3M/FASTA rows.") + + query_aligned_len = _aligned_len(query_sequence) + first_aligned_len = _aligned_len(sequences[0]) + if first_aligned_len != query_aligned_len: + raise ValueError( + f"{source_name} query row aligned length ({first_aligned_len}) does not " + f"match protein sequence length ({query_aligned_len})." + ) + + query_row = ("query", query_sequence) + paired_rows = [query_row] + unpaired_rows = [query_row] + pairable_rows = 0 + unpairable_rows = 0 + invalid_rows = 0 + + for seq, desc in zip(sequences[1:], descriptions[1:]): + if _aligned_len(seq) != query_aligned_len: + invalid_rows += 1 + continue + + unpaired_rows.append((desc, seq)) + species_id = extract_species_id(desc) + if species_id: + paired_rows.append((desc, seq)) + pairable_rows += 1 + else: + unpairable_rows += 1 + + return MonomerMsaInput( + paired_a3m=_to_a3m(paired_rows), + unpaired_a3m=_to_a3m(unpaired_rows), + total_rows=len(sequences), + pairable_rows=pairable_rows, + unpairable_rows=unpairable_rows, + invalid_rows=invalid_rows, + ) diff --git a/protenix/data/msa/msa_utils.py b/protenix/data/msa/msa_utils.py index f2d8c95..7346aea 100644 --- a/protenix/data/msa/msa_utils.py +++ b/protenix/data/msa/msa_utils.py @@ -42,6 +42,9 @@ r"(?:tr|sp)\|[A-Z0-9]{6,10}(?:_\d+)?\|(?:[A-Z0-9]{1,10}_)(?P[A-Z0-9]{1,5})" ) _UNIREF_REGEX = re.compile(r"^UniRef100_[^_]+_([^_/]+)") +_TAXONOMY_TAG_REGEX = re.compile( + r"(?:^|\s)(?:OX|TaxID|taxid|Tax)=(?P\d+)(?:\b|$)" +) MSA_GAP_IDX = STD_RESIDUES_WITH_GAP.get("-") NUM_SEQ_NUM_RES_MSA_FEATURES = ("msa", "msa_mask", "deletion_matrix") @@ -49,6 +52,19 @@ MSA_PAD_VALUES = {"msa": MSA_GAP_IDX, "msa_mask": 1, "deletion_matrix": 0} +def extract_species_id(description: str) -> str: + """Extract a species/taxonomy identifier from an MSA description line.""" + description = description.strip() + m = _UNIPROT_REGEX.match(description) or _UNIREF_REGEX.match(description) + if m: + return m.group("SpeciesId") if "SpeciesId" in m.groupdict() else m.group(1) + + m = _TAXONOMY_TAG_REGEX.search(description) + if m: + return m.group("SpeciesId") + return "" + + class MSACore: """Basic MSA parsing and numerical conversion operations.""" @@ -265,17 +281,7 @@ class MSAPairingEngine: @staticmethod def get_species_ids(descriptions: Sequence[str]) -> List[str]: """Extracts species identifiers from UniProt or UniRef description lines.""" - ids = [] - for d in descriptions: - d = d.strip() - m = _UNIPROT_REGEX.match(d) or _UNIREF_REGEX.match(d) - if m: - ids.append( - m.group("SpeciesId") if "SpeciesId" in m.groupdict() else m.group(1) - ) - else: - ids.append("") - return ids + return [extract_species_id(d) for d in descriptions] @staticmethod def _align_species( diff --git a/runner/msa_search.py b/runner/msa_search.py index 07260c8..1e1107f 100644 --- a/runner/msa_search.py +++ b/runner/msa_search.py @@ -17,7 +17,6 @@ from typing import Any, Sequence, Tuple from protenix.utils.logger import get_logger -from protenix.web_service.colab_request_parser import RequestParser logger = get_logger(__name__) @@ -33,29 +32,34 @@ def need_msa_search(json_data: dict) -> bool: bool: True if an MSA search is required, False otherwise. """ need_msa = False - # the new format of msa filed is `pairedMsaPath` and `unpairedMsaPath` - # we need to check `pairedMsaPath` and `unpairedMsaPath` + # the new format of msa field is `pairedMsaPath` and `unpairedMsaPath`; + # `monomerMsaPath` is an inference-time single-file MSA source. for sequence in json_data["sequences"]: if "proteinChain" in sequence: protein_chain = sequence["proteinChain"] - paired_msa_path = protein_chain.get("pairedMsaPath") - unpaired_msa_path = protein_chain.get("unpairedMsaPath") - - if paired_msa_path is None and unpaired_msa_path is None: + split_msa_paths = { + "pairedMsaPath": protein_chain.get("pairedMsaPath"), + "unpairedMsaPath": protein_chain.get("unpairedMsaPath"), + } + provided_paths = { + field_name: path + for field_name, path in split_msa_paths.items() + if path is not None + } + monomer_msa_path = protein_chain.get("monomerMsaPath") + + if not provided_paths and monomer_msa_path is not None: + provided_paths = {"monomerMsaPath": monomer_msa_path} + + if not provided_paths: need_msa = True else: - if paired_msa_path is not None and not os.path.exists(paired_msa_path): - logger.warning( - f"pairedMsaPath {paired_msa_path} does not exist, will re-search MSA." - ) - need_msa = True - if unpaired_msa_path is not None and not os.path.exists( - unpaired_msa_path - ): - logger.warning( - f"unpairedMsaPath {unpaired_msa_path} does not exist, will re-search MSA." - ) - need_msa = True + for field_name, path in provided_paths.items(): + if not os.path.exists(path): + logger.warning( + f"{field_name} {path} does not exist, will re-search MSA." + ) + need_msa = True return need_msa @@ -136,6 +140,8 @@ def msa_search( Returns: Sequence[str]: List of directories containing MSA results for each sequence. """ + from protenix.web_service.colab_request_parser import RequestParser + os.makedirs(msa_res_dir, exist_ok=True) tmp_fasta_fpath = os.path.join(msa_res_dir, f"tmp_{uuid.uuid4().hex}.fasta") msa_res_subdirs = RequestParser.msa_search( @@ -180,13 +186,13 @@ def update_seq_msa(infer_seq: dict, msa_res_dir: str, mode: str) -> dict: sequence["proteinChain"]["sequence"] ] if os.path.exists(f"{precomputed_msa_dir}/pairing.a3m"): - sequence["proteinChain"][ - "pairedMsaPath" - ] = f"{precomputed_msa_dir}/pairing.a3m" + sequence["proteinChain"]["pairedMsaPath"] = ( + f"{precomputed_msa_dir}/pairing.a3m" + ) if os.path.exists(f"{precomputed_msa_dir}/non_pairing.a3m"): - sequence["proteinChain"][ - "unpairedMsaPath" - ] = f"{precomputed_msa_dir}/non_pairing.a3m" + sequence["proteinChain"]["unpairedMsaPath"] = ( + f"{precomputed_msa_dir}/non_pairing.a3m" + ) return infer_seq diff --git a/tests/test_inference_monomer_msa.py b/tests/test_inference_monomer_msa.py new file mode 100644 index 0000000..164e646 --- /dev/null +++ b/tests/test_inference_monomer_msa.py @@ -0,0 +1,79 @@ +import logging + +import numpy as np + +from protenix.data.msa.msa_featurizer import InferenceMSAFeaturizer + + +class FakeAtomArray: + def __init__(self, asym_id_int, chain_id, res_id): + self.asym_id_int = np.asarray(asym_id_int) + self.chain_id = np.asarray(chain_id) + self.res_id = np.asarray(res_id) + self.centre_atom_mask = np.ones(len(self.res_id), dtype=bool) + + def __getitem__(self, item): + return FakeAtomArray( + asym_id_int=self.asym_id_int[item], + chain_id=self.chain_id[item], + res_id=self.res_id[item], + ) + + +def _two_chain_atom_array(seq_a: str, seq_b: str) -> FakeAtomArray: + return FakeAtomArray( + asym_id_int=[0] * len(seq_a) + [1] * len(seq_b), + chain_id=["A"] * len(seq_a) + ["B"] * len(seq_b), + res_id=list(range(1, len(seq_a) + 1)) + list(range(1, len(seq_b) + 1)), + ) + + +def test_inference_monomer_msa_path_produces_paired_rows(tmp_path): + seq_a, seq_b = "ACDEF", "FGHIK" + msa_a = tmp_path / "chain_a.a3m" + msa_b = tmp_path / "chain_b.a3m" + msa_a.write_text( + ">query\nACDEF\n>UniRef100_HITA_9606/\nACDEF\n>env_a\nAC-EF\n", + encoding="utf-8", + ) + msa_b.write_text( + ">query\nFGHIK\n>UniRef100_HITB_9606/\nFGHIK\n>env_b\nF-HIK\n", + encoding="utf-8", + ) + bioassembly = [ + {"proteinChain": {"sequence": seq_a, "count": 1, "monomerMsaPath": str(msa_a)}}, + {"proteinChain": {"sequence": seq_b, "count": 1, "monomerMsaPath": str(msa_b)}}, + ] + + features = InferenceMSAFeaturizer.make_msa_feature( + bioassembly=bioassembly, + atom_array=_two_chain_atom_array(seq_a, seq_b), + msa_pair_as_unpair=False, + ) + + assert int(features["prot_paired_num_alignments"]) == 2 + assert int(features["prot_unpaired_num_alignments"]) >= 1 + assert features["msa"].shape[1] == len(seq_a) + len(seq_b) + + +def test_inference_monomer_msa_path_without_taxonomy_is_unpaired_only(tmp_path, caplog): + seq_a, seq_b = "ACDEF", "FGHIK" + msa_a = tmp_path / "chain_a.a3m" + msa_b = tmp_path / "chain_b.a3m" + msa_a.write_text(">query\nACDEF\n>env_a\nAC-EF\n", encoding="utf-8") + msa_b.write_text(">query\nFGHIK\n>env_b\nF-HIK\n", encoding="utf-8") + bioassembly = [ + {"proteinChain": {"sequence": seq_a, "count": 1, "monomerMsaPath": str(msa_a)}}, + {"proteinChain": {"sequence": seq_b, "count": 1, "monomerMsaPath": str(msa_b)}}, + ] + + with caplog.at_level(logging.WARNING): + features = InferenceMSAFeaturizer.make_msa_feature( + bioassembly=bioassembly, + atom_array=_two_chain_atom_array(seq_a, seq_b), + msa_pair_as_unpair=False, + ) + + assert int(features["prot_paired_num_alignments"]) == 1 + assert int(features["prot_unpaired_num_alignments"]) == 1 + assert "using it as unpaired MSA only" in caplog.text diff --git a/tests/test_msa_input.py b/tests/test_msa_input.py new file mode 100644 index 0000000..45e6fb1 --- /dev/null +++ b/tests/test_msa_input.py @@ -0,0 +1,119 @@ +import pytest + +from protenix.data.msa.msa_input import split_monomer_msa +from protenix.data.msa.msa_utils import MSAPairingEngine +from protenix.data.tools.common import parse_fasta +from runner.msa_search import need_msa_search + + +def test_species_ids_support_explicit_taxonomy_tags(): + descriptions = [ + "UniRef100_HITA_9606/", + "sp|P12345|ABC_HUMAN Some protein OX=9606", + "hit TaxID=10090 description", + "hit taxid=7227 description", + "hit Tax=7955 description", + "speciesless_hit_42", + ] + + assert MSAPairingEngine.get_species_ids(descriptions) == [ + "9606", + "HUMAN", + "10090", + "7227", + "7955", + "", + ] + + +def test_split_monomer_msa_separates_pairable_and_unpairable_rows(): + split = split_monomer_msa( + query_sequence="ACDE", + a3m=( + ">query\n" + "ACDE\n" + ">UniRef100_HITA_9606/\n" + "ACDE\n" + ">environmental_hit\n" + "ACdDE\n" + ">bad_length\n" + "ACD\n" + ), + source_name="test.a3m", + ) + + paired_seqs, paired_descs = parse_fasta(split.paired_a3m) + unpaired_seqs, unpaired_descs = parse_fasta(split.unpaired_a3m) + + assert split.total_rows == 4 + assert split.pairable_rows == 1 + assert split.unpairable_rows == 1 + assert split.invalid_rows == 1 + assert paired_descs == ["query", "UniRef100_HITA_9606/"] + assert paired_seqs == ["ACDE", "ACDE"] + assert unpaired_descs == ["query", "UniRef100_HITA_9606/", "environmental_hit"] + assert unpaired_seqs == ["ACDE", "ACDE", "ACdDE"] + + +def test_split_monomer_msa_rejects_query_length_mismatch(): + with pytest.raises(ValueError, match="query row aligned length"): + split_monomer_msa( + query_sequence="ACDE", + a3m=">query\nACD\n", + source_name="bad.a3m", + ) + + +def test_need_msa_search_accepts_existing_monomer_msa_path(tmp_path): + msa_path = tmp_path / "chain_a.a3m" + msa_path.write_text(">query\nACDE\n", encoding="utf-8") + + json_data = { + "sequences": [ + { + "proteinChain": { + "sequence": "ACDE", + "count": 1, + "monomerMsaPath": str(msa_path), + } + } + ] + } + + assert need_msa_search(json_data) is False + + +def test_need_msa_search_researches_missing_monomer_msa_path(tmp_path): + json_data = { + "sequences": [ + { + "proteinChain": { + "sequence": "ACDE", + "count": 1, + "monomerMsaPath": str(tmp_path / "missing.a3m"), + } + } + ] + } + + assert need_msa_search(json_data) is True + + +def test_need_msa_search_honors_split_path_precedence(tmp_path): + paired_msa_path = tmp_path / "pairing.a3m" + paired_msa_path.write_text(">query\nACDE\n", encoding="utf-8") + + json_data = { + "sequences": [ + { + "proteinChain": { + "sequence": "ACDE", + "count": 1, + "pairedMsaPath": str(paired_msa_path), + "monomerMsaPath": str(tmp_path / "missing_monomer.a3m"), + } + } + ] + } + + assert need_msa_search(json_data) is False