Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: docstrings and overall design for expectation maximization iterative refinement #21

Open
wants to merge 175 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
175 commits
Select commit Hold shift + click to select a range
a2629da
Init
jedyeo Mar 20, 2022
ccc4359
precommit yaml
jedyeo Mar 20, 2022
186fddc
remove ref
jedyeo Mar 20, 2022
62b9f42
Format code with black and isort
deepsource-autofix[bot] Mar 20, 2022
ee5adab
Docstrings, precommit broken?
jedyeo Mar 20, 2022
b027987
Format code with black and isort
deepsource-autofix[bot] Mar 20, 2022
26dcf79
Update docstrings
jedyeo Mar 21, 2022
2110a8a
Format code with black and isort
deepsource-autofix[bot] Mar 21, 2022
f493da6
Dummy tests
Mar 22, 2022
00b8afb
docstrings and tests
jedyeo Mar 22, 2022
387c8ed
Format code with black and isort
deepsource-autofix[bot] Mar 22, 2022
be31e4c
Minor ds change. Fixed tests w/ fixture
jedyeo Mar 22, 2022
bb3b232
Minor ds change. Fixed tests w/ fixture
jedyeo Mar 22, 2022
7d42322
Format code with black and isort
deepsource-autofix[bot] Mar 22, 2022
e3becc4
Minor docstring changes
jedyeo Mar 23, 2022
e12673d
minor ds changes
jedyeo Mar 23, 2022
dd4d35b
ds changes
jedyeo Mar 23, 2022
2429d1b
Change class docstring
jedyeo Mar 24, 2022
e8503c2
Remove todo's breaking deepsource
jedyeo Mar 24, 2022
87a9a87
Format code with black and isort
deepsource-autofix[bot] Mar 24, 2022
65273f0
DeepSource changes
jedyeo Mar 24, 2022
5309ac8
deepsoruce changes
jedyeo Mar 24, 2022
d5ea497
Format code with black and isort
deepsource-autofix[bot] Mar 24, 2022
1a79772
literals removed
jedyeo Mar 24, 2022
c473d43
literals removed
jedyeo Mar 24, 2022
38a9a5e
add pre commit
jedyeo Mar 24, 2022
7a4de6f
Init tests
jedyeo Mar 24, 2022
df6c749
temporary static methods
jedyeo Mar 24, 2022
6849b0c
Format code with black and isort
deepsource-autofix[bot] Mar 24, 2022
4e927ab
temporary static methods
jedyeo Mar 24, 2022
93ca7fc
ds changes
jedyeo Mar 24, 2022
ea2c766
ds changes
jedyeo Mar 24, 2022
714b33d
changed deepsource
jedyeo Mar 24, 2022
5b4db6e
lint
jedyeo Mar 24, 2022
ccd0d04
add black back
jedyeo Mar 24, 2022
c8f3c60
Format code with black
deepsource-autofix[bot] Mar 24, 2022
c2726b2
update toml
jedyeo Mar 24, 2022
7f96273
ds
jedyeo Mar 24, 2022
3d4ad77
added .ini for pytest
jedyeo Mar 24, 2022
b293181
init
jedyeo Mar 24, 2022
02eec0a
Codecov files
Mar 24, 2022
49138c7
Linting files
Mar 24, 2022
6135c72
Dev requirements
Mar 24, 2022
f54ece8
Environment config
Mar 24, 2022
b9b540e
Added setup.py
Mar 24, 2022
afd7290
env yaml
jedyeo Mar 24, 2022
5a4a56e
Reformatted docstrings
Mar 24, 2022
e504e32
Format code with black
deepsource-autofix[bot] Mar 24, 2022
a282ffc
Testing workspace changes
Mar 24, 2022
43a440e
Merge branch 'itr_ref_docstrings' of github.com:compSPI/reconstructSP…
Mar 24, 2022
1a0af09
Fixed test docstrings
Mar 24, 2022
3d1be9b
Format code with black
deepsource-autofix[bot] Mar 24, 2022
5268b73
Edited module name
Mar 24, 2022
3e638af
Merge branch 'itr_ref_docstrings' of github.com:compSPI/reconstructSP…
Mar 24, 2022
984af8a
Linting fixes
Mar 24, 2022
eecb97e
Format code with black
deepsource-autofix[bot] Mar 24, 2022
f7f98b4
Linting fixes
Mar 24, 2022
1a79652
Linting fixes
Mar 24, 2022
d2faa25
Fixed imports
Mar 24, 2022
a3eedd6
Readded simSPI import
Mar 24, 2022
c898048
Fixed function reference
Mar 24, 2022
71356a5
Fixed test_build_ctf_array
Mar 24, 2022
0522a91
tets
jedyeo Mar 24, 2022
ee0ef8f
Format code with black
deepsource-autofix[bot] Mar 24, 2022
7141f37
Reworked tests to use consistent array sizing
Mar 24, 2022
5d0d1a1
Tab snuck in
Mar 24, 2022
7289762
Format code with black
deepsource-autofix[bot] Mar 24, 2022
f1bd414
Comment change to force checks
Mar 24, 2022
756adbe
Fixes to tests
Mar 24, 2022
40a6fb0
Refactored tests a bit
Mar 24, 2022
5c74fb3
Fix to test_split_array
Mar 24, 2022
c210d36
Fix to test_generate_xy_plane
Mar 24, 2022
8753a1f
Docstring fix
Mar 24, 2022
1ce2209
Deepsource tricked me
Mar 24, 2022
0673ebb
Dosctring to force checks
Mar 24, 2022
ae8c96f
Fixture fix:
Mar 24, 2022
16a0b90
Format code with black
deepsource-autofix[bot] Mar 24, 2022
3c3595c
Format code with black
deepsource-autofix[bot] Mar 24, 2022
a5bc9ee
Format code with black
deepsource-autofix[bot] Mar 24, 2022
b76ec98
Format code with black
deepsource-autofix[bot] Mar 24, 2022
b58ba2f
Forcing checks again...
Mar 24, 2022
1fb62fa
Merge branch 'itr_ref_docstrings' of github.com:compSPI/reconstructSP…
Mar 24, 2022
98a879c
Format code with black
deepsource-autofix[bot] Mar 24, 2022
a1dc0cd
Better ctf_info for tests
Mar 24, 2022
5818e9e
Merge branch 'itr_ref_docstrings' of github.com:compSPI/reconstructSP…
Mar 24, 2022
4bcf15a
Format code with black
deepsource-autofix[bot] Mar 24, 2022
ecee780
Strings not variables...
Mar 24, 2022
f9c4196
Merging
Mar 24, 2022
0ac923c
Format code with black
deepsource-autofix[bot] Mar 24, 2022
cb7b069
Fix
Mar 24, 2022
0fb3b1b
Fix
Mar 24, 2022
966ed3c
test fixes
Mar 24, 2022
ccbd07b
Format code with black
deepsource-autofix[bot] Mar 24, 2022
4858d54
Test fixes
Mar 24, 2022
b7b8688
Merge branch 'itr_ref_docstrings' of github.com:compSPI/reconstructSP…
Mar 24, 2022
bc88498
Force checks
Mar 24, 2022
ab014c5
Format code with black
deepsource-autofix[bot] Mar 24, 2022
656b5d2
test fixes
Mar 24, 2022
4a86e20
Merge branch 'itr_ref_docstrings' of github.com:compSPI/reconstructSP…
Mar 24, 2022
b9e72ad
Forcing checks
Mar 24, 2022
5577ded
Reshape fix
Mar 24, 2022
bb6e258
Shape fix
Mar 24, 2022
81e5faa
Structure changes
Mar 24, 2022
a2f4b52
Test framework adjustment
Mar 24, 2022
17f94ca
Import change
Mar 24, 2022
9f7ee21
Small typo fixes
thisFreya Mar 24, 2022
0d418b9
Added a slightly more comprehensive split test
thisFreya Mar 24, 2022
e6410cd
Format code with black
deepsource-autofix[bot] Mar 24, 2022
511d105
Added back references
Mar 24, 2022
b575e13
Added the big method
thisFreya Mar 25, 2022
41a5807
Format code with black
deepsource-autofix[bot] Mar 25, 2022
22efe76
Pre-commit and black changes
thisFreya Mar 25, 2022
5a22371
Merging
thisFreya Mar 25, 2022
bda6275
Deepsource
thisFreya Mar 25, 2022
976be19
Expanded split_array to allow lists and tuples
thisFreya Mar 25, 2022
00e7281
Format code with black
deepsource-autofix[bot] Mar 25, 2022
ff36588
Deepsource fix
thisFreya Mar 25, 2022
8de2090
Merge branch 'itr_ref_docstrings' of https://github.com/compSPI/recon…
thisFreya Mar 25, 2022
b8e37ef
Fixed split_array expansion
thisFreya Mar 25, 2022
abfd9aa
Fix to fsc test
thisFreya Mar 25, 2022
e1cd120
Data type fix
thisFreya Mar 25, 2022
83d9cf7
Concatenate fix
thisFreya Mar 25, 2022
f8dc07a
Change to insert slice shape
Mar 25, 2022
352a989
Format code with black
deepsource-autofix[bot] Mar 25, 2022
d381e3c
fix to numpy array initialization
Mar 25, 2022
56f02a1
Merging
Mar 25, 2022
294d4b5
Format code with black
deepsource-autofix[bot] Mar 25, 2022
3819306
num rotations change
Mar 25, 2022
2b25434
Merge branch 'itr_ref_docstrings' of github.com:compSPI/reconstructSP…
Mar 25, 2022
f28ccfa
Format code with black
deepsource-autofix[bot] Mar 25, 2022
c053980
fsc fix
Mar 25, 2022
89cdaa8
Merge branch 'itr_ref_docstrings' of github.com:compSPI/reconstructSP…
Mar 25, 2022
7083a54
Forgot what a dot product did for a second there
Mar 25, 2022
2a1eda7
Updated iterative refinement test to reflect fsc fix
Mar 25, 2022
cba9320
Removed superfluous imports
Mar 27, 2022
d40e6ae
Updated docstrings, fixed splitting arrs
Mar 27, 2022
6c6e1f0
Split n into n_pix, n_particles
Mar 27, 2022
1e11a9d
Format code with black
deepsource-autofix[bot] Mar 27, 2022
d621fb1
Added back numba
Mar 27, 2022
b48b5ef
Merge branch 'itr_ref_docstrings' of github.com:compSPI/reconstructSP…
Mar 27, 2022
496f88a
Force checks
Mar 27, 2022
8aaff78
n_pix in fft docstrings
Mar 28, 2022
33e3229
Normalizing half maps
Mar 28, 2022
01ceaf5
Format code with black
deepsource-autofix[bot] Mar 28, 2022
20f2232
Removed comments, refactored big method, fixed bayesian weights shapes
Mar 31, 2022
e76d8a2
Format code with black
deepsource-autofix[bot] Mar 31, 2022
6e7ed32
Infrastructure files
Mar 31, 2022
6a9ee69
Infrastructure files
Mar 31, 2022
fd9cefa
Missed a name
Mar 31, 2022
bb8b274
Removed last few comments, will re-add if needed
Mar 31, 2022
66c0ef3
Format code with black
deepsource-autofix[bot] Mar 31, 2022
b9e8c10
Merging changes
Mar 31, 2022
579f716
Fixed variable name
Mar 31, 2022
6fe971c
Merge branch 'itr_ref_docstrings' of github.com:compSPI/reconstructSP…
Mar 31, 2022
55ee718
Changed codecov parameters
Mar 31, 2022
0d18f66
Removed __init__.py files
Mar 31, 2022
7f81456
Revert "Removed __init__.py files"
Mar 31, 2022
887bff7
Merging
Mar 31, 2022
e583a54
Revert "Changed codecov parameters"
Mar 31, 2022
aa607f3
Merge branch 'reconstructSPI_infrastructure' of github.com:compSPI/re…
Mar 31, 2022
15840cf
Refactoring library format
Mar 31, 2022
2f91bc6
Directory fixes
Mar 31, 2022
bccb41c
Revert "Refactoring library format"
Mar 31, 2022
2c42e8d
Revert "Directory fixes"
Mar 31, 2022
9126f86
Removed dependencies
Mar 31, 2022
3330c11
Testing something
Mar 31, 2022
aaf1078
Merge branch 'reconstructSPI_infrastructure' of github.com:compSPI/re…
Mar 31, 2022
ef10249
Testing things
Mar 31, 2022
d15bdb6
Merge branch 'reconstructSPI_infrastructure' of github.com:compSPI/re…
Mar 31, 2022
8dda440
Force checks
Mar 31, 2022
749bcd2
Testing things
Mar 31, 2022
2ea9459
Added library requirements
Mar 31, 2022
5152547
Added dev branch to branches on which tests will run.
Mar 31, 2022
5a4ebb5
Merging changes. Have temporarily added infrastructure branch to work…
Mar 31, 2022
9d05c1a
removed itr_ref_docstrings branch checks
Mar 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
346 changes: 346 additions & 0 deletions iterative_refinement/expectation_maximization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,346 @@
"""
Iterative refinement in Bayesian expection maximization setting.
"""

