Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataexample update #1774

Merged
merged 31 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
f982806
Edits from zenodo_get
hrobarts Apr 9, 2024
06b7128
Use zenodo_get for download
hrobarts Apr 11, 2024
8b7c1f6
Merge branch 'master' into dataexample_update
hrobarts Apr 11, 2024
c4b0b7a
add zenodo_get dependency
casperdcl Apr 12, 2024
8581269
Update dependencies
hrobarts May 15, 2024
4f4265e
Merge branch 'master' into dataexample_update
hrobarts May 15, 2024
0785e63
Merge branch 'dataexample_update' of github.com:TomographicImaging/CI…
hrobarts May 15, 2024
da0872a
Update test
hrobarts Jun 28, 2024
93ecc5c
Add get for sandstone
hrobarts Jun 28, 2024
63850e9
Merge branch 'master' into dataexample_update
hrobarts Jun 28, 2024
188fca0
Update changelog
hrobarts Jun 28, 2024
84ae004
Tidy
hrobarts Jun 28, 2024
9ee2b9f
Merge branch 'master' into dataexample_update
hrobarts Jun 28, 2024
8247c20
Merge branch 'master' into dataexample_update
hrobarts Jul 1, 2024
11cac80
Merge branch 'master' into dataexample_update
hrobarts Jul 22, 2024
f9c1a59
Merge branch 'master' into dataexample_update
hrobarts Sep 12, 2024
a5b6864
Review updates, add zenodo record test
hrobarts Sep 16, 2024
949e118
Test zip file is removed
hrobarts Sep 16, 2024
83bd33a
Add examples to docs
hrobarts Sep 17, 2024
5958da2
Merge branch 'master' into dataexample_update
hrobarts Sep 17, 2024
3be0ee4
Documentation updates
hrobarts Sep 17, 2024
7933fa1
Remove zenodo record exists test
hrobarts Sep 17, 2024
ba29832
Merge branch 'dataexample_update' of github.com:TomographicImaging/CI…
hrobarts Sep 17, 2024
8be1ffb
Documentation update
hrobarts Sep 17, 2024
51ac09b
Merge branch 'master' into dataexample_update
hrobarts Sep 23, 2024
9f30fda
Apply suggestions from code review
hrobarts Sep 24, 2024
042cbfc
Add test for data_download(prompt=False)
hrobarts Sep 24, 2024
c3c6801
Apply suggestions from code review
casperdcl Sep 24, 2024
0d94a6e
slight refactor
casperdcl Sep 24, 2024
d3fb961
Merge branch 'master' into dataexample_update
hrobarts Sep 26, 2024
eaac1e1
Merge branch 'master' into dataexample_update
hrobarts Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
- BlockOperator that would return a BlockDataContainer of shape (1,1) now returns the appropriate DataContainer. BlockDataContainer direct and adjoint methods accept DataContainer as parameter (#1802).
- BlurringOperator: remove check for geometry class (old SIRF integration bug) (#1807)
- The `ZeroFunction` and `ConstantFunction` now have a Lipschitz constant of 1. (#1768)
- Update dataexample remote data download to work with Windows and use zenodo_get for data download (#1774)
- Changes that break backwards compatibility:
- Merged the files `BlockGeometry.py` and `BlockDataContainer.py` in `framework` to one file `block.py`. Please use `from cil.framework import BlockGeometry, BlockDataContainer` as before (#1799)
- Bug fix in `FGP_TV` function to set the default behaviour not to enforce non-negativity (#1826).
Expand Down
102 changes: 75 additions & 27 deletions Wrappers/Python/cil/utilities/dataexample.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@
import os.path
import sys
from zipfile import ZipFile
from urllib.request import urlopen
from io import BytesIO
from scipy.io import loadmat
from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader
from zenodo_get import zenodo_get

class DATA(object):
@classmethod
Expand All @@ -46,21 +45,15 @@ def get(cls, size=None, scale=(0,1), **kwargs):
class REMOTEDATA(DATA):

FOLDER = ''
URL = ''
FILE_SIZE = ''
ZENODO_RECORD = ''
ZIP_FILE = ''

@classmethod
def get(cls, data_dir):
return None

@classmethod
def _download_and_extract_from_url(cls, data_dir):
with urlopen(cls.URL) as response:
with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile:
zipfile.extractall(path = data_dir)

@classmethod
def download_data(cls, data_dir, prompt=True):
    '''
    Download a dataset from a remote repository

    The archive ``cls.ZIP_FILE`` is requested from the Zenodo record
    ``cls.ZENODO_RECORD`` into ``data_dir``, extracted into
    ``data_dir/cls.FOLDER`` and the zip file is then deleted.
    Nothing is downloaded if ``data_dir/cls.FOLDER`` already exists.

    Parameters
    ----------
    data_dir: str
        The path to the directory the zip file is downloaded to and extracted in

    prompt: bool, default True
        If True, ask the user to confirm the download and only continue when the
        answer is exactly "y". If False, download without asking.
    '''
    if os.path.isdir(os.path.join(data_dir, cls.FOLDER)):
        print("Dataset folder already exists in " + data_dir)
        return

    # get user confirmation for download
    if prompt:
        # The question must be formatted *before* it is handed to input();
        # calling .format() on the value returned by input() leaves the
        # '{}' placeholders visible to the user.
        question = "Are you sure you want to download {} dataset from Zenodo record {} ? (y/n)".format(cls.ZIP_FILE, cls.ZENODO_RECORD)
        user_input = input(question)
    else:
        user_input = 'y'

    if user_input != "y":
        print('Download cancelled')
        return

    zenodo_get([cls.ZENODO_RECORD, '-g', cls.ZIP_FILE, '-o', data_dir])

    # unzip file
    zip_path = os.path.join(data_dir, cls.ZIP_FILE)
    with ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(os.path.join(data_dir, cls.FOLDER))
    # remove the archive only after it has been closed (an open file cannot be deleted on Windows)
    os.remove(zip_path)
casperdcl marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -195,15 +194,21 @@ def get(cls, **kwargs):
class WALNUT(REMOTEDATA):
'''
A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516

Example
--------
>>> data_dir = 'my_PC/data_folder'
>>> dataexample.WALNUT.download_data(data_dir) # download the data
>>> dataexample.WALNUT.get(data_dir) # load the data
'''
FOLDER = 'walnut'
URL = 'https://zenodo.org/record/4822516/files/walnut.zip'
FILE_SIZE = '6.4 GB'
ZENODO_RECORD = '4822516'
ZIP_FILE = 'walnut.zip'

@classmethod
def get(cls, data_dir):
'''
A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516
Get the microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516
This function returns the raw projection data from the .txrm file

Parameters
Expand All @@ -227,15 +232,21 @@ def get(cls, data_dir):
class USB(REMOTEDATA):
'''
A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516

Example
--------
>>> data_dir = 'my_PC/data_folder'
>>> dataexample.USB.download_data(data_dir) # download the data
>>> dataexample.USB.get(data_dir) # load the data
'''
FOLDER = 'USB'
URL = 'https://zenodo.org/record/4822516/files/usb.zip'
FILE_SIZE = '3.2 GB'
ZENODO_RECORD = '4822516'
ZIP_FILE = 'usb.zip'

@classmethod
def get(cls, data_dir):
'''
A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516
Get the microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516
This function returns the raw projection data from the .txrm file

Parameters
Expand All @@ -259,15 +270,21 @@ def get(cls, data_dir):
class KORN(REMOTEDATA):
'''
A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123

Example
--------
>>> data_dir = 'my_PC/data_folder'
>>> dataexample.KORN.download_data(data_dir) # download the data
>>> dataexample.KORN.get(data_dir) # load the data
'''
FOLDER = 'korn'
URL = 'https://zenodo.org/record/6874123/files/korn.zip'
FILE_SIZE = '2.9 GB'
ZENODO_RECORD = '6874123'
ZIP_FILE = 'korn.zip'

@classmethod
def get(cls, data_dir):
'''
A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123
Get the microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123
This function returns the raw projection data from the .xtekct file

Parameters
Expand All @@ -279,6 +296,7 @@ def get(cls, data_dir):
-------
ImageData
The korn dataset

'''
filepath = os.path.join(data_dir, cls.FOLDER, 'Korn i kasse','47209 testscan korn01_recon.xtekct')
try:
Expand All @@ -293,10 +311,40 @@ class SANDSTONE(REMOTEDATA):
'''
A synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435
A small subset of the data containing selected projections and 4 slices of the reconstruction

Example
--------
>>> data_dir = 'my_PC/data_folder'
>>> dataexample.SANDSTONE.download_data(data_dir) # download the data
>>> dataexample.SANDSTONE.get(data_dir) # load the data
'''
FOLDER = 'sandstone'
URL = 'https://zenodo.org/records/4912435/files/small.zip'
FILE_SIZE = '227 MB'
ZENODO_RECORD = '4912435'
ZIP_FILE = 'small.zip'

@classmethod
def get(cls, data_dir, filename):
    '''
    Get the synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435
    A small subset of the data containing selected projections and 4 slices of the reconstruction

    Parameters
    ----------
    data_dir: str
        The path to the directory where the dataset is stored. Data can be downloaded with dataexample.SANDSTONE.download_data(data_dir)

    filename: str
        The slices or projections to return, given as the path to a '.mat' file relative to data_dir
        (the path is joined directly onto data_dir)

    Returns
    -------
    The object returned by scipy.io.loadmat for the selected file

    Raises
    ------
    KeyError
        If filename does not have the extension '.mat'
    '''
    _, suffix = os.path.splitext(filename)
    # only MATLAB files are supported; reject anything else before touching the disk
    if suffix != '.mat':
        raise KeyError(f"Unknown extension: {suffix}")
    mat_path = os.path.join(data_dir, filename)
    return loadmat(mat_path)


class TestData(object):
'''Class to return test data
Expand Down
163 changes: 68 additions & 95 deletions Wrappers/Python/test/test_dataexample.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
from testclass import CCPiTestClass
import platform
import numpy as np
from unittest.mock import patch, MagicMock
from urllib import request
from unittest.mock import patch
from zipfile import ZipFile
from io import StringIO
from tempfile import NamedTemporaryFile
import uuid
from zenodo_get import zenodo_get

initialise_tests()

Expand Down Expand Up @@ -157,116 +157,89 @@ def test_load_SIMULATED_CONE_BEAM_DATA(self):
class TestRemoteData(unittest.TestCase):

def setUp(self):

self.data_list = ['WALNUT','USB','KORN','SANDSTONE']
self.shapes_path = os.path.join(dataexample.CILDATA.data_dir, dataexample.TestData.SHAPES)

def mock_urlopen(self, mock_urlopen, zipped_bytes):
mock_response = MagicMock()
mock_response.read.return_value = zipped_bytes
mock_response.__enter__.return_value = mock_response
mock_urlopen.return_value = mock_response

@unittest.skipIf(platform.system() == 'Windows', "Skip on Windows")
@patch('cil.utilities.dataexample.urlopen')
def test_unzip_remote_data(self, mock_urlopen):
'''
Test the _download_and_extract_data_from_url function correctly extracts files from a byte string
The zipped byte string is mocked using a temporary local zip file
def mock_zenodo_get(*args):
    # Stand-in for zenodo_get: instead of downloading, write a zip file containing
    # the shapes test data to the location the real call would have used.
    # download_data calls zenodo_get([record, '-g', zip_file, '-o', data_dir]),
    # so the first positional argument is that list.
    cli_args = args[0]
    zip_name = cli_args[2]
    output_dir = cli_args[4]
    shapes_path = os.path.join(dataexample.CILDATA.data_dir, dataexample.TestData.SHAPES)
    target_zip = os.path.join(output_dir, zip_name)
    with ZipFile(target_zip, mode='w') as zip_file:
        zip_file.write(shapes_path, arcname=dataexample.TestData.SHAPES)


@patch('cil.utilities.dataexample.input', return_value='y')
@patch('cil.utilities.dataexample.zenodo_get', side_effect=mock_zenodo_get)
def test_download_data_input_y(self, mock_zenodo_get, input):
    '''
    Test the download_data function, when the user input is 'y' to 'are you sure you want to download data'
    The user input to confirm the download is mocked as 'y'
    The zip file download is mocked by creating a zip file locally
    Test the download_data function correctly extracts files from the zip file
    '''
    # create a temporary folder in the CIL data directory
    tmp_dir = os.path.join(dataexample.CILDATA.data_dir, str(uuid.uuid4()))
    os.makedirs(tmp_dir)
    # redirect print output, remembering the stream that was active so it can be
    # put back even if an assertion fails
    original_stdout = sys.stdout
    capturedOutput = StringIO()
    sys.stdout = capturedOutput
    try:
        for data in self.data_list:
            test_func = getattr(dataexample, data)
            test_func.download_data(tmp_dir)
            # Test the data file exists
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, getattr(test_func, 'FOLDER'), dataexample.TestData.SHAPES)),
                            msg = "Download data test failed with dataset " + data)
            # Test the zip file is removed
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, getattr(test_func, 'ZIP_FILE'))))
    finally:
        # return to the previous print output and always clean up the temporary folder
        sys.stdout = original_stdout
        shutil.rmtree(tmp_dir)

