import torch.utils.data
from torch.utils.data.distributed import DistributedSampler

-from composer.utils import dist, ensure_tuple
+from composer.utils import VersionedDeprecationWarning, dist, ensure_tuple

if TYPE_CHECKING:
    from composer.core.types import Batch
@@ -126,16 +126,16 @@ def _default_split_batch(batch: Any, microbatch_size: Union[int, float]) -> Sequ
class DataSpec:
    """Specifications for operating and training on data.

-    An example of constructing a :class:`DataSpec` object with a ``device_transforms``
+    An example of constructing a :class:`DataSpec` object with a ``batch_transforms``
    callable and then using it with :class:`~.Trainer`:

    .. doctest::

        >>> # Construct DataSpec and subtract mean from the batch
-        >>> device_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys)
-        >>> train_dspec = DataSpec(train_dataloader, device_transforms=device_transform_fn)
+        >>> batch_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys)
+        >>> train_dspec = DataSpec(train_dataloader, batch_transforms=batch_transform_fn)
        >>> # The same function can be used for eval dataloader as well
-        >>> eval_dspec = DataSpec(eval_dataloader, device_transforms=device_transform_fn)
+        >>> eval_dspec = DataSpec(eval_dataloader, batch_transforms=batch_transform_fn)
        >>> # Use this DataSpec object to construct trainer
        >>> trainer = Trainer(
        ...     model=model,
@@ -155,11 +155,20 @@ class DataSpec:
        num_tokens (int, optional): The total number of tokens in an epoch. This field is used by the
            :class:`.Timestamp` (training progress tracker).

-        device_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
-            batch once it has been moved onto the device. For example, this function can be used for GPU-based
+        device_transforms ((Batch) -> Batch, optional): Deprecated argument. Please use ``batch_transforms`` for batch
+            level transformations on CPU and ``microbatch_transforms`` for microbatch level transformations on target
+            device.
+
+        batch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
+            batch before it is moved onto the device. For example, this function can be used for CPU-based
            normalization. It can modify the batch in-place, and it should return the modified batch. If not specified,
            the batch is not modified.

+        microbatch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
+            microbatch once it has been moved onto the device. For example, this function can be used for GPU-based
+            normalization. It can modify the microbatch in-place, and it should return the modified microbatch. If not
+            specified, the microbatch is not modified.
+
        split_batch ((Batch, (int | float)) -> Sequence[Batch], optional): Function called by the :class:`.Trainer` to
            split a batch (the first parameter) into microbatches of a given size (the second parameter). If
            the ``dataloader`` yields batches not of type :class:`torch.Tensor`, Mapping, tuple, or list, then
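To make the split between the two new arguments concrete, here is a minimal sketch of constructing a DataSpec with both a CPU-side batch transform and a device-side microbatch transform. The function names `normalize_cpu` and `add_noise_on_device` are illustrative, and `train_dataloader` and `model` are assumed to exist as in the doctest above; this is not code from the PR itself.

import torch

from composer import Trainer
from composer.core import DataSpec

def normalize_cpu(batch):
    # Batch-level transform: runs on CPU, once per batch, before device movement.
    xs, ys = batch
    xs = (xs - xs.mean()) / (xs.std() + 1e-6)
    return xs, ys

def add_noise_on_device(microbatch):
    # Microbatch-level transform: runs on each microbatch on the target device.
    xs, ys = microbatch
    return xs + 0.01 * torch.randn_like(xs), ys

train_dspec = DataSpec(
    train_dataloader,
    batch_transforms=normalize_cpu,
    microbatch_transforms=add_noise_on_device,
)
trainer = Trainer(model=model, train_dataloader=train_dspec, max_duration='1ep')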
@@ -186,13 +195,32 @@ def __init__(
        num_samples: Optional[int] = None,
        num_tokens: Optional[int] = None,
        device_transforms: Optional[Callable[[Batch], Batch]] = None,
+        batch_transforms: Optional[Callable[[Batch], Batch]] = None,
+        microbatch_transforms: Optional[Callable[[Batch], Batch]] = None,
        split_batch: Optional[Callable[[Batch, Union[int, float]], Sequence[Batch]]] = None,
        get_num_samples_in_batch: Optional[Callable[[Batch], Union[int, float]]] = None,
        get_num_tokens_in_batch: Optional[Callable[[Batch], Union[int, dict[str, int]]]] = None,
    ) -> None:
        self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader
        self.num_tokens = num_tokens
-        self.device_transforms = self._default_device_transforms if device_transforms is None else device_transforms
+        if device_transforms is not None:
+            if batch_transforms is not None:
+                raise ValueError(
+                    'Cannot specify both `device_transforms` and `batch_transforms`. Please use `batch_transforms` for '
+                    'batch level transformations on CPU and `microbatch_transforms` for microbatch level transformations '
+                    'on target device.',
+                )
+            warnings.warn(
+                VersionedDeprecationWarning(
+                    'The `device_transforms` argument is deprecated. Please use `batch_transforms` for batch level '
+                    'transformations on CPU and `microbatch_transforms` for microbatch level transformations on target '
+                    'device.',
+                    'v0.29.0',
+                ),
+            )
+            batch_transforms = device_transforms
+        self.batch_transforms = self._default_transforms if batch_transforms is None else batch_transforms
+        self.microbatch_transforms = self._default_transforms if microbatch_transforms is None else microbatch_transforms
        self.split_batch = default_split_batch if split_batch is None else split_batch
        self.get_num_samples_in_batch = self._default_get_num_samples_in_batch if get_num_samples_in_batch is None else get_num_samples_in_batch
        self._get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch
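A small sketch of the backward-compatibility path implemented above: passing the deprecated `device_transforms` still works but emits the deprecation warning and the callable is carried over as a batch-level transform, while passing both the old and the new argument raises. `my_transform` and `train_dataloader` are placeholder names, not part of this change.

import warnings

from composer.core import DataSpec

def my_transform(batch):  # placeholder transform, illustrative only
    return batch

# Old-style call: still accepted for now, but emits a VersionedDeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    dspec = DataSpec(train_dataloader, device_transforms=my_transform)
assert any('device_transforms' in str(w.message) for w in caught)

# Passing both the deprecated and the new argument is rejected outright.
try:
    DataSpec(train_dataloader, device_transforms=my_transform, batch_transforms=my_transform)
except ValueError as err:
    print(err)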
@@ -242,7 +270,7 @@ def __init__(
                'For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler.',
            )

-    def _default_device_transforms(self, batch: Batch):
+    def _default_transforms(self, batch: Batch):
        return batch

    def _default_get_num_samples_in_batch(self, batch: Batch) -> int:
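For readers comparing the old single hook with the new pair, the following is a rough pseudo-flow of where each transform conceptually runs in a training step, consistent with the docstrings above. It is illustrative only, not the Trainer's actual loop, and it assumes an (inputs, targets) tuple of tensors purely for the device-move step.

def illustrative_step(data_spec, batch, microbatch_size, device):
    # 1. Batch-level transform: applied to the CPU-side batch before any device movement.
    batch = data_spec.batch_transforms(batch)
    # 2. Split into microbatches using the DataSpec's split function.
    for microbatch in data_spec.split_batch(batch, microbatch_size):
        # 3. Move the microbatch onto the target device (illustrative tuple-of-tensors assumption).
        microbatch = tuple(t.to(device) for t in microbatch)
        # 4. Microbatch-level transform: applied once the microbatch is on the device.
        yield data_spec.microbatch_transforms(microbatch)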