
Commit 193822b
make format
matsen committed Sep 30, 2024
1 parent e42529e commit 193822b
Showing 2 changed files with 19 additions and 14 deletions.
netam/dasm.py (9 additions, 10 deletions)
@@ -33,6 +33,7 @@
     translate_sequences,
 )
 
+
 class DASMDataset(dnsm.DNSMDataset):
 
     # TODO should we rename this?
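The only change in this hunk is the second blank line before the class definition, which matches the PEP 8 spacing that formatters such as black enforce (an assumption; the commit message only says "make format" and the Makefile target is not shown). A minimal sketch of the rule:

def helper():
    return 1


class Example:
    # Exactly two blank lines separate top-level definitions, as above.
    def method(self):
        # A single blank line is enough between members inside a class.
        return helper()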
@@ -74,9 +75,7 @@ def update_neutral_aa_mut_probs(self):
             print(f"rates: {rates}")
             print(f"subs_probs: {subs_probs}")
             print(f"branch_length: {branch_length}")
-            raise ValueError(
-                f"neutral_aa_probs is not finite: {neutral_aa_probs}"
-            )
+            raise ValueError(f"neutral_aa_probs is not finite: {neutral_aa_probs}")
 
         # Ensure that all values are positive before taking the log later
         neutral_aa_probs = clamp_probability(neutral_aa_probs)
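clamp_probability itself is not part of this commit; for orientation, here is a minimal sketch of the kind of clamping that keeps the subsequent log finite, with a hypothetical epsilon (the actual bounds used by netam are not shown here):

import torch


def clamp_probability_sketch(probs: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Illustrative only: keep probabilities inside [eps, 1 - eps] so log() stays finite.
    return probs.clamp(min=eps, max=1.0 - eps)


probs = torch.tensor([0.0, 0.3, 1.0])
log_probs = torch.log(clamp_probability_sketch(probs))  # no -inf entries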
@@ -134,13 +133,11 @@ def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
     return train_dataset, val_dataset
 
 
-
-
 class DASMBurrito(dnsm.DNSMBurrito):
 
     def prediction_pair_of_batch(self, batch):
-        """Get log neutral AA probabilities and log selection factors for a batch
-        of data."""
+        """Get log neutral AA probabilities and log selection factors for a batch of
+        data."""
         aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
         mask = batch["mask"].to(self.device)
         log_neutral_aa_probs = batch["log_neutral_aa_probs"].to(self.device)
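The visible context lines of prediction_pair_of_batch simply move the batch tensors onto the model's device. A generic sketch of that pattern (the helper below is hypothetical, not part of netam's API):

import torch


def batch_to_device(batch: dict, device: torch.device) -> dict:
    # Move every tensor value in a batch dictionary to the given device.
    return {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}


batch = {
    "aa_parents_idxs": torch.zeros(2, 5, dtype=torch.long),
    "mask": torch.ones(2, 5, dtype=torch.bool),
}
batch = batch_to_device(batch, torch.device("cpu"))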
@@ -182,15 +179,17 @@ def loss_of_batch(self, batch):
             [predictions, torch.zeros_like(predictions[:, :, :1])], dim=-1
         )
         # Now we make predictions of mutation by taking everything off the diagonal.
-        #predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = 0.0
+        # predictions[torch.arange(len(aa_parents_idxs)), aa_parents_idxs] = 0.0
 
         # Get batch size and sequence length
         batch_size, L, _ = predictions.shape
         # Create indices for each batch
         batch_indices = torch.arange(batch_size, device=self.device)
         # Zero out the diagonal by setting predictions[batch_idx, site_idx, aa_idx] to 0
-        predictions[batch_indices[:, None], torch.arange(L, device=self.device), aa_parents_idxs] = 0.0
+        predictions[
+            batch_indices[:, None], torch.arange(L, device=self.device), aa_parents_idxs
+        ] = 0.0
 
         predictions_of_mut = torch.sum(predictions, dim=-1)
         predictions_of_mut = predictions_of_mut.masked_select(mask)
         return self.bce_loss(predictions_of_mut, aa_subs_indicator)
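The reformatted assignment is the core of this hunk: advanced indexing zeroes, at each site, the entry belonging to the parent amino acid, so the subsequent sum over the last axis aggregates only the off-parent ("mutation") predictions. A small self-contained sketch of that indexing pattern, with toy shapes chosen purely for illustration:

import torch

batch_size, L, n_aa = 2, 3, 4  # toy dimensions, not the real model's
predictions = torch.ones(batch_size, L, n_aa)
aa_parents_idxs = torch.tensor([[0, 1, 2], [3, 0, 1]])  # parent amino acid index per site

batch_indices = torch.arange(batch_size)
site_indices = torch.arange(L)
# For each (batch, site) pair, zero the entry for the parent amino acid.
predictions[batch_indices[:, None], site_indices, aa_parents_idxs] = 0.0

# Summing over the amino acid axis now counts only off-parent mass.
predictions_of_mut = predictions.sum(dim=-1)
print(predictions_of_mut)  # every entry is n_aa - 1 == 3.0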
tests/test_dasm.py (10 additions, 4 deletions)
@@ -11,10 +11,10 @@
 )
 from netam.common import aa_idx_tensor_of_str_ambig, MAX_AMBIG_AA_IDX
 from netam.models import TransformerBinarySelectionModelWiggleAct
-from netam.dasm import DASMBurrito,train_val_datasets_of_pcp_df
+from netam.dasm import DASMBurrito, train_val_datasets_of_pcp_df
 
 
-#TODO code dup
+# TODO code dup
 @pytest.fixture
 def pcp_df():
     df = load_pcp_df(
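For context on how these fixtures are consumed (the names below are hypothetical, not taken from this file): pytest injects a fixture into any test or other fixture that names it as a parameter, which is how the dasm_burrito fixture in the next hunk receives pcp_df.

import pytest


@pytest.fixture
def toy_value():
    return 21


@pytest.fixture
def doubled(toy_value):  # fixtures can depend on other fixtures, as dasm_burrito does on pcp_df
    return 2 * toy_value


def test_doubled(doubled):  # hypothetical test: pytest resolves the whole fixture chain
    assert doubled == 42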
@@ -35,7 +35,11 @@ def dasm_burrito(pcp_df):
     train_dataset, val_dataset = train_val_datasets_of_pcp_df(pcp_df)
 
     model = TransformerBinarySelectionModelWiggleAct(
-        nhead=2, d_model_per_head=4, dim_feedforward=256, layer_count=2, output_dim=20,
+        nhead=2,
+        d_model_per_head=4,
+        dim_feedforward=256,
+        layer_count=2,
+        output_dim=20,
     )
 
     burrito = DASMBurrito(
@@ -46,7 +50,9 @@ def dasm_burrito(pcp_df):
         learning_rate=0.001,
         min_learning_rate=0.0001,
     )
-    burrito.joint_train(epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False)
+    burrito.joint_train(
+        epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
+    )
     return burrito
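Both test hunks (and the loss_of_batch change above) follow the same wrapping rule used by black-style formatters, which is assumed here since the commit only says "make format": a call that exceeds the line length is folded onto a single indented continuation line, while a trailing comma after the last argument makes it explode to one argument per line.

# Hypothetical helper, used only to illustrate the wrapping rule.
def train(epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False):
    return epochs, cycle_count, training_method, optimize_bl_first_cycle


# No trailing comma in the source: the formatter folds the call onto one continuation line.
result = train(
    epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
)

# With a trailing comma, the same call is exploded to one argument per line.
result = train(
    epochs=1,
    cycle_count=2,
    training_method="full",
    optimize_bl_first_cycle=False,
)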


