Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Download data in dataexamples #1712

Merged
merged 14 commits into from
Mar 13, 2024
Merged
112 changes: 82 additions & 30 deletions Wrappers/Python/cil/utilities/dataexample.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,57 +23,79 @@
import os
import os.path
import sys
from cil.io import NEXUSDataReader

data_dir = os.path.abspath(os.path.join(
os.path.dirname(__file__),
'../data/')
)

# this is the default location after a conda install
data_dir = os.path.abspath(
os.path.join(sys.prefix, 'share','cil')
)
from zipfile import ZipFile
from urllib.request import urlopen
from io import BytesIO
from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader

class DATA(object):
@classmethod
def dfile(cls):
return None

class INTERNALDATA(DATA):
data_dir = os.path.abspath(os.path.join(sys.prefix, 'share','cil'))
@classmethod
def get(cls, size=None, scale=(0,1), **kwargs):
ddir = kwargs.get('data_dir', data_dir)
ddir = kwargs.get('data_dir', INTERNALDATA.data_dir)
loader = TestData(data_dir=ddir)
return loader.load(cls.dfile(), size, scale, **kwargs)

class REMOTEDATA(DATA):
PATH = ''
URL = ''

@classmethod
def get(cls, data_dir):
return None

@classmethod
def download_from_url(cls, data_dir):
with urlopen(cls.URL) as response:
with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile:
zipfile.extractall(path = data_dir)

@classmethod
def download_data(cls, data_dir):
if os.path.isfile(os.path.join(data_dir, cls.PATH)):
print("Dataset already exists in " + data_dir)
else:
if input("Are you sure you want to download the dataset from " + cls.URL + " ? (y/n)") == "y":
print('Downloading dataset from ' + cls.URL)
cls.download_from_url(data_dir)
print('Download complete')
else:
print('Download cancelled')

class BOAT(DATA):
class BOAT(INTERNALDATA):
@classmethod
def dfile(cls):
return TestData.BOAT
class CAMERA(DATA):
class CAMERA(INTERNALDATA):
@classmethod
def dfile(cls):
return TestData.CAMERA
class PEPPERS(DATA):
class PEPPERS(INTERNALDATA):
@classmethod
def dfile(cls):
return TestData.PEPPERS
class RESOLUTION_CHART(DATA):
class RESOLUTION_CHART(INTERNALDATA):
@classmethod
def dfile(cls):
return TestData.RESOLUTION_CHART
class SIMPLE_PHANTOM_2D(DATA):
class SIMPLE_PHANTOM_2D(INTERNALDATA):
@classmethod
def dfile(cls):
return TestData.SIMPLE_PHANTOM_2D
class SHAPES(DATA):
class SHAPES(INTERNALDATA):
@classmethod
def dfile(cls):
return TestData.SHAPES
class RAINBOW(DATA):
class RAINBOW(INTERNALDATA):
@classmethod
def dfile(cls):
return TestData.RAINBOW
class SYNCHROTRON_PARALLEL_BEAM_DATA(DATA):
class SYNCHROTRON_PARALLEL_BEAM_DATA(INTERNALDATA):
@classmethod
def get(cls, **kwargs):
'''
Expand All @@ -90,11 +112,11 @@ def get(cls, **kwargs):
The DLS dataset
'''

ddir = kwargs.get('data_dir', data_dir)
ddir = kwargs.get('data_dir', INTERNALDATA.data_dir)
loader = NEXUSDataReader()
loader.set_up(file_name=os.path.join(os.path.abspath(ddir), '24737_fd_normalised.nxs'))
return loader.read()
class SIMULATED_PARALLEL_BEAM_DATA(DATA):
class SIMULATED_PARALLEL_BEAM_DATA(INTERNALDATA):
@classmethod
def get(cls, **kwargs):
'''
Expand All @@ -111,11 +133,11 @@ def get(cls, **kwargs):
The simulated spheres dataset
'''

ddir = kwargs.get('data_dir', data_dir)
ddir = kwargs.get('data_dir', INTERNALDATA.data_dir)
loader = NEXUSDataReader()
loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_parallel_beam.nxs'))
return loader.read()
class SIMULATED_CONE_BEAM_DATA(DATA):
class SIMULATED_CONE_BEAM_DATA(INTERNALDATA):
@classmethod
def get(cls, **kwargs):
'''
Expand All @@ -132,11 +154,11 @@ def get(cls, **kwargs):
The simulated spheres dataset
'''

ddir = kwargs.get('data_dir', data_dir)
ddir = kwargs.get('data_dir', INTERNALDATA.data_dir)
loader = NEXUSDataReader()
loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_cone_beam.nxs'))
return loader.read()
class SIMULATED_SPHERE_VOLUME(DATA):
class SIMULATED_SPHERE_VOLUME(INTERNALDATA):
@classmethod
def get(cls, **kwargs):
'''
Expand All @@ -153,11 +175,41 @@ def get(cls, **kwargs):
The simulated spheres volume
'''

ddir = kwargs.get('data_dir', data_dir)
ddir = kwargs.get('data_dir', INTERNALDATA.data_dir)
loader = NEXUSDataReader()
loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_volume.nxs'))
return loader.read()

class WALNUT(REMOTEDATA):

