diff --git a/Wrappers/Python/cil/optimisation/utilities/callbacks.py b/Wrappers/Python/cil/optimisation/utilities/callbacks.py index a5c136695a..a94db702bd 100644 --- a/Wrappers/Python/cil/optimisation/utilities/callbacks.py +++ b/Wrappers/Python/cil/optimisation/utilities/callbacks.py @@ -4,7 +4,9 @@ from tqdm.auto import tqdm as tqdm_auto from tqdm.std import tqdm as tqdm_std import numpy as np - +from cil.processors import Slicer +import os +from cil.io import TIFFWriter class Callback(ABC): '''Base Callback to inherit from for use in :code:`Algorithm.run(callbacks: list[Callback])`. @@ -135,6 +137,7 @@ class LogfileCallback(TextProgressCallback): def __init__(self, log_file, mode='a', **kwargs): self.fd = open(log_file, mode=mode) super().__init__(file=self.fd, **kwargs) + class EarlyStoppingObjectiveValue(Callback): '''Callback that stops iterations if the change in the objective value is less than a provided threshold value. @@ -158,8 +161,9 @@ def __call__(self, algorithm): raise StopIteration class CGLSEarlyStopping(Callback): - '''Callback to work with CGLS. It causes the algorithm to terminate if :math:`||A^T(Ax-b)||_2 < \epsilon||A^T(Ax_0-b)||_2` where `epsilon` is set to default as '1e-6', :math:`x` is the current iterate and :math:`x_0` is the initial value. + r'''Callback to work with CGLS. It causes the algorithm to terminate if :math:`||A^T(Ax-b)||_2 < \epsilon||A^T(Ax_0-b)||_2` where `epsilon` is set to default as '1e-6', :math:`x` is the current iterate and :math:`x_0` is the initial value. It will also terminate if the algorithm begins to diverge i.e. if :math:`||x||_2> \omega`, where `omega` is set to default as 1e6. + Parameters ---------- epsilon: float, default 1e-6 @@ -187,3 +191,45 @@ def __call__(self, algorithm): raise StopIteration +class SaveIterates(Callback): + r'''Callback to save iterates as tiff files every set number of iterations. + + Parameters + ---------- + interval: integer, + The iterates will be saved every `interval` number of iterations e.g. if `interval =4` the 0, 4, 8, 12,... iterates will be saved. + file_name : string + This defines the file name prefix, i.e. the file name without the extension. + dir_path : string + The place to store the images + roi: dict, optional default is None and no slicing will be applied + The region-of-interest to slice {'axis_name1':(start,stop,step), 'axis_name2':(start,stop,step)} + The `key` being the axis name to apply the processor to, the `value` holding a tuple containing the ROI description + Start: Starting index of input data. Must be an integer, or `None` defaults to index 0. + Stop: Stopping index of input data. Must be an integer, or `None` defaults to index N. + Step: Number of pixels to average together. Must be an integer or `None` defaults to 1. + compression : str, default None. Accepted values None, 'uint8', 'uint16' + The lossy compression to apply. The default None will not compress data. + uint8' or 'unit16' will compress to unsigned int 8 and 16 bit respectively. + ''' + def __init__(self, interval=1, file_name='iter', dir_path='./', roi=None, compression=None): + + self.file_path= os.path.join(dir_path, file_name) + + self.interval=interval + self.roi=roi + if self.roi is not None: + self.slicer= Slicer(roi=self.roi) + self.compression=compression + super(SaveIterates, self).__init__() + + def __call__(self, algo): + + if algo.iteration % self.interval ==0: + if self.roi is None: + TIFFWriter(data=algo.solution, file_name=self.file_path+f'_{algo.iteration:04d}.tiff', counter_offset=-1,compression=self.compression ).write() + else: + self.slicer.set_input(algo.solution) + TIFFWriter(self.slicer.get_output(), file_name=self.file_path+f'_{algo.iteration:04d}.tiff', counter_offset=-1,compression=self.compression ).write() + + diff --git a/Wrappers/Python/test/test_algorithms.py b/Wrappers/Python/test/test_algorithms.py index e3f6302d78..288fde8117 100644 --- a/Wrappers/Python/test/test_algorithms.py +++ b/Wrappers/Python/test/test_algorithms.py @@ -20,7 +20,7 @@ import unittest from os import unlink from tempfile import NamedTemporaryFile - +import os, glob import numpy as np import logging @@ -63,6 +63,8 @@ from unittest.mock import MagicMock +from cil.io import TIFFStackReader + log = logging.getLogger(__name__) initialise_tests() @@ -1428,6 +1430,92 @@ def test_EarlyStoppingObjectiveValue(self): callbacks.EarlyStoppingObjectiveValue(0.1)(alg) + +class TestSaveIteratesCallback(unittest.TestCase): + + class MockAlgo(Algorithm): + def __init__(self, initial, update_objective_interval=10, **kwargs): + super().__init__(update_objective_interval=update_objective_interval, **kwargs) + self.configured = True + self.x=initial + + def update(self): + self.x -= 1 + + def update_objective(self): + self.loss.append(2 ** getattr(self, 'x', np.nan)) + + + def setUp(self): + # Mock the algorithm object + + self.image_geometry = ImageGeometry(10, 2) + self.data = self.image_geometry.allocate(10) + self.mock_algorithm = self.MockAlgo(self.data) + self.file_name= 'myfile' + self.cwd = os.getcwd() + self.dir_path=os.path.join(self.cwd, 'test_tiff' ) + + def test_save_iterates_no_writer_no_roi(self): + # Test saving iterates to a list with no writer and no ROI + callback = callbacks.SaveIterates(interval=1, file_name= self.file_name, dir_path=self.dir_path) + + # Call the callback multiple times and increment iteration + self.mock_algorithm.run(5, callbacks=[callback]) + + # Check if iterates are saved correctly + files = glob.glob(os.path.join(glob.escape(self.dir_path), '*')) + assert len(files) == 6 + reader = TIFFStackReader(file_name = self.dir_path) + read = reader.read() + for i in range(6): + np.testing.assert_array_equal(read[i], (10-i)*np.ones((2,10))) + [os.remove(file) for file in files] + os.rmdir(self.dir_path) + + + def test_save_iterates_with_roi(self): + # Test saving iterates with an ROI applied + roi = {'horizontal_x': (0, 2, 1)} + + callback = callbacks.SaveIterates(interval=1, file_name= self.file_name, dir_path=self.dir_path, roi=roi) + + # Call the callback and check if slicer was used + callback(self.mock_algorithm) + # Check if iterates are saved correctly + files = glob.glob(os.path.join(glob.escape(self.dir_path), '*')) + assert len(files) == 1 + reader = TIFFStackReader(file_name = self.dir_path) + read = reader.read() + np.testing.assert_array_equal(read, 10*np.ones([2, 2])) + [os.remove(file) for file in files] + os.rmdir(self.dir_path) + + def test_save_iterates_with_interval(self): + # Test saving iterates with a specified interval + callback = callbacks.SaveIterates(interval=2, file_name= self.file_name, dir_path=self.dir_path) + + # Call the callback multiple times and increment iteration + self.mock_algorithm.run(5, callbacks=[callback]) + + # Check if iterates are saved correctly + files = glob.glob(os.path.join(glob.escape(self.dir_path), '*')) + print(files) + self.assertEqual( len(files), 3) + reader = TIFFStackReader(file_name = self.dir_path) + read = reader.read() + np.testing.assert_array_equal(read[0], (10-0)*np.ones((2,10))) + np.testing.assert_array_equal(read[1], (10-2)*np.ones((2,10))) + np.testing.assert_array_equal(read[2], (10-4)*np.ones((2,10))) + [os.remove(file) for file in files] + os.rmdir(self.dir_path) + + + + + + + class TestADMM(unittest.TestCase): def setUp(self): ig = ImageGeometry(2, 3, 2) diff --git a/docs/source/optimisation.rst b/docs/source/optimisation.rst index c7d1fa7eba..26aee43843 100644 --- a/docs/source/optimisation.rst +++ b/docs/source/optimisation.rst @@ -606,6 +606,16 @@ A list of :code:`Callback` s to be executed each iteration can be passed to `Alg Built-in callbacks include: +.. autoclass:: cil.optimisation.utilities.callbacks.SaveIterates + :members: + +.. autoclass:: cil.optimisation.utilities.callbacks.EarlyStoppingObjectiveValue + :members: + +.. autoclass:: cil.optimisation.utilities.callbacks.CGLSEarlyStopping + :members: + + .. autoclass:: cil.optimisation.utilities.callbacks.ProgressCallback :members: