diff --git a/Wrappers/Python/cil/utilities/dataexample.py b/Wrappers/Python/cil/utilities/dataexample.py index a6a4f020e..f410b91e8 100644 --- a/Wrappers/Python/cil/utilities/dataexample.py +++ b/Wrappers/Python/cil/utilities/dataexample.py @@ -16,7 +16,7 @@ # Authors: # CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt -from cil.framework import ImageData, ImageGeometry, DataContainer +from cil.framework import ImageGeometry import numpy import numpy as np from PIL import Image @@ -26,301 +26,44 @@ from zipfile import ZipFile from urllib.request import urlopen from io import BytesIO -from scipy.io import loadmat from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader +from abc import ABC +from matplotlib.pyplot import imread +from scipy.io import loadmat +from pathlib import Path -class DATA(object): - @classmethod - def dfile(cls): - return None - -class CILDATA(DATA): - data_dir = os.path.abspath(os.path.join(sys.prefix, 'share','cil')) - @classmethod - def get(cls, size=None, scale=(0,1), **kwargs): - ddir = kwargs.get('data_dir', CILDATA.data_dir) - loader = TestData(data_dir=ddir) - return loader.load(cls.dfile(), size, scale, **kwargs) - -class REMOTEDATA(DATA): - - FOLDER = '' - URL = '' - FILE_SIZE = '' - - @classmethod - def get(cls, data_dir): - return None - - @classmethod - def _download_and_extract_from_url(cls, data_dir): - with urlopen(cls.URL) as response: - with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile: - zipfile.extractall(path = data_dir) - - @classmethod - def download_data(cls, data_dir): - ''' - Download a dataset from a remote repository - - Parameters - ---------- - data_dir: str, optional - The path to the data directory where the downloaded data should be stored - - ''' - if os.path.isdir(os.path.join(data_dir, cls.FOLDER)): - print("Dataset already exists in " + data_dir) - else: - if input("Are you sure you want to download " + cls.FILE_SIZE + " dataset from " + cls.URL + " ? (y/n)") == "y": - print('Downloading dataset from ' + cls.URL) - cls._download_and_extract_from_url(os.path.join(data_dir,cls.FOLDER)) - print('Download complete') - else: - print('Download cancelled') - -class BOAT(CILDATA): - @classmethod - def dfile(cls): - return TestData.BOAT -class CAMERA(CILDATA): - @classmethod - def dfile(cls): - return TestData.CAMERA -class PEPPERS(CILDATA): - @classmethod - def dfile(cls): - return TestData.PEPPERS -class RESOLUTION_CHART(CILDATA): - @classmethod - def dfile(cls): - return TestData.RESOLUTION_CHART -class SIMPLE_PHANTOM_2D(CILDATA): - @classmethod - def dfile(cls): - return TestData.SIMPLE_PHANTOM_2D -class SHAPES(CILDATA): - @classmethod - def dfile(cls): - return TestData.SHAPES -class RAINBOW(CILDATA): - @classmethod - def dfile(cls): - return TestData.RAINBOW -class SYNCHROTRON_PARALLEL_BEAM_DATA(CILDATA): - @classmethod - def get(cls, **kwargs): - ''' - A DLS dataset - - Parameters - ---------- - data_dir: str, optional - The path to the data directory - - Returns - ------- - AcquisitionData - The DLS dataset - ''' - - ddir = kwargs.get('data_dir', CILDATA.data_dir) - loader = NEXUSDataReader() - loader.set_up(file_name=os.path.join(os.path.abspath(ddir), '24737_fd_normalised.nxs')) - return loader.read() -class SIMULATED_PARALLEL_BEAM_DATA(CILDATA): - @classmethod - def get(cls, **kwargs): - ''' - A simulated parallel-beam dataset generated from SIMULATED_SPHERE_VOLUME - - Parameters - ---------- - data_dir: str, optional - The path to the data directory - - Returns - ------- - AcquisitionData - The simulated spheres dataset - ''' - - ddir = kwargs.get('data_dir', CILDATA.data_dir) - loader = NEXUSDataReader() - loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_parallel_beam.nxs')) - return loader.read() -class SIMULATED_CONE_BEAM_DATA(CILDATA): - @classmethod - def get(cls, **kwargs): - ''' - A cone-beam dataset generated from SIMULATED_SPHERE_VOLUME - - Parameters - ---------- - data_dir: str, optional - The path to the data directory - - Returns - ------- - AcquisitionData - The simulated spheres dataset - ''' - - ddir = kwargs.get('data_dir', CILDATA.data_dir) - loader = NEXUSDataReader() - loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_cone_beam.nxs')) - return loader.read() -class SIMULATED_SPHERE_VOLUME(CILDATA): - @classmethod - def get(cls, **kwargs): - ''' - A simulated volume of spheres - - Parameters - ---------- - data_dir: str, optional - The path to the data directory - - Returns - ------- - ImageData - The simulated spheres volume - ''' - ddir = kwargs.get('data_dir', CILDATA.data_dir) - loader = NEXUSDataReader() - loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_volume.nxs')) - return loader.read() - -class WALNUT(REMOTEDATA): - ''' - A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 - ''' - FOLDER = 'walnut' - URL = 'https://zenodo.org/record/4822516/files/walnut.zip' - FILE_SIZE = '6.4 GB' - - @classmethod - def get(cls, data_dir): - ''' - A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 - This function returns the raw projection data from the .txrm file - - Parameters - ---------- - data_dir: str - The path to the directory where the dataset is stored. Data can be downloaded with dataexample.WALNUT.download_data(data_dir) - - Returns - ------- - ImageData - The walnut dataset - ''' - filepath = os.path.join(data_dir, cls.FOLDER, 'valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm') - try: - loader = ZEISSDataReader(file_name=filepath) - return loader.read() - except(FileNotFoundError): - raise(FileNotFoundError("Dataset .txrm file not found in specifed data_dir: {} \n \ - Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__))) - -class USB(REMOTEDATA): - ''' - A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 - ''' - FOLDER = 'USB' - URL = 'https://zenodo.org/record/4822516/files/usb.zip' - FILE_SIZE = '3.2 GB' - - @classmethod - def get(cls, data_dir): - ''' - A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 - This function returns the raw projection data from the .txrm file - - Parameters - ---------- - data_dir: str - The path to the directory where the dataset is stored. Data can be downloaded with dataexample.WALNUT.download_data(data_dir) - - Returns - ------- - ImageData - The usb dataset - ''' - filepath = os.path.join(data_dir, cls.FOLDER, 'gruppe 4','gruppe 4_2014-03-20_1404_12','tomo-A','gruppe 4_tomo-A.txrm') - try: - loader = ZEISSDataReader(file_name=filepath) - return loader.read() - except(FileNotFoundError): - raise(FileNotFoundError("Dataset .txrm file not found in: {} \n \ - Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__))) - -class KORN(REMOTEDATA): - ''' - A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123 - ''' - FOLDER = 'korn' - URL = 'https://zenodo.org/record/6874123/files/korn.zip' - FILE_SIZE = '2.9 GB' - - @classmethod - def get(cls, data_dir): - ''' - A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123 - This function returns the raw projection data from the .xtekct file - Parameters - ---------- - data_dir: str - The path to the directory where the dataset is stored. Data can be downloaded with dataexample.KORN.download_data(data_dir) +DEFAULT_DATA_DIR = os.path.abspath(os.path.join(sys.prefix, 'share', 'cil')) - Returns - ------- - ImageData - The korn dataset - ''' - filepath = os.path.join(data_dir, cls.FOLDER, 'Korn i kasse','47209 testscan korn01_recon.xtekct') - try: - loader = NikonDataReader(file_name=filepath) - return loader.read() - except(FileNotFoundError): - raise(FileNotFoundError("Dataset .xtekct file not found in: {} \n \ - Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__))) +class BaseTestData(ABC): + def __init__(self, data_dir=DEFAULT_DATA_DIR): + self.data_dir = data_dir +class TestData(BaseTestData): + '''Provides 6 datasets: -class SANDSTONE(REMOTEDATA): - ''' - A synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435 - A small subset of the data containing selected projections and 4 slices of the reconstruction + BOAT: 'boat.tiff' + CAMERA: 'camera.png' + PEPPERS: 'peppers.tiff' + RESOLUTION_CHART: 'resolution_chart.tiff' + SIMPLE_PHANTOM_2D: 'hotdog' + SHAPES: 'shapes.png' + RAINBOW: 'rainbow.png' ''' - FOLDER = 'sandstone' - URL = 'https://zenodo.org/records/4912435/files/small.zip' - FILE_SIZE = '227 MB' - -class TestData(object): - '''Class to return test data - - provides 6 dataset: BOAT = 'boat.tiff' CAMERA = 'camera.png' PEPPERS = 'peppers.tiff' RESOLUTION_CHART = 'resolution_chart.tiff' SIMPLE_PHANTOM_2D = 'hotdog' - SHAPES = 'shapes.png' + SHAPES = 'shapes.png' RAINBOW = 'rainbow.png' - ''' - BOAT = 'boat.tiff' - CAMERA = 'camera.png' - PEPPERS = 'peppers.tiff' - RESOLUTION_CHART = 'resolution_chart.tiff' - SIMPLE_PHANTOM_2D = 'hotdog' - SHAPES = 'shapes.png' - RAINBOW = 'rainbow.png' + dfile: str - def __init__(self, data_dir): - self.data_dir = data_dir - - def load(self, which, size=None, scale=(0,1), **kwargs): + @classmethod + def _datasets(cls): + return {cls.BOAT, cls.CAMERA, cls.PEPPERS, cls.RESOLUTION_CHART, cls.SIMPLE_PHANTOM_2D, cls.SHAPES, cls.RAINBOW} + + def load(self, which, size=None, scale=None): ''' Return a test data of the requested image @@ -338,52 +81,28 @@ def load(self, which, size=None, scale=(0,1), **kwargs): ImageData The simulated spheres volume ''' - if which not in [TestData.BOAT, TestData.CAMERA, - TestData.PEPPERS, TestData.RESOLUTION_CHART, - TestData.SIMPLE_PHANTOM_2D, TestData.SHAPES, - TestData.RAINBOW]: - raise ValueError('Unknown TestData {}.'.format(which)) + if scale is None: + scale = 0, 1 + if which not in self._datasets(): + raise KeyError(f"Unknown TestData: {which}") if which == TestData.SIMPLE_PHANTOM_2D: - if size is None: - N = 512 - M = 512 - else: - N = size[0] - M = size[1] - + N, M = (512, 512) if size is None else (size[0], size[1]) sdata = numpy.zeros((N, M)) sdata[int(round(N/4)):int(round(3*N/4)), int(round(M/4)):int(round(3*M/4))] = 0.5 sdata[int(round(N/8)):int(round(7*N/8)), int(round(3*M/8)):int(round(5*M/8))] = 1 ig = ImageGeometry(voxel_num_x = M, voxel_num_y = N, dimension_labels=[ImageGeometry.HORIZONTAL_Y, ImageGeometry.HORIZONTAL_X]) data = ig.allocate() data.fill(sdata) - elif which == TestData.SHAPES: - with Image.open(os.path.join(self.data_dir, which)) as f: - - if size is None: - N = 200 - M = 300 - else: - N = size[0] - M = size[1] - + N, M = (200, 300) if size is None else (size[0], size[1]) ig = ImageGeometry(voxel_num_x = M, voxel_num_y = N, dimension_labels=[ImageGeometry.HORIZONTAL_Y, ImageGeometry.HORIZONTAL_X]) data = ig.allocate() tmp = numpy.array(f.convert('L').resize((M,N))) data.fill(tmp/numpy.max(tmp)) - else: with Image.open(os.path.join(self.data_dir, which)) as tmp: - - if size is None: - N = tmp.size[1] - M = tmp.size[0] - else: - N = size[0] - M = size[1] - + N, M = (tmp.size[1], tmp.size[0]) if size is None else (size[0], size[1]) bands = tmp.getbands() if len(bands) > 1: if len(bands) == 4: @@ -414,26 +133,22 @@ def load(self, which, size=None, scale=(0,1), **kwargs): # print ("data.geometry", data.geometry) return data - @staticmethod - def random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs): + @classmethod + def random_noise(cls, image, **kwargs): '''Function to add noise to input image :param image: input dataset, DataContainer of numpy.ndarray - :param mode: type of noise - :param seed: seed for random number generator - :param clip: should clip the data. + :param **kwargs: Passed to `scikit_random_noise` See https://github.com/scikit-image/scikit-image/blob/master/skimage/util/noise.py - ''' if hasattr(image, 'as_array'): - arr = TestData.scikit_random_noise(image.as_array(), mode=mode, seed=seed, clip=clip, - **kwargs) + arr = cls.scikit_random_noise(image.as_array(), **kwargs) out = image.copy() out.fill(arr) return out elif issubclass(type(image), numpy.ndarray): - return TestData.scikit_random_noise(image, mode=mode, seed=seed, clip=clip, - **kwargs) + return cls.scikit_random_noise(image, **kwargs) + raise TypeError(type(image)) @staticmethod def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs): @@ -538,7 +253,6 @@ def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs): STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - """ mode = mode.lower() @@ -548,7 +262,7 @@ def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs): else: low_clip = 0. - image = numpy.asarray(image, dtype=(np.float64)) + image = numpy.asarray(image, dtype=np.float64) if seed is not None: np.random.seed(seed=seed) @@ -645,4 +359,171 @@ def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs): if clip: out = np.clip(out, low_clip, 1.0) - return out \ No newline at end of file + return out + + @classmethod + def get(cls, data_dir=DEFAULT_DATA_DIR, **load_kwargs): + """Calls cls(data_dir).load(cls.dfile, **load_kwargs)""" + return cls(data_dir).load(cls.dfile, **load_kwargs) + +class BOAT(TestData): + dfile = TestData.BOAT +class CAMERA(TestData): + dfile = TestData.CAMERA +class PEPPERS(TestData): + dfile = TestData.PEPPERS +class RESOLUTION_CHART(TestData): + dfile = TestData.RESOLUTION_CHART +class SIMPLE_PHANTOM_2D(TestData): + dfile = TestData.SIMPLE_PHANTOM_2D +class SHAPES(TestData): + dfile = TestData.SHAPES +class RAINBOW(TestData): + dfile = TestData.RAINBOW + +class NexusTestData(BaseTestData): + @classmethod + def get(cls, data_dir=DEFAULT_DATA_DIR): + ''' + Returns + ------- + AcquisitionData + ''' + loader = NEXUSDataReader() + loader.set_up(file_name=os.path.join(data_dir, cls.dfile)) + return loader.read() + +class SYNCHROTRON_PARALLEL_BEAM_DATA(NexusTestData): + '''A DLS dataset''' + dfile = '24737_fd_normalised.nxs' +class SIMULATED_PARALLEL_BEAM_DATA(NexusTestData): + '''A simulated parallel-beam dataset generated from SIMULATED_SPHERE_VOLUME''' + dfile = 'sim_parallel_beam.nxs' +class SIMULATED_CONE_BEAM_DATA(NexusTestData): + '''A cone-beam dataset generated from SIMULATED_SPHERE_VOLUME''' + dfile = 'sim_cone_beam.nxs' +class SIMULATED_SPHERE_VOLUME(NexusTestData): + '''A simulated volume of spheres''' + dfile = 'sim_volume.nxs' + +class RemoteTestData(BaseTestData): + URL: str + FILE_SIZE: str + + @staticmethod + def _prompt(msg): + while (res := input(f"{msg} [y/n]").lower()) not in "yn": + pass + return res == "y" + + def download_data(self): + '''Download a dataset from a remote repository''' + folder = os.path.join(self.data_dir, type(self).__name__) + if os.path.isdir(folder): + print(f"Dataset already exists in {folder}") + else: + if self._prompt(f"Are you sure you want to download {self.FILE_SIZE} dataset from {self.URL}?"): + print(f"Downloading dataset from {self.URL}") + with urlopen(self.URL) as response, BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile: + zipfile.extractall(path=folder) + print('Download complete') + else: + print('Download cancelled') + +class WALNUT(RemoteTestData): + '''A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516''' + URL = 'https://zenodo.org/record/4822516/files/walnut.zip' + FILE_SIZE = '6.4 GB' + + @classmethod + def get(cls, data_dir=DEFAULT_DATA_DIR): + ''' + This function returns the raw projection data from the .txrm file + + Returns + ------- + ImageData + The walnut dataset + ''' + self = cls(data_dir) + filepath = os.path.join(self.data_dir, cls.__name__, 'valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm') + self.download_data() + loader = ZEISSDataReader(file_name=filepath) + return loader.read() + +class USB(RemoteTestData): + '''A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516''' + URL = 'https://zenodo.org/record/4822516/files/usb.zip' + FILE_SIZE = '3.2 GB' + + @classmethod + def get(cls, data_dir=DEFAULT_DATA_DIR): + ''' + This function returns the raw projection data from the .txrm file + + Returns + ------- + ImageData + The usb dataset + ''' + self = cls(data_dir) + filepath = os.path.join(self.data_dir, cls.__name__, 'gruppe 4','gruppe 4_2014-03-20_1404_12','tomo-A','gruppe 4_tomo-A.txrm') + self.download_data() + loader = ZEISSDataReader(file_name=filepath) + return loader.read() + +class KORN(RemoteTestData): + '''A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123''' + URL = 'https://zenodo.org/record/6874123/files/korn.zip' + FILE_SIZE = '2.9 GB' + + @classmethod + def get(cls, data_dir=DEFAULT_DATA_DIR): + ''' + This function returns the raw projection data from the .xtekct file + + Returns + ------- + ImageData + The korn dataset + ''' + self = cls(data_dir) + filepath = os.path.join(self.data_dir, cls.__name__, 'Korn i kasse','47209 testscan korn01_recon.xtekct') + self.download_data() + loader = NikonDataReader(file_name=filepath) + return loader.read() + +class SANDSTONE(RemoteTestData): + ''' + A synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435 + A small subset of the data containing selected projections and 4 slices of the reconstruction + ''' + URL = 'https://zenodo.org/records/4912435/files/small.zip' + FILE_SIZE = '227 MB' + + @classmethod + def get(cls, filename, data_dir=DEFAULT_DATA_DIR): + ''' + This function returns data from a specified file in the sandstone folder + Parameters + ---------- + filename : str + filename of the data to get, specify the filepath within the sandstone folder e.g. 'slice_0270_data.mat' or 'proj/BBii_0001.tif' + + Returns + ------- + DataContainer + Data from the sandstone dataset + ''' + self = cls(data_dir) + filepath = os.path.join(self.data_dir, cls.__name__, filename) + print(filepath) + self.download_data() + if Path(filename).suffix == '.tif': + return imread(filepath) + + elif Path(filename).suffix == '.mat': + return loadmat(filepath) + + else: + raise ValueError('{0} file type not recognised'.format( Path(filename).suffix) ) \ No newline at end of file diff --git a/Wrappers/Python/test/test_dataexample.py b/Wrappers/Python/test/test_dataexample.py index 2baf4eaad..afd5f74eb 100644 --- a/Wrappers/Python/test/test_dataexample.py +++ b/Wrappers/Python/test/test_dataexample.py @@ -25,7 +25,7 @@ from testclass import CCPiTestClass import platform import numpy as np -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock from urllib import request from zipfile import ZipFile from io import StringIO @@ -151,26 +151,25 @@ def test_load_SIMULATED_CONE_BEAM_DATA(self): .set_panel((128,128),(64,64))\ .set_angles(np.linspace(0,360,300,False)) - self.assertEqual(ag_expected,image.geometry,msg="Acquisition geometry mismatch") + self.assertEqual(ag_expected,image.geometry,msg="Acquisition geometry mismatch") class TestRemoteData(unittest.TestCase): def setUp(self): - self.data_list = ['WALNUT','USB','KORN','SANDSTONE'] self.tmp_file = 'tmp.txt' self.tmp_zip = 'tmp.zip' with ZipFile(self.tmp_zip, 'w') as zipped_file: zipped_file.writestr(self.tmp_file, np.array([1, 2, 3])) with open(self.tmp_zip, 'rb') as zipped_file: - self.zipped_bytes = zipped_file.read() - + self.zipped_bytes = zipped_file.read() + def tearDown(self): - for data in self.data_list: - test_func = getattr(dataexample, data) - if os.path.exists(os.path.join(test_func.FOLDER)): - shutil.rmtree(test_func.FOLDER) - + for data in self.data_list + ['REMOTE_TEST']: + folder = os.path.join(dataexample.DEFAULT_DATA_DIR, data) + if os.path.exists(folder): + shutil.rmtree(folder) + if os.path.exists(self.tmp_zip): os.remove(self.tmp_zip) @@ -183,45 +182,50 @@ def mock_urlopen(self, mock_urlopen): mock_response.__enter__.return_value = mock_response mock_urlopen.return_value = mock_response + @patch('cil.utilities.dataexample.input', return_value='y') @patch('cil.utilities.dataexample.urlopen') - def test_unzip_remote_data(self, mock_urlopen): + def test_unzip_remote_data(self, mock_urlopen, input): self.mock_urlopen(mock_urlopen) - dataexample.REMOTEDATA._download_and_extract_from_url('.') - self.assertTrue(os.path.isfile(self.tmp_file)) + sys.stdout = StringIO() # redirect print output + + fname = os.path.join(dataexample.DEFAULT_DATA_DIR, 'REMOTE_TEST', self.tmp_file) + self.assertFalse(os.path.isfile(fname)) + class REMOTE_TEST(dataexample.RemoteTestData): + URL = '' + FILE_SIZE = '0 B' + REMOTE_TEST().download_data() + self.assertTrue(os.path.isfile(fname)) + os.remove(fname) + + sys.stdout = sys.__stdout__ # return to standard print output - @patch('cil.utilities.dataexample.input', return_value='n') + @patch('cil.utilities.dataexample.input', return_value='n') @patch('cil.utilities.dataexample.urlopen') def test_download_data_input_n(self, mock_urlopen, input): self.mock_urlopen(mock_urlopen) - - data_list = ['WALNUT','USB','KORN','SANDSTONE'] - for data in data_list: - # redirect print output - capturedOutput = StringIO() - sys.stdout = capturedOutput + for data in self.data_list: + sys.stdout = capturedOutput = StringIO() # redirect print output test_func = getattr(dataexample, data) - test_func.download_data('.') - + test_func().download_data() + self.assertFalse(os.path.isfile(self.tmp_file)) self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n') - + # return to standard print output - sys.stdout = sys.__stdout__ + sys.stdout = sys.__stdout__ - @patch('cil.utilities.dataexample.input', return_value='y') + @patch('cil.utilities.dataexample.input', return_value='y') @patch('cil.utilities.dataexample.urlopen') def test_download_data_input_y(self, mock_urlopen, input): self.mock_urlopen(mock_urlopen) + sys.stdout = StringIO() # redirect print output - # redirect print output - capturedOutput = StringIO() - sys.stdout = capturedOutput - - for data in self.data_list: test_func = getattr(dataexample, data) - test_func.download_data('.') - self.assertTrue(os.path.isfile(os.path.join(test_func.FOLDER,self.tmp_file))) - - # return to standard print output - sys.stdout = sys.__stdout__ + fname = os.path.join(dataexample.DEFAULT_DATA_DIR, data, self.tmp_file) + self.assertFalse(os.path.isfile(fname)) + test_func().download_data() + self.assertTrue(os.path.isfile(fname)) + os.remove(fname) + + sys.stdout = sys.__stdout__ # return to standard print output