Skip to content

Commit

Permalink
include physical information in simulator and implement pdb/traj parser
Browse files Browse the repository at this point in the history
closes #24
  • Loading branch information
DSilva27 committed Jul 29, 2023
1 parent 23e0a65 commit d953d01
Show file tree
Hide file tree
Showing 20 changed files with 680 additions and 227 deletions.
3 changes: 3 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[flake8]
max-line-length = 88
extend-ignore = E203
215 changes: 134 additions & 81 deletions notebooks/Untitled.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/analysis_nma.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
"metadata": {},
"outputs": [],
"source": [
"images = CryoEmSimulator.simulate( return_parameters=False)"
"images = CryoEmSimulator.simulate(return_parameters=False)"
]
},
{
Expand Down
32 changes: 20 additions & 12 deletions notebooks/analysis_nma_refurbed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
"import numpy as np\n",
"import torch\n",
"import torchvision.transforms as transforms\n",
"#import umap\n",
"\n",
"# import umap\n",
"import sklearn.cluster as cluster\n",
"import sklearn.manifold as manifold\n",
"\n",
Expand Down Expand Up @@ -155,7 +156,7 @@
" synthetic_images = torch.stack(\n",
" [cryosbi.simulator(index) for index in indices], dim=0\n",
" )\n",
" \n",
"\n",
" synthetic_images = img_utils.Mask(128, 60)(synthetic_images)\n",
"\n",
" samples_syntehtic = est_utils.sample_posterior(\n",
Expand Down Expand Up @@ -219,7 +220,9 @@
"source": [
"folder_with_micrographs = \"../../ceph/cryo_sbi_data/micrographs/\"\n",
"paths_to_micrographs = [\n",
" f\"{folder_with_micrographs}{file_name}\" for file_name in os.listdir(folder_with_micrographs) if file_name.endswith(\"mrc\")\n",
" f\"{folder_with_micrographs}{file_name}\"\n",
" for file_name in os.listdir(folder_with_micrographs)\n",
" if file_name.endswith(\"mrc\")\n",
"]\n",
"for path in paths_to_micrographs:\n",
" assert os.path.isfile(path), f\"Path {path} does not exist\""
Expand Down Expand Up @@ -294,7 +297,9 @@
"metadata": {},
"outputs": [],
"source": [
"average_psd = torch.load(\"../../ceph/cryo_sbi_data/whitening_filter/average_psd_19_micrographs.pt\")"
"average_psd = torch.load(\n",
" \"../../ceph/cryo_sbi_data/whitening_filter/average_psd_19_micrographs.pt\"\n",
")"
]
},
{
Expand Down Expand Up @@ -455,8 +460,8 @@
" linewidth=2,\n",
")\n",
"\n",
"#xticks = np.load(\"../../6wxb/6wxb_nma/nma_files/distance_to_reference_bending_mode.npy\")\n",
"#plt.xlabel(\"Posterior means (RMSD [A] to Reference)\", fontsize=15)\n",
"# xticks = np.load(\"../../6wxb/6wxb_nma/nma_files/distance_to_reference_bending_mode.npy\")\n",
"# plt.xlabel(\"Posterior means (RMSD [A] to Reference)\", fontsize=15)\n",
"plt.legend()\n",
"plt.yticks([])\n",
"if save_figures:\n",
Expand Down Expand Up @@ -579,11 +584,11 @@
" linewidth=2,\n",
")\n",
"\n",
"#xticks = np.load(\"../../6wxb/6wxb_nma/nma_files/distance_to_reference_bending_mode.npy\")\n",
"#plt.xlabel(\"Posterior means (RMSD [A] to Reference)\", fontsize=15)\n",
"#plt.xticks(\n",
"# xticks = np.load(\"../../6wxb/6wxb_nma/nma_files/distance_to_reference_bending_mode.npy\")\n",
"# plt.xlabel(\"Posterior means (RMSD [A] to Reference)\", fontsize=15)\n",
"# plt.xticks(\n",
"# ticks=[0, 20, 40, 60, 80, 100], labels=list(map(lambda x: f\"{x:.2f}\", xticks[::20]))\n",
"#)\n",
"# )\n",
"plt.legend()\n",
"plt.yticks([])\n",
"if save_figures:\n",
Expand Down Expand Up @@ -617,7 +622,7 @@
" cum_counts = np.cumsum(counts) / np.sum(counts)\n",
" plt.plot(bins[1:], cum_counts, label=f\"{2*alpha:.2f}\")\n",
"\n",
"#avg_index_to_rmsd = np.mean(np.abs(xticks[1:] - xticks[:-1]))\n",
"# avg_index_to_rmsd = np.mean(np.abs(xticks[1:] - xticks[:-1]))\n",
"\"\"\"x_tick_pos = np.linspace(0, 100, 9)\n",
"x_tick_labels = map(lambda x: f\"{x:.2f}\", x_tick_pos * avg_index_to_rmsd.item())\n",
"plt.xticks(\n",
Expand Down Expand Up @@ -702,6 +707,7 @@
"outputs": [],
"source": [
"from sklearn.manifold import SpectralEmbedding\n",
"\n",
"reducer = SpectralEmbedding(n_components=2)\n",
"embedding = reducer.fit_transform(cat_latent_samples.numpy())"
]
Expand Down Expand Up @@ -784,7 +790,9 @@
"metadata": {},
"outputs": [],
"source": [
"latent_embedding = manifold.SpectralEmbedding(n_components=5).fit_transform(particles_latent[::10])"
"latent_embedding = manifold.SpectralEmbedding(n_components=5).fit_transform(\n",
" particles_latent[::10]\n",
")"
]
},
{
Expand Down
84 changes: 31 additions & 53 deletions notebooks/test_gpu_simulation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@
"metadata": {},
"outputs": [],
"source": [
"%%timeit \n",
"%%timeit\n",
"project_density(\n",
" models[0],\n",
" torch.tensor(2.0),#torch.tensor([1.0, 3.0, 2.5, 3.3]),\n",
" torch.tensor(2.0), # torch.tensor([1.0, 3.0, 2.5, 3.3]),\n",
" torch.tensor(128),\n",
" torch.tensor(2.0),\n",
")"
Expand All @@ -220,12 +220,14 @@
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(project_density(\n",
" models[0],\n",
" torch.tensor(2.0),#torch.tensor([1.0, 3.0, 2.5, 3.3]),\n",
" torch.tensor(128),\n",
" torch.tensor(2.0)\n",
"))"
"plt.imshow(\n",
" project_density(\n",
" models[0],\n",
" torch.tensor(2.0), # torch.tensor([1.0, 3.0, 2.5, 3.3]),\n",
" torch.tensor(128),\n",
" torch.tensor(2.0),\n",
" )\n",
")"
]
},
{
Expand Down Expand Up @@ -292,11 +294,8 @@
"metadata": {},
"outputs": [],
"source": [
"%%timeit \n",
"gen_img(\n",
" models[0],\n",
" config\n",
")"
"%%timeit\n",
"gen_img(models[0], config)"
]
},
{
Expand Down Expand Up @@ -351,57 +350,57 @@
"def gen_rot_matrix(quat: torch.Tensor) -> torch.Tensor:\n",
" \"\"\"\n",
" Generate a rotation matrix from a quaternion.\n",
" \n",
"\n",
" Args:\n",
" quat (torch.Tensor): Quaternion\n",
" \n",
"\n",
" Returns:\n",
" rot_matrix (torch.Tensor): Rotation matrix\n",
" \"\"\"\n",
"\n",
" rot_matrix = torch.zeros((3, 3), device=quat.device)\n",
"\n",
" rot_matrix[0, 0] = 1 - 2 * (quat[2]**2 + quat[3]**2)\n",
" rot_matrix[0, 0] = 1 - 2 * (quat[2] ** 2 + quat[3] ** 2)\n",
" rot_matrix[0, 1] = 2 * (quat[1] * quat[2] - quat[3] * quat[0])\n",
" rot_matrix[0, 2] = 2 * (quat[1] * quat[3] + quat[2] * quat[0])\n",
"\n",
" rot_matrix[1, 0] = 2 * (quat[1] * quat[2] + quat[3] * quat[0])\n",
" rot_matrix[1, 1] = 1 - 2 * (quat[1]**2 + quat[3]**2)\n",
" rot_matrix[1, 1] = 1 - 2 * (quat[1] ** 2 + quat[3] ** 2)\n",
" rot_matrix[1, 2] = 2 * (quat[2] * quat[3] - quat[1] * quat[0])\n",
"\n",
" rot_matrix[2, 0] = 2 * (quat[1] * quat[3] - quat[2] * quat[0])\n",
" rot_matrix[2, 1] = 2 * (quat[2] * quat[3] + quat[1] * quat[0])\n",
" rot_matrix[2, 2] = 1 - 2 * (quat[1]**2 + quat[2]**2)\n",
" rot_matrix[2, 2] = 1 - 2 * (quat[1] ** 2 + quat[2] ** 2)\n",
"\n",
" return -rot_matrix\n",
"\n",
"\n",
"def gen_rot_matrix_batched(quats: torch.Tensor) -> torch.Tensor:\n",
" \"\"\"\n",
" Generate a rotation matrix from a quaternion.\n",
" \n",
"\n",
" Args:\n",
" quat (torch.Tensor): Quaternion\n",
" \n",
"\n",
" Returns:\n",
" rot_matrix (torch.Tensor): Rotation matrix\n",
" \"\"\"\n",
"\n",
" rot_matrix = torch.zeros((quats.shape[0], 3, 3), device=quats.device)\n",
"\n",
" rot_matrix[:, 0, 0] = 1 - 2 * (quats[:, 2]**2 + quats[:, 3]**2)\n",
" rot_matrix[:, 0, 0] = 1 - 2 * (quats[:, 2] ** 2 + quats[:, 3] ** 2)\n",
" rot_matrix[:, 0, 1] = 2 * (quats[:, 1] * quats[:, 2] - quats[:, 3] * quats[:, 0])\n",
" rot_matrix[:, 0, 2] = 2 * (quats[:, 1] * quats[:, 3] + quats[:, 2] * quats[:, 0])\n",
"\n",
" rot_matrix[:, 1, 0] = 2 * (quats[:, 1] * quats[:, 2] + quats[:, 3] * quats[:, 0])\n",
" rot_matrix[:, 1, 1] = 1 - 2 * (quats[:, 1]**2 + quats[:, 3]**2)\n",
" rot_matrix[:, 1, 1] = 1 - 2 * (quats[:, 1] ** 2 + quats[:, 3] ** 2)\n",
" rot_matrix[:, 1, 2] = 2 * (quats[:, 2] * quats[:, 3] - quats[:, 1] * quats[:, 0])\n",
"\n",
" rot_matrix[:, 2, 0] = 2 * (quats[:, 1] * quats[:, 3] - quats[:, 2] * quats[:, 0])\n",
" rot_matrix[:, 2, 1] = 2 * (quats[:, 2] * quats[:, 3] + quats[:, 1] * quats[:, 0])\n",
" rot_matrix[:, 2, 2] = 1 - 2 * (quats[:, 1]**2 + quats[:, 2]**2)\n",
" rot_matrix[:, 2, 2] = 1 - 2 * (quats[:, 1] ** 2 + quats[:, 2] ** 2)\n",
"\n",
" return -rot_matrix\n",
"\n"
" return -rot_matrix"
]
},
{
Expand Down Expand Up @@ -473,7 +472,7 @@
"metadata": {},
"outputs": [],
"source": [
"i[:, 10: 246, 10: 246].shape"
"i[:, 10:246, 10:246].shape"
]
},
{
Expand All @@ -482,7 +481,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"def project_density(\n",
" coord: torch.Tensor, sigma: float, num_pxels: int, pixel_size: float\n",
") -> torch.Tensor:\n",
Expand Down Expand Up @@ -539,12 +537,7 @@
],
"source": [
"%%timeit\n",
"mini_pro = project_density(\n",
" models[0],\n",
" 2.0,\n",
" 80,\n",
" 2.06\n",
")"
"mini_pro = project_density(models[0], 2.0, 80, 2.06)"
]
},
{
Expand All @@ -554,12 +547,7 @@
"outputs": [],
"source": [
"%%timeit\n",
"mini_pro = project_density(\n",
" models[0],\n",
" 2.0,\n",
" 60,\n",
" 2.06\n",
")"
"mini_pro = project_density(models[0], 2.0, 60, 2.06)"
]
},
{
Expand All @@ -569,12 +557,7 @@
"outputs": [],
"source": [
"%%timeit\n",
"mini_pro = project_density(\n",
" models[0],\n",
" 2.0,\n",
" 128,\n",
" 2.06\n",
")"
"mini_pro = project_density(models[0], 2.0, 128, 2.06)"
]
},
{
Expand All @@ -583,12 +566,7 @@
"metadata": {},
"outputs": [],
"source": [
"mini_pro = project_density(\n",
" models[50],\n",
" 2.0,\n",
" 70,\n",
" 2.06\n",
")\n",
"mini_pro = project_density(models[50], 2.0, 70, 2.06)\n",
"plt.imshow(mini_pro)"
]
},
Expand All @@ -601,7 +579,7 @@
"pixel_size = 2.06\n",
"num_pxels = 128\n",
"coord = models[50]\n",
"sigma=2"
"sigma = 2"
]
},
{
Expand Down Expand Up @@ -630,7 +608,7 @@
}
],
"source": [
"%%timeit \n",
"%%timeit\n",
"gauss_x = torch.exp_(-0.5 * (((grid[:, None] - coord[0, :]) / sigma) ** 2))\n",
"gauss_y = torch.exp_(-0.5 * (((grid[:, None] - coord[1, :]) / sigma) ** 2))"
]
Expand Down
2 changes: 1 addition & 1 deletion src/cryo_sbi/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from cryo_sbi.wpa_simulator.cryo_em_simulator import CryoEmSimulator
from cryo_sbi.wpa_simulator.cryo_em_simulator import CryoEmSimulator
6 changes: 5 additions & 1 deletion src/cryo_sbi/inference/command_line_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def cl_npe_train_no_saving():
"--saving_freq", action="store", type=int, required=False, default=20
)
cl_parser.add_argument(
"--simulation_batch_size", action="store", type=int, required=False, default=1024
"--simulation_batch_size",
action="store",
type=int,
required=False,
default=1024,
)

