Skip to content

Commit

Permalink
changed definition of atom sigma to bioEM radius
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Aug 25, 2023
1 parent b9f8ddb commit f4e2481
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 52 deletions.
120 changes: 96 additions & 24 deletions notebooks/Untitled.ipynb

Large diffs are not rendered by default.

31 changes: 13 additions & 18 deletions src/cryo_sbi/inference/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,17 @@ def get_image_priors(
Returns:
zuko.distributions.BoxUniform: prior
"""
if isinstance(image_config["RES"], list) and len(image_config["RES"]) == 2:
if isinstance(image_config["DELTA_SIGMA"], list) and len(image_config["DELTA_SIGMA"]) == 2:
lower = torch.tensor(
[[image_config["RES"][0]]], dtype=torch.float32, device=device
[[image_config["DELTA_SIGMA"][0]]], dtype=torch.float32, device=device
)
upper = torch.tensor(
[[image_config["RES"][1]]], dtype=torch.float32, device=device
[[image_config["DELTA_SIGMA"][1]]], dtype=torch.float32, device=device
)

# assert (
# lower > 2.0 * image_config["PIXEL_SIZE"]
# ), "The lower bound for RES must be at least 2 times the pixel size."
assert lower <= upper, "Lower bound must be smaller or equal than upper bound."

# assert lower < upper, "Lower bound must be smaller than upper bound."

res = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1)
delta_sigma = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1)

shift = zuko.distributions.BoxUniform(
lower=torch.tensor(
Expand All @@ -73,8 +69,8 @@ def get_image_priors(
[[image_config["DEFOCUS"][1]]], dtype=torch.float32, device=device
)

# assert lower > 0.0, "The lower bound for DEFOCUS must be positive."
# assert lower < upper, "Lower bound must be smaller than upper bound."
assert lower > 0.0, "The lower bound for DEFOCUS must be positive."
assert lower <= upper, "Lower bound must be smaller or equal than upper bound."

defocus = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1)

Expand All @@ -89,8 +85,8 @@ def get_image_priors(
[[image_config["B_FACTOR"][1]]], dtype=torch.float32, device=device
)

# assert lower > 0.0, "The lower bound for DEFOCUS must be positive."
# assert lower < upper, "Lower bound must be smaller than upper bound."
assert lower > 0.0, "The lower bound for B_FACTOR must be positive."
assert lower <= upper, "Lower bound must be smaller or equal than upper bound."

b_factor = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1)

Expand All @@ -102,8 +98,7 @@ def get_image_priors(
[[image_config["SNR"][1]]], dtype=torch.float32, device=device
).log10()

# assert lower > 0.0, "The lower bound for DEFOCUS must be positive."
# assert lower < upper, "Lower bound must be smaller than upper bound."
assert lower <= upper, "Lower bound must be smaller or equal than upper bound."

snr = zuko.distributions.BoxUniform(lower=lower, upper=upper, ndims=1)

Expand All @@ -129,7 +124,7 @@ def get_image_priors(
return ImagePrior(
index_prior,
quaternion_prior,
res,
delta_sigma,
shift,
defocus,
b_factor,
Expand Down Expand Up @@ -165,7 +160,7 @@ def __init__(
self,
index_prior,
quaternion_prior,
res_prior,
delta_sigma_prior,
shift_prior,
defocus_prior,
b_factor_prior,
Expand All @@ -176,7 +171,7 @@ def __init__(
self.priors = [
index_prior,
quaternion_prior,
res_prior,
delta_sigma_prior,
shift_prior,
defocus_prior,
b_factor_prior,
Expand Down
4 changes: 1 addition & 3 deletions src/cryo_sbi/utils/pdb_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,7 @@ def pdb_parser_resid_(fname: str) -> torch.tensor:
atomic_model = torch.zeros((5, residues.n_residues))
atomic_model[0:3, :] = torch.from_numpy(univ.select_atoms("name CA").positions.T)
atomic_model[3, :] = torch.tensor([resid_density[x] for x in residues.resnames])
atomic_model[4, :] = (
torch.tensor([resid_radius[x] for x in residues.resnames]) / torch.pi
) ** 2
atomic_model[4, :] = 2 * (torch.tensor([resid_radius[x] / 2 for x in residues.resnames]) ** 2) # Residue radius is will be the 2 sigma interval of the gaussian

return atomic_model

Expand Down
4 changes: 1 addition & 3 deletions src/cryo_sbi/utils/traj_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,7 @@ def traj_parser_resid_(top_file: str, traj_file: str) -> torch.tensor:
atomic_models[i, 3, :] = torch.tensor(
[resid_density[x] for x in residues.resnames]
)
atomic_models[i, 4, :] = (
torch.tensor([resid_radius[x] for x in residues.resnames]) / torch.pi
) ** 2
atomic_models[i, 4, :] = 2 * (torch.tensor([resid_radius[x] / 2 for x in residues.resnames]) ** 2) # Residue radius is will be the 2 sigma interval of the gaussian

return atomic_models

Expand Down
6 changes: 3 additions & 3 deletions src/cryo_sbi/wpa_simulator/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def gen_rot_matrix(quats: torch.Tensor) -> torch.Tensor:
def project_density(
atomic_model: torch.Tensor,
quats: torch.Tensor,
res: torch.Tensor,
delta_sigma: torch.Tensor,
shift: torch.Tensor,
num_pixels: int,
pixel_size: float,
Expand All @@ -73,7 +73,7 @@ def project_density(

num_batch, _, num_atoms = atomic_model.shape

variances = atomic_model[:, 4, :] * res[:, 0] ** 2
variances = atomic_model[:, 4, :] * delta_sigma[:, 0]
amplitudes = atomic_model[:, 3, :] / torch.sqrt((2 * torch.pi * variances))

grid_min = -pixel_size * num_pixels * 0.5
Expand All @@ -100,6 +100,6 @@ def project_density(
) * amplitudes.unsqueeze(1)

image = torch.bmm(gauss_x, gauss_y.transpose(1, 2)) # * norms
image /= torch.norm(image, dim=[-2, -1]).reshape(-1, 1, 1)
image /= torch.norm(image, dim=[-2, -1]).reshape(-1, 1, 1) # do we need this normalization?

return image
2 changes: 1 addition & 1 deletion src/cryo_sbi/wpa_simulator/validate_image_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def check_image_params(config: dict) -> None:
needed_keys = [
"N_PIXELS",
"PIXEL_SIZE",
"RES",
"DELTA_SIGMA",
"SHIFT",
"DEFOCUS",
"SNR",
Expand Down

0 comments on commit f4e2481

Please sign in to comment.