Skip to content

Commit

Permalink
aha, it was the fake row
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Oct 16, 2024
1 parent 6a8ae49 commit 7f5c4cf
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ def __init__(self, *args, **kwargs):
self.xent_loss = torch.nn.CrossEntropyLoss()
self.loss_weights = torch.tensor([1.0, 0.01]).to(self.device)

# # TODO code dup
# def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None):
# if loss_reduction is None:
# loss_reduction = lambda x: torch.sum(x * self.loss_weights)
# TODO code dup
def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None):
if loss_reduction is None:
loss_reduction = lambda x: torch.sum(x * self.loss_weights)

# return super().process_data_loader(data_loader, train_mode, loss_reduction)
return super().process_data_loader(data_loader, train_mode, loss_reduction)

def prediction_pair_of_batch(self, batch):
"""Get log neutral AA probabilities and log selection factors for a batch of
Expand Down Expand Up @@ -188,7 +188,7 @@ def loss_of_batch(self, batch):
# TODO we have ambiguous amino acids? Note that we're going to want to
# have a symbol for the junction between the heavy and light chains.
predictions = torch.cat(
[predictions, torch.zeros_like(predictions[:, :, :1])], dim=-1
[predictions, torch.full_like(predictions[:, :, :1], -BIG)], dim=-1
)

predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs)
Expand All @@ -211,8 +211,7 @@ def loss_of_batch(self, batch):
csp_targets = aa_children_idxs.masked_select(subs_mask)
csp_loss = self.xent_loss(csp_pred, csp_targets)

return mut_pos_loss
#return torch.stack([mut_pos_loss, csp_loss])
return torch.stack([mut_pos_loss, csp_loss])


def build_selection_matrix_from_parent(self, parent: str):
Expand All @@ -225,13 +224,14 @@ def build_selection_matrix_from_parent(self, parent: str):

return selection_factors

# # TODO code dup
# def write_loss(self, loss_name, loss, step):
# rate_loss, csp_loss = loss.unbind()
# self.writer.add_scalar(
# "Mut pos " + loss_name, rate_loss.item(), step, walltime=self.execution_time()
# )
# self.writer.add_scalar(
# "CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_time()
# )

# TODO code dup
def write_loss(self, loss_name, loss, step):
rate_loss, csp_loss = loss.unbind()
self.writer.add_scalar(
"Mut pos " + loss_name, rate_loss.item(), step, walltime=self.execution_time()
)
self.writer.add_scalar(
"CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_time()
)

0 comments on commit 7f5c4cf

Please sign in to comment.