
Commit 35c3efa

matsen and willdumm authored

Paired heavy-light modeling (#92)

Enables training on paired heavy/light sequences. A separator token `^` is added after the heavy chain and before the light chain sequence. For now, only heavy chain sequences can be used for validation.

Co-authored-by: Will Dumm <[email protected]>

1 parent 685a84e; commit 35c3efa

16 files changed: +240 -63 lines (one binary file not shown)
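To make the encoding concrete, here is a minimal sketch of the joined representation (toy sequences, not from the repository; at the nucleotide level the separator is the codon-width `^^^`, which translates to the single `^` amino acid token):

```python
# Toy heavy/light nucleotide pair (hypothetical data).
parent_h = "ATGGCA"  # heavy chain
parent_l = "ATGTTT"  # light chain

# Joined training sequence: heavy chain, separator codon, light chain.
parent = parent_h + "^^^" + parent_l
assert parent == "ATGGCA^^^ATGTTT"
```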

Diff for: netam/common.py (+23 -11)

@@ -13,15 +13,16 @@
 from torch import nn, Tensor
 import multiprocessing as mp
 
-from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence
+from netam.sequences import (
+    iter_codons,
+    apply_aa_mask_to_nt_sequence,
+    RESERVED_TOKEN_TRANSLATIONS,
+    BASES,
+    AA_TOKEN_STR_SORTED,
+)
 
 BIG = 1e9
 SMALL_PROB = 1e-6
-BASES = ["A", "C", "G", "T"]
-BASES_AND_N_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}
-AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"
-AA_STR_SORTED_AMBIG = AA_STR_SORTED + "X"
-MAX_AMBIG_AA_IDX = len(AA_STR_SORTED_AMBIG) - 1
 
 # I needed some sequence to use to normalize the rate of mutation in the SHM model.
 # So, I chose perhaps the most famous antibody sequence, VRC01:
@@ -65,7 +66,7 @@ def aa_idx_tensor_of_str_ambig(aa_str):
     character."""
     try:
         return torch.tensor(
-            [AA_STR_SORTED_AMBIG.index(aa) for aa in aa_str], dtype=torch.int
+            [AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str], dtype=torch.int
         )
     except ValueError:
         print(f"Found an invalid amino acid in the string: {aa_str}")
@@ -88,17 +89,28 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
     return mask
 
 
+def _consider_codon(codon):
+    """Return False if codon should be masked, True otherwise."""
+    if "N" in codon:
+        return False
+    elif codon in RESERVED_TOKEN_TRANSLATIONS:
+        return False
+    else:
+        return True
+
+
 def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
     """Return a mask tensor indicating codons which contain at least one N.
 
     Codons beyond the length of the sequence are masked. If other_nt_seqs are provided,
-    the "and" mask will be computed for all sequences
+    the "and" mask will be computed for all sequences. Codons containing marker tokens
+    are also masked.
     """
     if aa_length is None:
         aa_length = len(nt_parent) // 3
     sequences = (nt_parent,) + other_nt_seqs
     mask = [
-        all("N" not in codon for codon in codons)
+        all(_consider_codon(codon) for codon in codons)
         for codons in zip(*(iter_codons(sequence) for sequence in sequences))
     ]
     if len(mask) < aa_length:
@@ -114,7 +126,7 @@ def aa_strs_from_idx_tensor(idx_tensor):
 
     Args:
         idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing
-            indices into AA_STR_SORTED_AMBIG.
+            indices into AA_TOKEN_STR_SORTED.
 
     Returns:
         List[str]: A list of amino acid strings with trailing 'X's removed.
@@ -123,7 +135,7 @@ def aa_strs_from_idx_tensor(idx_tensor):
 
     aa_str_list = []
     for row in idx_tensor:
-        aa_str = "".join(AA_STR_SORTED_AMBIG[idx] for idx in row.tolist())
+        aa_str = "".join(AA_TOKEN_STR_SORTED[idx] for idx in row.tolist())
         aa_str_list.append(aa_str.rstrip("X"))
 
     return aa_str_list
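For intuition, here is a sketch of the new masking rule in isolation, assuming `RESERVED_TOKEN_TRANSLATIONS` is a dict keyed by reserved-token codons such as the `^^^` separator (the real mapping lives in `netam.sequences`):

```python
# Assumed shape of the mapping in netam.sequences: separator codon -> token.
RESERVED_TOKEN_TRANSLATIONS = {"^^^": "^"}

def _consider_codon(codon):
    """Return False if codon should be masked, True otherwise."""
    return "N" not in codon and codon not in RESERVED_TOKEN_TRANSLATIONS

assert _consider_codon("ATG")      # ordinary codon: kept
assert not _consider_codon("ANG")  # contains an ambiguous N: masked
assert not _consider_codon("^^^")  # separator token codon: masked
```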

Diff for: netam/dasm.py (+4 -1)

@@ -139,7 +139,10 @@ def prediction_pair_of_batch(self, batch):
             raise ValueError(
                 f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}"
             )
-        log_selection_factors = self.model(aa_parents_idxs, mask)
+        # We need the model to see special tokens here. For every other purpose
+        # they are masked out.
+        keep_token_mask = mask | sequences.token_mask_of_aa_idxs(aa_parents_idxs)
+        log_selection_factors = self.model(aa_parents_idxs, keep_token_mask)
         return log_neutral_aa_probs, log_selection_factors
 
     def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
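The union with the token mask is a plain elementwise OR; a small sketch with made-up values (assuming `token_mask_of_aa_idxs` returns True at positions holding special tokens):

```python
import torch

mask = torch.tensor([True, True, False, True])          # positions valid for loss
token_mask = torch.tensor([False, False, True, False])  # separator positions

# The model attends to real residues *and* the separator token, even though
# the separator stays excluded from the loss everywhere else.
keep_token_mask = mask | token_mask
assert keep_token_mask.tolist() == [True, True, True, True]
```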

Diff for: netam/dnsm.py (+3 -1)

@@ -163,7 +163,9 @@ def build_selection_matrix_from_parent(self, parent: str):
         """
         parent = sequences.translate_sequence(parent)
         selection_factors = self.model.selection_factors_of_aa_str(parent)
-        selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
+        selection_matrix = torch.zeros(
+            (len(selection_factors), sequences.MAX_AA_TOKEN_IDX + 1), dtype=torch.float
+        )
         # Every "off-diagonal" entry of the selection matrix is set to the selection
         # factor, where "diagonal" means keeping the same amino acid.
         selection_matrix[:, :] = selection_factors[:, None]
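The only change here is the matrix width: columns are now indexed by the full token alphabet rather than the 20 standard amino acids. A sketch, assuming the token alphabet is the 20 amino acids plus `X` and `^` (so `MAX_AA_TOKEN_IDX == 21`):

```python
import torch

MAX_AA_TOKEN_IDX = 21  # assumed: 20 amino acids, then "X", then "^"

selection_factors = torch.rand(5)  # toy per-site selection factors
selection_matrix = torch.zeros(
    (len(selection_factors), MAX_AA_TOKEN_IDX + 1), dtype=torch.float
)
selection_matrix[:, :] = selection_factors[:, None]
assert selection_matrix.shape == (5, 22)
```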

Diff for: netam/dxsm.py (+18 -9)

@@ -15,7 +15,6 @@
 from tqdm import tqdm
 
 from netam.common import (
-    MAX_AMBIG_AA_IDX,
     aa_idx_tensor_of_str_ambig,
     stack_heterogeneous,
     codon_mask_tensor_of,
@@ -28,6 +27,8 @@
     translate_sequences,
     apply_aa_mask_to_nt_sequence,
     nt_mutation_frequency,
+    MAX_AA_TOKEN_IDX,
+    RESERVED_TOKEN_REGEX,
 )
 
 
@@ -43,8 +44,12 @@ def __init__(
         branch_lengths: torch.Tensor,
         multihit_model=None,
     ):
-        self.nt_parents = nt_parents
-        self.nt_children = nt_children
+        self.nt_parents = nt_parents.str.replace(RESERVED_TOKEN_REGEX, "N", regex=True)
+        # We will replace reserved tokens with Ns but use the unmodified
+        # originals for translation and mask creation.
+        self.nt_children = nt_children.str.replace(
+            RESERVED_TOKEN_REGEX, "N", regex=True
+        )
         self.nt_ratess = nt_ratess
         self.nt_cspss = nt_cspss
         self.multihit_model = copy.deepcopy(multihit_model)
@@ -56,14 +61,16 @@ def __init__(
         assert len(self.nt_parents) == len(self.nt_children)
         pcp_count = len(self.nt_parents)
 
-        aa_parents = translate_sequences(self.nt_parents)
-        aa_children = translate_sequences(self.nt_children)
+        # Important to use the unmodified versions of nt_parents and
+        # nt_children so they still contain special tokens.
+        aa_parents = translate_sequences(nt_parents)
+        aa_children = translate_sequences(nt_children)
         self.max_aa_seq_len = max(len(seq) for seq in aa_parents)
         # We have sequences of varying length, so we start with all tensors set
         # to the ambiguous amino acid, and then will fill in the actual values
         # below.
         self.aa_parents_idxss = torch.full(
-            (pcp_count, self.max_aa_seq_len), MAX_AMBIG_AA_IDX
+            (pcp_count, self.max_aa_seq_len), MAX_AA_TOKEN_IDX
         )
         self.aa_children_idxss = self.aa_parents_idxss.clone()
         self.aa_subs_indicators = torch.zeros((pcp_count, self.max_aa_seq_len))
@@ -90,7 +97,7 @@ def __init__(
         )
 
         assert torch.all(self.masks.sum(dim=1) > 0)
-        assert torch.max(self.aa_parents_idxss) <= MAX_AMBIG_AA_IDX
+        assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX
 
         self._branch_lengths = branch_lengths
         self.update_neutral_probs()
@@ -296,9 +303,11 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
 
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         worker_count = min(mp.cpu_count() // 2, 10)
-        # # The following can be used when one wants a better traceback.
+        # The following can be used when one wants a better traceback.
         # burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
-        # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
+        # return burrito.serial_find_optimal_branch_lengths(
+        #     dataset, **optimization_kwargs
+        # )
         our_optimize_branch_length = partial(
             worker_optimize_branch_length,
             self.__class__,
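For intuition, a sketch of the token-to-N substitution on a pandas Series, assuming `RESERVED_TOKEN_REGEX` is a character class matching the separator (note `^` must be escaped inside a regex):

```python
import pandas as pd

# Assumed definition; the real pattern lives in netam.sequences.
RESERVED_TOKEN_REGEX = r"[\^]"

nt_parents = pd.Series(["ATGGCA^^^ATGTTT"])
# The neutral model sees Ns in place of the separator, so those positions
# are treated as ambiguous rather than as real nucleotides.
masked = nt_parents.str.replace(RESERVED_TOKEN_REGEX, "N", regex=True)
assert masked[0] == "ATGGCANNNATGTTT"
```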

Diff for: netam/framework.py (+75 -6)

@@ -22,12 +22,12 @@
     optimizer_of_name,
     tensor_to_np_if_needed,
     BASES,
-    BASES_AND_N_TO_INDEX,
     BIG,
     VRC01_NT_SEQ,
     encode_sequences,
     parallelize_function,
 )
+from netam.sequences import BASES_AND_N_TO_INDEX
 from netam import models
 import netam.molevol as molevol
 
@@ -352,31 +352,100 @@ def trimmed_shm_model_outputs_of_crepe(crepe, parents):
     return trimmed_rates, trimmed_csps
 
 
+def join_chains(pcp_df):
+    """Join the parent and child chains in the pcp_df.
+
+    Make a parent column that is the parent_h + "^^^" + parent_l, and same for child.
+
+    If parent_h and parent_l are not present, then we assume that the parent is the
+    heavy chain. If only one of parent_h or parent_l is present, then we place the ^^^
+    padding to the right of heavy, or to the left of light.
+    """
+    cols = pcp_df.columns
+    # Look for heavy chain
+    if "parent_h" in cols:
+        assert "child_h" in cols, "child_h column missing!"
+        assert "v_gene_h" in cols, "v_gene_h column missing!"
+    elif "parent" in cols:
+        assert "child" in cols, "child column missing!"
+        assert "v_gene" in cols, "v_gene column missing!"
+        pcp_df["parent_h"] = pcp_df["parent"]
+        pcp_df["child_h"] = pcp_df["child"]
+        pcp_df["v_gene_h"] = pcp_df["v_gene"]
+    else:
+        pcp_df["parent_h"] = ""
+        pcp_df["child_h"] = ""
+        pcp_df["v_gene_h"] = "N/A"
+    # Look for light chain
+    if "parent_l" in cols:
+        assert "child_l" in cols, "child_l column missing!"
+        assert "v_gene_l" in cols, "v_gene_l column missing!"
+    else:
+        pcp_df["parent_l"] = ""
+        pcp_df["child_l"] = ""
+        pcp_df["v_gene_l"] = "N/A"
+
+    if (pcp_df["parent_h"].str.len() + pcp_df["parent_l"].str.len()).min() < 3:
+        raise ValueError("At least one PCP has fewer than three nucleotides.")
+
+    pcp_df["parent"] = pcp_df["parent_h"] + "^^^" + pcp_df["parent_l"]
+    pcp_df["child"] = pcp_df["child_h"] + "^^^" + pcp_df["child_l"]
+
+    pcp_df.drop(
+        columns=["parent_h", "parent_l", "child_h", "child_l", "v_gene"],
+        inplace=True,
+        errors="ignore",
+    )
+    return pcp_df
+
+
 def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):
     """Load a PCP dataframe from a gzipped CSV file.
 
     `orig_pcp_idx` is the index column from the original file, even if we subset by
     sampling or by choosing V families.
+
+    We join the heavy and light chain sequences into a single sequence starting
+    with the heavy chain, using a `^^^` separator. If only the heavy or light
+    chain sequence is present, the separator is added on the appropriate side of
+    the available sequence.
     """
     pcp_df = (
         pd.read_csv(pcp_df_path_gz, compression="gzip", index_col=0)
         .reset_index()
         .rename(columns={"index": "orig_pcp_idx"})
     )
-    pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0]
+    pcp_df = join_chains(pcp_df)
+
+    pcp_df["v_family_h"] = pcp_df["v_gene_h"].str.split("-").str[0]
+    pcp_df["v_family_l"] = pcp_df["v_gene_l"].str.split("-").str[0]
     if chosen_v_families is not None:
         chosen_v_families = set(chosen_v_families)
-        pcp_df = pcp_df[pcp_df["v_family"].isin(chosen_v_families)]
+        pcp_df = pcp_df[
+            pcp_df["v_family_h"].isin(chosen_v_families)
+            & pcp_df["v_family_l"].isin(chosen_v_families)
+        ]
     if sample_count is not None:
         pcp_df = pcp_df.sample(sample_count)
     pcp_df.reset_index(drop=True, inplace=True)
     return pcp_df
 
 
 def add_shm_model_outputs_to_pcp_df(pcp_df, crepe):
-    rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"])
-    pcp_df["nt_rates"] = rates
-    pcp_df["nt_csps"] = csps
+    # Split parent heavy and light chains to apply neutral model separately
+    split_parents = pcp_df["parent"].str.split(pat="^^^", expand=True, regex=False)
+    # To keep prediction aligned to joined h/l sequence, pad parent
+    h_parents = split_parents[0] + "NNN"
+    l_parents = split_parents[1]
+
+    h_rates, h_csps = trimmed_shm_model_outputs_of_crepe(crepe, h_parents)
+    l_rates, l_csps = trimmed_shm_model_outputs_of_crepe(crepe, l_parents)
+    # Join predictions
+    pcp_df["nt_rates"] = [
+        torch.cat([h_rate, l_rate], dim=0) for h_rate, l_rate in zip(h_rates, l_rates)
+    ]
+    pcp_df["nt_csps"] = [
+        torch.cat([h_csp, l_csp], dim=0) for h_csp, l_csp in zip(h_csps, l_csps)
+    ]
     return pcp_df
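The `"NNN"` padding in `add_shm_model_outputs_to_pcp_df` is what keeps the concatenated per-nucleotide outputs aligned with the joined `heavy + "^^^" + light` string: the heavy-chain predictions gain three positions that line up with the separator. A toy check of the arithmetic (fake tensors, not real model output):

```python
import torch

parent_h, parent_l = "ATGGCA", "ATGTTT"
joined = parent_h + "^^^" + parent_l  # 15 characters

# Stand-ins for per-nucleotide neutral-model rates.
h_rates = torch.ones(len(parent_h + "NNN"))  # heavy padded to cover separator
l_rates = torch.ones(len(parent_l))

nt_rates = torch.cat([h_rates, l_rates], dim=0)
assert nt_rates.shape[0] == len(joined)
```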

Diff for: netam/models.py (+2 -2)

@@ -10,8 +10,8 @@
 from torch import Tensor
 
 from netam.hit_class import apply_multihit_correction
+from netam.sequences import MAX_AA_TOKEN_IDX
 from netam.common import (
-    MAX_AMBIG_AA_IDX,
     aa_idx_tensor_of_str_ambig,
     PositionalEncoding,
     generate_kmers,
@@ -622,7 +622,7 @@ def __init__(
         self.nhead = nhead
         self.dim_feedforward = dim_feedforward
         self.pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
-        self.amino_acid_embedding = nn.Embedding(MAX_AMBIG_AA_IDX + 1, self.d_model)
+        self.amino_acid_embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, self.d_model)
         self.encoder_layer = nn.TransformerEncoderLayer(
             d_model=self.d_model,
             nhead=nhead,

Diff for: netam/molevol.py (+2 -2)

@@ -9,7 +9,7 @@
 import torch
 from torch import Tensor, optim
 
-from netam.sequences import CODON_AA_INDICATOR_MATRIX
+from netam.sequences import CODON_AA_INDICATOR_MATRIX, MAX_AA_TOKEN_IDX
 
 import netam.sequences as sequences
 
@@ -444,7 +444,7 @@ def mutsel_log_pcp_probability_of(
     """
 
     assert len(parent) % 3 == 0
-    assert sel_matrix.shape == (len(parent) // 3, 20)
+    assert sel_matrix.shape == (len(parent) // 3, MAX_AA_TOKEN_IDX + 1)
 
     parent_idxs = sequences.nt_idx_tensor_of_str(parent)
     child_idxs = sequences.nt_idx_tensor_of_str(child)
