Skip to content

Commit

Permalink
whitespace
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Mar 13, 2024
1 parent b16a429 commit d101983
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 35 deletions.
34 changes: 17 additions & 17 deletions Wrappers/Python/cil/utilities/dataexample.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@ class DATA(object):
@classmethod
def dfile(cls):
return None

class CILDATA(DATA):
data_dir = os.path.abspath(os.path.join(sys.prefix, 'share','cil'))
@classmethod
def get(cls, size=None, scale=(0,1), **kwargs):
ddir = kwargs.get('data_dir', CILDATA.data_dir)
loader = TestData(data_dir=ddir)
return loader.load(cls.dfile(), size, scale, **kwargs)

class REMOTEDATA(DATA):

FOLDER = ''
URL = ''
FILE_SIZE = ''
Expand All @@ -56,7 +56,7 @@ def get(cls, data_dir):
def _download_and_extract_from_url(cls, data_dir):
with urlopen(cls.URL) as response:
with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile:
zipfile.extractall(path = data_dir)
zipfile.extractall(path = data_dir)

@classmethod
def download_data(cls, data_dir):
Expand All @@ -72,8 +72,8 @@ def download_data(cls, data_dir):
if os.path.isdir(os.path.join(data_dir, cls.FOLDER)):
print("Dataset already exists in " + data_dir)
else:
if input("Are you sure you want to download " + cls.FILE_SIZE + " dataset from " + cls.URL + " ? (y/n)") == "y":
print('Downloading dataset from ' + cls.URL)
if input("Are you sure you want to download " + cls.FILE_SIZE + " dataset from " + cls.URL + " ? (y/n)") == "y":
print('Downloading dataset from ' + cls.URL)
cls._download_and_extract_from_url(os.path.join(data_dir,cls.FOLDER))
print('Download complete')
else:
Expand Down Expand Up @@ -185,15 +185,15 @@ def get(cls, **kwargs):
-------
ImageData
The simulated spheres volume
'''
'''
ddir = kwargs.get('data_dir', CILDATA.data_dir)
loader = NEXUSDataReader()
loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_volume.nxs'))
return loader.read()

class WALNUT(REMOTEDATA):
'''
A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516
A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516
'''
FOLDER = 'walnut'
URL = 'https://zenodo.org/record/4822516/files/walnut.zip'
Expand All @@ -202,7 +202,7 @@ class WALNUT(REMOTEDATA):
@classmethod
def get(cls, data_dir):
'''
A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516
A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516
This function returns the raw projection data from the .txrm file
Parameters
Expand All @@ -222,19 +222,19 @@ def get(cls, data_dir):
except(FileNotFoundError):
raise(FileNotFoundError("Dataset .txrm file not found in specifed data_dir: {} \n \
Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__)))

class USB(REMOTEDATA):
'''
A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516
A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516
'''
FOLDER = 'USB'
FOLDER = 'USB'
URL = 'https://zenodo.org/record/4822516/files/usb.zip'
FILE_SIZE = '3.2 GB'

@classmethod
def get(cls, data_dir):
'''
A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516
A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516
This function returns the raw projection data from the .txrm file
Parameters
Expand All @@ -254,7 +254,7 @@ def get(cls, data_dir):
except(FileNotFoundError):
raise(FileNotFoundError("Dataset .txrm file not found in: {} \n \
Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__)))

class KORN(REMOTEDATA):
'''
A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123
Expand Down Expand Up @@ -319,7 +319,7 @@ class TestData(object):

def __init__(self, data_dir):
self.data_dir = data_dir

def load(self, which, size=None, scale=(0,1), **kwargs):
'''
Return a test data of the requested image
Expand Down Expand Up @@ -645,4 +645,4 @@ def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs):
if clip:
out = np.clip(out, low_clip, 1.0)

return out
return out
36 changes: 18 additions & 18 deletions Wrappers/Python/test/test_dataexample.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from testclass import CCPiTestClass
import platform
import numpy as np
from unittest.mock import patch, MagicMock
from unittest.mock import patch, MagicMock
from urllib import request
from zipfile import ZipFile
from io import StringIO
Expand Down Expand Up @@ -151,26 +151,26 @@ def test_load_SIMULATED_CONE_BEAM_DATA(self):
.set_panel((128,128),(64,64))\
.set_angles(np.linspace(0,360,300,False))

self.assertEqual(ag_expected,image.geometry,msg="Acquisition geometry mismatch")
self.assertEqual(ag_expected,image.geometry,msg="Acquisition geometry mismatch")

class TestRemoteData(unittest.TestCase):

def setUp(self):

self.data_list = ['WALNUT','USB','KORN','SANDSTONE']
self.tmp_file = 'tmp.txt'
self.tmp_zip = 'tmp.zip'
with ZipFile(self.tmp_zip, 'w') as zipped_file:
zipped_file.writestr(self.tmp_file, np.array([1, 2, 3]))
with open(self.tmp_zip, 'rb') as zipped_file:
self.zipped_bytes = zipped_file.read()
self.zipped_bytes = zipped_file.read()

def tearDown(self):
for data in self.data_list:
test_func = getattr(dataexample, data)
if os.path.exists(os.path.join(test_func.FOLDER)):
shutil.rmtree(test_func.FOLDER)

if os.path.exists(self.tmp_zip):
os.remove(self.tmp_zip)

Expand All @@ -189,39 +189,39 @@ def test_unzip_remote_data(self, mock_urlopen):
dataexample.REMOTEDATA._download_and_extract_from_url('.')
self.assertTrue(os.path.isfile(self.tmp_file))

@patch('cil.utilities.dataexample.input', return_value='n')
@patch('cil.utilities.dataexample.input', return_value='n')
@patch('cil.utilities.dataexample.urlopen')
def test_download_data_input_n(self, mock_urlopen, input):
self.mock_urlopen(mock_urlopen)

data_list = ['WALNUT','USB','KORN','SANDSTONE']
for data in data_list:
# redirect print output
capturedOutput = StringIO()
sys.stdout = capturedOutput
capturedOutput = StringIO()
sys.stdout = capturedOutput
test_func = getattr(dataexample, data)
test_func.download_data('.')

self.assertFalse(os.path.isfile(self.tmp_file))
self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n')

# return to standard print output
sys.stdout = sys.__stdout__
sys.stdout = sys.__stdout__

@patch('cil.utilities.dataexample.input', return_value='y')
@patch('cil.utilities.dataexample.input', return_value='y')
@patch('cil.utilities.dataexample.urlopen')
def test_download_data_input_y(self, mock_urlopen, input):
self.mock_urlopen(mock_urlopen)

# redirect print output
capturedOutput = StringIO()
sys.stdout = capturedOutput
capturedOutput = StringIO()
sys.stdout = capturedOutput



for data in self.data_list:
test_func = getattr(dataexample, data)
test_func.download_data('.')
self.assertTrue(os.path.isfile(os.path.join(test_func.FOLDER,self.tmp_file)))

# return to standard print output
sys.stdout = sys.__stdout__
sys.stdout = sys.__stdout__

0 comments on commit d101983

Please sign in to comment.