diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index eb5080eaee..c680d1d3d7 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -946,6 +946,8 @@ class Trainer:
             .. note:: This is implemented by taking the batch yielded by the ``train_dataloader`` and splitting
                 it into sections of size ``device_train_microbatch_size``. If the batch size of the dataloader
                 is not divisible by ``device_train_microbatch_size``, the last section will be potentially smaller.
+        accumulate_train_batch_on_tokens (bool, optional): Whether training loss is accumulated over the number of tokens in a batch,
+            rather than the number of samples. Only works if the train data spec implements ``get_num_tokens_in_batch``. (default: ``False``)
         seed (int, optional): The seed used in randomization. If ``None``, then a random seed
             will be created. (default: ``None``)

@@ -1080,6 +1082,7 @@ def __init__(
         precision: Optional[Union[str, Precision]] = None,
         precision_config: Optional[dict[str, Any]] = None,
         device_train_microbatch_size: Optional[Union[int, float, str]] = None,
+        accumulate_train_batch_on_tokens: bool = False,

         # Reproducibility
         seed: Optional[int] = None,
@@ -1314,6 +1317,7 @@ def __init__(
             deepspeed_config=deepspeed_config,
             parallelism_config=parallelism_config,
         )
+        self.accumulate_train_batch_on_tokens = accumulate_train_batch_on_tokens

         # Console Logging
         loggers = list(ensure_tuple(loggers))
@@ -2848,7 +2852,22 @@ def _train_microbatches(
                     optimizer.zero_grad()

         # Tracker for gradient accumulation
-        current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(batch) for batch in microbatches])
+        if self.accumulate_train_batch_on_tokens:
+            current_batch_size = sum([self._train_data_spec.get_num_tokens_in_batch(b) for b in microbatches])
+            if current_batch_size == 0:
+                raise ValueError(
+                    textwrap.dedent(
+                        'Requested loss accumulation based on number of tokens in training batch, '
+                        'but zero tokens found (perhaps due to an improper DataSpec).',
+                    ),
+                )
+        else:
+            current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(b) for b in microbatches])
+        # Average the current batch size across ranks, to ensure each rank contributes appropriately
+        current_batch_size = self.state.device.tensor_to_device(torch.tensor(current_batch_size))
+        dist.all_reduce(current_batch_size, reduce_operation='SUM')
+        current_batch_size = current_batch_size.item() / dist.get_world_size()
+
         # Cache batch, which will be overwritten by microbatches. Restore after microbatches complete
         current_batch = self.state.batch

@@ -2895,7 +2914,10 @@ def _train_microbatch(
         # Cache the device batch, because `self.state.batch` gets overridden in microbatching loop
         device_batch = deepcopy(self.state.batch)

-        microbatch_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
+        if self.accumulate_train_batch_on_tokens:
+            microbatch_size = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
+        else:
+            microbatch_size = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
         if self.state.deepspeed_enabled or not isinstance(self.state.model, DistributedDataParallel):
             sync_context = contextlib.nullcontext()
         elif self.state.auto_microbatching and not self.first_batch_complete:
@@ -2985,7 +3007,7 @@ def _train_microbatch(

             # For each loss to log: detach, clone, mean, then multiply by (microbatch size) / (batch size)
             for k, loss in microbatch_loss_dict.items():
-                microbatch_loss_dict[k] = loss.detach().clone().mean() * (microbatch_num_samples / current_batch_size)
+                microbatch_loss_dict[k] = loss.detach().clone().mean() * (microbatch_size / current_batch_size)

             if use_grad_scaling:
                 microbatch_loss = cast(torch.Tensor, self.state.scaler.scale(microbatch_loss))  # type: ignore
@@ -2994,7 +3016,7 @@ def _train_microbatch(
                 self.state.deepspeed_model.backward(microbatch_loss)
             else:
                 # Scale loss based on the number of samples in the microbatch to maintain gradient numerics
-                microbatch_loss.mul_(microbatch_num_samples / current_batch_size)
+                microbatch_loss.mul_(microbatch_size / current_batch_size)
                 microbatch_loss.backward(create_graph=self._backwards_create_graph)

             if self.state.device.dist_backend == 'xla':
diff --git a/tests/test_simple_nlp.py b/tests/test_simple_nlp.py
index 19b3c3d6e5..5fd107aaa5 100644
--- a/tests/test_simple_nlp.py
+++ b/tests/test_simple_nlp.py
@@ -4,8 +4,10 @@
 import pytest
 from torch.utils.data import DataLoader

+from composer.core import DataSpec
 from composer.trainer import Trainer
-from composer.utils import dist
+from composer.utils import dist, get_device
+from tests.common import device
 from tests.common.datasets import RandomTextClassificationDataset, RandomTextLMDataset
 from tests.common.models import SimpleTransformerClassifier, SimpleTransformerMaskedLM

@@ -125,3 +127,93 @@ def test_simple_nlp_mlm(tiny_bert_tokenizer, tiny_bert_model):
     num_predict_batches_expected = ((size - 1) // batch_size) + 1
     assert len(predictions) == num_predict_batches_expected
     assert predictions[0].shape == (batch_size, sequence_length, vocab_size)
+
+
+@device('gpu')
+def test_simple_nlp_mlm_token_batch(tiny_bert_tokenizer, device):
+    transformers = pytest.importorskip('transformers')
+
+    vocab_size = tiny_bert_tokenizer.vocab_size
+    sequence_length = 32
+    size = 96
+    batch_size = 8
+    device = get_device(device)
+
+    train_dataset = RandomTextLMDataset(
+        size=size,
+        vocab_size=vocab_size,
+        sequence_length=sequence_length,
+        use_keys=True,
+        pad_token_id=tiny_bert_tokenizer.pad_token_id,
+    )
+    for i in range(size):  # Proactively load dataset for consistent randomization
+        train_dataset[i]
+    collator = transformers.DataCollatorForLanguageModeling(tokenizer=tiny_bert_tokenizer)
+
+    # Get the model's state dict before training starts, so we can reproduce results
+    model = SimpleTransformerMaskedLM(vocab_size=vocab_size)
+    state_dict = model.state_dict()
+
+    # Set up the data spec that can count the non-padding tokens in a batch
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        sampler=dist.get_sampler(train_dataset),
+        collate_fn=collator,
+    )
+    data_spec = DataSpec(
+        dataloader=train_dataloader,
+        get_num_tokens_in_batch=lambda b: (b['input_ids'] != tiny_bert_tokenizer.pad_token_id).sum().item(),
+    )
+
+    trainer = Trainer(
+        model=model,
+        seed=42,
+        train_dataloader=data_spec,
+        max_duration='2ep',
+        device_train_microbatch_size=batch_size // 2,
+        accumulate_train_batch_on_tokens=False,
+        device=device,
+    )
+    trainer.fit()
+
+    # Check that there is some train cross entropy
+    assert trainer.state.train_metrics is not None
+    cross_entropy = trainer.state.train_metrics['LanguageCrossEntropy'].compute()
+    assert cross_entropy != 0.0
+
+    # Set up a trainer that accumulates train loss based on token counts, after reloading the original state dict
+    model.load_state_dict(state_dict)
+    token_trainer = Trainer(
+        model=model,
+        seed=42,
+        train_dataloader=data_spec,
+        max_duration='2ep',
+        device_train_microbatch_size=batch_size // 2,
+        accumulate_train_batch_on_tokens=True,
+        device=device,
+    )
+    token_trainer.fit()
+
+    # Check that there is some train cross entropy
+    assert token_trainer.state.train_metrics is not None
+    token_cross_entropy = token_trainer.state.train_metrics['LanguageCrossEntropy'].compute()
+    assert token_cross_entropy != 0.0
+
+    # Require that the train cross entropies are different between the trainers
+    assert cross_entropy != token_cross_entropy
+
+    # Make sure we can reproduce the original cross entropy calculation
+    model.load_state_dict(state_dict)
+    trainer2 = Trainer(
+        model=model,
+        seed=42,
+        train_dataloader=data_spec,
+        max_duration='2ep',
+        device_train_microbatch_size=batch_size // 2,
+        accumulate_train_batch_on_tokens=False,
+        device=device,
+    )
+    trainer2.fit()
+    assert trainer2.state.train_metrics is not None
+    assert trainer2.state.train_metrics['LanguageCrossEntropy'].compute() == cross_entropy
diff --git a/tests/utils/test_autolog_hparams.py b/tests/utils/test_autolog_hparams.py
index b2ece95622..c73c1bbbde 100644
--- a/tests/utils/test_autolog_hparams.py
+++ b/tests/utils/test_autolog_hparams.py
@@ -185,6 +185,7 @@ def test_extract_hparams_trainer():
         'precision': 'fp32',
         'precision_config': None,
         'device_train_microbatch_size': 16,
+        'accumulate_train_batch_on_tokens': False,

         # Reproducibility
         'seed': 3,
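
For context, enabling the new behavior in user code mirrors the test above: wrap the dataloader in a `DataSpec` that supplies `get_num_tokens_in_batch`, then pass `accumulate_train_batch_on_tokens=True` to the `Trainer`. A minimal sketch, assuming an existing `model`, `train_dataloader`, and a `tokenizer` with a `pad_token_id` (placeholder names, not part of this PR):

```python
from composer import Trainer
from composer.core import DataSpec

# Wrap the dataloader so the trainer can count non-padding tokens per batch;
# padded positions are excluded so they don't inflate the loss weighting.
data_spec = DataSpec(
    dataloader=train_dataloader,
    get_num_tokens_in_batch=lambda batch: int((batch['input_ids'] != tokenizer.pad_token_id).sum()),
)

trainer = Trainer(
    model=model,
    train_dataloader=data_spec,
    max_duration='1ep',
    accumulate_train_batch_on_tokens=True,  # weight microbatch losses by token count
)
trainer.fit()
```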
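
The behavioral difference reduces to the `microbatch_size / current_batch_size` factor applied in `_train_microbatch`. The sketch below (illustrative only; the loss values and counts are made up) reproduces that weighting to show why sample-based and token-based accumulation diverge once padding makes token counts uneven across microbatches:

```python
import torch

# One optimizer step split into two microbatches: equal sample counts,
# but the second microbatch carries more padding, hence fewer real tokens.
microbatch_losses = [torch.tensor(2.0), torch.tensor(3.0)]
num_samples = [4, 4]
num_tokens = [100, 60]

def accumulate(losses, sizes):
    # Mirrors the trainer's scaling: each microbatch loss is weighted by
    # its share of the full batch, then the weighted losses are summed.
    batch_size = sum(sizes)
    return sum(loss * (size / batch_size) for loss, size in zip(losses, sizes))

print(accumulate(microbatch_losses, num_samples))  # tensor(2.5000)
print(accumulate(microbatch_losses, num_tokens))   # tensor(2.3750)
```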