Implement unpaired alignment #398

Open

wants to merge 12 commits into base: beta
2 changes: 1 addition & 1 deletion colabfold/alphafold/msa.py
@@ -47,11 +47,11 @@ def make_fixed_size_multimer(
feat: Mapping[str, Any],
shape_schema,
num_res,
msa_cluster_size,
num_templates) -> FeatureDict:
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
msa_cluster_size = feat["bert_mask"].shape[0]
pad_size_map = {
NUM_RES: num_res,
NUM_MSA_SEQ: msa_cluster_size,
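For context on this hunk: `make_fixed_size_multimer` zero-pads every feature up to fixed sizes keyed by the placeholder strings, and the change appears to route the MSA depth (`msa_cluster_size`) through the signature instead of reading it off `bert_mask`. A minimal numpy sketch of that padding idea, with `pad_to` as a hypothetical helper that is not part of ColabFold:

```python
import numpy as np

NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"

def pad_to(feat, schema, pad_size_map):
    # Zero-pad each axis of `feat` up to the size its schema placeholder
    # maps to; axes without a mapping keep their current size.
    pad_width = [
        (0, pad_size_map.get(placeholder, feat.shape[axis]) - feat.shape[axis])
        for axis, placeholder in enumerate(schema)
    ]
    return np.pad(feat, pad_width)

# An MSA feature of 120 rows x 80 residues padded up to 508 x 256:
msa_feat = np.ones((120, 80))
padded = pad_to(msa_feat, [NUM_MSA_SEQ, NUM_RES], {NUM_MSA_SEQ: 508, NUM_RES: 256})
assert padded.shape == (508, 256)
```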
69 changes: 62 additions & 7 deletions colabfold/batch.py
@@ -21,11 +21,12 @@

from colabfold.inputs import (
get_queries_pairwise, unpack_a3ms,
parse_fasta, get_queries,
parse_fasta, get_queries, msa_to_str
)
from colabfold.run_alphafold import set_model_type

from colabfold.download import default_data_dir, download_alphafold_params

import sys
import logging
logger = logging.getLogger(__name__)

@@ -299,24 +300,43 @@ def main():
headers_list[0].remove(headers_list[0][0])
header_first = headers[0]

queries_temp = []
queries_rest = []
if args.pair_mode in ("none", "unpaired", "unpaired_paired"):
max_msa_cluster = 0
else:
max_msa_cluster = None
for jobname, batch in enumerate(output):
query_seqs_unique = []
for x in batch:
if x not in query_seqs_unique:
query_seqs_unique.append(x)
query_seqs_cardinality = [0] * len(query_seqs_unique)
for seq in batch:
seq_idx = query_seqs_unique.index(seq)
query_seqs_cardinality[seq_idx] += 1
use_env = "env" in args.msa_mode or "Environmental" in args.msa_mode
paired_a3m_lines = run_mmseqs2(
query_seqs_unique,
str(Path(args.results).joinpath(str(jobname))),
str(Path(args.results).joinpath(str(jobname)+"_paired")),
use_env=use_env,
use_pairwise=True,
use_pairing=True,
host_url=args.host_url,
)

path_o = Path(args.results).joinpath(f"{jobname}_pairwise")
if args.pair_mode in ("none", "unpaired", "unpaired_paired"):
unpaired_path = Path(args.results).joinpath(str(jobname)+"_unpaired_env")
unpaired_a3m_lines = run_mmseqs2(
query_seqs_unique,
str(Path(args.results).joinpath(str(jobname)+"_unpaired")),
use_env=use_env,
use_pairwise=False,
use_pairing=False,
host_url=args.host_url,
)
path_o = Path(args.results).joinpath(f"{jobname}_paired_pairwise")
for filenum in path_o.iterdir():
queries_new = []
if Path(filenum).suffix.lower() == ".a3m":
outdir = path_o.joinpath("tmp")
unpack_a3ms(filenum, outdir)
@@ -326,14 +346,49 @@ def main():
query_sequence = seqs[0]
a3m_lines = [Path(file).read_text()]
val = int(header[0].split('\t')[1][1:]) - 102
# match paired seq id and unpaired seq id
if args.pair_mode in ("none", "unpaired", "unpaired_paired"):
paired_query_a3m_lines = '>101\n' + paired_a3m_lines[0].split('>101\n')[val+1]
# a3m_lines = [msa_to_str(
# [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]
# )]
## Alternative: skip msa_to_str and the unserialize step; instead pass
## unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality directly as arguments.
a3m_lines = [[unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [paired_query_a3m_lines, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]]
queries_new.append((header_first + '_' + headers_list[jobname][val], query_sequence, a3m_lines))

### generate features then find max_msa_cluster
if args.pair_mode in ("none", "unpaired", "unpaired_paired"):
inputs = ([batch[0], batch[val+1]], [1, 1], [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [paired_query_a3m_lines, paired_a3m_lines[val+1]])
from colabfold.inputs import generate_msa_size
msa_size = generate_msa_size(inputs, query_seqs_unique, args.use_templates, is_complex, model_type)
# config.model.embeddings_and_evoformer.extra_msa_seqs=2048
# config.model.embeddings_and_evoformer.num_msa=508
# if msa_size < 2048 + 508, pop the sequences and run the model with recompilation
if msa_size < 2556:
queries_rest.append(queries_new.pop())
continue
max_msa_cluster = max(max_msa_cluster, msa_size)

if args.sort_queries_by == "length":
queries_new.sort(key=lambda t: len(''.join(t[1])),reverse=True)
elif args.sort_queries_by == "random":
random.shuffle(queries_new)
queries_temp.append(queries_new)

queries_sel = sum(queries_temp, [])
run_params["max_msa_cluster"] = max_msa_cluster
run_params["interaction_scan"] = args.interaction_scan
run(queries=queries_sel, **run_params)

run(queries=queries_new, **run_params)
if args.pair_mode in ("none", "unpaired", "unpaired_paired"):
if len(queries_rest) > 0:
if args.sort_queries_by == "length":
queries_rest.sort(key=lambda t: len(''.join(t[1])),reverse=True)
elif args.sort_queries_by == "random":
random.shuffle(queries_rest)
run_params["max_msa_cluster"] = None
run(queries=queries_rest, **run_params)

else:
run(queries=queries, **run_params)
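The new control flow in `main()` buckets queries by their featurized MSA depth: queries under the default budget (2048 extra MSA sequences + 508 clustered rows = 2556) are deferred and later run with the stock compiled shapes (`max_msa_cluster = None`), while oversized ones are padded to a shared `max_msa_cluster` so they compile once as a group. A standalone sketch of that scheduling idea, assuming hypothetical `estimate_msa_size` and `run_batch` stand-ins for `generate_msa_size` and `run(queries=..., **run_params)`:

```python
from typing import Callable, List, Optional, Tuple

Query = Tuple[str, str]   # (jobname, sequence) stand-in
MSA_BUDGET = 2048 + 508   # extra_msa_seqs + num_msa in the default model config

def schedule_by_msa_size(
    queries: List[Query],
    estimate_msa_size: Callable[[Query], int],
    run_batch: Callable[[List[Query], Optional[int]], None],
) -> None:
    fits_default: List[Query] = []
    oversized: List[Query] = []
    max_msa_cluster = 0
    for query in queries:
        size = estimate_msa_size(query)
        if size < MSA_BUDGET:
            fits_default.append(query)   # deferred; runs with stock shapes
        else:
            oversized.append(query)
            max_msa_cluster = max(max_msa_cluster, size)
    if oversized:
        run_batch(oversized, max_msa_cluster)  # one shared, padded compilation
    if fits_default:
        run_batch(fits_default, None)          # None -> default compiled shapes
```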
22 changes: 19 additions & 3 deletions colabfold/inputs.py
@@ -83,9 +83,9 @@ def pad_input_multimer(
model_runner: model.RunModel,
model_name: str,
pad_len: int,
msa_cluster_size: Optional[int],
use_templates: bool,
) -> model.features.FeatureDict:
model_config = model_runner.config
shape_schema = {
"aatype": ["num residues placeholder"],
"residue_index": ["num residues placeholder"],
@@ -123,6 +123,7 @@
input_features,
shape_schema,
num_res=pad_len,
msa_cluster_size=msa_cluster_size,
num_templates=4,
)
return input_fix
@@ -654,6 +655,20 @@ def generate_input_feature(
}
return (input_feature, domain_names)

def generate_msa_size(inputs, query_seqs_unique, use_templates, is_complex, model_type):
template_features_ = []
from colabfold.inputs import mk_mock_template
from colabfold.inputs import generate_input_feature
for query_seq in query_seqs_unique:
template_feature = mk_mock_template(query_seq)
template_features_.append(template_feature)
if not use_templates: template_features = template_features_
else: raise NotImplementedError

(feature_dict, _) \
= generate_input_feature(*inputs, template_features, is_complex, model_type)
return feature_dict["bert_mask"].shape[0]

def unserialize_msa(
a3m_lines: List[str], query_sequence: Union[List[str], str]
) -> Tuple[
@@ -696,7 +711,7 @@ def unserialize_msa(
)
prev_query_start += query_len
paired_msa = [""] * len(query_seq_len)
unpaired_msa = None
unpaired_msa = [""] * len(query_seq_len)
already_in = dict()
for i in range(1, len(a3m_lines), 2):
header = a3m_lines[i]
@@ -734,7 +749,6 @@
paired_msa[j] += ">" + header_no_faster_split[j] + "\n"
paired_msa[j] += seqs_line[j] + "\n"
else:
unpaired_msa = [""] * len(query_seq_len)
for j, seq in enumerate(seqs_line):
if has_amino_acid[j]:
unpaired_msa[j] += header + "\n"
@@ -752,6 +766,8 @@
template_feature = mk_mock_template(query_seq)
template_features.append(template_feature)

if unpaired_msa == [""] * len(query_seq_len):
unpaired_msa = None
return (
unpaired_msa,
paired_msa,
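The new `generate_msa_size` helper builds mock template features and runs `generate_input_feature` only to read off `bert_mask.shape[0]`, i.e. the number of MSA rows the model will see after featurization. A hedged sketch of the call pattern, mirroring the call site in `batch.py` (the sequences, a3m strings, and model_type value are made-up placeholders):

```python
from colabfold.inputs import generate_msa_size

# Two chains of a hypothetical pair; a3m strings are placeholders.
query_seqs_unique = ["MKTAYIAKQR", "MLSGEKVRAL"]
inputs = (
    query_seqs_unique,                              # query sequences
    [1, 1],                                         # cardinality of each sequence
    [">101\nMKTAYIAKQR\n", ">102\nMLSGEKVRAL\n"],   # unpaired a3m lines per chain
    [">101\nMKTAYIAKQR\n", ">102\nMLSGEKVRAL\n"],   # paired a3m lines per chain
)
msa_size = generate_msa_size(
    inputs,
    query_seqs_unique,
    use_templates=False,                  # real templates raise NotImplementedError here
    is_complex=True,
    model_type="alphafold2_multimer_v3",  # assumed model_type string
)
# msa_size == feature_dict["bert_mask"].shape[0], the MSA depth after featurization
```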
140 changes: 100 additions & 40 deletions colabfold/mmseqs/search.py
@@ -11,12 +11,50 @@
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Union
import os
import os, pandas

from colabfold.batch import get_queries, msa_to_str, get_queries_pairwise
from colabfold.inputs import get_queries, msa_to_str, parse_fasta
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

logger = logging.getLogger(__name__)

def get_queries_pairwise(
input_path: Union[str, Path], sort_queries_by: str = "length"
) -> Tuple[List[Tuple[str, str, Optional[List[str]]]], bool]:
"""Reads a directory of fasta files, a single fasta file or a csv file and returns a tuple
of job name, sequence and the optional a3m lines"""
input_path = Path(input_path)
if not input_path.exists():
raise OSError(f"{input_path} could not be found")
if input_path.is_file():
if input_path.suffix == ".csv" or input_path.suffix == ".tsv":
sep = "\t" if input_path.suffix == ".tsv" else ","
df = pandas.read_csv(input_path, sep=sep)
assert "id" in df.columns and "sequence" in df.columns
queries = [
(str(df["id"][0])+'&'+str(seq_id), sequence.upper(), None)
for i, (seq_id, sequence) in enumerate(df[["id", "sequence"]].itertuples(index=False))
]
elif input_path.suffix == ".a3m":
raise NotImplementedError()
elif input_path.suffix in [".fasta", ".faa", ".fa"]:
(sequences, headers) = parse_fasta(input_path.read_text())
queries = []
for i, (sequence, header) in enumerate(zip(sequences, headers)):
sequence = sequence.upper()
if sequence.count(":") == 0:
# Single sequence
queries.append((header, sequence, None))
else:
# Complex mode
queries.append((header, sequence.upper().split(":"), None))
else:
raise ValueError(f"Unknown file format {input_path.suffix}")
else:
raise NotImplementedError()

is_complex = True
return queries, is_complex

def run_mmseqs(mmseqs: Path, params: List[Union[str, Path]]):
params_log = " ".join(str(i) for i in params)
@@ -61,7 +99,6 @@ def mmseqs_search_monomer(
used_dbs.append(template_db)
if use_env:
used_dbs.append(metagenomic_db)

for db in used_dbs:
if not dbbase.joinpath(f"{db}.dbtype").is_file():
raise FileNotFoundError(f"Database {db} does not exist")
@@ -405,9 +442,9 @@ def main():
args = parser.parse_args()

if args.interaction_scan:
queries, is_complex = get_queries_pairwise(args.query, None)
queries, is_complex = get_queries_pairwise(args.query)
else:
queries, is_complex = get_queries(args.query, None)
queries, is_complex = get_queries(args.query)

queries_unique = []
for job_number, (raw_jobname, query_sequences, a3m_lines) in enumerate(queries):
@@ -437,10 +474,9 @@ def main():
query_seqs_cardinality,
) in enumerate(queries_unique):
if job_number==0:
f.write(f">{raw_jobname}_0\n{query_sequences[0]}\n")
f.write(f">{raw_jobname}\n{query_sequences[1]}\n")
f.write(f">{raw_jobname}_0\n{query_sequences}\n")
else:
f.write(f">{raw_jobname}\n{query_sequences[1]}\n")
f.write(f">{queries_unique[0][0]+'&'+raw_jobname}\n{query_sequences}\n")
else:
with query_file.open("w") as f:
for job_number, (
@@ -454,18 +490,6 @@
args.mmseqs,
["createdb", query_file, args.base.joinpath("qdb"), "--shuffle", "0"],
)
with args.base.joinpath("qdb.lookup").open("w") as f:
id = 0
file_number = 0
for job_number, (
raw_jobname,
query_sequences,
query_seqs_cardinality,
) in enumerate(queries_unique):
for seq in query_sequences:
f.write(f"{id}\t{raw_jobname}\t{file_number}\n")
id += 1
file_number += 1

mmseqs_search_monomer(
mmseqs=args.mmseqs,
@@ -498,30 +522,66 @@ def main():
interaction_scan=args.interaction_scan,
)

if args.interaction_scan:
if len(queries_unique) > 1:
for i in range(len(queries_unique)-2):
idx = 2 + i*2
## delete duplicated query files 2.paired, 4.paired...
os.remove(args.base.joinpath(f"{idx}.paired.a3m"))
for j in range(len(queries_unique)-2):
# rename the remaining target files to consecutive indices
id1 = j*2 + 3
id2 = j + 2
os.replace(args.base.joinpath(f"{id1}.paired.a3m"), args.base.joinpath(f"{id2}.paired.a3m"))

id = 0
for job_number, (
raw_jobname,
query_sequences,
query_seqs_cardinality,
) in enumerate(queries_unique):
unpaired_msa = []
paired_msa = None
if len(query_seqs_cardinality) > 1:
if not args.interaction_scan:
for job_number, (
raw_jobname,
query_sequences,
query_seqs_cardinality,
) in enumerate(queries_unique):
unpaired_msa = []
paired_msa = None
if len(query_seqs_cardinality) > 1:
paired_msa = []
else:
for seq in query_sequences:
with args.base.joinpath(f"{id}.a3m").open("r") as f:
unpaired_msa.append(f.read())
args.base.joinpath(f"{id}.a3m").unlink()
if len(query_seqs_cardinality) > 1:
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
paired_msa.append(f.read())
args.base.joinpath(f"{id}.paired.a3m").unlink()
id += 1
msa = msa_to_str(
unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
)
args.base.joinpath(f"{job_number}.a3m").write_text(msa)
else:
for job_number, _ in enumerate(queries_unique[:-1]):
query_sequences = [queries_unique[0][1], queries_unique[job_number+1][1]]
unpaired_msa = []
paired_msa = []
for seq in query_sequences:
with args.base.joinpath(f"{id}.a3m").open("r") as f:
with args.base.joinpath(f"0.a3m").open("r") as f:
unpaired_msa.append(f.read())
with args.base.joinpath(f"{job_number+1}.a3m").open("r") as f:
unpaired_msa.append(f.read())
args.base.joinpath(f"{id}.a3m").unlink()
if len(query_seqs_cardinality) > 1:
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
paired_msa.append(f.read())
args.base.joinpath(f"{id}.paired.a3m").unlink()
id += 1
msa = msa_to_str(
unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
)
args.base.joinpath(f"{job_number}.a3m").write_text(msa)

with args.base.joinpath(f"0.paired.a3m").open("r") as f:
paired_msa.append(f.read())
with args.base.joinpath(f"{job_number+1}.paired.a3m").open("r") as f:
paired_msa.append(f.read())
msa = msa_to_str(
unpaired_msa, paired_msa, query_sequences, [1,1]
)
args.base.joinpath(f"{job_number}_final.a3m").write_text(msa)
for job_number, _ in enumerate(queries_unique):
args.base.joinpath(f"{job_number}.a3m").unlink()
args.base.joinpath(f"{job_number}.paired.a3m").unlink()
for job_number, _ in enumerate(queries_unique[:-1]):
os.replace(args.base.joinpath(f"{job_number}_final.a3m"), args.base.joinpath(f"{job_number}.a3m"))
query_file.unlink()
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")])
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")])
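Taken together, the interaction-scan path treats the first input sequence as a bait, pairs it with every other sequence, and stitches each pair's unpaired and paired a3m files into one `{job_number}.a3m`. A short, hedged usage example of the new `get_queries_pairwise` reader defined in this file (the file contents are made up):

```python
from pathlib import Path
from colabfold.mmseqs.search import get_queries_pairwise

# A made-up CSV; in interaction-scan mode the first row acts as the bait
# that every sequence (including itself) is paired against.
csv = Path("scan.csv")
csv.write_text(
    "id,sequence\n"
    "bait,MKTAYIAKQRQISFVK\n"
    "prey1,MLSGEKVRALQDWLNG\n"
    "prey2,MADQLTEEQIAEFKEA\n"
)

queries, is_complex = get_queries_pairwise(csv)
for jobname, sequence, a3m_lines in queries:
    print(jobname, sequence)
# CSV mode prefixes every id with the first row's id:
#   bait&bait, bait&prey1, bait&prey2; is_complex is always True.
```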