Skip to content

Commit

Permalink
happy linter
Browse files Browse the repository at this point in the history
  • Loading branch information
dkazanc committed Jun 19, 2023
1 parent dce5055 commit 0dace70
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 20 deletions.
1 change: 1 addition & 0 deletions httomo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from enum import IntEnum, unique
from pathlib import Path


@unique
class PipelineTasks(IntEnum):
"""An enumeration of available pipeline stages."""
Expand Down
1 change: 0 additions & 1 deletion httomo/data/hdf/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from httomo.utils import Colour, _parse_preview, log_once, log_rank



@dataclass
class LoaderData:
data: ndarray
Expand Down
23 changes: 13 additions & 10 deletions httomo/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ def _get_method_funcs(yaml_config: Path, comm: MPI.Comm) -> List[MethodFunc]:
parameters=method_conf,
cpu=True if not is_httomolibgpu else wrapper_init_module.meta.cpu,
gpu=False if not is_httomolibgpu else wrapper_init_module.meta.gpu,
calc_max_slices=None if not is_httomolibgpu else wrapper_init_module.calc_max_slices,
calc_max_slices=None
if not is_httomolibgpu
else wrapper_init_module.calc_max_slices,
reslice_ahead=False,
pattern=Pattern.all,
is_loader=False,
Expand Down Expand Up @@ -810,7 +812,9 @@ def run_method(
recon_algorithm = None
if recon_center is not None:
slice_dim = 1
recon_algorithm = dict_params_method.pop("algorithm", None) # covers tomopy case
recon_algorithm = dict_params_method.pop(
"algorithm", None
) # covers tomopy case
else:
slice_dim = _get_slicing_dim(current_func.pattern)

Expand Down Expand Up @@ -1121,7 +1125,7 @@ def _get_available_gpu_memory(safety_margin_percent: float = 10.0) -> int:
def _update_max_slices(
section: PlatformSection,
process_data_shape: Optional[Tuple[int, int, int]],
input_data_type: Optional[np.dtype]
input_data_type: Optional[np.dtype],
) -> Tuple[np.dtype, Tuple[int, int]]:
if process_data_shape is None or input_data_type is None:
return
Expand All @@ -1142,20 +1146,19 @@ def _update_max_slices(
output_dims = non_slice_dims_shape
if section.gpu:
available_memory = _get_available_gpu_memory(10.0)
available_memory_in_GB = round(available_memory/(1024**3),2)
max_slices_methods = [max_slices]*len(section.methods)
available_memory_in_GB = round(available_memory / (1024**3), 2)
max_slices_methods = [max_slices] * len(section.methods)
idx = 0
for m in section.methods:
if m.calc_max_slices is not None:
(slices_estimated, data_type, output_dims) = m.calc_max_slices(
slice_dim,
non_slice_dims_shape,
data_type,
available_memory
slice_dim, non_slice_dims_shape, data_type, available_memory
)
max_slices_methods[idx] = min(max_slices, slices_estimated)
idx += 1
non_slice_dims_shape = output_dims # overwrite input dims with estimated output ones
non_slice_dims_shape = (
output_dims # overwrite input dims with estimated output ones
)
section.max_slices = min(max_slices_methods)
else:
# TODO: How do we determine the output dtype in functions that aren't on GPU, tomopy, etc.
Expand Down
7 changes: 2 additions & 5 deletions httomo/wrappers_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from mpi4py.MPI import Comm


def _gpumem_cleanup():
"""cleans up GPU memory and also the FFT plan cache"""
if gpu_enabled:
Expand Down Expand Up @@ -358,9 +359,5 @@ def calc_max_slices(
default_args[name] = par.default
kwargs = {**default_args, **self.dict_params}
return self.meta.calc_max_slices(
slice_dim,
non_slice_dims_shape,
dtype,
available_memory,
**kwargs
slice_dim, non_slice_dims_shape, dtype, available_memory, **kwargs
)
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def sample_pipelines():
def gpu_pipeline():
return "samples/pipeline_template_examples/03_basic_gpu_pipeline_tomo_standard.yaml"


@pytest.fixture
def merge_yamls():
def _merge_yamls(*yamls) -> None:
Expand Down
5 changes: 4 additions & 1 deletion tests/test_method_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_get_invalid_method():
with pytest.raises(KeyError, match="key doesntexist is not present"):
get_method_info("tomopy.misc.corr", "doesntexist", "pattern")


def test_get_invalid_attr():
with pytest.raises(KeyError, match="attribute doesntexist is not present"):
get_method_info("tomopy.misc.corr", "median_filter", "doesntexist")
Expand All @@ -45,7 +46,9 @@ def test_httomolibgpu_memfunc():
def test_httomolibgpu_meta():
from httomolibgpu import MethodMeta

assert isinstance(get_httomolibgpu_method_meta("prep.normalize.normalize"), MethodMeta)
assert isinstance(
get_httomolibgpu_method_meta("prep.normalize.normalize"), MethodMeta
)


def test_httomolibgpu_meta_incomplete_path():
Expand Down
7 changes: 5 additions & 2 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def test_i12_testing_pipeline_output(
assert "Saving intermediate file: 6-tomopy-recon-tomo-gridrec.h5" in log_contents
assert "INFO | ~~~ Pipeline finished ~~~" in log_contents


def test_diad_testing_pipeline_output(
cmd, diad_data, diad_loader, testing_pipeline, output_folder, merge_yamls
):
Expand Down Expand Up @@ -383,6 +384,7 @@ def test_sweep_pipeline_with_save_all_using_mpi(
log_files = list(filter(lambda x: ".log" in x, files))
assert len(log_files) == 2


"""
# Something weird going on here with the logs
Expand All @@ -403,6 +405,7 @@ def test_sweep_pipeline_with_save_all_using_mpi(
)
"""


def test_sweep_range_pipeline_with_step_absent(
cmd, standard_data, sample_pipelines, output_folder
):
Expand Down Expand Up @@ -440,13 +443,13 @@ def test_multi_inputs_pipeline(cmd, standard_data, sample_pipelines, output_fold

with h5py.File(h5_files[0], "r") as f:
arr = np.array(f["data"])
assert arr.shape == (20, 128, 160)
assert arr.shape == (20, 128, 160)
assert arr.dtype == np.uint16
with h5py.File(h5_files[1], "r") as f:
arr = np.array(f["data"])
assert arr.shape == (20, 128, 160)
assert arr.dtype == np.uint16
with h5py.File(h5_files[2], "r") as f:
arr = np.array(f["data"])
assert arr.shape == (180, 128, 160)
assert arr.shape == (180, 128, 160)
assert arr.dtype == np.uint16
2 changes: 1 addition & 1 deletion tests/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_determine_platform_sections_platform_change() -> None:
"sino-all-sino",
"all-sino-sino",
"all-all-all",
]
],
)
def test_determine_platform_sections_pattern_all_combine(
pattern1: Pattern, pattern2: Pattern, expected: Pattern
Expand Down

0 comments on commit 0dace70

Please sign in to comment.