Skip to content

Commit

Permalink
better batch_fn and custom batch_fn (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-g committed Feb 1, 2021
1 parent 59acb3b commit e8f4884
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 21 deletions.
49 changes: 28 additions & 21 deletions elegy/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import jax.numpy as jnp
import jax, jax.numpy as jnp
import multiprocessing.pool
import typing as tp
from .data_adapter import DataAdapter
Expand Down Expand Up @@ -42,12 +42,20 @@ def __len__(self) -> int:
"""Abstract method. In a subclass this should return the number of data samples in the dataset."""
raise NotImplementedError

def batch_fn(
self, list_of_samples: tp.List[tp.Any]
) -> tp.Union[jnp.ndarray, tp.Tuple[jnp.ndarray]]:
"""Group a list of individual samples into a single batch.

Used by DataLoader. This default implementation simply delegates to
`default_batch_fn`, which tries to stack the elements of the samples
according to their position.
Can be overridden in a subclass for more complex use cases.
"""
return default_batch_fn(list_of_samples)


class DataLoader:
"""Loads samples from a dataset and combines them into batches. Can be directly passed to `Model.fit()`""" # +_example_usage_docstring

# TODO: __getitem__ incl slicing e.g. [:5]
# TODO: custom batch_fn parameter
# TODO: n_workers='auto'
# TODO: timeout parameter

Expand Down Expand Up @@ -126,18 +134,13 @@ def default_batch_fn(
) -> tp.Union[jnp.ndarray, tp.Tuple[jnp.ndarray]]:
"""Batches individual data samples."""
assert len(list_of_samples) > 0
first_sample = list_of_samples[0]
if hasattr(first_sample, "__array__"):
return jnp.asarray(list_of_samples)
elif isinstance(first_sample, (tp.Tuple, tp.List)):
sample_len = len(first_sample)
batched_lists = [
[sample[i] for sample in list_of_samples] for i in range(sample_len)
]
batched_stacks = [jnp.asarray(batch) for batch in batched_lists]
return tuple(batched_stacks)
else:
return tuple(list_of_samples)
return jax.tree_multimap(lambda *x: jnp.asarray(x), *list_of_samples)


def get_batch_fn(ds: tp.Any) -> tp.Callable:
    """Pick the function used to batch samples taken from `ds`.

    If `ds` exposes a `batch_fn` attribute, that attribute is returned.
    Otherwise `default_batch_fn` is returned, so that plain arrays or
    datasets that do not inherit from elegy.data.Dataset can be used too.
    """
    try:
        return ds.batch_fn
    except AttributeError:
        return default_batch_fn


def mainthread_data_iterator(
Expand All @@ -146,12 +149,7 @@ def mainthread_data_iterator(
"""Generator that loads datasamples from the data set in the main thread"""
for batch_of_indices in batched_indices:
samples = list(map(ds.__getitem__, batch_of_indices))
yield default_batch_fn(samples)


def data_transfer_fn(async_map_result, timeout=10):
samples = async_map_result.get(timeout)
return default_batch_fn(samples)
yield get_batch_fn(ds)(samples)


class WorkerContext:
Expand Down Expand Up @@ -229,7 +227,9 @@ def __next__(self):

def dispatch_tasks(self, batch_of_indices):
    """Queue the asynchronous loading and batching of one batch of indices.

    The indices are mapped over `self.worker_pool` with
    `WorkerContext.get_sample`. The pending result of that map is then
    handed to `self.data_transfer_worker`, which runs
    `self.data_transfer_fn` on it. The handle of this second stage is
    appended to `self.async_results_queue`.

    Args:
        batch_of_indices: Iterable of dataset indices forming one batch.
    """
    pending_samples = self.worker_pool.map_async(
        WorkerContext.get_sample, batch_of_indices
    )
    # Submit the transfer stage exactly once, using the bound method: it
    # reads `self.ds` to pick a dataset-specific batch function. Submitting
    # a second, module-level transfer function first would wrap the result
    # twice.
    pending_batch = self.data_transfer_worker.apply_async(
        self.data_transfer_fn, (pending_samples,)
    )
    self.async_results_queue.append(pending_batch)

def shutdown(self):
Expand All @@ -244,6 +244,13 @@ def shutdown(self):
def __del__(self):
# Finalizer: delegate cleanup to shutdown() when the object is destroyed.
self.shutdown()

def data_transfer_fn(self, async_map_result, timeout=10):
    """Wait for the samples of one batch and combine them into a batch.

    Args:
        async_map_result: Pending result object; its `.get(timeout)` is
            expected to return the list of samples for one batch.
        timeout: Value forwarded to `async_map_result.get` (presumably
            seconds, as for a multiprocessing AsyncResult -- confirm).

    Returns:
        The batch produced by the batch function selected for `self.ds`,
        with every leaf passed through `jnp.asarray`.
    """
    samples = async_map_result.get(timeout)
    batch_fn = get_batch_fn(self.ds)
    batch = batch_fn(samples)
    # make sure the batch is transferred to the device
    # `jax.tree_util.tree_map` is used instead of the top-level
    # `jax.tree_map` alias, which was deprecated and later removed from JAX.
    # NOTE(review): leaves that `jnp.asarray` cannot convert (e.g. strings
    # added by a custom batch_fn) will raise here -- confirm this is intended.
    return jax.tree_util.tree_map(jnp.asarray, batch)


class DataLoaderAdapter(DataAdapter):
@staticmethod
Expand Down
30 changes: 30 additions & 0 deletions elegy/data/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,40 @@ def test_prefetch(self):
batches = list(loader)
assert len(loader) == len(batches)

def test_custom_batch_fn(self):
    """A dataset that overrides batch_fn has its custom batches yielded."""
    dataset = DS_custom_batch_fn()
    loader = elegy.data.DataLoader(dataset, batch_size=3)
    batches = list(loader)
    assert len(loader) == len(batches)

    # Inspect the first batch: stacked arrays plus the extra custom entry.
    first_batch = batches[0]
    assert first_batch["a"].shape == (3, 10)
    assert first_batch["b"].shape == (3,)
    assert first_batch["c"] == "banana"
    assert np.all(first_batch["b"] == np.array([0, 1, 2]))

def test_loader_from_array(self):
    """A plain NumPy array can be passed to DataLoader instead of a Dataset."""
    array_as_dataset = np.arange(65)
    batches = list(elegy.data.DataLoader(array_as_dataset, batch_size=10))
    # 65 elements in batches of 10 -> 7 batches.
    assert len(batches) == 7
    second_batch = batches[1]
    assert np.all(second_batch == np.arange(10, 20))


class DS0(elegy.data.Dataset):
    """Toy dataset reporting 11 samples, each an (array, scalar) pair."""

    def __len__(self):
        return 11

    def __getitem__(self, i):
        # First element: a constant zero array; second: the i-th integer.
        zeros = np.zeros([100, 200, 3])
        label = np.arange(20)[i]
        return zeros, label


class DS_custom_batch_fn(elegy.data.Dataset):
    """Toy dataset of dict samples whose batch_fn adds an extra entry."""

    def __len__(self):
        return 11

    def __getitem__(self, i):
        return {"a": np.random.random(size=10), "b": i}

    def batch_fn(self, list_of_samples):
        # Start from the inherited batching, then attach a custom entry.
        batch = super().batch_fn(list_of_samples)
        batch.update({"c": "banana"})
        return batch

0 comments on commit e8f4884

Please sign in to comment.