Skip to content

Commit

Permalink
make format
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Oct 13, 2024
1 parent 6c32126 commit b9df856
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
7 changes: 4 additions & 3 deletions netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,9 @@ def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
def zero_predictions_along_diagonal(predictions, aa_parents_idxs):
"""Zero out the diagonal of a batch of predictions.
We do this so that we can sum then have the same type of predictions as for
the DNSM."""
We do this so that we can sum then have the same type of predictions as for the
DNSM.
"""
# We would like to do
# predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = 0.0
# but we have a batch dimension. Thus the following.
Expand Down Expand Up @@ -199,7 +200,7 @@ def loss_of_batch(self, batch):
predictions = torch.cat(
[predictions, torch.zeros_like(predictions[:, :, :1])], dim=-1
)

predictions = zero_predictions_along_diagonal(predictions, aa_parents_idxs)

predictions_of_mut = torch.sum(predictions, dim=-1)
Expand Down
15 changes: 12 additions & 3 deletions tests/test_dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
add_shm_model_outputs_to_pcp_df,
)
from netam.models import TransformerBinarySelectionModelWiggleAct
from netam.dasm import DASMBurrito, train_val_datasets_of_pcp_df, zero_predictions_along_diagonal
from netam.dasm import (
DASMBurrito,
train_val_datasets_of_pcp_df,
zero_predictions_along_diagonal,
)


# TODO code dup
Expand Down Expand Up @@ -87,12 +91,17 @@ def test_zero_diagonal(dasm_burrito):
)
aa_parents_idxs = batch["aa_parents_idxs"].to(dasm_burrito.device)
zeroed_predictions = predictions.clone()
zeroed_predictions = zero_predictions_along_diagonal(zeroed_predictions, aa_parents_idxs)
zeroed_predictions = zero_predictions_along_diagonal(
zeroed_predictions, aa_parents_idxs
)
L = predictions.shape[1]
for batch_idx in range(2):
for i in range(L):
for j in range(20):
if j == aa_parents_idxs[batch_idx, i]:
assert zeroed_predictions[batch_idx, i, j] == 0.0
else:
assert zeroed_predictions[batch_idx, i, j] == predictions[batch_idx, i, j]
assert (
zeroed_predictions[batch_idx, i, j]
== predictions[batch_idx, i, j]
)

0 comments on commit b9df856

Please sign in to comment.