Skip to content

Commit

Permalink
pos_loss not budging with simplified setup
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Oct 15, 2024
1 parent bac67d2 commit 6a8ae49
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ def __init__(self, *args, **kwargs):
self.xent_loss = torch.nn.CrossEntropyLoss()
self.loss_weights = torch.tensor([1.0, 0.01]).to(self.device)

# TODO code dup
def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None):
if loss_reduction is None:
loss_reduction = lambda x: torch.sum(x * self.loss_weights)
# # TODO code dup
# def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None):
# if loss_reduction is None:
# loss_reduction = lambda x: torch.sum(x * self.loss_weights)

return super().process_data_loader(data_loader, train_mode, loss_reduction)
# return super().process_data_loader(data_loader, train_mode, loss_reduction)

def prediction_pair_of_batch(self, batch):
"""Get log neutral AA probabilities and log selection factors for a batch of
Expand All @@ -157,7 +157,7 @@ def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
"""
predictions = log_neutral_aa_probs + log_selection_factors
assert torch.isnan(predictions).sum() == 0
predictions = clamp_log_probability(predictions)
#predictions = clamp_log_probability(predictions)
return predictions

def predictions_of_batch(self, batch):
Expand Down Expand Up @@ -211,7 +211,8 @@ def loss_of_batch(self, batch):
csp_targets = aa_children_idxs.masked_select(subs_mask)
csp_loss = self.xent_loss(csp_pred, csp_targets)

return torch.stack([mut_pos_loss, csp_loss])
return mut_pos_loss
#return torch.stack([mut_pos_loss, csp_loss])


def build_selection_matrix_from_parent(self, parent: str):
Expand All @@ -224,13 +225,13 @@ def build_selection_matrix_from_parent(self, parent: str):

return selection_factors

# TODO code dup
def write_loss(self, loss_name, loss, step):
    """Write both loss components to the TensorBoard writer as separate scalars.

    `loss` is a length-2 tensor whose components are the mutation-position
    loss and the CSP loss (in that order); each is logged under its own
    tag prefix combined with `loss_name`.

    Parameters:
        loss_name: label appended to each scalar tag (e.g. the phase name).
        loss: length-2 tensor of (mut-pos loss, CSP loss).
        step: global step index for the scalar series.
    """
    # Pair each component of the stacked loss with its tag prefix.
    prefixes = ("Mut pos ", "CSP ")
    for prefix, component in zip(prefixes, loss.unbind()):
        # NOTE(review): execution_time() is defined elsewhere in the class;
        # it is re-queried per scalar, matching the original call pattern.
        self.writer.add_scalar(
            prefix + loss_name,
            component.item(),
            step,
            walltime=self.execution_time(),
        )
# # TODO code dup
# def write_loss(self, loss_name, loss, step):
# rate_loss, csp_loss = loss.unbind()
# self.writer.add_scalar(
# "Mut pos " + loss_name, rate_loss.item(), step, walltime=self.execution_time()
# )
# self.writer.add_scalar(
# "CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_time()
# )

0 comments on commit 6a8ae49

Please sign in to comment.