import numpy as np
from compSPI.transforms import do_fft, do_ifft
# currently only 2D ffts in compSPI.transforms. can use torch.fft for 3d fft and convert back to numpy array


def do_iterative_refinement(map_3d_init, particles, ctf_info):
"""
Performs interative refimenent in a Bayesian expectation maximization setting,
i.e. maximum a posteriori estimation.

Input
-----
map_3d_init
initial estimate
input map
shape (n_pix,n_pix,n_pix)
particles
particles to be reconstructed
shape (n_pix,n_pix)


Returns
-------

map_3d_final
shape (n_pix,n_pix,n_pix)

map_3d_r_final
final updated map
shape (n_pix,n_pix,n_pix)
half_map_3d_r_1
half map 1
half_map_3d_r_2
half map 2
fsc_1d
final 1d fsc
shape (n_pix//2,)

"""

# split particles up into two half sets for statistical validation

def do_split(arr):
idx_half = arr.shape[0] // 2
arr_1, arr_2 = arr[:idx_half], arr[idx_half:]
assert arr_1.shape[0] == arr_2.shape[0]
return parr_1, arr_2

particles_1, particles_2 = do_split(particles)

def do_build_ctf(ctf_params):
"""
Build 2D array of ctf from ctf params

Input
___
Params of ctfs (defocus, etc)
Suggest list of dicts, one for each particle.

Returns
___
ctfs
type np.ndarray
shape (n_ctfs,n_pix,n_pix)

"""
n_ctfs = len(ctf_params)
# TODO: see simSPI.transfer
# https://github.com/compSPI/simSPI/blob/master/simSPI/transfer.py#L57
ctfs = np.ones((n_ctfs,n_pix,n_pix))

