Skip to content

Commit

Permalink
notes live
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Sep 30, 2024
1 parent 193822b commit 8cff29f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
12 changes: 8 additions & 4 deletions netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ def update_neutral_aa_mut_probs(self):
parent_idxs, subs_probs[:parent_len, :]
)

# BIG TODO here: this is where we need to change to the new per-aa probabilities.
# Then we need to think carefully about masking and padding etc.
neutral_aa_probs = molevol.neutral_aa_probs(
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
Expand Down Expand Up @@ -113,7 +111,7 @@ def to(self, device):
self.all_subs_probs = self.all_subs_probs.to(device)


# TODO code dup
# TODO second step. Code duplication: make this a class method, as in dnsm.py
def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
"""Perform a train-val split based on a "in_train" column.
Expand Down Expand Up @@ -150,6 +148,8 @@ def prediction_pair_of_batch(self, batch):

def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
# Take the product of the neutral mutation probabilities and the selection factors.
# NOTE each of these now has a last dimension of 20
# this is p_{j, a} * f_{j, a}
predictions = torch.exp(log_neutral_aa_probs + log_selection_factors)
assert torch.isfinite(predictions).all()
predictions = clamp_probability(predictions)
Expand All @@ -174,18 +174,22 @@ def loss_of_batch(self, batch):
predictions = self.predictions_of_batch(batch)
# add one entry, zero, to the last dimension of the predictions tensor
# to handle the ambiguous amino acids
# TODO perhaps we can do better
# TODO perhaps we can do better: if we reassign all the 20s to 0s, perhaps we can be confident that our masking will take care of this.
# OR should we just always output a 21st dimension?
predictions = torch.cat(
[predictions, torch.zeros_like(predictions[:, :, :1])], dim=-1
)
# Now we make predictions of mutation by taking everything off the diagonal.
# We would like to do
# predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = 0.0
# but we have a batch dimension. Thus the following.

# Get batch size and sequence length
batch_size, L, _ = predictions.shape
# Create indices for each batch
batch_indices = torch.arange(batch_size, device=self.device)
# Zero out the diagonal by setting predictions[batch_idx, site_idx, aa_idx] to 0
# TODO play around with this in the notebook? Or just print things?
predictions[
batch_indices[:, None], torch.arange(L, device=self.device), aa_parents_idxs
] = 0.0
Expand Down
1 change: 1 addition & 0 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def to(self, device):
self.all_subs_probs = self.all_subs_probs.to(device)


# TODO second step. package this inside of DNSMDataset as a class method.
def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
"""Perform a train-val split based on a "in_train" column.
Expand Down

0 comments on commit 8cff29f

Please sign in to comment.