diff --git a/elegy/data/dataset.py b/elegy/data/dataset.py
index 6f4cb38c..373b59c2 100644
--- a/elegy/data/dataset.py
+++ b/elegy/data/dataset.py
@@ -1,5 +1,5 @@
 import numpy as np
-import jax.numpy as jnp
+import jax, jax.numpy as jnp
 import multiprocessing.pool
 import typing as tp
 from .data_adapter import DataAdapter
@@ -42,12 +42,20 @@ def __len__(self) -> int:
         """Abstract method. In a subclass this should return the number of data samples in the dataset."""
         raise NotImplementedError
 
+    def batch_fn(
+        self, list_of_samples: tp.List[tp.Any]
+    ) -> tp.Union[jnp.ndarray, tp.Tuple[jnp.ndarray]]:
+        """Used by DataLoader to group a list of individual samples into a batch.
+        By default tries to stack elements in the samples according to their position.
+        Can be overridden for more complex use cases.
+        """
+        return default_batch_fn(list_of_samples)
+
 
 class DataLoader:
     """Loads samples from a dataset and combines them into batches. Can be directly passed to `Model.fit()`"""
     # +_example_usage_docstring
     # TODO: __getitem__ incl slicing e.g. [:5]
-    # TODO: custom batch_fn parameter
     # TODO: n_workers='auto'
     # TODO: timeout parameter
@@ -126,18 +134,13 @@ def default_batch_fn(
 ) -> tp.Union[jnp.ndarray, tp.Tuple[jnp.ndarray]]:
     """Batches individual data samples."""
     assert len(list_of_samples) > 0
-    first_sample = list_of_samples[0]
-    if hasattr(first_sample, "__array__"):
-        return jnp.asarray(list_of_samples)
-    elif isinstance(first_sample, (tp.Tuple, tp.List)):
-        sample_len = len(first_sample)
-        batched_lists = [
-            [sample[i] for sample in list_of_samples] for i in range(sample_len)
-        ]
-        batched_stacks = [jnp.asarray(batch) for batch in batched_lists]
-        return tuple(batched_stacks)
-    else:
-        return tuple(list_of_samples)
+    return jax.tree_multimap(lambda *x: jnp.asarray(x), *list_of_samples)
+
+
+def get_batch_fn(ds: tp.Any) -> tp.Callable:
+    """Returns the batch_fn of the argument if it has one, otherwise `default_batch_fn`,
+    to allow arrays or datasets that don't inherit from elegy.data.Dataset"""
+    return getattr(ds, "batch_fn", default_batch_fn)
 
 
 def mainthread_data_iterator(
@@ -146,12 +149,7 @@
     """Generator that loads datasamples from the data set in the main thread"""
     for batch_of_indices in batched_indices:
         samples = list(map(ds.__getitem__, batch_of_indices))
-        yield default_batch_fn(samples)
-
-
-def data_transfer_fn(async_map_result, timeout=10):
-    samples = async_map_result.get(timeout)
-    return default_batch_fn(samples)
+        yield get_batch_fn(ds)(samples)
 
 
 class WorkerContext:
@@ -229,7 +227,9 @@ def __next__(self):
 
     def dispatch_tasks(self, batch_of_indices):
         async_x = self.worker_pool.map_async(WorkerContext.get_sample, batch_of_indices)
-        async_x = self.data_transfer_worker.apply_async(data_transfer_fn, (async_x,))
+        async_x = self.data_transfer_worker.apply_async(
+            self.data_transfer_fn, (async_x,)
+        )
         self.async_results_queue.append(async_x)
 
     def shutdown(self):
@@ -244,6 +244,13 @@ def shutdown(self):
     def __del__(self):
         self.shutdown()
 
+    def data_transfer_fn(self, async_map_result, timeout=10):
+        samples = async_map_result.get(timeout)
+        batch = get_batch_fn(self.ds)(samples)
+        # make sure the batch is transferred to the device
+        batch = jax.tree_map(jnp.asarray, batch)
+        return batch
+
 
 class DataLoaderAdapter(DataAdapter):
     @staticmethod
diff --git a/elegy/data/dataset_test.py b/elegy/data/dataset_test.py
index 638618b6..53931232 100644
--- a/elegy/data/dataset_test.py
+++ b/elegy/data/dataset_test.py
@@ -158,6 +158,23 @@ def test_prefetch(self):
         batches = list(loader)
         assert len(loader) == len(batches)
 
+    def test_custom_batch_fn(self):
+        ds = DS_custom_batch_fn()
+        loader = elegy.data.DataLoader(ds, batch_size=3)
+        batches = list(loader)
+        assert len(loader) == len(batches)
+        assert batches[0]["a"].shape == (3, 10)
+        assert batches[0]["b"].shape == (3,)
+        assert batches[0]["c"] == "banana"
+        assert np.all(batches[0]["b"] == np.array([0, 1, 2]))
+
+    def test_loader_from_array(self):
+        pseudo_ds = np.arange(65)
+        loader = elegy.data.DataLoader(pseudo_ds, batch_size=10)
+        batches = list(loader)
+        assert len(batches) == 7
+        assert np.all(batches[1] == np.arange(10, 20))
+
 
 class DS0(elegy.data.Dataset):
     def __len__(self):
@@ -165,3 +182,16 @@ def __len__(self):
         return np.zeros([100, 200, 3]), np.arange(20)[i]
+
+
+class DS_custom_batch_fn(elegy.data.Dataset):
+    def __len__(self):
+        return 11
+
+    def __getitem__(self, i):
+        return dict(a=np.random.random(size=10), b=i)
+
+    def batch_fn(self, list_of_samples):
+        x = super().batch_fn(list_of_samples)
+        x.update(c="banana")
+        return x
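
Usage sketch (not part of the patch): the `batch_fn` hook added above lets a dataset keep the default stacking for array fields while attaching extra, non-array data to each batch. `elegy.data.Dataset`, `elegy.data.DataLoader`, and `batch_fn` come from this change; the `TaggedVectors` class and its field names below are hypothetical, chosen only for illustration.

import numpy as np
import elegy


class TaggedVectors(elegy.data.Dataset):
    # hypothetical dataset: each sample is a dict with an array field and an index
    def __len__(self):
        return 32

    def __getitem__(self, i):
        return dict(x=np.random.random(size=4), idx=i)

    def batch_fn(self, list_of_samples):
        # stack the array/scalar fields with the default behaviour ...
        batch = super().batch_fn(list_of_samples)
        # ... then attach per-batch metadata that should not be stacked
        batch.update(split="train")
        return batch


loader = elegy.data.DataLoader(TaggedVectors(), batch_size=8)
batch = next(iter(loader))
# batch["x"].shape == (8, 4), batch["idx"].shape == (8,), batch["split"] == "train"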