return ctfs

ctfs = do_build_ctf(ctf_info)
ctfs_1, ctfs_2 = do_split(ctfs)

# work in Fourier space. so particles can stay in Fourier space the whole time.
# they are experimental measurements and are fixed in the algorithm
particles_f_1 = do_fft(particles_1)
particles_f_2 = do_fft(particles_2)

n_pix = map_3d_init.shape[0]
# suggest 32 or 64 to start with. real data will be more like 128 or 256.
# can have issues with ctf at small pixels and need to zero pad to avoid artefacts
# artefacts from ctf not going to zero at edges, and sinusoidal ctf rippling too fast
# can zero pad when do Fourier convolution (fft is on zero paded and larger sized array)

max_n_iters = 7 # in practice programs using 3-10 iterations.

half_map_3d_r_1,half_map_3d_r_2 = map_3d_init, map_3d_init.copy()
# should diverge because different particles averaging in

for iteration in range(max_n_iters):

half_map_3d_f_1 = do_fft(half_map_3d_r_1,d=3)
half_map_3d_f_2 = do_fft(half_map_3d_r_2,d=3)


# align particles to 3D volume
# decide on granularity of rotations
# i.e. how finely the rotational SO(3) space is sampled in a grid search.
# smarter method is branch and bound...
# perhaps can make grid up front of slices, and then only compute norms on finer grid later. so re-use slices


