Skip to content

Commit

Permalink
Trying pre-commit to see if it fixes the imports
Browse files Browse the repository at this point in the history
  • Loading branch information
thisTyler committed Apr 2, 2022
1 parent d51f0bc commit b251e5c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Iterative refinement with Bayesian expectation maximization."""

import numpy as np

from geomstats.geometry import special_orthogonal
from simSPI.transfer import eval_ctf

Expand Down Expand Up @@ -205,7 +204,7 @@ def normalize_map(map_3d, counts, norm_const):
Shape (n_pix, n_pix, n_pix)
map normalized by counts.
"""
return map_3d * counts / (norm_const + counts**2)
return map_3d * counts / (norm_const + counts ** 2)

@staticmethod
def apply_noise_model(map_3d_f_norm_1, map_3d_f_norm_2):
Expand Down Expand Up @@ -360,7 +359,7 @@ def generate_slices(map_3d_f, xy_plane, n_pix, rots):
map_3d_f = np.ones_like(map_3d_f)
xyz_rotated = np.ones_like(xy_plane)

size = n_rotations * n_pix**2
size = n_rotations * n_pix ** 2
slices = np.random.normal(size=size)
slices = slices.reshape((n_rotations, n_pix, n_pix))
return slices, xyz_rotated
Expand Down
6 changes: 3 additions & 3 deletions tests/test_expectation_maximization.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_grid_SO3_uniform(test_ir, n_particles):
def test_generate_xy_plane(test_ir, n_pix):
"""Test generation of xy plane."""
xy_plane = test_ir.generate_xy_plane(n_pix)
assert xy_plane.shape == (n_pix**2, 3)
assert xy_plane.shape == (n_pix ** 2, 3)


def test_generate_slices(test_ir, n_particles, n_pix):
Expand All @@ -101,9 +101,9 @@ def test_generate_slices(test_ir, n_particles, n_pix):

slices, xyz_rotated = test_ir.generate_slices(map_3d, xy_plane, n_pix, rots)

assert xy_plane.shape == (n_pix**2, 3)
assert xy_plane.shape == (n_pix ** 2, 3)
assert slices.shape == (n_particles, n_pix, n_pix)
assert xyz_rotated.shape == (n_pix**2, 3)
assert xyz_rotated.shape == (n_pix ** 2, 3)


def test_apply_ctf_to_slice(test_ir, n_pix):
Expand Down

0 comments on commit b251e5c

Please sign in to comment.