PATH = os.path.join('valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm')
URL = 'https://zenodo.org/record/4822516/files/walnut.zip'

@classmethod
def get(cls, data_dir):
try:
loader = ZEISSDataReader(file_name=os.path.join(data_dir,cls.PATH))
return loader.read()
except(FileNotFoundError):
raise(FileNotFoundError("Dataset not found in specifed data_dir: {} \n \
Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(data_dir, cls.__name__)))



class KORN(REMOTEDATA):
PATH = os.path.join('Korn i kasse','47209 testscan korn01_recon.xtekct')
URL = 'https://zenodo.org/record/6874123/files/korn.zip'

@classmethod
def get(cls, data_dir):
try:
loader = NikonDataReader(file_name=os.path.join(data_dir, cls.PATH))
return loader.read()
except(FileNotFoundError):
raise(FileNotFoundError("Dataset not found in specifed data_dir: {} \n \
Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(data_dir, cls.__name__)))


class TestData(object):
'''Class to return test data

Expand All @@ -178,8 +230,8 @@ class TestData(object):
SHAPES = 'shapes.png'
RAINBOW = 'rainbow.png'

def __init__(self, **kwargs):
self.data_dir = kwargs.get('data_dir', data_dir)
def __init__(self, data_dir):
self.data_dir = data_dir

def load(self, which, size=None, scale=(0,1), **kwargs):
'''
Expand Down Expand Up @@ -506,4 +558,4 @@ def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs):
if clip:
out = np.clip(out, low_clip, 1.0)

return out
return out
66 changes: 66 additions & 0 deletions Wrappers/Python/test/test_dataexample.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
from testclass import CCPiTestClass
import platform
import numpy as np
from unittest.mock import patch, MagicMock
from urllib import request
from zipfile import ZipFile
from io import StringIO

initialise_tests()

Expand Down Expand Up @@ -148,3 +152,65 @@ def test_load_SIMULATED_CONE_BEAM_DATA(self):
.set_angles(np.linspace(0,360,300,False))

self.assertEqual(ag_expected,image.geometry,msg="Acquisition geometry mismatch")

class TestRemoteData(unittest.TestCase):

def setUp(self):
self.tmp_file = 'tmp.txt'
self.tmp_zip = 'tmp.zip'
with ZipFile(self.tmp_zip, 'w') as zipped_file:
zipped_file.writestr(self.tmp_file, np.array([1, 2, 3]))
with open(self.tmp_zip, 'rb') as zipped_file:
self.zipped_bytes = zipped_file.read()

def tearDown(self):
if os.path.exists(self.tmp_file):
os.remove(self.tmp_file)
if os.path.exists(self.tmp_zip):
os.remove(self.tmp_zip)

def mock_urlopen(self, mock_urlopen):
mock_response = MagicMock()
mock_response.read.return_value = self.zipped_bytes
mock_response.__enter__.return_value = mock_response
mock_urlopen.return_value = mock_response

@patch('cil.utilities.dataexample.urlopen')
def test_unzip_remote_data(self, mock_urlopen):
self.mock_urlopen(mock_urlopen)
dataexample.REMOTEDATA.download_from_url('.')
self.assertTrue(os.path.isfile(self.tmp_file))

@patch('cil.utilities.dataexample.input', return_value='n')
@patch('cil.utilities.dataexample.urlopen')
def test_download_data_input_n(self, mock_urlopen, input):
self.mock_urlopen(mock_urlopen)

# redirect print output
capturedOutput = StringIO()
sys.stdout = capturedOutput

dataexample.WALNUT.download_data('.')

self.assertFalse(os.path.isfile(self.tmp_file))
self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n')

# return to standard print output
sys.stdout = sys.__stdout__

@patch('cil.utilities.dataexample.input', return_value='y')
@patch('cil.utilities.dataexample.urlopen')
def test_download_data_input_y(self, mock_urlopen, input):
self.mock_urlopen(mock_urlopen)

# redirect print output
capturedOutput = StringIO()
sys.stdout = capturedOutput

dataexample.WALNUT.download_data('.')
self.assertTrue(os.path.isfile(self.tmp_file))

# return to standard print output
sys.stdout = sys.__stdout__


6 changes: 5 additions & 1 deletion Wrappers/Python/test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
from cil.framework import AcquisitionGeometry
import numpy as np
import os
import sys
from cil.framework import ImageGeometry
from cil.io import TXRMDataReader, NEXUSDataReader, NikonDataReader, ZEISSDataReader
from cil.io import TIFFWriter, TIFFStackReader
from cil.io.utilities import HDF5_utilities
from cil.processors import Slicer
from utils import has_astra, has_nvidia
from cil.utilities.dataexample import data_dir
from cil.utilities.quality_measures import mse
from cil.utilities import dataexample
import shutil
Expand Down Expand Up @@ -65,6 +65,10 @@
# change basedir to point to the location of the walnut dataset which can
# be downloaded from https://zenodo.org/record/4822516
# basedir = os.path.abspath('/home/edo/scratch/Data/Walnut/valnut_2014-03-21_643_28/tomo-A/')

data_dir = os.path.abspath(
os.path.join(sys.prefix, 'share','cil')
)
basedir = data_dir
filename = os.path.join(basedir, "valnut_tomo-A.txrm")
has_file = os.path.isfile(filename)
Expand Down
Loading