# def do_adaptive_grid_search(particle, map_3d):
# # a la branch and bound
# # not sure exactly how you decide how finely gridded to make it.
# # perhaps heuristics based on how well the signal agrees in half_map_1, half_map_2 (Fourier frequency)


def grid_SO3_uniform(n_rotations):
"""
uniformly grid (not sample) SO(3)
can use some intermediate encoding of SO(3) like quaternions, axis angle, Euler
final output 3x3 rotations

"""
# TODO: sample over the sphere at given granularity.
# easy: draw uniform samples of rotations on sphere. lots of code for this all over the internet. quick solution in geomstats
# harder: draw samples around some rotation using ProjectedNormal distribution (ask Geoff)
# unknown difficulty: make a regular grid of SO(3) at given granularity. Khanh says non-trivial.
rots = np.ones((n_rotations,3,3))
return rots

n_rotations = 1000
rots = grid_SO3(n_rotations)

def do_xy0_plane(n_pix):
"""
generate xy0 plane
xy values are over the xy plane
all z values are 0
see how meshgrid and generate coordinates functions used in https://github.com/geoffwoollard/compSPI/blob/stash_simulate/src/simulate.py#L96

"""

### methgrid
xy0_plane = np.ones(n_pix**2,3)
return xy0

xy0_plane = do_xy0_plane(n_pix):


