diff --git a/CHANGELOG.md b/CHANGELOG.md index b83c5f43b2..7f7102b335 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ * 21.2.1 + - CCPi Regularisation plugin is refactored, only FGP_TV, FGP_dTV, TGV and TNV are exposed. Docstrings and functionality unit tests are added. Tests of the functions are meant to be in the CCPi-Regularisation toolkit itself. - Add dtype for ImageGeometry, AcquisitionGeometry, VectorGeometry, BlockGeometry - Fix GradientOperator to handle pseudo 2D CIL geometries - Created Reconstructor base class for simpler use of CIL methods diff --git a/Wrappers/Python/cil/framework/framework.py b/Wrappers/Python/cil/framework/framework.py index 9841ba8a8c..eb6c93bf0f 100644 --- a/Wrappers/Python/cil/framework/framework.py +++ b/Wrappers/Python/cil/framework/framework.py @@ -2401,11 +2401,6 @@ def log(self, *args, **kwargs): '''Applies log pixel-wise to the DataContainer''' return self.pixel_wise_unary(numpy.log, *args, **kwargs) - #def __abs__(self): - # operation = FM.OPERATION.ABS - # return self.callFieldMath(operation, None, self.mask, self.maskOnValue) - # __abs__ - ## reductions def sum(self, *args, **kwargs): return self.as_array().sum(*args, **kwargs) @@ -3085,7 +3080,12 @@ def get_order_for_engine(engine, geometry): if isinstance(geometry, AcquisitionGeometry): dim_order = DataOrder.TIGRE_AG_LABELS else: - dim_order = DataOrder.TIGRE_IG_LABELS + dim_order = DataOrder.TIGRE_IG_LABELS + elif engine == 'cil': + if isinstance(geometry, AcquisitionGeometry): + dim_order = DataOrder.CIL_AG_LABELS + else: + dim_order = DataOrder.CIL_IG_LABELS else: raise ValueError("Unknown engine expected 'tigre' or 'astra' got {}".format(engine)) diff --git a/Wrappers/Python/cil/plugins/ccpi_regularisation/functions/__init__.py b/Wrappers/Python/cil/plugins/ccpi_regularisation/functions/__init__.py index aea6902e7d..1b27aef58e 100644 --- a/Wrappers/Python/cil/plugins/ccpi_regularisation/functions/__init__.py +++ b/Wrappers/Python/cil/plugins/ccpi_regularisation/functions/__init__.py @@ -15,5 +15,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .regularisers import FGP_TV, ROF_TV, TGV, LLT_ROF, FGP_dTV,\ - SB_TV, TNV +from .regularisers import FGP_TV, TGV, FGP_dTV, TNV diff --git a/Wrappers/Python/cil/plugins/ccpi_regularisation/functions/regularisers.py b/Wrappers/Python/cil/plugins/ccpi_regularisation/functions/regularisers.py index 9940c95766..b0d39f6222 100644 --- a/Wrappers/Python/cil/plugins/ccpi_regularisation/functions/regularisers.py +++ b/Wrappers/Python/cil/plugins/ccpi_regularisation/functions/regularisers.py @@ -25,21 +25,36 @@ "Minimal version is 20.04") -from cil.framework import DataContainer +from cil.framework import DataOrder from cil.optimisation.functions import Function import numpy as np import warnings +from numbers import Number class RegulariserFunction(Function): def proximal(self, x, tau, out=None): + '''Generic proximal method for a RegulariserFunction + + :param x: image to be regularised + :type x: an ImageData + :param tau: + :type tau: Number + :param out: a placeholder for the result + :type out: same as x: ImageData + + If the ImageData contains complex data, rather than the default float32, the regularisation + is run indipendently on the real and imaginary part. + ''' + + self.check_input(x) arr = x.as_array() if arr.dtype in [np.complex, np.complex64]: # do real and imag part indep in_arr = np.asarray(arr.real, dtype=np.float32, order='C') - res, info = self.proximal_numpy(in_arr, tau, out) + res, info = self.proximal_numpy(in_arr, tau) arr.real = res[:] in_arr = np.asarray(arr.imag, dtype=np.float32, order='C') - res, info = self.proximal_numpy(in_arr, tau, out) + res, info = self.proximal_numpy(in_arr, tau) arr.imag = res[:] self.info = info if out is not None: @@ -50,7 +65,7 @@ def proximal(self, x, tau, out=None): return out else: arr = np.asarray(x.as_array(), dtype=np.float32, order='C') - res, info = self.proximal_numpy(arr, tau, out) + res, info = self.proximal_numpy(arr, tau) self.info = info if out is not None: out.fill(res) @@ -58,9 +73,12 @@ def proximal(self, x, tau, out=None): out = x.copy() out.fill(res) return out - def proximal_numpy(self, xarr, tau, out=None): + def proximal_numpy(self, xarr, tau): raise NotImplementedError('Please implement proximal_numpy') + def check_input(self, input): + pass + class TV_Base(RegulariserFunction): def __call__(self,x): in_arr = np.asarray(x.as_array(), dtype=np.float32, order='C') @@ -70,28 +88,25 @@ def __call__(self,x): def convex_conjugate(self,x): return 0.0 -class ROF_TV(TV_Base): - def __init__(self,lambdaReg,iterationsTV,tolerance,time_marchstep,device): - # set parameters - self.alpha = lambdaReg - self.max_iteration = iterationsTV - self.time_marchstep = time_marchstep - self.device = device # string for 'cpu' or 'gpu' - self.tolerance = tolerance - - def proximal_numpy(self, in_arr, tau, out = None): - res , info = regularisers.ROF_TV(in_arr, - self.alpha, - self.max_iteration, - self.time_marchstep, - self.tolerance, - self.device) - - return res, info class FGP_TV(TV_Base): - def __init__(self, alpha=1, max_iteration=100, tolerance=1e-6, isotropic=True, nonnegativity=True, printing=False, device='cpu'): + def __init__(self, alpha=1, max_iteration=100, tolerance=0, isotropic=True, nonnegativity=True, device='cpu'): + '''Creator of FGP_TV Function + + :param alpha: regularisation parameter + :type alpha: number, default 1 + :param isotropic: Whether it uses L2 (isotropic) or L1 (unisotropic) norm + :type isotropic: boolean, default True, can range between 1 and 2 + :param nonnegativity: Whether to add the non-negativity constraint + :type nonnegativity: boolean, default True + :param max_iteration: max number of sub iterations. The algorithm will iterate up to this number of iteration or up to when the tolerance has been reached + :type max_iteration: integer, default 100 + :param tolerance: minimum difference between previous iteration of the algorithm that determines the stop of the iteration earlier than max_iteration. If set to 0 only the max_iteration will be used as stop criterion. + :type tolerance: float, default 0 + :param device: determines if the code runs on CPU or GPU + :type device: string, default 'cpu', can be 'gpu' if GPU is installed + ''' if isotropic == True: self.methodTV = 0 else: @@ -108,40 +123,83 @@ def __init__(self, alpha=1, max_iteration=100, tolerance=1e-6, isotropic=True, n self.nonnegativity = nonnegativity self.device = device # string for 'cpu' or 'gpu' - def proximal_numpy(self, in_arr, tau, out = None): + def proximal_numpy(self, in_arr, tau): res , info = regularisers.FGP_TV(\ in_arr,\ - self.alpha*tau,\ + self.alpha * tau,\ self.max_iteration,\ self.tolerance,\ self.methodTV,\ self.nonnegativity,\ self.device) return res, info + + def __rmul__(self, scalar): + '''Define the multiplication with a scalar + + this changes the regularisation parameter in the plugin''' + if not isinstance (scalar, Number): + raise NotImplemented + else: + self.alpha *= scalar + return self + def check_input(self, input): + if input.geometry.length > 3: + raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length)) class TGV(RegulariserFunction): - def __init__(self, regularisation_parameter, alpha1, alpha2, iter_TGV, LipshitzConstant, torelance, device ): - self.regularisation_parameter = regularisation_parameter - self.alpha1 = alpha1 - self.alpha2 = alpha2 - self.iter_TGV = iter_TGV - self.LipshitzConstant = LipshitzConstant - self.torelance = torelance + def __init__(self, alpha=1, gamma=1, max_iteration=100, tolerance=0, device='cpu' , **kwargs): + '''Creator of Total Generalised Variation Function + + :param alpha: regularisation parameter + :type alpha: number, default 1 + :param gamma: ratio of TGV terms + :type gamma: number, default 1, can range between 1 and 2 + :param max_iteration: max number of sub iterations. The algorithm will iterate up to this number of iteration or up to when the tolerance has been reached + :type max_iteration: integer, default 100 + :param tolerance: minimum difference between previous iteration of the algorithm that determines the stop of the iteration earlier than max_iteration. If set to 0 only the max_iteration will be used as stop criterion. + :type tolerance: float, default 0 + :param device: determines if the code runs on CPU or GPU + :type device: string, default 'cpu', can be 'gpu' if GPU is installed + + ''' + + self.alpha = alpha + self.gamma = gamma + self.max_iteration = max_iteration + self.tolerance = tolerance self.device = device + + if kwargs.get('iter_TGV', None) is not None: + # raise ValueError('iter_TGV parameter has been superseded by num_iter. Use that instead.') + self.num_iter = kwargs.get('iter_TGV') def __call__(self,x): warnings.warn("{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__)) return np.nan + @property + def gamma(self): + return self.__gamma + @gamma.setter + def gamma(self, value): + if value <= 2 and value >= 1: + self.__gamma = value + @property + def alpha2(self): + return self.alpha1 * self.gamma + @property + def alpha1(self): + return 1. - def proximal_numpy(self, in_arr, tau, out = None): + def proximal_numpy(self, in_arr, tau): res , info = regularisers.TGV(in_arr, - self.regularisation_parameter, + self.alpha * tau, self.alpha1, self.alpha2, - self.iter_TGV, + self.max_iteration, self.LipshitzConstant, - self.torelance, + self.tolerance, self.device) # info: return number of iteration and reached tolerance @@ -152,46 +210,51 @@ def proximal_numpy(self, in_arr, tau, out = None): def convex_conjugate(self, x): warnings.warn("{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__)) return np.nan - -class LLT_ROF(RegulariserFunction): - - def __init__(self, regularisation_parameterROF, - regularisation_parameterLLT, - iter_LLT_ROF, time_marching_parameter, torelance, device ): - self.regularisation_parameterROF = regularisation_parameterROF - self.regularisation_parameterLLT = regularisation_parameterLLT - self.iter_LLT_ROF = iter_LLT_ROF - self.time_marching_parameter = time_marching_parameter - self.torelance = torelance - self.device = device + def __rmul__(self, scalar): + '''Define the multiplication with a scalar - def __call__(self,x): - warnings.warn("{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__)) - return np.nan - - def proximal_numpy(self, in_arr, tau, out = None): - res , info = regularisers.LLT_ROF(in_arr, - self.regularisation_parameterROF, - self.regularisation_parameterLLT, - self.iter_LLT_ROF, - self.time_marching_parameter, - self.torelance, - self.device) - - # info: return number of iteration and reached tolerance - # https://github.com/vais-ral/CCPi-Regularisation-Toolkit/blob/master/src/Core/regularisers_CPU/TGV_core.c#L168 - # Stopping Criteria || u^k - u^(k-1) ||_{2} / || u^{k} ||_{2} - - return res, info + this changes the regularisation parameter in the plugin''' + if not isinstance (scalar, Number): + raise NotImplemented + else: + self.alpha *= scalar + return self - def convex_conjugate(self, x): - warnings.warn("{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__)) - return np.nan + # f = TGV() + # f = alpha * f + + def check_input(self, input): + if len(input.dimension_labels) == 2: + self.LipshitzConstant = 12 + elif len(input.dimension_labels) == 3: + self.LipshitzConstant = 16 # Vaggelis to confirm + else: + raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length)) + class FGP_dTV(RegulariserFunction): + '''Creator of FGP_dTV Function + + :param reference: reference image + :type reference: ImageData + :param alpha: regularisation parameter + :type alpha: number, default 1 + :param max_iteration: max number of sub iterations. The algorithm will iterate up to this number of iteration or up to when the tolerance has been reached + :type max_iteration: integer, default 100 + :param tolerance: minimum difference between previous iteration of the algorithm that determines the stop of the iteration earlier than max_iteration. If set to 0 only the max_iteration will be used as stop criterion. + :type tolerance: float, default 0 + :param eta: smoothing constant to calculate gradient of the reference + :type eta: number, default 0.01 + :param isotropic: Whether it uses L2 (isotropic) or L1 (anisotropic) norm + :type isotropic: boolean, default True, can range between 1 and 2 + :param nonnegativity: Whether to add the non-negativity constraint + :type nonnegativity: boolean, default True + :param device: determines if the code runs on CPU or GPU + :type device: string, default 'cpu', can be 'gpu' if GPU is installed + ''' def __init__(self, reference, alpha=1, max_iteration=100, - tolerance=1e-6, eta=0.01, isotropic=True, nonnegativity=True, device='cpu'): + tolerance=0, eta=0.01, isotropic=True, nonnegativity=True, device='cpu'): if isotropic == True: self.methodTV = 0 @@ -214,11 +277,11 @@ def __call__(self,x): warnings.warn("{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__)) return np.nan - def proximal_numpy(self, in_arr, tau, out = None): + def proximal_numpy(self, in_arr, tau): res , info = regularisers.FGP_dTV(\ in_arr,\ self.reference,\ - self.alpha*tau,\ + self.alpha * tau,\ self.max_iteration,\ self.tolerance,\ self.eta,\ @@ -231,47 +294,68 @@ def convex_conjugate(self, x): warnings.warn("{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__)) return np.nan -class SB_TV(TV_Base): - def __init__(self,lambdaReg,iterationsTV,tolerance,methodTV,printing,device): - # set parameters - self.alpha = lambdaReg - self.max_iteration = iterationsTV - self.tolerance = tolerance - self.methodTV = methodTV - self.printing = printing - self.device = device # string for 'cpu' or 'gpu' - - def proximal_numpy(self, in_arr, tau, out = None): - res , info = regularisers.SB_TV(in_arr, - self.alpha*tau, - self.max_iteration, - self.tolerance, - self.methodTV, - self.device) + def __rmul__(self, scalar): + '''Define the multiplication with a scalar - return res, info + this changes the regularisation parameter in the plugin''' + if not isinstance (scalar, Number): + raise NotImplemented + else: + self.alpha *= scalar + return self + + def check_input(self, input): + if input.geometry.length > 3: + raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length)) class TNV(RegulariserFunction): - def __init__(self,regularisation_parameter,iterationsTNV,tolerance): - + def __init__(self,alpha=1, max_iteration=100, tolerance=0): + '''Creator of TNV Function + + :param alpha: regularisation parameter + :type alpha: number, default 1 + :param max_iteration: max number of sub iterations. The algorithm will iterate up to this number of iteration or up to when the tolerance has been reached + :type max_iteration: integer, default 100 + :param tolerance: minimum difference between previous iteration of the algorithm that determines the stop of the iteration earlier than max_iteration. If set to 0 only the max_iteration will be used as stop criterion. + :type tolerance: float, default 0 + ''' # set parameters - self.regularisation_parameter = regularisation_parameter - self.iterationsTNV = iterationsTNV + self.alpha = alpha + self.max_iteration = max_iteration self.tolerance = tolerance def __call__(self,x): warnings.warn("{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__)) return np.nan - def proximal_numpy(self, in_arr, tau, out = None): + def proximal_numpy(self, in_arr, tau): + if in_arr.ndim != 3: + # https://github.com/vais-ral/CCPi-Regularisation-Toolkit/blob/413c6001003c6f1272aeb43152654baaf0c8a423/src/Python/src/cpu_regularisers.pyx#L584-L588 + raise ValueError('Only 3D data is supported. Passed data has {} dimensions'.format(in_arr.ndim)) res = regularisers.TNV(in_arr, - self.regularisation_parameter, - self.iterationsTNV, + self.alpha * tau, + self.max_iteration, self.tolerance) - return res, [] def convex_conjugate(self, x): warnings.warn("{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__)) - return np.nan \ No newline at end of file + return np.nan + + def __rmul__(self, scalar): + '''Define the multiplication with a scalar + + this changes the regularisation parameter in the plugin''' + if not isinstance (scalar, Number): + raise NotImplemented + else: + self.alpha *= scalar + return self + + def check_input(self, input): + '''TNV requires 2D+channel data with the first dimension as the channel dimension''' + DataOrder.check_order_for_engine('cil', input.geometry) + if ( input.geometry.channels == 1 ) or ( not input.geometry.length == 3) : + raise ValueError('TNV requires 2D+channel data. Got {}'.format(input.geometry.dimension_labels)) + \ No newline at end of file diff --git a/Wrappers/Python/test/test_PluginsRegularisation.py b/Wrappers/Python/test/test_PluginsRegularisation.py index 8950af7f10..372ea5f676 100644 --- a/Wrappers/Python/test/test_PluginsRegularisation.py +++ b/Wrappers/Python/test/test_PluginsRegularisation.py @@ -40,10 +40,11 @@ # "Minimal version is 20.04") has_regularisation_toolkit = False print ("has_regularisation_toolkit", has_regularisation_toolkit) +TNV_fixed = False class TestPlugin(unittest.TestCase): def setUp(self): - print ("test plugins") + # print ("test plugins") pass def tearDown(self): pass @@ -55,13 +56,7 @@ def test_import_FGP_TV(self): except ModuleNotFoundError as ie: print (ie) assert False - @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") - def test_import_ROF_TV(self): - try: - from cil.plugins.ccpi_regularisation.functions import ROF_TV - assert True - except ModuleNotFoundError as ie: - assert False + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") def test_import_TGV(self): try: @@ -69,13 +64,7 @@ def test_import_TGV(self): assert True except ModuleNotFoundError as ie: assert False - @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") - def test_import_LLT_ROF(self): - try: - from cil.plugins.ccpi_regularisation.functions import LLT_ROF - assert True - except ModuleNotFoundError as ie: - assert False + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") def test_import_FGP_dTV(self): try: @@ -83,13 +72,7 @@ def test_import_FGP_dTV(self): assert True except ModuleNotFoundError as ie: assert False - @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") - def test_import_SB_TV(self): - try: - from cil.plugins.ccpi_regularisation.functions import SB_TV - assert True - except ModuleNotFoundError as ie: - assert False + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") def test_import_TNV(self): try: @@ -97,6 +80,7 @@ def test_import_TNV(self): assert True except ModuleNotFoundError as ie: assert False + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") def test_FGP_TV_complex(self): data = dataexample.CAMERA.get(size=(256,256)) @@ -108,4 +92,209 @@ def test_FGP_TV_complex(self): reg = FGP_TV() out = reg.proximal(data, 1) outarr = out.as_array() - np.testing.assert_almost_equal(outarr.imag, outarr.real) \ No newline at end of file + np.testing.assert_almost_equal(outarr.imag, outarr.real) + + def rmul_test(self, f): + + alpha = f.alpha + scalar = 2.123 + af = scalar*f + + assert (id(af) == id(f)) + assert af.alpha == scalar * alpha + + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_FGP_TV_rmul(self): + from cil.plugins.ccpi_regularisation.functions import FGP_TV + f = FGP_TV() + + self.rmul_test(f) + + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_FGP_TGV_rmul(self): + from cil.plugins.ccpi_regularisation.functions import FGP_TGV + f = FGP_TGV() + + self.rmul_test(f) + + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_FGP_TGV_rmul(self): + from cil.plugins.ccpi_regularisation.functions import TNV + f = TNV() + + self.rmul_test(f) + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_FGP_dTV_rmul(self): + from cil.plugins.ccpi_regularisation.functions import FGP_dTV + data = dataexample.CAMERA.get(size=(256,256)) + f = FGP_dTV(data) + + self.rmul_test(f) + + + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_functionality_FGP_TV(self): + + data = dataexample.CAMERA.get(size=(256,256)) + datarr = data.as_array() + from cil.plugins.ccpi_regularisation.functions import FGP_TV + from ccpi.filters import regularisers + + tau = 1. + fcil = FGP_TV() + outcil = fcil.proximal(data, tau=tau) + # use CIL defaults + outrgl, info = regularisers.FGP_TV(datarr, fcil.alpha*tau, fcil.max_iteration, fcil.tolerance, 0, 1, 'cpu' ) + np.testing.assert_almost_equal(outrgl, outcil.as_array()) + + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_functionality_TGV(self): + + data = dataexample.CAMERA.get(size=(256,256)) + datarr = data.as_array() + from cil.plugins.ccpi_regularisation.functions import TGV + from ccpi.filters import regularisers + + tau = 1. + fcil = TGV() + outcil = fcil.proximal(data, tau=tau) + # use CIL defaults + outrgl, info = regularisers.TGV(datarr, fcil.alpha*tau, 1,1, fcil.max_iteration, 12, fcil.tolerance, 'cpu' ) + + np.testing.assert_almost_equal(outrgl, outcil.as_array()) + + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_functionality_FGP_dTV(self): + + data = dataexample.CAMERA.get(size=(256,256)) + datarr = data.as_array() + ref = data*0.3 + from cil.plugins.ccpi_regularisation.functions import FGP_dTV + from ccpi.filters import regularisers + + tau = 1. + fcil = FGP_dTV(ref) + outcil = fcil.proximal(data, tau=tau) + # use CIL defaults + outrgl, info = regularisers.FGP_dTV(datarr, ref.as_array(), fcil.alpha*tau, fcil.max_iteration, fcil.tolerance, 0.01, 0, 1, 'cpu' ) + np.testing.assert_almost_equal(outrgl, outcil.as_array()) + + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_functionality_TNV(self): + + # fake a 2D+channel image + d = dataexample.SYNCHROTRON_PARALLEL_BEAM_DATA.get() + ig = ImageGeometry(160, 135, channels=91) + data = ig.allocate(None) + data.fill(d) + del d + + datarr = data.as_array() + from cil.plugins.ccpi_regularisation.functions import TNV + from ccpi.filters import regularisers + + tau = 1. + + # CIL defaults + outrgl = regularisers.TNV(datarr, 1, 100, 1e-6 ) + + fcil = TNV() + outcil = fcil.proximal(data, tau=tau) + np.testing.assert_almost_equal(outrgl, outcil.as_array()) + + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_TNV_raise_on_2D(self): + + # data = dataexample.SYNCHROTRON_PARALLEL_BEAM_DATA.get() + data = dataexample.CAMERA.get(size=(256,256)) + datarr = data.as_array() + from cil.plugins.ccpi_regularisation.functions import TNV + + tau = 1. + + fcil = TNV() + try: + outcil = fcil.proximal(data, tau=tau) + assert False + except ValueError: + assert True + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_TNV_raise_on_3D_nochannel(self): + + # data = dataexample.SYNCHROTRON_PARALLEL_BEAM_DATA.get() + data = dataexample.CAMERA.get(size=(256,256)) + datarr = data.as_array() + from cil.plugins.ccpi_regularisation.functions import TNV + + tau = 1. + + fcil = TNV() + try: + outcil = fcil.proximal(data, tau=tau) + assert False + except ValueError: + assert True + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_TNV_raise_on_4D(self): + + from cil.plugins.ccpi_regularisation.functions import TNV + + data = ImageGeometry(3,4,5,channels=5).allocate(1) + + tau = 1. + + fcil = TNV() + try: + outcil = fcil.proximal(data, tau=tau) + assert False + except ValueError: + assert True + + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_FGP_TV_raise_on_4D_data(self): + + from cil.plugins.ccpi_regularisation.functions import FGP_TV + + tau = 1. + fcil = FGP_TV() + data = ImageGeometry(3,4,5,channels=10).allocate(0) + + + try: + outcil = fcil.proximal(data, tau=tau) + assert False + except ValueError: + assert True + + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_TGV_raise_on_4D_data(self): + + from cil.plugins.ccpi_regularisation.functions import TGV + + tau = 1. + fcil = TGV() + data = ImageGeometry(3,4,5,channels=10).allocate(0) + + + try: + outcil = fcil.proximal(data, tau=tau) + assert False + except ValueError: + assert True + @unittest.skipUnless(has_regularisation_toolkit, "Skipping as CCPi Regularisation Toolkit is not installed") + def test_FGP_dTV_raise_on_4D_data(self): + + from cil.plugins.ccpi_regularisation.functions import FGP_dTV + + tau = 1. + + data = ImageGeometry(3,4,5,channels=10).allocate(0) + ref = data * 2 + + fcil = FGP_dTV(ref) + + try: + outcil = fcil.proximal(data, tau=tau) + assert False + except ValueError: + assert True diff --git a/Wrappers/Python/test/test_functions.py b/Wrappers/Python/test/test_functions.py index 328604b0f6..a061426a16 100644 --- a/Wrappers/Python/test/test_functions.py +++ b/Wrappers/Python/test/test_functions.py @@ -988,10 +988,8 @@ def test_compare_regularisation_toolkit(self): r_tolerance = 1e-9 r_iso = True r_nonneg = True - r_printing = 0 - # g_CCPI_reg_toolkit = alpha * FGP_TV(1., r_iterations, r_tolerance, r_iso, r_nonneg, r_printing, 'cpu') g_CCPI_reg_toolkit = alpha * FGP_TV(max_iteration=r_iterations, tolerance=r_tolerance, - isotropic=r_iso, nonnegativity=r_nonneg, printing=r_printing, device='cpu') + isotropic=r_iso, nonnegativity=r_nonneg, device='cpu') t2 = timer() res2 = g_CCPI_reg_toolkit.proximal(noisy_data, 1.) @@ -1021,10 +1019,8 @@ def test_compare_regularisation_toolkit(self): r_tolerance = 1e-9 r_iso = True r_nonneg = True - r_printing = 0 - # g_CCPI_reg_toolkit = alpha * FGP_TV(1., r_iterations, r_tolerance, r_iso, r_nonneg, r_printing, 'cpu') g_CCPI_reg_toolkit = alpha * FGP_TV(max_iteration=r_iterations, tolerance=r_tolerance, - isotropic=r_iso, nonnegativity=r_nonneg, printing=r_printing, device='cpu') + isotropic=r_iso, nonnegativity=r_nonneg, device='cpu') t2 = timer() res2 = g_CCPI_reg_toolkit.proximal(noisy_data, 1.) @@ -1073,9 +1069,8 @@ def test_compare_regularisation_toolkit_tomophantom(self): r_tolerance = 1e-9 r_iso = True r_nonneg = True - r_printing = 0 g_CCPI_reg_toolkit = alpha * FGP_TV(max_iteration=r_iterations, tolerance=r_tolerance, - isotropic=r_iso, nonnegativity=r_nonneg, printing=r_printing, device='cpu') + isotropic=r_iso, nonnegativity=r_nonneg, device='cpu') t2 = timer() @@ -1085,52 +1080,6 @@ def test_compare_regularisation_toolkit_tomophantom(self): np.testing.assert_allclose(res1.as_array(), res2.as_array(), atol=7.5e-2) - # the following were in the unit tests but didn't assert anything - # # CIL_FGP_TV no tolerance - # g_CIL.tolerance = None - # t0 = timer() - # res1 = g_CIL.proximal(noisy_data, 1.) - # t1 = timer() - # # print(t1-t0) - - # ################################################################### - # ################################################################### - # ################################################################### - # ################################################################### - - # data = dataexample.PEPPERS.get(size=(256, 256)) - # ig = data.geometry - # ag = ig - - # noisy_data = noise.gaussian(data, seed=10) - - # alpha = 0.1 - # iters = 1000 - - # # CIL_FGP_TV no tolerance - # g_CIL = alpha * TotalVariation(iters, tolerance=None) - # t0 = timer() - # res1 = g_CIL.proximal(noisy_data, 1.) - # t1 = timer() - # # print(t1-t0) - - # # CCPi Regularisation toolkit high tolerance - - # r_alpha = alpha - # r_iterations = iters - # r_tolerance = 1e-9 - # r_iso = True - # r_nonneg = True - # r_printing = 0 - # g_CCPI_reg_toolkit = alpha * FGP_TV(max_iteration=r_iterations, tolerance=r_tolerance, - # isotropic=r_iso, nonnegativity=r_nonneg, printing=r_printing, device='cpu') - - - # t2 = timer() - # res2 = g_CCPI_reg_toolkit.proximal(noisy_data, 1.) - # t3 = timer() - # # print (t3-t2) - class TestKullbackLeiblerNumba(unittest.TestCase): def setUp(self):