Skip to content

Commit

Permalink
trying max mean discrepancy
Browse files Browse the repository at this point in the history
  • Loading branch information
aevans1 committed Aug 1, 2023
1 parent 5e2f7c4 commit 59115bc
Show file tree
Hide file tree
Showing 5 changed files with 322 additions and 32 deletions.
283 changes: 283 additions & 0 deletions Lukes_folder/MMD_testing.ipynb

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions Lukes_folder/image_params_test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"N_PIXELS": 128,
"PIXEL_SIZE": 1.5,
"SIGMA": 4.0,
"MODEL_FILE": "../data/protein_models/hsp90_models.npy",
"SHIFT": 30,
"DEFOCUS": 1.5,
"SNR": 0.01,
"RADIUS_MASK": 64,
"AMP": 0.1,
"B_FACTOR": 1.0,
"ELECWAVE": 0.019866
}
45 changes: 16 additions & 29 deletions Lukes_folder/trying_it_out.ipynb

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions src/cryo_sbi/inference/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def get_image_priors(
[[image_config["SIGMA"][1]]], dtype=torch.float32, device=device
)
sigma = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1)
else:
sigma = image_config["SIGMA"]

shift = zuko.distributions.BoxUniform(
lower=torch.tensor(
Expand All @@ -67,6 +69,8 @@ def get_image_priors(
[[image_config["DEFOCUS"][1]]], dtype=torch.float32, device=device
)
defocus = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1)
else:
defocus = image_config["DEFOCUS"]

if (
isinstance(image_config["B_FACTOR"], list)
Expand All @@ -79,6 +83,8 @@ def get_image_priors(
[[image_config["B_FACTOR"][1]]], dtype=torch.float32, device=device
)
b_factor = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1)
else:
b_factor = image_config["B_FACTOR"]

if isinstance(image_config["SNR"], list) and len(image_config["SNR"]) == 2:
lower = torch.tensor(
Expand All @@ -88,6 +94,8 @@ def get_image_priors(
[[image_config["SNR"][1]]], dtype=torch.float32, device=device
).log10()
snr = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1)
else:
snr = image_config["SNR"]

amp = zuko.distributions.BoxUniform(
lower=torch.tensor([[image_config["AMP"]]], dtype=torch.float32, device=device),
Expand Down
5 changes: 2 additions & 3 deletions src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def max_index(self) -> int:

def simulate(self, num_sim, indices=None, return_parameters=False):
parameters = self._priors.sample((num_sim,))
indices = parameters[0] if indices is None else indices
if indices is not None:
assert isinstance(
indices, torch.Tensor
Expand All @@ -117,8 +116,8 @@ def simulate(self, num_sim, indices=None, return_parameters=False):
assert (
indices.ndim == 2
), "Indices are not a 2D tensor, converting to 2D tensor. With shape (batch_size, 1)."
indices = torch.tensor(indices, dtype=torch.float32)

else:
indices = parameters[0]
images = cryo_em_simulator(
self._models,
indices,
Expand Down

0 comments on commit 59115bc

Please sign in to comment.