def do_slices(map_3d_f,rots):
"""
generates slice coordinates by rotating xy0 plane
interpolate values from map_3d_f onto 3D coordinates
see how scipy map_values used to interpolate in https://github.com/geoffwoollard/compSPI/blob/stash_simulate/src/simulate.py#L111

Returns
___
slices
slice of map_3d_f
by Fourier slice theorem corresponds to Fourier transform of projection of rotated map_3d_f

"""
n_rotations = rots.shape[0]
### TODO: map_values interpolation
xyz_rotated = np.ones_like(xy0_plane)
slices = np.random.normal(size=n_rotations*n_pix**2).reshape(n_rotations,n_pix,n_pix)
return slices, xyz_rotated

slices_1, xyz_rotated = do_slices(half_map_3d_1_f,rots) # Here rots are the same for the half maps, but could be different in general
slices_2, xyz_rotated = do_slices(half_map_3d_2_f,rots)


def do_conv_ctf(projection_f, ctf):
"""
Apply CTF to projection
"""

# TODO: vectorize and have shape match
projection_f_conv_ctf = ctf*projection_f
return



def do_bayesean_weights(particle, slices):
"""
compute bayesean weights of particle to slice
under gaussian white noise model

Input
____
slices
shape (n_slices, n_pix,n_pix)
dtype complex32 or complex64

Returns
___
bayesean_weights
shape (n_slices,)
dtyle float32 or float64

"""
n_slices = slices.shape[0]
particle_l2 = np.linalg.norm(particle, ord='fro')**2
slices_l2 = np.linalg.norm(slices,axis=(1,2),ord='fro')**2 # TODO: check right axis. should n_slices l2 norms, one for each slice
# can precompute slices_l2 and keep for all particles if slices the same for different particles

corr = np.zeros(slices)
a_slice_corr = particle.dot(slices) # |particle|^2 - particle.dot(a_slice) + |a_slice|^2
### see Sigrowth et al and Nelson for how to get bayes factors
bayes_factors = np.random.normal(n_slices) # TODO: replace placeholder with right shape
return bayes_factors



# initialize
map_3d_f_updated_1 = np.zeros_like(half_map_3d_f_1) # complex
map_3d_f_updated_2 = np.zeros_like(half_map_3d_f_2) # complex
counts_3d_updated_1 = np.zeros_like(half_map_3d_r_1) # float/real
counts_3d_updated_2 = np.zeros_like(half_map_3d_r_2) # float/real

for particle_idx in range(particles_1_f.shape[0]):
ctf_1 = ctfs_1[particle_idx]
ctf_2 = ctfs_2[particle_idx]
particle_f_1 = particles_f_1[particle_idx]
particle_f_2 = particles_f_2[particle_idx]

def do_wiener_filter(projection, ctf, small_number):
wfilter = ctf/(ctf*ctf+small_number)
projection_wfilter_f = projection*w_filter
return projection_wfilter_f


particle_f_deconv_1 = do_wiener_filter(particles_f_1, ctf_1)
particle_f_deconv_1 = do_wiener_filter(particles_f_1, ctf_1)

slices_conv_ctfs_1 = do_conv_ctf(slices_1, ctf_1) # all slices get convolved with the ctf for the particle
slices_conv_ctfs_2 = do_conv_ctf(slices_2, ctf_2)

bayes_factors_1 = do_bayesean_weights(particles_1_f[particle_idx], slices_conv_ctfs_1)
bayes_factors_2 = do_bayesean_weights(particles_2_f[particle_idx], slices_conv_ctfs_2)

def do_insert_slice(slice_real,xyz,n_pix):
"""
Update map_3d_f_updated with values from slice. Requires interpolation of off grid
see "Insert Fourier slices" in https://github.com/geoffwoollard/learn_cryoem_math/blob/master/nb/fourier_slice_2D_3D_with_trilinear.ipynb
# TODO: vectorize so can take in many slices;
# i.e. do the compoutation in a vetorized way and return inserted_slices_3d, counts_3d of shape (n_slice,n_pix,n_pix,n_pix)

Input
___

slice
type array of shape (n_pix,n_pix)
dtype float32 or float64. real since imaginary and real part done separately
xyz
type array of shape (n_pix**2,3)
volume_3d_shape: float n_pix




Return
___
inserted_slice_3d
count_3d

"""

volume_3d = np.zeros((n_pix,n_pix,n_pix))
# TODO: write insertion code. use linear interpolation (order of interpolation kernel) so not expensive.
# nearest neightbors cheaper, but we can afford to do better than that

