Add BERT Model To NeMo2.0 #11333

Open

suiyoubi wants to merge 58 commits into main from aot/bert-nemo-ux

Changes shown below are from 24 of the 58 commits.

Commits (58)
b922807
add bert in nemo2
suiyoubi Sep 4, 2024
0e0822c
Add bert loss
suiyoubi Sep 9, 2024
ccab2d5
Merge branch 'main' into aot/bert-nemo-ux
suiyoubi Nov 14, 2024
70973cd
update bert with latest nemo
suiyoubi Nov 14, 2024
686eed1
Apply isort and black reformatting
suiyoubi Nov 14, 2024
794951e
Bert update
suiyoubi Nov 18, 2024
bf91626
Apply isort and black reformatting
suiyoubi Nov 18, 2024
c2b2949
import optimize
suiyoubi Nov 18, 2024
f64457f
Apply isort and black reformatting
suiyoubi Nov 18, 2024
d70c098
pylint
suiyoubi Nov 19, 2024
27707f5
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 19, 2024
b91c247
Apply isort and black reformatting
suiyoubi Nov 19, 2024
146c21b
Merge remote-tracking branch 'origin/main' into aot/bert-nemo-ux
suiyoubi Nov 19, 2024
3a8841c
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 19, 2024
3247f36
pylint
suiyoubi Nov 19, 2024
bcc9b7c
Apply isort and black reformatting
suiyoubi Nov 19, 2024
5f19763
pylint
suiyoubi Nov 19, 2024
9bbef09
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 19, 2024
f42ec30
Apply isort and black reformatting
suiyoubi Nov 19, 2024
388ed8a
more comment
suiyoubi Nov 20, 2024
11a141e
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 20, 2024
8dad5a5
Merge remote-tracking branch 'origin/main' into aot/bert-nemo-ux
suiyoubi Nov 21, 2024
f00d2df
use lightning package
suiyoubi Nov 21, 2024
ba73a4d
Apply isort and black reformatting
suiyoubi Nov 21, 2024
546a646
comments resolved
suiyoubi Nov 22, 2024
2cf371d
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 22, 2024
659e29c
Add loss
suiyoubi Nov 22, 2024
2c542d0
Apply isort and black reformatting
suiyoubi Nov 22, 2024
451a1a0
Fix NaN loss when resume
suiyoubi Nov 22, 2024
c72f1be
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 22, 2024
d2bbe9c
Apply isort and black reformatting
suiyoubi Nov 22, 2024
846b236
add default for num_tokentypes
suiyoubi Nov 22, 2024
cc1893d
update forward output to be a dict
suiyoubi Nov 24, 2024
9654cfd
Apply isort and black reformatting
suiyoubi Nov 24, 2024
8658f53
ReName to HuggingFace Bert Model
suiyoubi Nov 24, 2024
cb1332d
Add BertLarge
suiyoubi Nov 26, 2024
7ae576c
Apply isort and black reformatting
suiyoubi Nov 26, 2024
3826aa3
Add BERT Tests
suiyoubi Nov 26, 2024
51692ad
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 26, 2024
ea969ed
Apply isort and black reformatting
suiyoubi Nov 26, 2024
06f03ff
typo
suiyoubi Nov 26, 2024
f61877a
add exporter
suiyoubi Nov 26, 2024
79acd82
Apply isort and black reformatting
suiyoubi Nov 26, 2024
19337e3
Fix Unit Tests
suiyoubi Nov 26, 2024
b15051d
Apply isort and black reformatting
suiyoubi Nov 26, 2024
cca4c79
add sig
suiyoubi Nov 27, 2024
db19dd1
rename pretraining dataset testing for bert
suiyoubi Nov 28, 2024
a2a859c
delete pretraining dataset testing for bert
suiyoubi Nov 28, 2024
1e9bc3f
Merge branch 'main' into aot/bert-nemo-ux
suiyoubi Nov 28, 2024
69b6128
resolve TE dependency
suiyoubi Dec 2, 2024
27b56a9
Apply isort and black reformatting
suiyoubi Dec 2, 2024
56fa3f1
resolve TE dependency
suiyoubi Dec 2, 2024
9b11d84
resolve TE dependency
suiyoubi Dec 2, 2024
4a7f608
resolve TE dependency
suiyoubi Dec 2, 2024
5acf9aa
Apply isort and black reformatting
suiyoubi Dec 2, 2024
d363a11
resolve TE dependency
suiyoubi Dec 2, 2024
c0a7836
Apply isort and black reformatting
suiyoubi Dec 2, 2024
de1c788
Merge branch 'main' into aot/bert-nemo-ux
suiyoubi Dec 3, 2024
8 changes: 8 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -18,6 +18,14 @@
safe_import("transformer_engine")

from nemo.collections.llm import peft
from nemo.collections.llm.bert.data import BERTPreTrainingDataModule
from nemo.collections.llm.bert.model import (
BertConfig,
BertModel,
GoogleBertBaseConfig,
GoogleBertConfig,
GoogleBertModel,
)
from nemo.collections.llm.gpt.data import (
AlpacaDataModule,
DollyDataModule,
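
With these additions the new BERT entry points become importable from the collection root; a minimal sketch (illustrative only, not part of the diff):

from nemo.collections.llm import (
    BERTPreTrainingDataModule,
    BertConfig,
    BertModel,
    GoogleBertBaseConfig,
    GoogleBertModel,
)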
Empty file.
3 changes: 3 additions & 0 deletions nemo/collections/llm/bert/data/__init__.py
@@ -0,0 +1,3 @@
from nemo.collections.llm.bert.data.pre_training import BERTPreTrainingDataModule

__all__ = ["BERTPreTrainingDataModule"]
323 changes: 323 additions & 0 deletions nemo/collections/llm/bert/data/pre_training.py
@@ -0,0 +1,323 @@
import logging
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import lightning.pytorch as pl
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils import data

from nemo.lightning.data import WrappedDataLoader
from nemo.lightning.io.mixin import IOMixin
from nemo.lightning.pytorch.plugins import MegatronDataSampler

if TYPE_CHECKING:
from megatron.core.datasets.bert_dataset import BERTMaskedWordPieceDatasetConfig

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec


class BERTPreTrainingDataModule(pl.LightningDataModule, IOMixin):
"""PyTorch Lightning-compatible data module for pre-training
BERT-style models.
Args:
paths (Path | List | Dict[str, List]): Paths of the data distributions. Can be either a
single path, a list of paths, or a dictionary. If a single path or a list of paths,
the given paths will be used to generate the train, validation and test datasets. If
providing a list of paths, the format can be either (1) a list of paths, e.g.
["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"],
or (2) a flattened, zipped list of weights and paths, e.g.
["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]
If a dictionary is provided, it is expected to have the following form:
{
'train': <TRAIN PATHS>,
'validation': <VALID PATHS>,
'test': <TEST PATHS>
}
where each value is either a path or a list of paths as described above.
In this case, each split will be generated using the given paths.
Note that if limit_val_batches <= 1, we generate the entire validation dataset, so
weights should not be provided for the validation split.
seq_length (int): Sequence length.
tokenizer (Optional["TokenizerSpec"]): An instance of a TokenizerSpec object.
micro_batch_size (int): Batch size per GPU.
global_batch_size (int): Global batch size.
rampup_batch_size (Optional[List[int]]): Rampup batch size, should be in format of
[start_global_batch_size, batch_size_increment, rampup_samples].
num_workers (int): See ``torch.utils.data.DataLoader`` documentation.
pin_memory (bool): See ``torch.utils.data.DataLoader`` documentation.
persistent_workers (bool): See ``torch.utils.data.DataLoader`` documentation.
reset_position_ids (bool): Option to reset the position IDs in the dataset at an interval.
reset_attention_mask (bool): Option to reset the attention mask from the dataset.
eod_mask_loss (bool): Option to enable the EOD mask loss.
seed (int): Seed for generating the BERT dataset.
split (str): A string of 3 comma-separated integers denoting how much of the distribution
to allocate to train, validation, and test sets, respectively. Unused if ``paths`` is a dict.
index_mapping_dir (Optional[str]): Path to a directory to write index mapping files.
"""

def __init__(
self,
paths: Union[Path, List, Dict[str, List]],
seq_length: int = 2048,
tokenizer: Optional["TokenizerSpec"] = None,
micro_batch_size: int = 4,
global_batch_size: int = 8,
rampup_batch_size: Optional[List[int]] = None,
num_workers: int = 8,
pin_memory: bool = True,
persistent_workers: bool = False,
reset_position_ids: bool = False,
reset_attention_mask: bool = False,
eod_mask_loss: bool = False,
seed: int = 1234,
split: str = "900,50,50",
index_mapping_dir: Optional[str] = None,
) -> None:
super().__init__()
if not isinstance(paths, (list, tuple, dict)):
paths = [paths]

from megatron.core.datasets.utils import get_blend_from_list

build_kwargs = {}
if isinstance(paths, dict):
if split is not None:
warnings.warn(
f"{split=} will be ignored since datasets are being created " f"from 3 separate distributions."
)
build_kwargs["blend_per_split"] = [
get_blend_from_list(paths["train"]),
get_blend_from_list(paths["validation"]),
get_blend_from_list(paths["test"]),
]
else:
paths, weights = get_blend_from_list(paths)
if len(paths) == 1:
weights = None
build_kwargs["blend"] = [paths, weights]
build_kwargs["split"] = split

self.build_kwargs = build_kwargs
self.seq_length = seq_length
self.tokenizer = tokenizer
self.num_workers = num_workers
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers
self.reset_position_ids = reset_position_ids
self.reset_attention_mask = reset_attention_mask
self.eod_mask_loss = eod_mask_loss
self.seed = seed
self.split = split
self.index_mapping_dir = index_mapping_dir
self.init_global_step = 0

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "BertWordPieceLowerCase")

self.data_sampler = MegatronDataSampler(
seq_len=self.seq_length,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
rampup_batch_size=rampup_batch_size,
)

def setup(self, stage: str = "") -> None:
"""Assign Train/Val/Test dataset"""
from megatron.core.datasets.bert_dataset import BERTMaskedWordPieceDataset
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder

assert (
hasattr(self, "trainer") and self.trainer is not None
), "Setup should be completed when trainer and config are attached."

# Trainer API
max_train_steps = self.trainer.max_steps
assert max_train_steps > 0, "Please specify trainer.max_steps"
eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
test_iters = self.trainer.limit_test_batches
num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size)
num_val_samples = int(eval_iters * self.data_sampler.global_batch_size)
num_test_samples = int(test_iters * self.data_sampler.global_batch_size)
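# Worked example (hypothetical numbers): with max_steps=1000, val_check_interval=100, limit_val_batches=50
# and global_batch_size=32, eval_iters = (1000 // 100 + 1) * 50 = 550, so 550 * 32 = 17,600 validation
# samples and 1000 * 32 = 32,000 training samples are requested from the dataset builder.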

if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float):
assert "blend" not in self.build_kwargs, (
"When using a single data distribution, limit_val_batches <= 1.0 is not supported. If you'd "
"like to run with a fractional value of limit_val_batches, please pass in separate datasets for "
"the train, validation, and test datasets by providing a dictionary of paths, e.g.: \n"
" paths={ \n "
" 'train': [PATHS FOR TRAIN], \n "
" 'validation': [PATHS FOR VALIDATION], \n "
" 'test' :[PATHS FOR TEST], \n"
" }"
)

# This is to make sure we only have one epoch on every validation iteration
num_val_samples = None

train_valid_test_num_samples = [num_train_samples, num_val_samples, num_test_samples]
print('Building dataset')
self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder(
BERTMaskedWordPieceDataset,
train_valid_test_num_samples,
is_built_on_rank=lambda: True,
config=self.bert_dataset_config,
).build()
print('Building Dataset Done.')

# uncomment once fabric API is merged
# def fabric_setup(
# self,
# fabric: fl.Fabric,
# num_train_samples: int,
# num_val_samples: int,
# num_test_samples: int,
# ) -> None:
# from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
# from megatron.core.datasets.gpt_dataset import GPTDataset
#
# del fabric
# train_valid_test_num_samples = [num_train_samples, num_val_samples, num_test_samples]
# self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder(
# GPTDataset, train_valid_test_num_samples, self.gpt_dataset_config,
# ).build()

def train_dataloader(self) -> TRAIN_DATALOADERS:
"""Create Train dataloader"""
return self._create_dataloader(self._train_ds, mode='train')

def val_dataloader(self) -> EVAL_DATALOADERS:
"""Create Validation dataloader"""
return self._create_dataloader(self._validation_ds, mode='validation')

def test_dataloader(self) -> EVAL_DATALOADERS:
"""Create Test dataloader"""
return self._create_dataloader(self._test_ds, mode='test')

def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader:
self.init_global_step = self.trainer.global_step
self.data_sampler.init_global_step = self.init_global_step
dataloader = WrappedDataLoader(
mode=mode,
dataset=dataset,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers,
collate_fn=getattr(dataset, 'collate_fn', data.dataloader.default_collate),
**kwargs,
)
return dataloader

@property
def bert_dataset_config(self) -> "BERTMaskedWordPieceDatasetConfig":
"""Create Bert Dataset Config using Mcore's BERT MaskedWordPieceDatasetConfig"""
from megatron.core.datasets.bert_dataset import BERTMaskedWordPieceDatasetConfig

return BERTMaskedWordPieceDatasetConfig(
random_seed=self.seed,
sequence_length=self.seq_length,
tokenizer=self.tokenizer,
path_to_cache=self.index_mapping_dir,
classification_head=True,
masking_probability=0.15,
short_sequence_probability=0.10,
masking_max_ngram=3, # Following values are taken from megatron-lm/pretrain_bert.py
masking_do_full_word=True,
masking_do_permutation=False,
masking_use_longer_ngrams=False,
masking_use_geometric_distribution=False,
**self.build_kwargs,
)

def state_dict(self) -> Dict[str, Any]:
"""Called when saving a checkpoint, implement to generate and save datamodule state.

Returns:
A dictionary containing datamodule state.

"""
consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
return {'consumed_samples': consumed_samples}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Called when loading a checkpoint, implement to reload datamodule state given datamodule stat

Args:
state_dict: the datamodule state returned by ``state_dict``.

"""
try:
from megatron.core.num_microbatches_calculator import update_num_microbatches

except (ImportError, ModuleNotFoundError):
logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
from apex.transformer.pipeline_parallel.utils import update_num_microbatches

consumed_samples = state_dict['consumed_samples']
self.data_sampler.init_consumed_samples = consumed_samples
self.data_sampler.prev_consumed_samples = consumed_samples

update_num_microbatches(
consumed_samples=consumed_samples,
consistency_check=False,
)
self.data_sampler.if_first_step = 1
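# Note: state_dict() and load_state_dict() above round-trip the sampler's consumed-sample count through
# Lightning checkpoints, so a resumed run continues from its previous position in the data blend instead
# of replaying already-seen samples.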

def reconfigure_limit_batches(self):
"""Reconfigure trainer.limit_val_batches for pretraining"""
# Override limit_train_batches in terms of num of microbatches
self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_ds, 'train')
# Override limit_val_batches to be a multiple of num microbatches
# to prevent val_step from exiting in between a step
self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_ds, 'val')

def _reconfigure_limit_batches(self, limit_batches, dataloader, mode):
"""
Reconfigure limit_batches (train or val) in terms of the number of microbatches
"""
# Override limit_batches in terms of num microbatches
# and so there are limit_batches//num_micro_batches num of global batches
try:
from megatron.core.num_microbatches_calculator import get_num_microbatches

except (ImportError, ModuleNotFoundError):
logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
from apex.transformer.pipeline_parallel.utils import get_num_microbatches

if isinstance(limit_batches, int):
limit_batches *= get_num_microbatches()
else:
assert isinstance(limit_batches, float)
# Don't reconfigure if limit_batches is 0.0 or if there's no dataloader
if limit_batches == 0.0 or dataloader is None:
return
# len(dataloader) returns len as num of microbatches
dl_len_in_micro_batches = len(dataloader)
if len(dataloader) != float("inf"):
if limit_batches == 1.0:
limit_batches = dl_len_in_micro_batches
else:
limit_micro_batches = int(dl_len_in_micro_batches * limit_batches)
if limit_micro_batches == 0 and limit_batches > 0.0:
min_percentage = 1.0 / len(dataloader)
raise MisconfigurationException(
f"You requested to check {limit_batches} of the val_dataloader but"
f" {limit_batches} * {len(dataloader)} < 1. Please increase the"
f" `limit_val_batches` argument. Try at least"
f" `limit_val_batches={min_percentage}`"
)
# Make sure trainer.limit_val_batches is a multiple of num of microbatches
if limit_micro_batches < get_num_microbatches():
limit_batches = get_num_microbatches()
else:
limit_batches = limit_micro_batches - limit_micro_batches % get_num_microbatches()
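# Worked example (hypothetical numbers): a dataloader of 100 microbatches with limit_batches=0.25 and
# 8 microbatches per global batch gives limit_micro_batches = 25, which is rounded down to 24 so that
# validation always stops on a global-batch boundary.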

if mode == 'train':
self.trainer.limit_train_batches = limit_batches
else:
self.trainer.limit_val_batches = limit_batches

# Override num sanity steps to be a multiple of num of microbatches
self.trainer.num_sanity_val_steps *= get_num_microbatches()
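
A minimal usage sketch of the new data module, based only on the constructor shown above (the dataset prefix and batch sizes are hypothetical; model and trainer wiring are outside this file):

from nemo.collections.llm import BERTPreTrainingDataModule

data = BERTPreTrainingDataModule(
    paths="/data/bert/my_corpus_text_sentence",  # hypothetical Megatron-preprocessed .bin/.idx prefix
    seq_length=512,
    micro_batch_size=4,
    global_batch_size=32,
    split="900,50,50",
)
# setup() then builds the BERTMaskedWordPieceDataset splits, and the standard
# train_dataloader()/val_dataloader()/test_dataloader() hooks serve them.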
10 changes: 10 additions & 0 deletions nemo/collections/llm/bert/model/__init__.py
@@ -0,0 +1,10 @@
from nemo.collections.llm.bert.model.base import BertConfig, BertModel
from nemo.collections.llm.bert.model.google_bert import GoogleBertBaseConfig, GoogleBertConfig, GoogleBertModel

__all__ = [
"BertConfig",
"BertModel",
"GoogleBertBaseConfig",
"GoogleBertConfig",
"GoogleBertModel",
]