Adding a per-AA loss to the DASM (#66)
* moved the DASM prediction to be output in log space
* added a CSP component to a weighted loss
* refactored the two-loss functionality into a mixin class
matsen authored Oct 16, 2024
1 parent ac86175 commit f204f6b
Showing 5 changed files with 83 additions and 53 deletions.
4 changes: 4 additions & 0 deletions netam/common.py
@@ -89,6 +89,10 @@ def clamp_probability(x: Tensor) -> Tensor:
return torch.clamp(x, min=SMALL_PROB, max=(1.0 - SMALL_PROB))


def clamp_log_probability(x: Tensor) -> Tensor:
return torch.clamp(x, max=np.log(1.0 - SMALL_PROB))

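A minimal sketch of the new clamp in use; the value of SMALL_PROB here is an assumed stand-in for the real constant defined elsewhere in netam/common.py:

    import numpy as np
    import torch

    SMALL_PROB = 1e-6  # assumed value; the actual constant lives in netam/common.py

    # A log probability of exactly 0.0 means probability 1.0; clamping keeps
    # exp(x) strictly below 1 so complementary probabilities stay positive.
    x = torch.tensor([-2.0, -1e-12, 0.0])
    clamped = torch.clamp(x, max=np.log(1.0 - SMALL_PROB))
    print(torch.exp(clamped))  # every entry is strictly less than 1.0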

def print_parameter_count(model):
total = 0
for name, module in model.named_modules():
72 changes: 47 additions & 25 deletions netam/dasm.py
@@ -1,4 +1,4 @@
"""Here we define a mutation-selection model that is per-amino-acid."""
"""Here we define a model that outputs a vector of 20 amino acid preferences."""

import torch
import torch.nn.functional as F
@@ -11,9 +11,12 @@
import pandas as pd

from netam.common import (
clamp_log_probability,
clamp_probability,
BIG,
)
import netam.dnsm as dnsm
import netam.framework as framework
import netam.molevol as molevol
import netam.sequences as sequences
from netam.sequences import (
@@ -81,6 +84,7 @@ def update_neutral_probs(self):
def __getitem__(self, idx):
return {
"aa_parents_idxs": self.aa_parents_idxs[idx],
"aa_children_idxs": self.aa_children_idxs[idx],
"subs_indicator": self.aa_subs_indicator_tensor[idx],
"mask": self.mask[idx],
"log_neutral_aa_probs": self.log_neutral_aa_probs[idx],
@@ -90,21 +94,19 @@ def __getitem__(self, idx):

def to(self, device):
self.aa_parents_idxs = self.aa_parents_idxs.to(device)
self.aa_children_idxs = self.aa_children_idxs.to(device)
self.aa_subs_indicator_tensor = self.aa_subs_indicator_tensor.to(device)
self.mask = self.mask.to(device)
self.log_neutral_aa_probs = self.log_neutral_aa_probs.to(device)
self.all_rates = self.all_rates.to(device)
self.all_subs_probs = self.all_subs_probs.to(device)


def zero_predictions_along_diagonal(predictions, aa_parents_idxs):
"""Zero out the diagonal of a batch of predictions.
We do this so that we can sum and then have the same type of predictions as for
the DNSM.
"""
# We would like to do
# predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = 0.0
def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
"""Set the diagonal (i.e. no amino acid change) of the predictions tensor to
-BIG."""
# This is effectively
# predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = -BIG
# but we have a batch dimension. Thus the following.

batch_size, L, _ = predictions.shape
@@ -113,12 +115,16 @@ def zero_predictions_along_diagonal(predictions, aa_parents_idxs):
batch_indices[:, None],
torch.arange(L, device=predictions.device),
aa_parents_idxs,
] = 0.0
] = -BIG

return predictions

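To see the batched indexing at work, here is a toy sketch; the value of BIG is an assumption (any suitably large positive constant plays the same role as netam.common.BIG):

    import torch

    BIG = 1e6  # assumed stand-in for netam.common.BIG

    predictions = torch.zeros(2, 3, 20)             # (batch, site, amino acid)
    aa_parents_idxs = torch.randint(0, 20, (2, 3))  # parent amino acid at each site
    batch_indices = torch.arange(2)
    predictions[
        batch_indices[:, None],       # broadcasts against the site dimension
        torch.arange(3),
        aa_parents_idxs,
    ] = -BIG
    # Exactly one entry per (batch, site) pair is now -BIG.
    assert (predictions == -BIG).sum() == 2 * 3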

class DASMBurrito(dnsm.DNSMBurrito):
class DASMBurrito(framework.TwoLossMixin, dnsm.DNSMBurrito):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.xent_loss = torch.nn.CrossEntropyLoss()
self.loss_weights = torch.tensor([1.0, 0.01]).to(self.device)

def prediction_pair_of_batch(self, batch):
"""Get log neutral AA probabilities and log selection factors for a batch of
@@ -134,12 +140,13 @@ def prediction_pair_of_batch(self, batch):
return log_neutral_aa_probs, log_selection_factors

def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
# Take the product of the neutral mutation probabilities and the
# selection factors, namely p_{j, a} * f_{j, a}.
# In contrast to a DNSM, each of these now have last dimension of 20.
predictions = torch.exp(log_neutral_aa_probs + log_selection_factors)
assert torch.isfinite(predictions).all()
predictions = clamp_probability(predictions)
"""Take the sum of the neutral mutation log probabilities and the selection
factors.
In contrast to a DNSM, each of these now has a last dimension of 20.
"""
predictions = log_neutral_aa_probs + log_selection_factors
assert torch.isnan(predictions).sum() == 0
return predictions

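In log space the per-site product p_{j, a} * f_{j, a} becomes a sum; a toy check of the identity that predictions_of_pair relies on:

    import torch

    p = torch.tensor([0.02])            # neutral mutation probability
    f = torch.tensor([1.5])             # selection factor
    pred = torch.log(p) + torch.log(f)  # what predictions_of_pair computes
    assert torch.isclose(torch.exp(pred), p * f).all()  # back in probability space: 0.03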
def predictions_of_batch(self, batch):
@@ -157,28 +164,43 @@ def loss_of_batch(self, batch):
aa_subs_indicator = batch["subs_indicator"].to(self.device)
mask = batch["mask"].to(self.device)
aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
aa_subs_indicator = aa_subs_indicator.masked_select(mask)
aa_children_idxs = batch["aa_children_idxs"].to(self.device)
masked_aa_subs_indicator = aa_subs_indicator.masked_select(mask)
predictions = self.predictions_of_batch(batch)
# Add one entry, -BIG (effectively zero probability), to the last dimension
# of the predictions tensor to handle the ambiguous amino acids. This is the
# conservative choice.
# It might be faster to reassign all the 20s to 0s if we are confident
# in our masking. Perhaps we should always output a 21st dimension
# for the ambiguous amino acids (see issue #16).
# Note that we're going to want to have a symbol for the junction
# between the heavy and light chains.
# If we change something here we should also change the test code
# in test_dasm.py::test_zero_diagonal.
predictions = torch.cat(
[predictions, torch.zeros_like(predictions[:, :, :1])], dim=-1
[predictions, torch.full_like(predictions[:, :, :1], -BIG)], dim=-1
)

predictions = zero_predictions_along_diagonal(predictions, aa_parents_idxs)
predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs)

# After zeroing out the diagonal, we are effectively summing over the
# After zapping out the diagonal, we can effectively sum over the
# off-diagonal elements to get the probability of a nonsynonymous
# mutation.
predictions_of_mut = torch.sum(predictions, dim=-1)
predictions_of_mut = predictions_of_mut.masked_select(mask)
predictions_of_mut = clamp_probability(predictions_of_mut)
return self.bce_loss(predictions_of_mut, aa_subs_indicator)
mut_pos_pred = torch.sum(torch.exp(predictions), dim=-1)
mut_pos_pred = mut_pos_pred.masked_select(mask)
mut_pos_pred = clamp_probability(mut_pos_pred)
mut_pos_loss = self.bce_loss(mut_pos_pred, masked_aa_subs_indicator)

# We now need to calculate the conditional substitution probability
# (CSP) loss. We have already zapped out the diagonal, and we're in
# logit space, so we are set up for using the cross entropy loss.
# However we have to mask out the sites that are not substituted, i.e.
# the sites for which aa_subs_indicator is 0.
subs_mask = aa_subs_indicator == 1
csp_pred = predictions[subs_mask]
csp_targets = aa_children_idxs[subs_mask]
csp_loss = self.xent_loss(csp_pred, csp_targets)

return torch.stack([mut_pos_loss, csp_loss])

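A self-contained sketch of the two loss components on hand-made tensors; the shapes and values are illustrative only, and the plain torch.clamp stands in for clamp_probability:

    import torch

    bce_loss = torch.nn.BCELoss()
    xent_loss = torch.nn.CrossEntropyLoss()

    predictions = torch.randn(1, 4, 21) - 4.0  # log space; 21st slot for ambiguity
    aa_subs_indicator = torch.tensor([[0.0, 1.0, 0.0, 1.0]])
    aa_children_idxs = torch.tensor([[3, 5, 3, 7]])
    mask = torch.ones(1, 4, dtype=torch.bool)

    # Mutation-position loss: sum probabilities over the classes, then BCE.
    mut_pos_pred = torch.sum(torch.exp(predictions), dim=-1).masked_select(mask)
    mut_pos_pred = torch.clamp(mut_pos_pred, min=1e-6, max=1.0 - 1e-6)
    mut_pos_loss = bce_loss(mut_pos_pred, aa_subs_indicator.masked_select(mask))

    # CSP loss: cross entropy only at the substituted sites.
    subs_mask = aa_subs_indicator == 1
    csp_loss = xent_loss(predictions[subs_mask], aa_children_idxs[subs_mask])
    loss = torch.stack([mut_pos_loss, csp_loss])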
def build_selection_matrix_from_parent(self, parent: str):
# This is simpler than the equivalent in dnsm.py because we get the selection
11 changes: 3 additions & 8 deletions netam/dnsm.py
@@ -1,11 +1,4 @@
"""Here we define a mutation-selection model that is just about mutation vs no mutation,
and is trainable.
We'll use these conventions:
* B is the batch size
* L is the max sequence length
"""
"""Defining the deep natural selection model (DNSM)."""

import copy
import multiprocessing as mp
@@ -74,6 +67,7 @@ def __init__(
self.aa_parents_idxs = torch.full(
(pcp_count, self.max_aa_seq_len), MAX_AMBIG_AA_IDX
)
self.aa_children_idxs = self.aa_parents_idxs.clone()
self.aa_subs_indicator_tensor = torch.zeros((pcp_count, self.max_aa_seq_len))

self.mask = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool)
@@ -82,6 +76,7 @@
self.mask[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len)
aa_seq_len = len(aa_parent)
self.aa_parents_idxs[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(aa_parent)
self.aa_children_idxs[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(aa_child)
self.aa_subs_indicator_tensor[i, :aa_seq_len] = aa_subs_indicator_tensor_of(
aa_parent, aa_child
)
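For intuition, the per-site indicator marks positions where parent and child differ; below is a hypothetical stand-in for aa_subs_indicator_tensor_of (the real helper lives in netam.sequences and may differ in details):

    import torch

    def toy_subs_indicator(parent: str, child: str) -> torch.Tensor:
        # hypothetical sketch: 1.0 wherever the amino acid changed
        return torch.tensor([float(p != c) for p, c in zip(parent, child)])

    print(toy_subs_indicator("ACDE", "ACFE"))  # tensor([0., 0., 1., 0.])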
40 changes: 24 additions & 16 deletions netam/framework.py
@@ -858,18 +858,35 @@ def to_crepe(self):
return Crepe(encoder, self.model, training_hyperparameters)


class RSSHMBurrito(SHMBurrito):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.xent_loss = nn.CrossEntropyLoss()
self.loss_weights = torch.tensor([1.0, 0.01]).to(self.device)
class TwoLossMixin:
"""A mixin for models that have two losses, one for mutation position and one for
conditional substitution probability (CSP)."""

def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None):
if loss_reduction is None:
loss_reduction = lambda x: torch.sum(x * self.loss_weights)

return super().process_data_loader(data_loader, train_mode, loss_reduction)

def write_loss(self, loss_name, loss, step):
rate_loss, csp_loss = loss.unbind()
self.writer.add_scalar(
"Mut pos " + loss_name,
rate_loss.item(),
step,
walltime=self.execution_time(),
)
self.writer.add_scalar(
"CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_time()
)

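The mixin's loss_reduction collapses the two-component loss vector to a scalar using the loss weights; a toy version of that reduction, using the weights from the source:

    import torch

    loss = torch.tensor([0.70, 2.30])        # [mut pos loss, CSP loss]
    loss_weights = torch.tensor([1.0, 0.01])
    scalar = torch.sum(loss * loss_weights)  # 0.70 * 1.0 + 2.30 * 0.01 = 0.723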

class RSSHMBurrito(TwoLossMixin, SHMBurrito):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.xent_loss = nn.CrossEntropyLoss()
self.loss_weights = torch.tensor([1.0, 0.01]).to(self.device)

def evaluate(self):
val_loader = self.build_val_loader()
return super().process_data_loader(
@@ -890,7 +907,7 @@ def loss_of_batch(self, batch):
mut_prob = 1 - torch.exp(-rates * branch_lengths.unsqueeze(-1))
mut_prob_masked = mut_prob[masks]
mutation_indicator_masked = mutation_indicators[masks].float()
rate_loss = self.bce_loss(mut_prob_masked, mutation_indicator_masked)
mut_pos_loss = self.bce_loss(mut_prob_masked, mutation_indicator_masked)

# Conditional substitution probability (CSP) loss calculation
# Mask the new_base_idxs to focus only on positions with mutations
@@ -902,7 +919,7 @@ def loss_of_batch(self, batch):
assert (new_base_idxs_masked >= 0).all()
csp_loss = self.xent_loss(csp_logits_masked, new_base_idxs_masked)

return torch.stack([rate_loss, csp_loss])
return torch.stack([mut_pos_loss, csp_loss])

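The mutation probability here is the standard exponential waiting-time formula, p = 1 - exp(-rate * branch_length), exactly as computed in loss_of_batch above; a toy check with made-up values:

    import torch

    rates = torch.tensor([[0.5, 2.0]])    # per-site rates, shape (batch, sites)
    branch_lengths = torch.tensor([0.1])  # shape (batch,)
    mut_prob = 1 - torch.exp(-rates * branch_lengths.unsqueeze(-1))
    # 1 - exp(-0.05) ≈ 0.0488 and 1 - exp(-0.2) ≈ 0.1813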
def _find_optimal_branch_length(
self,
@@ -983,15 +1000,6 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):

return torch.tensor(optimal_lengths)

def write_loss(self, loss_name, loss, step):
rate_loss, csp_loss = loss.unbind()
self.writer.add_scalar(
"Rate " + loss_name, rate_loss.item(), step, walltime=self.execution_time()
)
self.writer.add_scalar(
"CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_time()
)


def burrito_class_of_model(model):
if isinstance(model, models.RSCNNModel):
9 changes: 5 additions & 4 deletions tests/test_dasm.py
@@ -3,6 +3,7 @@
import torch
import pytest

from netam.common import BIG
from netam.framework import (
crepe_exists,
load_crepe,
@@ -11,7 +12,7 @@
from netam.dasm import (
DASMBurrito,
DASMDataset,
zero_predictions_along_diagonal,
zap_predictions_along_diagonal,
)


@@ -67,23 +68,23 @@ def test_crepe_roundtrip(dasm_burrito):
assert torch.equal(t1, t2)


def test_zero_diagonal(dasm_burrito):
def test_zap_diagonal(dasm_burrito):
batch = dasm_burrito.val_dataset[0:2]
predictions = dasm_burrito.predictions_of_batch(batch)
predictions = torch.cat(
[predictions, torch.zeros_like(predictions[:, :, :1])], dim=-1
)
aa_parents_idxs = batch["aa_parents_idxs"].to(dasm_burrito.device)
zeroed_predictions = predictions.clone()
zeroed_predictions = zero_predictions_along_diagonal(
zeroed_predictions = zap_predictions_along_diagonal(
zeroed_predictions, aa_parents_idxs
)
L = predictions.shape[1]
for batch_idx in range(2):
for i in range(L):
for j in range(20):
if j == aa_parents_idxs[batch_idx, i]:
assert zeroed_predictions[batch_idx, i, j] == 0.0
assert zeroed_predictions[batch_idx, i, j] == -BIG
else:
assert (
    zeroed_predictions[batch_idx, i, j]
    == predictions[batch_idx, i, j]
)
