Skip to content

Commit

Permalink
Download data in dataexamples (#1712)
Browse files Browse the repository at this point in the history
  • Loading branch information
hrobarts committed Mar 13, 2024
1 parent d4de7dc commit b16a429
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 36 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
- Bug fix for missing factor of 1/2 in SIRT update objective and catch in place errors in the SIRT constraint
- Bug fix to allow safe in place calculation for the soft shrinkage algorithm
- Allow Masker to take integer arrays in addition to boolean
- Add remote data class to example data to enable download of relevant datasets from remote repositories
- Improved import error/warning messages
- New adjoint operator
- Bug fix for complex matrix adjoint
Expand Down
205 changes: 172 additions & 33 deletions Wrappers/Python/cil/utilities/dataexample.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,57 +23,91 @@
import os
import os.path
import sys
from cil.io import NEXUSDataReader

data_dir = os.path.abspath(os.path.join(
os.path.dirname(__file__),
'../data/')
)

# this is the default location after a conda install
data_dir = os.path.abspath(
os.path.join(sys.prefix, 'share','cil')
)
from zipfile import ZipFile
from urllib.request import urlopen
from io import BytesIO
from scipy.io import loadmat
from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader

class DATA:
    '''Common base class for the example datasets.'''

    @classmethod
    def dfile(cls):
        '''
        Name of the file backing this dataset.

        Returns
        -------
        None
            The base class has no file; subclasses override this method.
        '''
        return None

class CILDATA(DATA):
    '''
    Base class for example datasets that are read from a local data directory.
    '''

    # Default data directory: <sys.prefix>/share/cil
    data_dir = os.path.abspath(os.path.join(sys.prefix, 'share', 'cil'))

    @classmethod
    def get(cls, size=None, scale=(0,1), **kwargs):
        '''
        Load the file named by ``cls.dfile()`` through ``TestData.load``.

        Parameters
        ----------
        size
            Forwarded unchanged to ``TestData.load``.
        scale
            Forwarded unchanged to ``TestData.load``.
        **kwargs
            ``data_dir`` (optional) overrides ``CILDATA.data_dir``. All keyword
            arguments, including ``data_dir``, are also forwarded to
            ``TestData.load``.

        Returns
        -------
        Whatever ``TestData.load`` returns for the requested file.
        '''
        # Single lookup only: the default comes from the class attribute, not
        # from a module-level ``data_dir`` global.
        ddir = kwargs.get('data_dir', CILDATA.data_dir)
        loader = TestData(data_dir=ddir)
        return loader.load(cls.dfile(), size, scale, **kwargs)

class REMOTEDATA(DATA):
    '''
    Base class for example datasets that are downloaded from a remote repository.

    Subclasses set ``FOLDER``, ``URL`` and ``FILE_SIZE`` and may override ``get``
    to read the downloaded files.
    '''

    # Sub-directory of ``data_dir`` that the archive is extracted into
    FOLDER = ''
    # Address of the zip archive to download
    URL = ''
    # Download size as text; only used in the confirmation prompt
    FILE_SIZE = ''

    @classmethod
    def get(cls, data_dir):
        '''
        Placeholder reader for the base class.

        Parameters
        ----------
        data_dir: str
            The path to the directory where the dataset is stored

        Returns
        -------
        None
            The base class reads nothing; subclasses may override this method.
        '''
        return None

    @classmethod
    def _download_and_extract_from_url(cls, data_dir):
        '''
        Fetch the zip archive at ``cls.URL`` and extract it into ``data_dir``.

        The response is read completely into memory before it is unpacked.

        Parameters
        ----------
        data_dir: str
            The directory the archive contents are extracted to
        '''
        with urlopen(cls.URL) as response:
            # ``buffer``/``archive`` rather than ``bytes``/``zipfile`` so that
            # the builtin and the module name are not shadowed.
            with BytesIO(response.read()) as buffer, ZipFile(buffer) as archive:
                archive.extractall(path=data_dir)

    @classmethod
    def download_data(cls, data_dir):
        '''
        Download a dataset from a remote repository

        The user is asked for confirmation on standard input before anything
        is downloaded. Nothing is downloaded if ``data_dir/FOLDER`` already
        exists.

        Parameters
        ----------
        data_dir: str
            The path to the data directory where the downloaded data should be stored
        '''
        target_dir = os.path.join(data_dir, cls.FOLDER)
        if os.path.isdir(target_dir):
            print("Dataset already exists in " + data_dir)
            return

        answer = input("Are you sure you want to download " + cls.FILE_SIZE + " dataset from " + cls.URL + " ? (y/n)")
        # Accept "y", "Y" and answers with stray whitespace; anything else cancels.
        if answer.strip().lower() == "y":
            print('Downloading dataset from ' + cls.URL)
            cls._download_and_extract_from_url(target_dir)
            print('Download complete')
        else:
            print('Download cancelled')

class BOAT(CILDATA):
    '''Example image stored in the file named by ``TestData.BOAT``.'''

    @classmethod
    def dfile(cls):
        '''Return the file name of this image, ``TestData.BOAT``.'''
        return TestData.BOAT
class CAMERA(CILDATA):
    '''Example image stored in the file named by ``TestData.CAMERA``.'''

    @classmethod
    def dfile(cls):
        '''Return the file name of this image, ``TestData.CAMERA``.'''
        return TestData.CAMERA
class PEPPERS(CILDATA):
    '''Example image stored in the file named by ``TestData.PEPPERS``.'''

    @classmethod
    def dfile(cls):
        '''Return the file name of this image, ``TestData.PEPPERS``.'''
        return TestData.PEPPERS
class RESOLUTION_CHART(CILDATA):
    '''Example image stored in the file named by ``TestData.RESOLUTION_CHART``.'''

    @classmethod
    def dfile(cls):
        '''Return the file name of this image, ``TestData.RESOLUTION_CHART``.'''
        return TestData.RESOLUTION_CHART
class SIMPLE_PHANTOM_2D(CILDATA):
    '''Example image stored in the file named by ``TestData.SIMPLE_PHANTOM_2D``.'''

    @classmethod
    def dfile(cls):
        '''Return the file name of this image, ``TestData.SIMPLE_PHANTOM_2D``.'''
        return TestData.SIMPLE_PHANTOM_2D
class SHAPES(CILDATA):
    '''Example image stored in the file named by ``TestData.SHAPES``.'''

    @classmethod
    def dfile(cls):
        '''Return the file name of this image, ``TestData.SHAPES``.'''
        return TestData.SHAPES
class RAINBOW(CILDATA):
    '''Example image stored in the file named by ``TestData.RAINBOW``.'''

    @classmethod
    def dfile(cls):
        '''Return the file name of this image, ``TestData.RAINBOW``.'''
        return TestData.RAINBOW
class SYNCHROTRON_PARALLEL_BEAM_DATA(DATA):
class SYNCHROTRON_PARALLEL_BEAM_DATA(CILDATA):
@classmethod
def get(cls, **kwargs):
'''
Expand All @@ -90,11 +124,11 @@ def get(cls, **kwargs):
The DLS dataset
'''

ddir = kwargs.get('data_dir', data_dir)
ddir = kwargs.get('data_dir', CILDATA.data_dir)
loader = NEXUSDataReader()
loader.set_up(file_name=os.path.join(os.path.abspath(ddir), '24737_fd_normalised.nxs'))
return loader.read()
class SIMULATED_PARALLEL_BEAM_DATA(DATA):
class SIMULATED_PARALLEL_BEAM_DATA(CILDATA):
@classmethod
def get(cls, **kwargs):
'''
Expand All @@ -111,11 +145,11 @@ def get(cls, **kwargs):
The simulated spheres dataset
'''

ddir = kwargs.get('data_dir', data_dir)
ddir = kwargs.get('data_dir', CILDATA.data_dir)
loader = NEXUSDataReader()
loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_parallel_beam.nxs'))
return loader.read()
class SIMULATED_CONE_BEAM_DATA(DATA):
class SIMULATED_CONE_BEAM_DATA(CILDATA):
@classmethod
def get(cls, **kwargs):
'''
Expand All @@ -132,11 +166,11 @@ def get(cls, **kwargs):
The simulated spheres dataset
'''

ddir = kwargs.get('data_dir', data_dir)
ddir = kwargs.get('data_dir', CILDATA.data_dir)
loader = NEXUSDataReader()
loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_cone_beam.nxs'))
return loader.read()
class SIMULATED_SPHERE_VOLUME(DATA):
class SIMULATED_SPHERE_VOLUME(CILDATA):
@classmethod
def get(cls, **kwargs):
'''
Expand All @@ -151,12 +185,117 @@ def get(cls, **kwargs):
-------
ImageData
The simulated spheres volume
'''

ddir = kwargs.get('data_dir', data_dir)
'''
ddir = kwargs.get('data_dir', CILDATA.data_dir)
loader = NEXUSDataReader()
loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_volume.nxs'))
return loader.read()

class WALNUT(REMOTEDATA):
    '''
    A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516
    '''
    FOLDER = 'walnut'
    URL = 'https://zenodo.org/record/4822516/files/walnut.zip'
    FILE_SIZE = '6.4 GB'

    @classmethod
    def get(cls, data_dir):
        '''
        A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516
        This function returns the raw projection data from the .txrm file

        Parameters
        ----------
        data_dir: str
            The path to the directory where the dataset is stored. Data can be downloaded with dataexample.WALNUT.download_data(data_dir)

        Returns
        -------
        The walnut projection data, as returned by ``ZEISSDataReader.read()``

        Raises
        ------
        FileNotFoundError
            If the .txrm file cannot be found below ``data_dir``
        '''
        filepath = os.path.join(data_dir, cls.FOLDER, 'valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm')
        try:
            loader = ZEISSDataReader(file_name=filepath)
            return loader.read()
        except FileNotFoundError as err:
            # Adjacent literals instead of a backslash continuation, so that no
            # source indentation ends up inside the message.
            raise FileNotFoundError(
                "Dataset .txrm file not found in: {}\n"
                "Specify a different data_dir or download data with "
                "dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__)
            ) from err

class USB(REMOTEDATA):
    '''
    A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516
    '''
    FOLDER = 'USB'
    URL = 'https://zenodo.org/record/4822516/files/usb.zip'
    FILE_SIZE = '3.2 GB'

    @classmethod
    def get(cls, data_dir):
        '''
        A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516
        This function returns the raw projection data from the .txrm file

        Parameters
        ----------
        data_dir: str
            The path to the directory where the dataset is stored. Data can be downloaded with dataexample.USB.download_data(data_dir)

        Returns
        -------
        The usb projection data, as returned by ``ZEISSDataReader.read()``

        Raises
        ------
        FileNotFoundError
            If the .txrm file cannot be found below ``data_dir``
        '''
        filepath = os.path.join(data_dir, cls.FOLDER, 'gruppe 4','gruppe 4_2014-03-20_1404_12','tomo-A','gruppe 4_tomo-A.txrm')
        try:
            loader = ZEISSDataReader(file_name=filepath)
            return loader.read()
        except FileNotFoundError as err:
            # Adjacent literals instead of a backslash continuation, so that no
            # source indentation ends up inside the message.
            raise FileNotFoundError(
                "Dataset .txrm file not found in: {}\n"
                "Specify a different data_dir or download data with "
                "dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__)
            ) from err

class KORN(REMOTEDATA):
    '''
    A microcomputed tomography dataset of sunflower seeds in a box from https://zenodo.org/records/6874123
    '''
    FOLDER = 'korn'
    URL = 'https://zenodo.org/record/6874123/files/korn.zip'
    FILE_SIZE = '2.9 GB'

    @classmethod
    def get(cls, data_dir):
        '''
        A microcomputed tomography dataset of sunflower seeds in a box from https://zenodo.org/records/6874123
        This function returns the raw projection data from the .xtekct file

        Parameters
        ----------
        data_dir: str
            The path to the directory where the dataset is stored. Data can be downloaded with dataexample.KORN.download_data(data_dir)

        Returns
        -------
        The korn projection data, as returned by ``NikonDataReader.read()``

        Raises
        ------
        FileNotFoundError
            If the .xtekct file cannot be found below ``data_dir``
        '''
        filepath = os.path.join(data_dir, cls.FOLDER, 'Korn i kasse','47209 testscan korn01_recon.xtekct')
        try:
            loader = NikonDataReader(file_name=filepath)
            return loader.read()
        except FileNotFoundError as err:
            # Adjacent literals instead of a backslash continuation, so that no
            # source indentation ends up inside the message.
            raise FileNotFoundError(
                "Dataset .xtekct file not found in: {}\n"
                "Specify a different data_dir or download data with "
                "dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__)
            ) from err


class SANDSTONE(REMOTEDATA):
    '''
    Synchrotron x-ray tomography of sandstone, taken from https://zenodo.org/records/4912435

    Only a small subset is downloaded: selected projections and 4 slices of
    the reconstruction.
    '''
    # Sub-directory of the data directory that the archive is extracted into
    FOLDER = 'sandstone'
    # Address of the zip archive
    URL = 'https://zenodo.org/records/4912435/files/small.zip'
    # Download size shown in the confirmation prompt
    FILE_SIZE = '227 MB'

class TestData(object):
'''Class to return test data
Expand All @@ -178,9 +317,9 @@ class TestData(object):
SHAPES = 'shapes.png'
RAINBOW = 'rainbow.png'

def __init__(self, **kwargs):
self.data_dir = kwargs.get('data_dir', data_dir)

def __init__(self, data_dir):
self.data_dir = data_dir
def load(self, which, size=None, scale=(0,1), **kwargs):
'''
Return a test data of the requested image
Expand Down Expand Up @@ -506,4 +645,4 @@ def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs):
if clip:
out = np.clip(out, low_clip, 1.0)

return out
return out
81 changes: 79 additions & 2 deletions Wrappers/Python/test/test_dataexample.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@
from cil.framework.framework import ImageGeometry,AcquisitionGeometry
from cil.utilities import dataexample
from cil.utilities import noise
import os, sys
import os, sys, shutil
from testclass import CCPiTestClass
import platform
import numpy as np
from unittest.mock import patch, MagicMock
from urllib import request
from zipfile import ZipFile
from io import StringIO

initialise_tests()

Expand Down Expand Up @@ -147,4 +151,77 @@ def test_load_SIMULATED_CONE_BEAM_DATA(self):
.set_panel((128,128),(64,64))\
.set_angles(np.linspace(0,360,300,False))

self.assertEqual(ag_expected,image.geometry,msg="Acquisition geometry mismatch")
self.assertEqual(ag_expected,image.geometry,msg="Acquisition geometry mismatch")

class TestRemoteData(unittest.TestCase):
    '''
    Tests for the remote dataset classes in ``dataexample``.

    ``urlopen`` and ``input`` are patched in every download test, so no
    network access and no user interaction take place.
    '''

    def setUp(self):
        '''Create a small zip archive on disk and keep its raw bytes for the fake response.'''
        self.data_list = ['WALNUT','USB','KORN','SANDSTONE']
        self.tmp_file = 'tmp.txt'
        self.tmp_zip = 'tmp.zip'
        with ZipFile(self.tmp_zip, 'w') as zipped_file:
            zipped_file.writestr(self.tmp_file, np.array([1, 2, 3]))
        with open(self.tmp_zip, 'rb') as zipped_file:
            self.zipped_bytes = zipped_file.read()

    def tearDown(self):
        '''Remove every folder and file the tests may have created in the working directory.'''
        for data in self.data_list:
            folder = getattr(dataexample, data).FOLDER
            if os.path.exists(folder):
                shutil.rmtree(folder)

        for leftover in (self.tmp_zip, self.tmp_file):
            if os.path.exists(leftover):
                os.remove(leftover)

    def mock_urlopen(self, mock_urlopen):
        '''Configure the patched ``urlopen`` to act as a context manager whose ``read()`` yields the test archive.'''
        mock_response = MagicMock()
        mock_response.read.return_value = self.zipped_bytes
        mock_response.__enter__.return_value = mock_response
        mock_urlopen.return_value = mock_response

    @patch('cil.utilities.dataexample.urlopen')
    def test_unzip_remote_data(self, mock_urlopen):
        '''The archive returned by ``urlopen`` is extracted into the given directory.'''
        self.mock_urlopen(mock_urlopen)
        dataexample.REMOTEDATA._download_and_extract_from_url('.')
        self.assertTrue(os.path.isfile(self.tmp_file))

    @patch('cil.utilities.dataexample.input', return_value='n')
    @patch('cil.utilities.dataexample.urlopen')
    def test_download_data_input_n(self, mock_urlopen, mock_input):
        '''Answering "n" cancels the download and extracts nothing.'''
        self.mock_urlopen(mock_urlopen)

        for data in self.data_list:
            test_func = getattr(dataexample, data)

            # Redirect print output. The previous stream is restored in
            # ``finally`` so a failure cannot leave stdout redirected, and a
            # test runner's own capture is put back rather than sys.__stdout__.
            captured_output = StringIO()
            previous_stdout = sys.stdout
            sys.stdout = captured_output
            try:
                test_func.download_data('.')
            finally:
                sys.stdout = previous_stdout

            self.assertFalse(os.path.isfile(self.tmp_file))
            self.assertEqual(captured_output.getvalue(),'Download cancelled\n')

        mock_urlopen.assert_not_called()

    @patch('cil.utilities.dataexample.input', return_value='y')
    @patch('cil.utilities.dataexample.urlopen')
    def test_download_data_input_y(self, mock_urlopen, mock_input):
        '''Answering "y" downloads the archive and extracts it into the dataset folder.'''
        self.mock_urlopen(mock_urlopen)

        for data in self.data_list:
            test_func = getattr(dataexample, data)

            # Silence the progress messages; restore the previous stream even
            # if the download raises.
            previous_stdout = sys.stdout
            sys.stdout = StringIO()
            try:
                test_func.download_data('.')
            finally:
                sys.stdout = previous_stdout

            self.assertTrue(os.path.isfile(os.path.join(test_func.FOLDER,self.tmp_file)))
Loading

0 comments on commit b16a429

Please sign in to comment.