Skip to content

Commit

Permalink
added tiled dataloaders
Browse files Browse the repository at this point in the history
  • Loading branch information
taxe10 committed Jan 26, 2024
1 parent 2d1c914 commit cc0de57
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 11 deletions.
33 changes: 22 additions & 11 deletions src/helper_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import tensorflow as tf
import tensorflow_io as tfio

from tiled_dataloader import CustomTiledDataset

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
SPLASH_URL = 'http://splash:80/api/v0'

Expand Down Expand Up @@ -46,17 +48,17 @@ def get_dataset(data, shuffle=False, event_id = None, seed=42):
'''
# Retrieve data set list
data_info = pd.read_parquet(data, engine='pyarrow')
if 'local_uri' in data_info:
uri_list = data_info['local_uri']
else:
uri_list = data_info['uri']
if uri_list[0].split('.')[-1] in ['tif', 'tiff', 'TIF', 'TIFF']:
tif = True
else:
tif = False
# Retrieve labels
if event_id:
labeled_uris, labels = load_from_splash(uri_list.tolist(), event_id)
if 'local_uri' in data_info:
uri_list = data_info['local_uri']
splash_uri_list = data_info['uri']
splash_labeled_uris, labels = load_from_splash(splash_uri_list.tolist(), event_id)
labeled_uris = data_info[data_info['uri'].isin(splash_labeled_uris)]
labeled_uris = list(labeled_uris['local_uri'])
else:
uri_list = data_info['uri']
labeled_uris, labels = load_from_splash(uri_list.tolist(), event_id)
classes = list(set(labels))
df_labels = pd.DataFrame(labels).replace({class_name: label for label, class_name in
enumerate(classes)})
Expand All @@ -66,9 +68,18 @@ def get_dataset(data, shuffle=False, event_id = None, seed=42):
dataset = tf.data.Dataset.from_tensor_slices((labeled_uris, categorical_labels))
kwargs = classes
else:
uri_list = data_info['uri']
kwargs = uri_list.to_list()
dataset = tf.data.Dataset.from_tensor_slices(uri_list)
num_imgs = len(uri_list)
if data_info['type'][0] == 'tiled':
dataset = CustomTiledDataset(uri_list, log=False)
else:
dataset = tf.data.Dataset.from_tensor_slices(uri_list)
num_imgs = len(uri_list)
# Check if data is in tif format
if uri_list[0].split('.')[-1] in ['tif', 'tiff', 'TIF', 'TIFF'] or data_info['type'][0] == 'tiled':
tif = True
else:
tif = False
# Shuffle data
if shuffle:
dataset.shuffle(seed=seed, buffer_size=num_imgs)
Expand Down
54 changes: 54 additions & 0 deletions src/tiled_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import tensorflow as tf
import requests


class CustomTiledDataset(tf.data.Dataset):
def __init__(self, uri_list, log):
self.uri_list = uri_list
self.log = log
self.dataset = tf.data.Dataset.from_tensor_slices(self.uri_list)

@staticmethod
def _get_tiled_response(tiled_uri, expected_shape, max_tries=5):
'''
Get response from tiled URI
Args:
tiled_uri: Tiled URI from which data should be retrieved
max_tries: Maximum number of tries to retrieve data, defaults to 5
Returns:
Response content
'''
status_code = 502
trials = 0
while status_code != 200 and trials < max_tries:
if len(expected_shape) == 3:
response = requests.get(f'{tiled_uri},0,:,:&format=png')
else:
response = requests.get(f'{tiled_uri},:,:&format=png')
status_code = response.status_code
trials += 1
if status_code != 200:
raise Exception(f'Failed to retrieve data from {tiled_uri}')
return response.content

def _parse_function(self, uri):
tiled_uri, metadata = tf.strings.split(uri, '&expected_shape=')
expected_shape = tf.strings.split(metadata, '&dtype=')[0]
expected_shape = tf.strings.split(expected_shape, '%2C')
expected_shape = tf.strings.to_number(expected_shape, out_type=tf.int32)
expected_shape = tf.cond(tf.equal(tf.shape(expected_shape)[0], 3) & tf.reduce_any(tf.equal(expected_shape[0], [1,3,4])),
lambda: expected_shape[[1,2,0]],
lambda: expected_shape)
contents = self._get_tiled_response(tiled_uri, expected_shape, max_tries=5)
image = tf.io.decode_image(contents, channels=1)
if self.log:
image = tf.math.log1p(tf.cast(image, tf.float32))
image = ((image - tf.reduce_min(image)) / (tf.reduce_max(image) - tf.reduce_min(image))) * 255
image = tf.cast(image, tf.uint8)
image = tf.cast(image, tf.float32) / 255.0
return image

def __new__(cls, uri_list, log):
instance = super(CustomTiledDataset, cls).__new__(cls)
instance.__init__(uri_list, log)
return tf.data.Dataset.map(instance.dataset, instance._parse_function)

0 comments on commit cc0de57

Please sign in to comment.