diff --git a/CHANGELOG.md b/CHANGELOG.md index 484b6e2b9..6198f8407 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ - Bug fix for missing factor of 1/2 in SIRT update objective and catch in place errors in the SIRT constraint - Bug fix to allow safe in place calculation for the soft shrinkage algorithm - Allow Masker to take integer arrays in addition to boolean + - Add remote data class to example data to enable download of relevant datasets from remote repositories - Improved import error/warning messages - New adjoint operator - Bug fix for complex matrix adjoint diff --git a/Wrappers/Python/cil/utilities/dataexample.py b/Wrappers/Python/cil/utilities/dataexample.py index abf986b25..a6a4f020e 100644 --- a/Wrappers/Python/cil/utilities/dataexample.py +++ b/Wrappers/Python/cil/utilities/dataexample.py @@ -23,57 +23,91 @@ import os import os.path import sys -from cil.io import NEXUSDataReader - -data_dir = os.path.abspath(os.path.join( - os.path.dirname(__file__), - '../data/') -) - -# this is the default location after a conda install -data_dir = os.path.abspath( - os.path.join(sys.prefix, 'share','cil') -) +from zipfile import ZipFile +from urllib.request import urlopen +from io import BytesIO +from scipy.io import loadmat +from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader class DATA(object): @classmethod def dfile(cls): return None + +class CILDATA(DATA): + data_dir = os.path.abspath(os.path.join(sys.prefix, 'share','cil')) @classmethod def get(cls, size=None, scale=(0,1), **kwargs): - ddir = kwargs.get('data_dir', data_dir) + ddir = kwargs.get('data_dir', CILDATA.data_dir) loader = TestData(data_dir=ddir) return loader.load(cls.dfile(), size, scale, **kwargs) + +class REMOTEDATA(DATA): + + FOLDER = '' + URL = '' + FILE_SIZE = '' + + @classmethod + def get(cls, data_dir): + return None + + @classmethod + def _download_and_extract_from_url(cls, data_dir): + with urlopen(cls.URL) as response: + with BytesIO(response.read()) as 
bytes, ZipFile(bytes) as zipfile: + zipfile.extractall(path = data_dir) + + @classmethod + def download_data(cls, data_dir): + ''' + Download a dataset from a remote repository + + Parameters + ---------- + data_dir: str + The path to the data directory where the downloaded data should be stored + + ''' + if os.path.isdir(os.path.join(data_dir, cls.FOLDER)): + print("Dataset already exists in " + data_dir) + else: + if input("Are you sure you want to download " + cls.FILE_SIZE + " dataset from " + cls.URL + " ? (y/n)") == "y": + print('Downloading dataset from ' + cls.URL) + cls._download_and_extract_from_url(os.path.join(data_dir,cls.FOLDER)) + print('Download complete') + else: + print('Download cancelled') -class BOAT(DATA): +class BOAT(CILDATA): @classmethod def dfile(cls): return TestData.BOAT -class CAMERA(DATA): +class CAMERA(CILDATA): @classmethod def dfile(cls): return TestData.CAMERA -class PEPPERS(DATA): +class PEPPERS(CILDATA): @classmethod def dfile(cls): return TestData.PEPPERS -class RESOLUTION_CHART(DATA): +class RESOLUTION_CHART(CILDATA): @classmethod def dfile(cls): return TestData.RESOLUTION_CHART -class SIMPLE_PHANTOM_2D(DATA): +class SIMPLE_PHANTOM_2D(CILDATA): @classmethod def dfile(cls): return TestData.SIMPLE_PHANTOM_2D -class SHAPES(DATA): +class SHAPES(CILDATA): @classmethod def dfile(cls): return TestData.SHAPES -class RAINBOW(DATA): +class RAINBOW(CILDATA): @classmethod def dfile(cls): return TestData.RAINBOW -class SYNCHROTRON_PARALLEL_BEAM_DATA(DATA): +class SYNCHROTRON_PARALLEL_BEAM_DATA(CILDATA): @classmethod def get(cls, **kwargs): ''' @@ -90,11 +124,11 @@ def get(cls, **kwargs): The DLS dataset ''' - ddir = kwargs.get('data_dir', data_dir) + ddir = kwargs.get('data_dir', CILDATA.data_dir) loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), '24737_fd_normalised.nxs')) return loader.read() -class SIMULATED_PARALLEL_BEAM_DATA(DATA): +class SIMULATED_PARALLEL_BEAM_DATA(CILDATA): @classmethod 
def get(cls, **kwargs): ''' @@ -111,11 +145,11 @@ def get(cls, **kwargs): The simulated spheres dataset ''' - ddir = kwargs.get('data_dir', data_dir) + ddir = kwargs.get('data_dir', CILDATA.data_dir) loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_parallel_beam.nxs')) return loader.read() -class SIMULATED_CONE_BEAM_DATA(DATA): +class SIMULATED_CONE_BEAM_DATA(CILDATA): @classmethod def get(cls, **kwargs): ''' @@ -132,11 +166,11 @@ def get(cls, **kwargs): The simulated spheres dataset ''' - ddir = kwargs.get('data_dir', data_dir) + ddir = kwargs.get('data_dir', CILDATA.data_dir) loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_cone_beam.nxs')) return loader.read() -class SIMULATED_SPHERE_VOLUME(DATA): +class SIMULATED_SPHERE_VOLUME(CILDATA): @classmethod def get(cls, **kwargs): ''' @@ -151,12 +185,117 @@ def get(cls, **kwargs): ------- ImageData The simulated spheres volume - ''' - - ddir = kwargs.get('data_dir', data_dir) + ''' + ddir = kwargs.get('data_dir', CILDATA.data_dir) loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_volume.nxs')) return loader.read() + +class WALNUT(REMOTEDATA): + ''' + A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 + ''' + FOLDER = 'walnut' + URL = 'https://zenodo.org/record/4822516/files/walnut.zip' + FILE_SIZE = '6.4 GB' + + @classmethod + def get(cls, data_dir): + ''' + A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 + This function returns the raw projection data from the .txrm file + + Parameters + ---------- + data_dir: str + The path to the directory where the dataset is stored. 
Data can be downloaded with dataexample.WALNUT.download_data(data_dir) + + Returns + ------- + ImageData + The walnut dataset + ''' + filepath = os.path.join(data_dir, cls.FOLDER, 'valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm') + try: + loader = ZEISSDataReader(file_name=filepath) + return loader.read() + except(FileNotFoundError): + raise(FileNotFoundError("Dataset .txrm file not found in specifed data_dir: {} \n \ + Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__))) + +class USB(REMOTEDATA): + ''' + A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 + ''' + FOLDER = 'USB' + URL = 'https://zenodo.org/record/4822516/files/usb.zip' + FILE_SIZE = '3.2 GB' + + @classmethod + def get(cls, data_dir): + ''' + A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 + This function returns the raw projection data from the .txrm file + + Parameters + ---------- + data_dir: str + The path to the directory where the dataset is stored. 
Data can be downloaded with dataexample.USB.download_data(data_dir) + + Returns + ------- + ImageData + The usb dataset + ''' + filepath = os.path.join(data_dir, cls.FOLDER, 'gruppe 4','gruppe 4_2014-03-20_1404_12','tomo-A','gruppe 4_tomo-A.txrm') + try: + loader = ZEISSDataReader(file_name=filepath) + return loader.read() + except(FileNotFoundError): + raise(FileNotFoundError("Dataset .txrm file not found in: {} \n \ + Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__))) + +class KORN(REMOTEDATA): + ''' + A microcomputed tomography dataset of sunflower seeds in a box from https://zenodo.org/records/6874123 + ''' + FOLDER = 'korn' + URL = 'https://zenodo.org/record/6874123/files/korn.zip' + FILE_SIZE = '2.9 GB' + + @classmethod + def get(cls, data_dir): + ''' + A microcomputed tomography dataset of sunflower seeds in a box from https://zenodo.org/records/6874123 + This function returns the raw projection data from the .xtekct file + + Parameters + ---------- + data_dir: str + The path to the directory where the dataset is stored. 
Data can be downloaded with dataexample.KORN.download_data(data_dir) + + Returns + ------- + ImageData + The korn dataset + ''' + filepath = os.path.join(data_dir, cls.FOLDER, 'Korn i kasse','47209 testscan korn01_recon.xtekct') + try: + loader = NikonDataReader(file_name=filepath) + return loader.read() + except(FileNotFoundError): + raise(FileNotFoundError("Dataset .xtekct file not found in: {} \n \ + Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__))) + + +class SANDSTONE(REMOTEDATA): + ''' + A synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435 + A small subset of the data containing selected projections and 4 slices of the reconstruction + ''' + FOLDER = 'sandstone' + URL = 'https://zenodo.org/records/4912435/files/small.zip' + FILE_SIZE = '227 MB' class TestData(object): '''Class to return test data @@ -178,9 +317,9 @@ class TestData(object): SHAPES = 'shapes.png' RAINBOW = 'rainbow.png' - def __init__(self, **kwargs): - self.data_dir = kwargs.get('data_dir', data_dir) - + def __init__(self, data_dir): + self.data_dir = data_dir + def load(self, which, size=None, scale=(0,1), **kwargs): ''' Return a test data of the requested image @@ -506,4 +645,4 @@ def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs): if clip: out = np.clip(out, low_clip, 1.0) - return out + return out \ No newline at end of file diff --git a/Wrappers/Python/test/test_dataexample.py b/Wrappers/Python/test/test_dataexample.py index fa6e0d49e..2baf4eaad 100644 --- a/Wrappers/Python/test/test_dataexample.py +++ b/Wrappers/Python/test/test_dataexample.py @@ -21,10 +21,14 @@ from cil.framework.framework import ImageGeometry,AcquisitionGeometry from cil.utilities import dataexample from cil.utilities import noise -import os, sys +import os, sys, shutil from testclass import CCPiTestClass import platform import numpy as np +from unittest.mock import patch, 
MagicMock +from urllib import request +from zipfile import ZipFile +from io import StringIO initialise_tests() @@ -147,4 +151,77 @@ def test_load_SIMULATED_CONE_BEAM_DATA(self): .set_panel((128,128),(64,64))\ .set_angles(np.linspace(0,360,300,False)) - self.assertEqual(ag_expected,image.geometry,msg="Acquisition geometry mismatch") + self.assertEqual(ag_expected,image.geometry,msg="Acquisition geometry mismatch") + +class TestRemoteData(unittest.TestCase): + + def setUp(self): + + self.data_list = ['WALNUT','USB','KORN','SANDSTONE'] + self.tmp_file = 'tmp.txt' + self.tmp_zip = 'tmp.zip' + with ZipFile(self.tmp_zip, 'w') as zipped_file: + zipped_file.writestr(self.tmp_file, np.array([1, 2, 3])) + with open(self.tmp_zip, 'rb') as zipped_file: + self.zipped_bytes = zipped_file.read() + + def tearDown(self): + for data in self.data_list: + test_func = getattr(dataexample, data) + if os.path.exists(os.path.join(test_func.FOLDER)): + shutil.rmtree(test_func.FOLDER) + + if os.path.exists(self.tmp_zip): + os.remove(self.tmp_zip) + + if os.path.exists(self.tmp_file): + os.remove(self.tmp_file) + + def mock_urlopen(self, mock_urlopen): + mock_response = MagicMock() + mock_response.read.return_value = self.zipped_bytes + mock_response.__enter__.return_value = mock_response + mock_urlopen.return_value = mock_response + + @patch('cil.utilities.dataexample.urlopen') + def test_unzip_remote_data(self, mock_urlopen): + self.mock_urlopen(mock_urlopen) + dataexample.REMOTEDATA._download_and_extract_from_url('.') + self.assertTrue(os.path.isfile(self.tmp_file)) + + @patch('cil.utilities.dataexample.input', return_value='n') + @patch('cil.utilities.dataexample.urlopen') + def test_download_data_input_n(self, mock_urlopen, input): + self.mock_urlopen(mock_urlopen) + + data_list = ['WALNUT','USB','KORN','SANDSTONE'] + for data in data_list: + # redirect print output + capturedOutput = StringIO() + sys.stdout = capturedOutput + test_func = getattr(dataexample, data) + 
test_func.download_data('.') + + self.assertFalse(os.path.isfile(self.tmp_file)) + self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n') + + # return to standard print output + sys.stdout = sys.__stdout__ + + @patch('cil.utilities.dataexample.input', return_value='y') + @patch('cil.utilities.dataexample.urlopen') + def test_download_data_input_y(self, mock_urlopen, input): + self.mock_urlopen(mock_urlopen) + + # redirect print output + capturedOutput = StringIO() + sys.stdout = capturedOutput + + + for data in self.data_list: + test_func = getattr(dataexample, data) + test_func.download_data('.') + self.assertTrue(os.path.isfile(os.path.join(test_func.FOLDER,self.tmp_file))) + + # return to standard print output + sys.stdout = sys.__stdout__ diff --git a/Wrappers/Python/test/test_io.py b/Wrappers/Python/test/test_io.py index d266a9170..b77dd1562 100644 --- a/Wrappers/Python/test/test_io.py +++ b/Wrappers/Python/test/test_io.py @@ -22,13 +22,13 @@ from cil.framework import AcquisitionGeometry import numpy as np import os +import sys from cil.framework import ImageGeometry from cil.io import TXRMDataReader, NEXUSDataReader, NikonDataReader, ZEISSDataReader from cil.io import TIFFWriter, TIFFStackReader from cil.io.utilities import HDF5_utilities from cil.processors import Slicer from utils import has_astra, has_nvidia -from cil.utilities.dataexample import data_dir from cil.utilities.quality_measures import mse from cil.utilities import dataexample import shutil @@ -65,6 +65,10 @@ # change basedir to point to the location of the walnut dataset which can # be downloaded from https://zenodo.org/record/4822516 # basedir = os.path.abspath('/home/edo/scratch/Data/Walnut/valnut_2014-03-21_643_28/tomo-A/') + +data_dir = os.path.abspath( + os.path.join(sys.prefix, 'share','cil') +) basedir = data_dir filename = os.path.join(basedir, "valnut_tomo-A.txrm") has_file = os.path.isfile(filename)