Skip to content

Commit

Permalink
more hypothesis testing on MMD
Browse files Browse the repository at this point in the history
  • Loading branch information
aevans1 committed Sep 6, 2023
1 parent c2a520c commit ac9cdf5
Show file tree
Hide file tree
Showing 34 changed files with 6,058 additions and 60 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Project specific files
ASPIRE-Python/
messing_around/
tests/test_simulator.ipynb
*.estimator
Expand Down
537 changes: 537 additions & 0 deletions 6wxb_files/6wxb_MMD.ipynb

Large diffs are not rendered by default.

106 changes: 73 additions & 33 deletions Lukes_folder/6xwb_images.ipynb → 6wxb_files/6wxb_images.ipynb

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions 6wxb_files/MMD.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch

def MMD_gaussian(X, Y, bandwidth=None):

# Compute all pairwise distances
Xdists = torch.cdist(X, X)**2
Ydists = torch.cdist(Y, Y)**2
XYdists = torch.cdist(X, Y)**2

# Use heuristic for bandwidth if none provided
if bandwidth == None:
bandwidth = torch.sqrt(torch.median(XYdists)*2)

# Compute all kernel sums
Xterm = torch.exp(-Xdists[None, ...]/bandwidth**2).mean()
Yterm = torch.exp(-Ydists[None, ...]/bandwidth**2).mean()
XYterm = torch.exp(-XYdists[None, ...]/bandwidth**2).mean()
return Xterm + Yterm - 2*XYterm

11 changes: 11 additions & 0 deletions 6wxb_files/image_params_mixed_training.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"N_PIXELS": 128,
"PIXEL_SIZE": 2.06,
"SIGMA": [0.5, 5.0],
"MODEL_FILE": "../data/protein_models/6wxb_mixed_models.npy",
"SHIFT": 40,
"DEFOCUS": [0.5, 5.0],
"SNR": [0.01,1.0],
"AMP": 0.1,
"B_FACTOR": [1.0, 100.0]
}
12 changes: 12 additions & 0 deletions 6wxb_files/resnet18_encoder.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{"EMBEDDING": "RESNET18",
"OUT_DIM": 256,
"NUM_TRANSFORM": 5,
"NUM_HIDDEN_FLOW": 10,
"HIDDEN_DIM_FLOW": 256,
"MODEL": "NSF",
"LEARNING_RATE": 0.0003,
"CLIP_GRADIENT": 5.0,
"THETA_SHIFT": 50,
"THETA_SCALE": 50,
"BATCH_SIZE": 256
}
12 changes: 12 additions & 0 deletions 6wxb_files/resnet18_fft_encoder.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{"EMBEDDING": "RESNET18_FFT_FILTER",
"OUT_DIM": 256,
"NUM_TRANSFORM": 5,
"NUM_HIDDEN_FLOW": 10,
"HIDDEN_DIM_FLOW": 256,
"MODEL": "NSF",
"LEARNING_RATE": 0.0003,
"CLIP_GRADIENT": 5.0,
"THETA_SHIFT": 50,
"THETA_SCALE": 50,
"BATCH_SIZE": 256
}
3 changes: 0 additions & 3 deletions Lukes_folder/6wxb/image_params_mixed_training.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
"PIXEL_SIZE": 2.06,
"SIGMA": [0.5, 5.0],
"MODEL_FILE": "../data/protein_models/6wxb_mixed_models.npy",
"ROTATIONS": true,
"SHIFT": true,
"CTF": true,
"NOISE": true,
"DEFOCUS": [0.5, 5.0],
"SNR": [0.001, 0.1],
"RADIUS_MASK": 100,
Expand Down
Loading

0 comments on commit ac9cdf5

Please sign in to comment.