diff --git a/kauldron/data/kmix/loaders/with_shuffle_buffer.py b/kauldron/data/kmix/loaders/with_shuffle_buffer.py index 43425219..5d55fa92 100644 --- a/kauldron/data/kmix/loaders/with_shuffle_buffer.py +++ b/kauldron/data/kmix/loaders/with_shuffle_buffer.py @@ -15,10 +15,14 @@ """TFDS dataset loader.""" import dataclasses -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Sequence +from grain._src.tensorflow import transforms as grain_transforms +import grain.tensorflow as grain import jax from kauldron import kd +from kauldron import random +from kauldron.data import grain_utils from kauldron.data.kmix import base import tensorflow as tf @@ -32,6 +36,7 @@ class WithShuffleBuffer(base.TFDataPipeline): object to fully reset the iterator. Attributes: + transforms_before_cache: Data transforms to apply before caching. cache: Whether to cache the dataset. shuffle: Whether to shuffle the dataset. shuffle_buffer_size: Size of the shuffle buffer. @@ -39,8 +44,10 @@ class WithShuffleBuffer(base.TFDataPipeline): iteration). 
"""
 
-  # TODO(epot): Could also add a `transform_before_cache` to allow
-  # filter/resizing
+  transforms_before_cache: (
+      Sequence[grain.Transformation] | dict[str, grain.Transformation]
+  ) = dataclasses.field(default_factory=tuple)
+
   cache: bool = False
   shuffle: bool = True
   shuffle_buffer_size: Optional[int] = None
@@ -48,7 +55,34 @@ class WithShuffleBuffer(base.TFDataPipeline):
 
   _supports_symbolic_checkpoint: ClassVar[bool] = False
 
+  def _maybe_apply_pre_cache_transforms(
+      self, ds: tf.data.Dataset, *, rng: random.PRNGKey
+  ) -> tf.data.Dataset:
+    """Applies `transforms_before_cache` to the dataset, if any are set.
+
+    Args:
+      ds: Dataset to transform.
+      rng: Passed to `grain_utils.maybe_add_grain_meta_features`.
+
+    Returns:
+      The transformed dataset, or `ds` unchanged when no pre-cache transforms
+      are configured.
+    """
+    if not self.transforms_before_cache:
+      return ds
+
+    ds = grain_utils.maybe_add_grain_meta_features(ds, rng=rng)
+    # A dict only labels the transforms; the values are what gets applied.
+    if isinstance(self.transforms_before_cache, dict):
+      transforms = list(self.transforms_before_cache.values())
+    else:
+      transforms = list(self.transforms_before_cache)
+    return grain_transforms.apply_transformations(ds, transforms, strict=True)
+
   def transform_ds(self, ds, *, rng: kd.random.PRNGKey) -> tf.data.Dataset:
+    # The helper returns the transformed dataset, so rebind `ds`.
+    ds = self._maybe_apply_pre_cache_transforms(ds, rng=rng)
+
     if self.cache:
       ds = ds.cache()
 