self.assertTrue(os.path.isfile(os.path.join(tmp_path, tmp_dir, dataexample.TestData.SHAPES)))

if os.path.exists(os.path.join(tmp_path,tmp_dir)):
shutil.rmtree(os.path.join(tmp_path,tmp_dir))

@unittest.skipIf(platform.system() == 'Windows', "Skip on Windows")
@patch('cil.utilities.dataexample.input', return_value='n')
@patch('cil.utilities.dataexample.urlopen')
def test_download_data_input_n(self, mock_urlopen, input):
@patch('cil.utilities.dataexample.input', return_value='n')
@patch('cil.utilities.dataexample.zenodo_get', side_effect=mock_zenodo_get)
def test_download_data_input_n(self, mock_zenodo_get, input):
    '''
    Test the download_data function, when the user input is 'n' to 'are you sure you want to download data'
    The user input is mocked as 'n', so no data should be extracted and 'Download cancelled' is printed
    Then test download_data(prompt=False) downloads the data without asking the user
    '''
    # create a temporary folder in the CIL data directory
    tmp_dir = os.path.join(dataexample.CILDATA.data_dir, str(uuid.uuid4()))
    os.makedirs(tmp_dir)
    # remember the active output stream so it can be put back even if an assertion fails
    original_stdout = sys.stdout
    try:
        for data in self.data_list:
            # redirect print output (fresh buffer per dataset so the message can be compared exactly)
            capturedOutput = StringIO()
            sys.stdout = capturedOutput
            test_func = getattr(dataexample, data)
            test_func.download_data(tmp_dir)
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, getattr(test_func, 'FOLDER'), dataexample.TestData.SHAPES)),
                             msg = "Download dataset test failed with dataset " + data)
            self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n',
                             msg = "Download dataset test failed with dataset " + data)

        # redirect print output
        capturedOutput = StringIO()
        sys.stdout = capturedOutput
        # Test the zip file IS created with prompt=False i.e. prompt not used
        dataexample.WALNUT.download_data(tmp_dir, prompt=False)
        # Test the data file exists
        # (the message names WALNUT explicitly rather than reusing the loop variable,
        # which would report the last dataset in data_list)
        self.assertTrue(os.path.isfile(os.path.join(tmp_dir, dataexample.WALNUT.FOLDER, dataexample.TestData.SHAPES)),
                        msg = "Download data test failed with dataset WALNUT")
        # Test the zip file is removed
        self.assertFalse(os.path.isfile(os.path.join(tmp_dir, dataexample.WALNUT.ZIP_FILE)))
    finally:
        # return to the previous print output and always clean up the temporary folder
        sys.stdout = original_stdout
        shutil.rmtree(tmp_dir)


def test_download_data_bad_URL(self):
@patch('cil.utilities.dataexample.input', return_value='y')
def test_download_data_empty(self, input):
    '''
    Test an error is raised when download_data is used on an empty Zenodo record
    '''
    remote_data = dataexample.REMOTEDATA
    # Override the class attributes only for the duration of this test; assigning
    # them directly would leave REMOTEDATA modified for every test that runs afterwards.
    # NOTE(review): zenodo_get is not patched here - confirm the ValueError is raised
    # without this test depending on network access.
    with patch.object(remote_data, 'ZENODO_RECORD', 'empty'), \
         patch.object(remote_data, 'FOLDER', 'empty'):
        with self.assertRaises(ValueError):
            remote_data.download_data('.')

def test_a(self):
from cil.utilities.dataexample import WALNUT

Loading
Loading