args = cl_parser.parse_args()
Expand Down
22 changes: 12 additions & 10 deletions src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ def forward(self, x):
x = x.unsqueeze(1)
x = self.resnet(x)
return x


@add_embedding("RESNET34_256_LP")
class ResNet34_Encoder(nn.Module):
def __init__(self, output_dimension: int):
Expand All @@ -324,26 +324,28 @@ def forward(self, x):
class VGG19_Encoder(nn.Module):
def __init__(self, output_dimension: int):
super(VGG19_Encoder, self).__init__()

self.vgg19 = models.vgg19_bn().features
self.vgg19[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

self.vgg19[0] = nn.Conv2d(
1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
)

self.avgpool = nn.AdaptiveAvgPool2d(output_size=(7, 7))

self.feedforward = nn.Sequential(
*[
nn.Linear(in_features=25088, out_features=4096),
nn.ReLU(inplace=True),
nn.Linear(in_features=4096, out_features=output_dimension, bias=True),
nn.ReLU(inplace=True)
nn.ReLU(inplace=True),
]
)
#self._fft_filter = LowPassFilter(256, 50)

# self._fft_filter = LowPassFilter(256, 50)

def forward(self, x):
# Low pass filter images
#x = self._fft_filter(x)
# x = self._fft_filter(x)
# Proceed as normal
x = x.unsqueeze(1)
x = self.vgg19(x)
Expand Down
2 changes: 1 addition & 1 deletion src/cryo_sbi/inference/models/estimator_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,4 @@ def sample(self, x: torch.Tensor, shape=(1,)) -> torch.Tensor:
"""

samples_standardized = self.flow(x).sample(shape)
return self.standardize.transform(samples_standardized)
return self.standardize.transform(samples_standardized)
Loading

0 comments on commit d953d01

Please sign in to comment.