Skip to content

Commit

Permalink
updated loss
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jun 21, 2024
1 parent 244dd4a commit d819cc8
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions src/cryo_sbi/inference/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,19 @@ class NPERobustStatsLoss(nn.Module):

def __init__(self, estimator: nn.Module, gamma: float):
super().__init__()

print("distance loss")
self.estimator = estimator
self.gamma = gamma
self.distance = nn.PairwiseDistance(p=10)

def forward(self, theta: torch.Tensor, x: torch.Tensor, x_obs: torch.Tensor) -> torch.Tensor:

latent_vecs_x = self.estimator.embedding(x)
self.estimator.eval()
latent_vecs_x_obs = self.estimator.embedding(x_obs)
summary_stats_regularization = self.distance(latent_vecs_x, latent_vecs_x_obs).mean()
self.estimator.train()

summary_stats_regularization = self.gamma * mmd_unweighted(
latent_vecs_x,
latent_vecs_x_obs,
median_heuristic(x)
)
log_p = self.estimator.npe(self.estimator.standardize(theta), latent_vecs_x)
print(-log_p.mean().item(), summary_stats_regularization.mean().item())

return -log_p.mean() + summary_stats_regularization
return -log_p.mean() + 0.0 * summary_stats_regularization

0 comments on commit d819cc8

Please sign in to comment.