Commit bba680d

Merge branch 'main' into bump_version_0.27.0.dev0

2 parents: 87d9b38 + 18da725

4 files changed: +70, -17 lines

composer/core/data_spec.py (+37, -9)
@@ -14,7 +14,7 @@
 import torch.utils.data
 from torch.utils.data.distributed import DistributedSampler
 
-from composer.utils import dist, ensure_tuple
+from composer.utils import VersionedDeprecationWarning, dist, ensure_tuple
 
 if TYPE_CHECKING:
     from composer.core.types import Batch
@@ -126,16 +126,16 @@ def _default_split_batch(batch: Any, microbatch_size: Union[int, float]) -> Sequ
 class DataSpec:
     """Specifications for operating and training on data.
 
-    An example of constructing a :class:`DataSpec` object with a ``device_transforms``
+    An example of constructing a :class:`DataSpec` object with a ``batch_transforms``
     callable and then using it with :class:`~.Trainer`:
 
     .. doctest::
 
         >>> # Construct DataSpec and subtract mean from the batch
-        >>> device_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys)
-        >>> train_dspec = DataSpec(train_dataloader, device_transforms=device_transform_fn)
+        >>> batch_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys)
+        >>> train_dspec = DataSpec(train_dataloader, batch_transforms=batch_transform_fn)
         >>> # The same function can be used for eval dataloader as well
-        >>> eval_dspec = DataSpec(eval_dataloader, device_transforms=device_transform_fn)
+        >>> eval_dspec = DataSpec(eval_dataloader, batch_transforms=batch_transform_fn)
         >>> # Use this DataSpec object to construct trainer
         >>> trainer = Trainer(
         ...     model=model,
@@ -155,11 +155,20 @@ class DataSpec:
         num_tokens (int, optional): The total number of tokens in an epoch. This field is used by the
             :class:`.Timestamp` (training progress tracker).
 
-        device_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
-            batch once it has been moved onto the device. For example, this function can be used for GPU-based
+        device_transforms ((Batch) -> Batch, optional): Deprecated argument. Please use ``batch_transforms`` for batch
+            level transformations on CPU and ``microbatch_transforms`` for microbatch level transformations on target
+            device.
+
+        batch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
+            batch before it is moved onto the device. For example, this function can be used for CPU-based
             normalization. It can modify the batch in-place, and it should return the modified batch. If not specified,
             the batch is not modified.
 
+        microbatch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
+            microbatch before it is moved onto the device. For example, this function can be used for GPU-based
+            normalization. It can modify the microbatch in-place, and it should return the modified microbatch. If not
+            specified, the microbatch is not modified.
+
         split_batch ((Batch, (int | float)) -> Sequence[Batch], optional): Function called by the :class:`.Trainer` to
             split a batch (the first parameter) into microbatches of a given size (the second parameter). If
             the ``dataloader`` yields batches not of type :class:`torch.Tensor`, Mapping, tuple, or list, then
@@ -186,13 +195,32 @@ def __init__(
         num_samples: Optional[int] = None,
         num_tokens: Optional[int] = None,
         device_transforms: Optional[Callable[[Batch], Batch]] = None,
+        batch_transforms: Optional[Callable[[Batch], Batch]] = None,
+        microbatch_transforms: Optional[Callable[[Batch], Batch]] = None,
         split_batch: Optional[Callable[[Batch, Union[int, float]], Sequence[Batch]]] = None,
         get_num_samples_in_batch: Optional[Callable[[Batch], Union[int, float]]] = None,
         get_num_tokens_in_batch: Optional[Callable[[Batch], Union[int, dict[str, int]]]] = None,
     ) -> None:
         self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader
         self.num_tokens = num_tokens
-        self.device_transforms = self._default_device_transforms if device_transforms is None else device_transforms
+        if device_transforms is not None:
+            if batch_transforms is not None:
+                raise ValueError(
+                    'Cannot specify both `device_transforms` and `batch_transforms`. Please use `batch_transforms` for '
+                    'batch level transformations on CPU and `microbatch_transforms` for microbatch level transformations '
+                    'on target device.',
+                )
+            warnings.warn(
+                VersionedDeprecationWarning(
+                    'The `device_transforms` argument is deprecated. Please use `batch_transforms` for batch level '
+                    'transformations on CPU and `microbatch_transforms` for microbatch level transformations on target '
+                    'device.',
+                    'v0.29.0',
+                ),
+            )
+            self.batch_transforms = device_transforms
+        self.batch_transforms = self._default_transforms if batch_transforms is None else batch_transforms
+        self.microbatch_transforms = self._default_transforms if microbatch_transforms is None else microbatch_transforms
         self.split_batch = default_split_batch if split_batch is None else split_batch
         self.get_num_samples_in_batch = self._default_get_num_samples_in_batch if get_num_samples_in_batch is None else get_num_samples_in_batch
         self._get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch
@@ -242,7 +270,7 @@ def __init__(
                     'For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler.',
                 )
 
-    def _default_device_transforms(self, batch: Batch):
+    def _default_transforms(self, batch: Batch):
         return batch
 
     def _default_get_num_samples_in_batch(self, batch: Batch) -> int:
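Taken together, the data_spec.py changes split the old device_transforms hook into two stages: batch_transforms for the whole batch on CPU and microbatch_transforms for each microbatch on the target device. A minimal usage sketch of the new arguments follows; the toy dataset, the transform callables, and the placeholder my_model are illustrative assumptions, not part of this commit:

import torch
from torch.utils.data import DataLoader, TensorDataset

from composer import Trainer
from composer.core import DataSpec

# Toy data; in practice this would be your own dataloader and ComposerModel.
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
train_dataloader = DataLoader(dataset, batch_size=16)

def cpu_normalize(batch):
    # Runs once per batch on CPU, before the batch is split and moved to the device.
    xs, ys = batch
    return xs.sub_(xs.mean()), ys

def device_scale(microbatch):
    # Runs once per microbatch, after it has been moved to the target device.
    xs, ys = microbatch
    return xs.div_(2.0), ys

train_dspec = DataSpec(
    train_dataloader,
    batch_transforms=cpu_normalize,        # replaces the deprecated device_transforms
    microbatch_transforms=device_scale,
)
# trainer = Trainer(model=my_model, train_dataloader=train_dspec, max_duration='1ba')

Passing the deprecated device_transforms argument still works but now emits a VersionedDeprecationWarning (removal targeted for v0.29.0), and passing it together with batch_transforms raises a ValueError.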

composer/trainer/trainer.py (+7, -5)
@@ -2622,7 +2622,7 @@ def _train_loop(self) -> None:
                     self._rng_state = None
                     continue
 
-                self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
+                self.state.batch = self._train_data_spec.batch_transforms(self.state.batch)
                 rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
                 rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
 
@@ -3034,6 +3034,7 @@ def _train_microbatches(
 
             for microbatch_idx, self.state.batch in enumerate(microbatches):
                 self.state.batch = self.state.device.batch_to_device(self.state.batch)
+                self.state.batch = self._train_data_spec.microbatch_transforms(self.state.batch)
                 is_final_microbatch = microbatch_idx + 1 == len(microbatches)
                 microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch)
 
@@ -3306,11 +3307,11 @@ def predict_batch_end(self, state: State, logger: Logger) -> None:
         self.engine.run_event(Event.PREDICT_START)
 
         for self.state.batch in self._iter_dataloader(TrainerMode.PREDICT):
+
             # Move the batch onto the device
+            self.state.batch = data_spec.batch_transforms(self.state.batch)
             self.state.batch = self.state.device.batch_to_device(self.state.batch)
-
-            # Perform any device transforms
-            self.state.batch = data_spec.device_transforms(self.state.batch)
+            self.state.batch = data_spec.microbatch_transforms(self.state.batch)
 
             # Count the batch size and num tokens before any events run
             rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch)
@@ -3586,7 +3587,7 @@ def _eval_loop(
             )
 
             for self.state.batch in self._iter_dataloader(TrainerMode.EVAL):
-                self.state.batch = data_spec.device_transforms(self.state.batch)
+                self.state.batch = data_spec.batch_transforms(self.state.batch)
 
                 # Count the batch size and num tokens before any events run
                 rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch)
@@ -3616,6 +3617,7 @@ def _eval_loop(
                 microbatches = data_spec.split_batch(device_batch, evaluator.device_eval_microbatch_size)
                 for i, self.state.batch in enumerate(microbatches):
                     self.state.batch = self.state.device.batch_to_device(self.state.batch)
+                    self.state.batch = data_spec.microbatch_transforms(self.state.batch)
                     last_microbatch = i == len(microbatches) - 1
                     skip_metric_update = False
                     # Distributed samplers pad batches to be the same size. If using a
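The net effect of the trainer.py changes is a fixed per-step ordering: batch_transforms runs on the whole batch before the device move, and microbatch_transforms runs on each microbatch after it lands on the device, in training, eval, and predict. A simplified sketch of that ordering, assuming a DataSpec and Device as used above; this is illustrative only and not the Trainer's actual loop, which also handles events, metrics, and gradient scaling:

def run_training_batch(data_spec, device, batch, microbatch_size):
    # 1. Whole-batch transform on the CPU-side batch.
    batch = data_spec.batch_transforms(batch)
    # 2. Split into microbatches, move each one to the device, then transform it there.
    for microbatch in data_spec.split_batch(batch, microbatch_size):
        microbatch = device.batch_to_device(microbatch)
        microbatch = data_spec.microbatch_transforms(microbatch)
        ...  # forward/backward on the transformed microbatch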

setup.py (+2, -2)
@@ -78,7 +78,7 @@ def package_files(prefix: str, directory: str, extension: str):
 install_requires = [
     'pyyaml>=6.0,<7',
     'tqdm>=4.62.3,<5',
-    'torchmetrics>=1.0,<1.4.1',
+    'torchmetrics>=1.0,<1.5.3',
     'torch_optimizer>=0.3.0,<0.4',
     'torchvision>=0.18.0,<0.20.2',
     'torch>=2.3.0,<2.5.2',
@@ -88,7 +88,7 @@ def package_files(prefix: str, directory: str, extension: str):
     'coolname>=1.1.0,<3',
     'tabulate==0.9.0',  # for auto-generating tables
     'py-cpuinfo>=8.0.0,<10',
-    'packaging>=21.3.0,<24.2',
+    'packaging>=21.3.0,<24.3',
     'importlib-metadata>=5.0.0,<9',
     'mosaicml-cli>=0.5.25,<0.7',
     'pillow>=10.3.0,<12',

tests/trainer/test_trainer.py (+24, -1)
@@ -19,7 +19,7 @@
 
 from composer import Callback, Evaluator, Trainer
 from composer.algorithms import CutOut, LabelSmoothing
-from composer.core import Event, Precision, State, Time, TimeUnit
+from composer.core import DataSpec, Event, Precision, State, Time, TimeUnit
 from composer.devices import Device
 from composer.loggers import InMemoryLogger, Logger, RemoteUploaderDownloader
 from composer.loss import soft_cross_entropy
@@ -1733,3 +1733,26 @@ def test_empty_eval_dataloader(self):
             max_duration='1ba',
         )
         trainer.fit()
+
+
+@device('cpu', 'gpu')
+def test_transforms(device: str):
+
+    def get_transform(device: str):
+
+        def transform(batch: list[torch.Tensor]):
+            batch_device = 'gpu' if batch[0].device.type == 'cuda' else 'cpu'
+            assert batch_device == device
+            return batch
+
+        return transform
+
+    dataloader = _get_classification_dataloader()
+    data_spec = DataSpec(
+        dataloader,
+        batch_transforms=get_transform('cpu'),
+        microbatch_transforms=get_transform(device),
+    )
+    model = SimpleModel()
+    trainer = Trainer(model=model, train_dataloader=data_spec, max_duration='1ba')
+    trainer.fit()
