cleaning up; Mixin class
matsen committed Oct 16, 2024
1 parent 7f5c4cf commit 3183d76
Showing 2 changed files with 34 additions and 48 deletions.
45 changes: 12 additions & 33 deletions netam/dasm.py
@@ -16,6 +16,7 @@
BIG,
)
import netam.dnsm as dnsm
import netam.framework as framework
import netam.molevol as molevol
import netam.sequences as sequences
from netam.sequences import (
@@ -122,19 +123,12 @@ def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
return predictions
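
The body of zap_predictions_along_diagonal is collapsed above, so the following is only an assumed sketch of what "zapping" means here, based on the comments later in this diff: at each site, the logit for the parent's own amino acid is driven to a large negative value so it vanishes after exponentiation.

```python
import torch

BIG = 1e6  # stand-in for netam's BIG constant
predictions = torch.randn(2, 5, 20)             # (batch, sites, amino acids)
aa_parents_idxs = torch.randint(0, 20, (2, 5))  # parent amino acid per site

# Set the parent's own amino acid to -BIG at every site ("the diagonal"),
# so exp(predictions) there is effectively zero.
predictions.scatter_(-1, aa_parents_idxs.unsqueeze(-1), -BIG)
```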


class DASMBurrito(dnsm.DNSMBurrito):
class DASMBurrito(framework.TwoLossMixin, dnsm.DNSMBurrito):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.xent_loss = torch.nn.CrossEntropyLoss()
self.loss_weights = torch.tensor([1.0, 0.01]).to(self.device)

# TODO code dup
def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None):
if loss_reduction is None:
loss_reduction = lambda x: torch.sum(x * self.loss_weights)

return super().process_data_loader(data_loader, train_mode, loss_reduction)

def prediction_pair_of_batch(self, batch):
"""Get log neutral AA probabilities and log selection factors for a batch of
data."""
@@ -149,15 +143,13 @@ def prediction_pair_of_batch(self, batch):
return log_neutral_aa_probs, log_selection_factors

def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
"""
Take the sum of the neutral mutation log probabilities and the
selection factors.
"""Take the sum of the neutral mutation log probabilities and the selection
factors.
In contrast to a DNSM, each of these now has a last dimension of 20.
"""
predictions = log_neutral_aa_probs + log_selection_factors
assert torch.isnan(predictions).sum() == 0
#predictions = clamp_log_probability(predictions)
return predictions
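
A minimal sketch of this sum under assumed shapes (batch × sites × 20): adding in log space multiplies the neutral substitution probabilities by the per-amino-acid selection factors.

```python
import torch

# Hypothetical inputs: 2 sequences, 5 sites, 20 amino acids.
log_neutral_aa_probs = torch.randn(2, 5, 20).log_softmax(dim=-1)
log_selection_factors = torch.randn(2, 5, 20)

# Elementwise sum in log space == product of the underlying quantities.
predictions = log_neutral_aa_probs + log_selection_factors
assert not torch.isnan(predictions).any()
```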

def predictions_of_batch(self, batch):
Expand All @@ -183,37 +175,36 @@ def loss_of_batch(self, batch):
# It might be faster to reassign all the 20s to 0s if we are confident
# in our masking. Perhaps we should always output a 21st dimension
# for the ambiguous amino acids (see issue #16).
# Note that we're going to want to have a symbol for the junction
# between the heavy and light chains.
# If we change something here we should also change the test code
# in test_dasm.py::test_zero_diagonal.
# TODO do we have ambiguous amino acids? Note that we're going to want to
# have a symbol for the junction between the heavy and light chains.
predictions = torch.cat(
[predictions, torch.full_like(predictions[:, :, :1], -BIG)], dim=-1
)

predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs)

# After zeroing out the diagonal, we are effectively summing over the
# After zapping out the diagonal, we can effectively sum over the
# off-diagonal elements to get the probability of a nonsynonymous
# mutation.
mut_pos_pred = torch.sum(torch.exp(predictions), dim=-1)
mut_pos_pred = mut_pos_pred.masked_select(mask)
mut_pos_pred = clamp_probability(mut_pos_pred)
mut_pos_loss = self.bce_loss(mut_pos_pred, masked_aa_subs_indicator)

# We now need to handle the conditional substitution probability (CSP) loss.
# We have already zapped out the diagonal, and we're in logit space, so
# we are set up for using the cross entropy loss. However we have to
# mask out the sites that are not substituted, i.e. the sites for which
# aa_subs_indicator is 0.
# We now need to calculate the conditional substitution probability
# (CSP) loss. We have already zapped out the diagonal, and we're in
# logit space, so we are set up for using the cross entropy loss.
# However we have to mask out the sites that are not substituted, i.e.
# the sites for which aa_subs_indicator is 0.
subs_mask = aa_subs_indicator.bool()
csp_pred = predictions[subs_mask]
csp_targets = aa_children_idxs.masked_select(subs_mask)
csp_loss = self.xent_loss(csp_pred, csp_targets)

return torch.stack([mut_pos_loss, csp_loss])
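
For reference, a self-contained sketch of the two loss terms as they read above. The shapes, the BIG value, and the clamp bounds are illustrative stand-ins for netam's actual batch tensors, constants, and helpers (e.g. clamp_probability), not the library's code.

```python
import torch

BIG = 1e6  # stand-in for netam's BIG constant
# Toy batch: 2 sequences, 5 sites; pretend log-probability predictions.
predictions = torch.randn(2, 5, 20).log_softmax(dim=-1)
aa_parents_idxs = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
aa_children_idxs = torch.tensor([[0, 1, 2, 3, 10], [5, 12, 7, 8, 9]])
aa_subs_indicator = (aa_parents_idxs != aa_children_idxs).float()
mask = torch.ones(2, 5, dtype=torch.bool)  # all sites valid in this toy

# Append a 21st slot at -BIG so ambiguous amino acids never contribute.
predictions = torch.cat(
    [predictions, torch.full_like(predictions[:, :, :1], -BIG)], dim=-1
)
# Zap the diagonal: the parent's own amino acid gets logit -BIG.
predictions = predictions.scatter(-1, aa_parents_idxs.unsqueeze(-1), -BIG)

# Loss 1: probability of a nonsynonymous mutation at each site (BCE).
mut_pos_pred = torch.exp(predictions).sum(dim=-1).masked_select(mask)
mut_pos_pred = mut_pos_pred.clamp(1e-10, 1 - 1e-10)  # cf. clamp_probability
mut_pos_loss = torch.nn.functional.binary_cross_entropy(
    mut_pos_pred, aa_subs_indicator.masked_select(mask)
)

# Loss 2: conditional substitution probability (CSP) via cross entropy,
# restricted to sites where a substitution actually occurred.
subs_mask = aa_subs_indicator.bool()
csp_loss = torch.nn.CrossEntropyLoss()(
    predictions[subs_mask], aa_children_idxs[subs_mask]
)

loss = torch.stack([mut_pos_loss, csp_loss])
```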


def build_selection_matrix_from_parent(self, parent: str):
# This is simpler than the equivalent in dnsm.py because we get the selection
# matrix directly.
@@ -223,15 +214,3 @@ def build_selection_matrix_from_parent(self, parent: str):
selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0

return selection_factors
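
The assignment above is the usual row-wise fancy-indexing trick; a small sketch with made-up inputs (in the real method, parent_idxs comes from encoding the parent sequence and selection_factors from the model):

```python
import torch

parent_idxs = torch.tensor([3, 7, 0, 19])             # hypothetical encoded parent
selection_factors = torch.rand(len(parent_idxs), 20)  # pretend model output

# At each site, force the selection factor of the parent's own amino
# acid to 1.0, i.e. "no change" is neutral by construction.
selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
```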


# TODO code dup
def write_loss(self, loss_name, loss, step):
rate_loss, csp_loss = loss.unbind()
self.writer.add_scalar(
"Mut pos " + loss_name, rate_loss.item(), step, walltime=self.execution_time()
)
self.writer.add_scalar(
"CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_time()
)

37 changes: 22 additions & 15 deletions netam/framework.py
@@ -858,18 +858,35 @@ def to_crepe(self):
return Crepe(encoder, self.model, training_hyperparameters)


class RSSHMBurrito(SHMBurrito):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.xent_loss = nn.CrossEntropyLoss()
self.loss_weights = torch.tensor([1.0, 0.01]).to(self.device)
class TwoLossMixin:
"""A mixin for models that have two losses, one for mutation position and one for
conditional substitution probability (CSP)."""

def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None):
if loss_reduction is None:
loss_reduction = lambda x: torch.sum(x * self.loss_weights)

return super().process_data_loader(data_loader, train_mode, loss_reduction)

def write_loss(self, loss_name, loss, step):
rate_loss, csp_loss = loss.unbind()
self.writer.add_scalar(
"Mut pos " + loss_name,
rate_loss.item(),
step,
walltime=self.execution_time(),
)
self.writer.add_scalar(
"CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_time()
)
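
To see why the mixin eliminates the duplicated code, here is a toy model of the cooperative-inheritance pattern. The class names and loss values are made up; only the method-resolution-order behavior mirrors how TwoLossMixin, listed first among the bases, wraps the parent burrito's process_data_loader:

```python
import torch

class BaseBurrito:
    def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None):
        loss = torch.tensor([0.5, 2.0])  # pretend per-component losses
        return loss_reduction(loss)

class TwoLossMixin:
    def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None):
        if loss_reduction is None:
            loss_reduction = lambda x: torch.sum(x * self.loss_weights)
        # super() here resolves to BaseBurrito via the subclass's MRO.
        return super().process_data_loader(data_loader, train_mode, loss_reduction)

class MyBurrito(TwoLossMixin, BaseBurrito):
    def __init__(self):
        self.loss_weights = torch.tensor([1.0, 0.01])

print(MyBurrito().process_data_loader(None))  # tensor(0.5200)
```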


class RSSHMBurrito(TwoLossMixin, SHMBurrito):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.xent_loss = nn.CrossEntropyLoss()
self.loss_weights = torch.tensor([1.0, 0.01]).to(self.device)

def evaluate(self):
val_loader = self.build_val_loader()
return super().process_data_loader(
@@ -984,16 +1001,6 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):

return torch.tensor(optimal_lengths)

def write_loss(self, loss_name, loss, step):
rate_loss, csp_loss = loss.unbind()
self.writer.add_scalar(
# TODO rename?
"Rate " + loss_name, rate_loss.item(), step, walltime=self.execution_time()
)
self.writer.add_scalar(
"CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_time()
)


def burrito_class_of_model(model):
if isinstance(model, models.RSCNNModel):
