Add option to accumulate train loss over tokens. (#3273)
aadyotb authored Jun 6, 2024
1 parent 4b46f64 commit f039f06
Showing 3 changed files with 120 additions and 5 deletions.
30 changes: 26 additions & 4 deletions composer/trainer/trainer.py
@@ -946,6 +946,8 @@ class Trainer:
.. note:: This is implemented by taking the batch yielded by the ``train_dataloader`` and splitting
it into sections of size ``device_train_microbatch_size``. If the batch size of the dataloader
is not divisible by ``device_train_microbatch_size``, the last section will be potentially smaller.
accumulate_train_batch_on_tokens (bool, optional): Whether training loss is accumulated over the number of tokens in a batch,
rather than the number of samples. Only works if the train data spec implements ``get_num_tokens_in_batch``. (default: ``False``)
seed (int, optional): The seed used in randomization. If ``None``, then a random seed
will be created. (default: ``None``)
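
As a usage note, here is a minimal sketch of how the new flag pairs with a DataSpec that implements get_num_tokens_in_batch, mirroring the test added below; my_dataset, my_collator, my_model, and pad_token_id are assumed placeholders, not names from this commit.

    from torch.utils.data import DataLoader

    from composer.core import DataSpec
    from composer.trainer import Trainer

    # Placeholders: my_dataset, my_collator, my_model, and pad_token_id are assumed
    # to exist elsewhere (e.g. a tokenized text dataset and its tokenizer).
    dataloader = DataLoader(my_dataset, batch_size=8, collate_fn=my_collator)

    # Count only the non-padding tokens in each batch.
    data_spec = DataSpec(
        dataloader=dataloader,
        get_num_tokens_in_batch=lambda b: (b['input_ids'] != pad_token_id).sum().item(),
    )

    trainer = Trainer(
        model=my_model,
        train_dataloader=data_spec,
        max_duration='1ep',
        device_train_microbatch_size=4,
        accumulate_train_batch_on_tokens=True,  # weight each microbatch's loss by its token count
    )
    trainer.fit()
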
@@ -1080,6 +1082,7 @@ def __init__(
precision: Optional[Union[str, Precision]] = None,
precision_config: Optional[dict[str, Any]] = None,
device_train_microbatch_size: Optional[Union[int, float, str]] = None,
accumulate_train_batch_on_tokens: bool = False,

# Reproducibility
seed: Optional[int] = None,
@@ -1314,6 +1317,7 @@ def __init__(
deepspeed_config=deepspeed_config,
parallelism_config=parallelism_config,
)
self.accumulate_train_batch_on_tokens = accumulate_train_batch_on_tokens

# Console Logging
loggers = list(ensure_tuple(loggers))
@@ -2848,7 +2852,22 @@ def _train_microbatches(
optimizer.zero_grad()

# Tracker for gradient accumulation
current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(batch) for batch in microbatches])
if self.accumulate_train_batch_on_tokens:
current_batch_size = sum([self._train_data_spec.get_num_tokens_in_batch(b) for b in microbatches])
if current_batch_size == 0:
raise ValueError(
textwrap.dedent(
'Requested loss accumulation based on number of tokens in training batch, '
'but zero tokens found (perhaps due to an improper DataSpec).',
),
)
else:
current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(b) for b in microbatches])
# Average the current batch size across ranks, to ensure each rank contributes appropriately
current_batch_size = self.state.device.tensor_to_device(torch.tensor(current_batch_size))
dist.all_reduce(current_batch_size, reduce_operation='SUM')
current_batch_size = current_batch_size.item() / dist.get_world_size()

# Cache batch, which will be overwritten by microbatches. Restore after microbatches complete
current_batch = self.state.batch
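
For intuition, a small single-node sketch (with made-up token counts and no torch.distributed calls) of the averaging that the all_reduce above performs:

    # Hypothetical token counts: 2 ranks, each with 2 microbatches.
    per_rank_microbatch_tokens = [[120, 110], [95, 130]]

    per_rank_totals = [sum(mb) for mb in per_rank_microbatch_tokens]    # [230, 225]
    reduced_sum = sum(per_rank_totals)                                  # 455, as the all_reduce SUM would produce
    current_batch_size = reduced_sum / len(per_rank_microbatch_tokens)  # 227.5, the mean batch size per rank

    # Each microbatch loss is later scaled by (its token count) / current_batch_size,
    # so a rank that saw more tokens contributes proportionally more to the step.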

@@ -2895,7 +2914,10 @@ def _train_microbatch(
# Cache the device batch, because `self.state.batch` gets overridden in microbatching loop
device_batch = deepcopy(self.state.batch)

microbatch_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
if self.accumulate_train_batch_on_tokens:
microbatch_size = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
else:
microbatch_size = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
if self.state.deepspeed_enabled or not isinstance(self.state.model, DistributedDataParallel):
sync_context = contextlib.nullcontext()
elif self.state.auto_microbatching and not self.first_batch_complete:
@@ -2985,7 +3007,7 @@ def _train_microbatch(

# For each loss to log: detach, clone, mean, then multiply by (microbatch size) / (batch size)
for k, loss in microbatch_loss_dict.items():
microbatch_loss_dict[k] = loss.detach().clone().mean() * (microbatch_num_samples / current_batch_size)
microbatch_loss_dict[k] = loss.detach().clone().mean() * (microbatch_size / current_batch_size)

if use_grad_scaling:
microbatch_loss = cast(torch.Tensor, self.state.scaler.scale(microbatch_loss)) # type: ignore
@@ -2994,7 +3016,7 @@
self.state.deepspeed_model.backward(microbatch_loss)
else:
# Scale loss based on the number of samples in the microbatch to maintain gradient numerics
microbatch_loss.mul_(microbatch_num_samples / current_batch_size)
microbatch_loss.mul_(microbatch_size / current_batch_size)
microbatch_loss.backward(create_graph=self._backwards_create_graph)

if self.state.device.dist_backend == 'xla':
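
Both scaling sites above now use the generic microbatch_size (tokens or samples) in place of the sample count. A minimal single-rank sketch, with made-up numbers, of why scaling by microbatch_size / current_batch_size yields a token-weighted average loss:

    import torch

    microbatch_losses = [torch.tensor(2.0), torch.tensor(3.0)]  # mean loss of each microbatch
    microbatch_tokens = [150, 50]
    current_batch_size = sum(microbatch_tokens)                 # 200 on a single rank

    accumulated = sum(
        loss * (n / current_batch_size)
        for loss, n in zip(microbatch_losses, microbatch_tokens)
    )
    # (2.0 * 150 + 3.0 * 50) / 200 = 2.25: the token-weighted mean, not the naive (2.0 + 3.0) / 2
    print(accumulated)  # tensor(2.2500)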
94 changes: 93 additions & 1 deletion tests/test_simple_nlp.py
@@ -4,8 +4,10 @@
import pytest
from torch.utils.data import DataLoader

from composer.core import DataSpec
from composer.trainer import Trainer
from composer.utils import dist
from composer.utils import dist, get_device
from tests.common import device
from tests.common.datasets import RandomTextClassificationDataset, RandomTextLMDataset
from tests.common.models import SimpleTransformerClassifier, SimpleTransformerMaskedLM

@@ -125,3 +127,93 @@ def test_simple_nlp_mlm(tiny_bert_tokenizer, tiny_bert_model):
num_predict_batches_expected = ((size - 1) // batch_size) + 1
assert len(predictions) == num_predict_batches_expected
assert predictions[0].shape == (batch_size, sequence_length, vocab_size)


@device('gpu')
def test_simple_nlp_mlm_token_batch(tiny_bert_tokenizer, device):
transformers = pytest.importorskip('transformers')

vocab_size = tiny_bert_tokenizer.vocab_size
sequence_length = 32
size = 96
batch_size = 8
device = get_device(device)

train_dataset = RandomTextLMDataset(
size=size,
vocab_size=vocab_size,
sequence_length=sequence_length,
use_keys=True,
pad_token_id=tiny_bert_tokenizer.pad_token_id,
)
for i in range(size): # Proactively load dataset for consistent randomization
train_dataset[i]
collator = transformers.DataCollatorForLanguageModeling(tokenizer=tiny_bert_tokenizer)

# Get the model's state dict before training starts, so we can reproduce results
model = SimpleTransformerMaskedLM(vocab_size=vocab_size)
state_dict = model.state_dict()

# Set up the data spec that can count the non-padding tokens in a batch
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
sampler=dist.get_sampler(train_dataset),
collate_fn=collator,
)
data_spec = DataSpec(
dataloader=train_dataloader,
get_num_tokens_in_batch=lambda b: (b['input_ids'] != tiny_bert_tokenizer.pad_token_id).sum().item(),
)

trainer = Trainer(
model=model,
seed=42,
train_dataloader=data_spec,
max_duration='2ep',
device_train_microbatch_size=batch_size // 2,
accumulate_train_batch_on_tokens=False,
device=device,
)
trainer.fit()

# Check that there is some train cross entropy
assert trainer.state.train_metrics is not None
cross_entropy = trainer.state.train_metrics['LanguageCrossEntropy'].compute()
assert cross_entropy != 0.0

# Set up a trainer that accumulates train loss based on token counts, after reloading original state dict
model.load_state_dict(state_dict)
token_trainer = Trainer(
model=model,
seed=42,
train_dataloader=data_spec,
max_duration='2ep',
device_train_microbatch_size=batch_size // 2,
accumulate_train_batch_on_tokens=True,
device=device,
)
token_trainer.fit()

# Check that there is some train cross entropy
assert token_trainer.state.train_metrics is not None
token_cross_entropy = token_trainer.state.train_metrics['LanguageCrossEntropy'].compute()
assert token_cross_entropy != 0.0

# Require that the train cross entropies are different between the trainers
assert cross_entropy != token_cross_entropy

# Make sure we can reproduce the original cross entropy calculation
model.load_state_dict(state_dict)
trainer2 = Trainer(
model=model,
seed=42,
train_dataloader=data_spec,
max_duration='2ep',
device_train_microbatch_size=batch_size // 2,
accumulate_train_batch_on_tokens=False,
device=device,
)
trainer2.fit()
assert trainer2.state.train_metrics is not None
assert trainer2.state.train_metrics['LanguageCrossEntropy'].compute() == cross_entropy
1 change: 1 addition & 0 deletions tests/utils/test_autolog_hparams.py
@@ -185,6 +185,7 @@ def test_extract_hparams_trainer():
'precision': 'fp32',
'precision_config': None,
'device_train_microbatch_size': 16,
'accumulate_train_batch_on_tokens': False,

# Reproducibility
'seed': 3,
