Skip to content

Commit

Permalink
Trying out CME free energy
Browse files Browse the repository at this point in the history
  • Loading branch information
aevans1 committed Aug 3, 2023
1 parent 9823753 commit fd80fc3
Show file tree
Hide file tree
Showing 6 changed files with 425 additions and 2 deletions.
Binary file removed Lukes_folder/Lars_hsp90.zip
Binary file not shown.
Binary file removed Lukes_folder/Lars_hsp90/hsp90_posterior
Binary file not shown.
2 changes: 1 addition & 1 deletion Lukes_folder/MMD_testing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
"source": [
"train_config = json.load(open(\"Lars_hsp90/resnet18_encoder.json\"))\n",
"estimator = build_models.build_npe_flow_model(train_config)\n",
"estimator.load_state_dict(torch.load(\"Lars_hsp90/hsp90_posterior_alt.estimator\"))\n",
"estimator.load_state_dict(torch.load(\"Lars_hsp90/hsp90_posterior.estimator\"))\n",
"estimator.cuda()\n",
"estimator.eval();\n"
]
Expand Down
274 changes: 274 additions & 0 deletions Lukes_folder/cryoBIFE_test.ipynb

Large diffs are not rendered by default.

149 changes: 149 additions & 0 deletions Lukes_folder/train_BIFE_prior.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Data/Packages/Utilities/miniconda3/envs/cryosbi_env/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from cryo_sbi import CryoEmSimulator\n",
"from cryo_sbi import gen_training_set\n",
"from cryo_sbi.inference.NPE_train_from_disk import npe_train_from_disk\n",
"from cryo_sbi.inference.NPE_train_without_saving import npe_train_no_saving"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating particles and then training with them"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"hsp90_models.npy\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10000/10000 [01:13<00:00, 136.60pair/s]\n",
"100%|██████████| 100/100 [00:00<00:00, 4779.94pair/s]\n"
]
}
],
"source": [
"gen_training_set(\n",
" config_file=\"config_file.json\",\n",
" num_train_samples=10000,\n",
" num_val_samples=100,\n",
" file_name=\"tut_imgs\",\n",
" save_as_tensor=False,\n",
" n_workers=2,\n",
" batch_size=100,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training neural netowrk:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 30/30 [04:44<00:00, 9.48s/epoch, train_loss=-.981, val_loss=0.108] \n"
]
}
],
"source": [
"npe_train_from_disk(\n",
" train_config=\"resnet18_encoder.json\",\n",
" epochs=30,\n",
" train_data_dir=\"tut_imgs_train.h5\",\n",
" val_data_dir=\"tut_imgs_valid.h5\",\n",
" estimator_file=\"tut_estimator\",\n",
" loss_file=\"tut_loss\",\n",
" train_from_checkpoint=False,\n",
" model_state_dict=None,\n",
" n_workers=2,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training without saving images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"npe_train_no_saving(\n",
" image_config=\"image_params_snr01_128.json\",\n",
" train_config=\"resnet18_encoder.json\",\n",
" epochs=350,\n",
" estimator_file=\"resnet18_encoder.estimator\",\n",
" loss_file=\"resnet18_encoder.estimator\",\n",
" n_workers=2, # CHANGE\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cryosbi_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion Lukes_folder/trying_it_out.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
"source": [
"train_config = json.load(open(\"Lars_hsp90/resnet18_encoder.json\"))\n",
"estimator = build_models.build_npe_flow_model(train_config)\n",
"estimator.load_state_dict(torch.load(\"Lars_hsp90/hsp90_posterior_alt.estimator\"))\n",
"estimator.load_state_dict(torch.load(\"Lars_hsp90/hsp90_posterior.estimator\"))\n",
"estimator.cuda()\n",
"estimator.eval();\n"
]
Expand Down

0 comments on commit fd80fc3

Please sign in to comment.