return inserted_slice_3d, count_3d


for one_slice_idx in range(bayes_factors_1.shape[0]):
xyz = xyz_rotated[one_slice_idx]
inserted_slice_3d_r, count_3d_r = do_insert_slice(particle_f_deconv_1.real,xyz,volume_3d) # if this can be vectorized, can avoid loop over slices
inserted_slice_3d_i, count_3d_i = do_insert_slice(particle_f_deconv_1.imag,xyz,volume_3d) # if this can be vectorized, can avoid loop over slices
map_3d_f_updated_1 += inserted_slice_3d_r + 1j*inserted_slice_3d_i
counts_3d_updated_1 += count_3d_r + count_3d_i

for one_slice_idx in range(bayes_factors_2.shape[0]):
xyz = xyz_rotated[one_slice_idx]
inserted_slice_3d_r, count_3d_r = do_insert_slice(particle_f_deconv_2.real,xyz,volume_3d) # if this can be vectorized, can avoid loop over slices
inserted_slice_3d_i, count_3d_i = do_insert_slice(particle_f_deconv_2.imag,xyz,volume_3d) # if this can be vectorized, can avoid loop over slices
map_3d_f_updated_2 += inserted_slice_3d_r + 1j*inserted_slice_3d_i
counts_3d_updated_2 += count_3d_r + count_3d_i


# apply noise model
# half_map_1, half_map_2 come from doing the above independently
# filter by noise estimate (e.g. multiply both half maps by FSC)

def do_fsc(map_3d_f_1,map_3d_f_2):
"""
Estimate noise from half maps
for now do noise estimate as FSC between half maps

"""
# TODO: write fast vectorized fsc from code snippets in
# https://github.com/geoffwoollard/learn_cryoem_math/blob/master/nb/fsc.ipynb
# https://github.com/geoffwoollard/learn_cryoem_math/blob/master/nb/mFSC.ipynb
# https://github.com/geoffwoollard/learn_cryoem_math/blob/master/nb/guinier_fsc_sharpen.ipynb
n_pix = map_3d_f_1.shape[0]
fsc_1d = np.ones(n_pix//2)
return noise_estimate


fsc_1d = do_estimate_noise(map_3d_f_updated_1,map_3d_f_updated_2)

def do_expand_1d_3d(arr_1d):
n_pix = arr_1d.shape[0]*2
arr_3d = np.ones((n_pix,n_pix,n_pix))
# TODO: arr_1d fsc_1d to 3d (spherical shells)
return arr_3d

fsc_3d = do_expand_1d_3d(fsc_1d)

# multiplicative filter on maps with fsc
# The FSC is 1D, one number per spherical shells
# it can be expanded back to a multiplicative filter of the same shape as the maps
map_3d_f_filtered_1 = map_3d_f_updated_1*fsc_3d
map_3d_f_filtered_2 = map_3d_f_updated_2*fsc_3d

# update iteration
half_map_3d_f_1 = map_3d_f_filtered_1
half_map_3d_f_2 = map_3d_f_filtered_2

# final map
fsc_1d = do_estimate_noise(half_map_3d_f_1,half_map_3d_f_2)
fsc_3d = do_expand_1d_3d(fsc_1d)
map_3d_f_final = (half_map_3d_f_1 + half_map_3d_f_2 / 2)*fsc_3d
map_3d_r_final = do_ifft(map_3d_f_final)
half_map_3d_r_1 = do_ifft(half_map_3d_f_1)
half_map_3d_r_2 = do_ifft(half_map_3d_f_2)

return map_3d_r_final, half_map_3d_r_1, half_map_3d_r_2, fsc_1d









9 changes: 9 additions & 0 deletions tests/test_expectation_maximization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
def test_do_iterative_refinement():
n_pix = 64
map_3d_init = np.random.normal(size=n_pix**3).reshape(n_pix,n_pix,n_pix)
particles = np.random.normal(size=n_pix**2).reshape(n_pix,n_pix)
jedyeo marked this conversation as resolved.
Show resolved Hide resolved
map_3d_r_final, half_map_3d_r_1, half_map_3d_r_2, fsc_1d = iterative_refinement(map_3d_init, particles)
assert map_3d_r_final.shape == (n_pix,n_pix,n_pix)
assert fsc_1d.dtype = np.float32
assert half_map_3d_r_1.dtype = np.float32
assert half_map_3d_r_2.dtype = np.float32