first prototype
Dingel321 committed Jul 3, 2023
1 parent f66ff91 commit 65eb909
Showing 19 changed files with 1,312 additions and 1,024 deletions.
900 changes: 900 additions & 0 deletions notebooks/Untitled.ipynb

Large diffs are not rendered by default.

7 changes: 2 additions & 5 deletions notebooks/hsp90_analysis.ipynb
@@ -370,7 +370,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "488a4635-ea67-4728-b016-a9ecebe23e4a",
"metadata": {},
@@ -450,7 +449,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "a7213a4a-3a5d-4e04-9edd-d1906ff02818",
"metadata": {},
@@ -559,7 +557,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5050ccaa-fcff-4e37-8255-ed65a172e6a1",
"metadata": {},
@@ -682,7 +679,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "cryo_sbi",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -696,7 +693,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0 (default, Mar 3 2022, 09:58:08) [GCC 7.5.0]"
"version": "3.9.15"
},
"vscode": {
"interpreter": {
2 changes: 1 addition & 1 deletion notebooks/image_params_mixed_training.json
@@ -11,7 +11,7 @@
"SNR": [0.01, 0.1],
"RADIUS_MASK": 100,
"AMP": 0.1,
"B_FACTOR": 1.0,
"B_FACTOR": [1.0, 100.0],
"ELECWAVE": 0.019866,
"NOISE_INTENSITY": 0.5
}
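
B_FACTOR now specifies a [min, max] interval instead of a single value. Below is a minimal sketch of how such a range could be consumed, assuming the simulator draws one B-factor per image uniformly from the interval when a two-element list is given; the sample_b_factor helper is hypothetical and not part of the repository.

import json
import torch

def sample_b_factor(b_factor, num_images):
    """Return per-image B-factors: a fixed value, or uniform draws from a [min, max] range."""
    if isinstance(b_factor, (int, float)):
        return torch.full((num_images,), float(b_factor))
    b_min, b_max = b_factor
    return torch.empty(num_images).uniform_(b_min, b_max)

# Assumes the script is run from the repository root.
with open("notebooks/image_params_mixed_training.json") as f:
    image_params = json.load(f)

b_factors = sample_b_factor(image_params["B_FACTOR"], num_images=4)
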
4 changes: 2 additions & 2 deletions notebooks/plots_hsp90.ipynb
@@ -680,7 +680,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "cryo_sbi",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -694,7 +694,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0 (default, Mar 3 2022, 09:58:08) [GCC 7.5.0]"
"version": "3.9.15"
},
"vscode": {
"interpreter": {
12 changes: 12 additions & 0 deletions notebooks/resnet18_encoder.json
@@ -0,0 +1,12 @@
{"EMBEDDING": "RESNET18",
"OUT_DIM": 256,
"NUM_TRANSFORM": 5,
"NUM_HIDDEN_FLOW": 10,
"HIDDEN_DIM_FLOW": 256,
"MODEL": "NSF",
"LEARNING_RATE": 0.0003,
"CLIP_GRADIENT": 5.0,
"THETA_SHIFT": 50,
"THETA_SCALE": 50,
"BATCH_SIZE": 256
}
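
The new file bundles the estimator settings: a ResNet-18 embedding with a 256-dimensional output feeding a normalizing-flow posterior (MODEL: NSF), plus optimizer and batch settings. A minimal sketch of reading the file, assuming it is consumed as plain JSON by training code not shown in this commit:

import json

with open("notebooks/resnet18_encoder.json") as f:
    train_config = json.load(f)

# A few of the keys defined above.
embedding = train_config["EMBEDDING"]          # "RESNET18"
out_dim = train_config["OUT_DIM"]              # 256
model = train_config["MODEL"]                  # "NSF"
learning_rate = train_config["LEARNING_RATE"]  # 0.0003
batch_size = train_config["BATCH_SIZE"]        # 256
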
3 changes: 1 addition & 2 deletions src/cryo_sbi/__init__.py
@@ -1,2 +1 @@
from cryo_sbi.wpa_simulator.cryo_em_simulator import CryoEmSimulator
from cryo_sbi.inference.generate_training_set import gen_training_set

145 changes: 0 additions & 145 deletions src/cryo_sbi/inference/command_line_tools.py
@@ -1,8 +1,6 @@
import argparse
from cryo_sbi.inference.train_npe_model import (
npe_train_no_saving,
npe_train_from_vram,
npe_train_from_disk,
)
from cryo_sbi.inference.generate_training_set import gen_training_set

@@ -54,149 +52,6 @@ def cl_npe_train_no_saving():
loss_file=args.loss_file,
train_from_checkpoint=args.train_from_checkpoint,
model_state_dict=args.state_dict_file,
n_workers=args.n_workers,
device=args.train_device,
saving_frequency=args.saving_freq,
whitening_filter=args.whitening_filter,
)


def cl_npe_train_from_vram():
cl_parser = argparse.ArgumentParser()
cl_parser.add_argument(
"--train_config_file", action="store", type=str, required=True
)
cl_parser.add_argument("--epochs", action="store", type=int, required=True)
cl_parser.add_argument(
"--training_data_file", action="store", type=str, required=True
)
cl_parser.add_argument(
"--validation_data_file", action="store", type=str, required=True
)
cl_parser.add_argument("--estimator_file", action="store", type=str, required=True)
cl_parser.add_argument("--loss_file", action="store", type=str, required=True)
cl_parser.add_argument(
"--train_from_checkpoint",
action="store",
type=bool,
nargs="?",
required=False,
const=True,
default=False,
)
cl_parser.add_argument(
"--state_dict_file", action="store", type=str, required=False, default=False
)
cl_parser.add_argument(
"--train_device", action="store", type=str, required=False, default="cpu"
)
cl_parser.add_argument(
"--saving_freq", action="store", type=int, required=False, default=20
)
args = cl_parser.parse_args()

npe_train_from_vram(
train_config=args.train_config_file,
epochs=args.epochs,
train_data_dir=args.training_data_file,
val_data_Dir=args.validation_data_file,
estimator_file=args.estimator_file,
loss_file=args.loss_file,
train_from_checkpoint=args.train_from_checkpoint,
state_dict_file=args.state_dict_file,
device=args.train_device,
saving_frequency=args.saving_freq,
)


def cl_npe_train_from_disk():
cl_parser = argparse.ArgumentParser()
cl_parser.add_argument(
"--train_config_file", action="store", type=str, required=True
)
cl_parser.add_argument("--epochs", action="store", type=int, required=True)
cl_parser.add_argument(
"--training_data_file", action="store", type=str, required=True
)
cl_parser.add_argument(
"--validation_data_file", action="store", type=str, required=True
)
cl_parser.add_argument("--estimator_file", action="store", type=str, required=True)
cl_parser.add_argument("--loss_file", action="store", type=str, required=True)
cl_parser.add_argument(
"--train_from_checkpoint",
action="store",
type=bool,
nargs="?",
required=False,
const=True,
default=False,
)
cl_parser.add_argument(
"--state_dict_file", action="store", type=str, required=False, default=False
)

cl_parser.add_argument(
"--n_workers", action="store", type=int, required=False, default=0
)
cl_parser.add_argument(
"--train_device", action="store", type=str, required=False, default="cpu"
)
cl_parser.add_argument(
"--saving_freq", action="store", type=int, required=False, default=20
)
args = cl_parser.parse_args()

npe_train_from_disk(
train_config=args.train_config_file,
epochs=args.epochs,
train_data_dir=args.training_data_file,
val_data_Dir=args.validation_data_file,
estimator_file=args.estimator_file,
loss_file=args.loss_file,
train_from_checkpoint=args.train_from_checkpoint,
state_dict_file=args.state_dict_file,
device=args.train_device,
saving_frequency=args.saving_freq,
)


def cl_generate_training_data():
cl_parser = argparse.ArgumentParser()

cl_parser.add_argument("--config_file", action="store", type=str, required=True)

cl_parser.add_argument(
"--num_train_samples", action="store", type=int, required=True
)

cl_parser.add_argument("--num_val_samples", action="store", type=int, required=True)

cl_parser.add_argument("--file_name", action="store", type=str, required=True)

cl_parser.add_argument(
"--save_as_tensor",
action="store",
type=bool,
nargs="?",
required=False,
const=True,
default=False,
)

cl_parser.add_argument("--n_workers", action="store", type=int, required=True)

cl_parser.add_argument(
"--batch_size", action="store", type=int, required=False, default=1000
)

args = cl_parser.parse_args()
gen_training_set(
args.config_file,
args.num_train_samples,
args.num_val_samples,
args.file_name,
args.save_as_tensor,
args.n_workers,
args.batch_size,
)
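
The argparse entry points above all rely on the same boolean-flag idiom (type=bool, nargs="?", const=True, default=False). A stand-alone sketch of how that pattern behaves; note that an explicit string argument such as "--train_from_checkpoint False" would still evaluate to True, because bool() of any non-empty string is True.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_from_checkpoint",
    action="store",
    type=bool,
    nargs="?",
    const=True,     # value used when the flag is passed without an argument
    default=False,  # value used when the flag is omitted
)

print(parser.parse_args([]).train_from_checkpoint)                           # False
print(parser.parse_args(["--train_from_checkpoint"]).train_from_checkpoint)  # True
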
4 changes: 2 additions & 2 deletions src/cryo_sbi/inference/generate_training_set.py
@@ -4,8 +4,8 @@
from lampe.data import JointLoader, H5Dataset
from tqdm import tqdm

from cryo_sbi.inference.priors import get_uniform_prior_1d
from cryo_sbi import CryoEmSimulator
# from cryo_sbi.inference.priors import get_uniform_prior_1d
# from cryo_sbi import CryoEmSimulator


def gen_training_set(
22 changes: 22 additions & 0 deletions src/cryo_sbi/inference/models/embedding_nets.py
@@ -260,5 +260,27 @@ def forward(self, x):
return x


@add_embedding("RESNET18_FFT_FILTER_132")
class ResNet18_FFT_Encoder_132(nn.Module):
def __init__(self, output_dimension: int):
super(ResNet18_FFT_Encoder_132, self).__init__()
self.resnet = models.resnet18()
self.resnet.conv1 = nn.Conv2d(
1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
self.resnet.fc = nn.Linear(
in_features=512, out_features=output_dimension, bias=True
)

self._fft_filter = LowPassFilter(132, 25)

def forward(self, x):
# Low pass filter images
x = self._fft_filter(x)
# Proceed as normal
x = x.unsqueeze(1)
x = self.resnet(x)
return x

if __name__ == "__main__":
pass
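
A minimal usage sketch for the new encoder, assuming 132x132 input images (consistent with LowPassFilter(132, 25)) and that the module is importable as cryo_sbi.inference.models.embedding_nets:

import torch
from cryo_sbi.inference.models.embedding_nets import ResNet18_FFT_Encoder_132

encoder = ResNet18_FFT_Encoder_132(output_dimension=256)

# Batch of 8 single-channel images; the channel axis is added inside forward().
images = torch.randn(8, 132, 132)
features = encoder(images)
print(features.shape)  # expected: torch.Size([8, 256])
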