Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BERT Model To NeMo2.0 #11333

Open
wants to merge 58 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
b922807
add bert in nemo2
suiyoubi Sep 4, 2024
0e0822c
Add bert loss
suiyoubi Sep 9, 2024
ccab2d5
Merge branch 'main' into aot/bert-nemo-ux
suiyoubi Nov 14, 2024
70973cd
update bert with latest nemo
suiyoubi Nov 14, 2024
686eed1
Apply isort and black reformatting
suiyoubi Nov 14, 2024
794951e
Bert update
suiyoubi Nov 18, 2024
bf91626
Apply isort and black reformatting
suiyoubi Nov 18, 2024
c2b2949
import optimize
suiyoubi Nov 18, 2024
f64457f
Apply isort and black reformatting
suiyoubi Nov 18, 2024
d70c098
pylint
suiyoubi Nov 19, 2024
27707f5
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 19, 2024
b91c247
Apply isort and black reformatting
suiyoubi Nov 19, 2024
146c21b
Merge remote-tracking branch 'origin/main' into aot/bert-nemo-ux
suiyoubi Nov 19, 2024
3a8841c
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 19, 2024
3247f36
pylint
suiyoubi Nov 19, 2024
bcc9b7c
Apply isort and black reformatting
suiyoubi Nov 19, 2024
5f19763
pylint
suiyoubi Nov 19, 2024
9bbef09
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 19, 2024
f42ec30
Apply isort and black reformatting
suiyoubi Nov 19, 2024
388ed8a
more comment
suiyoubi Nov 20, 2024
11a141e
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 20, 2024
8dad5a5
Merge remote-tracking branch 'origin/main' into aot/bert-nemo-ux
suiyoubi Nov 21, 2024
f00d2df
use lightning package
suiyoubi Nov 21, 2024
ba73a4d
Apply isort and black reformatting
suiyoubi Nov 21, 2024
546a646
comments resolved
suiyoubi Nov 22, 2024
2cf371d
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 22, 2024
659e29c
Add loss
suiyoubi Nov 22, 2024
2c542d0
Apply isort and black reformatting
suiyoubi Nov 22, 2024
451a1a0
Fix NaN loss when resume
suiyoubi Nov 22, 2024
c72f1be
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 22, 2024
d2bbe9c
Apply isort and black reformatting
suiyoubi Nov 22, 2024
846b236
add default for num_tokentypes
suiyoubi Nov 22, 2024
cc1893d
update forward output to be a dict
suiyoubi Nov 24, 2024
9654cfd
Apply isort and black reformatting
suiyoubi Nov 24, 2024
8658f53
ReName to HuggingFace Bert Model
suiyoubi Nov 24, 2024
cb1332d
Add BertLarge
suiyoubi Nov 26, 2024
7ae576c
Apply isort and black reformatting
suiyoubi Nov 26, 2024
3826aa3
Add BERT Tests
suiyoubi Nov 26, 2024
51692ad
Merge remote-tracking branch 'origin/aot/bert-nemo-ux' into aot/bert-…
suiyoubi Nov 26, 2024
ea969ed
Apply isort and black reformatting
suiyoubi Nov 26, 2024
06f03ff
typo
suiyoubi Nov 26, 2024
f61877a
add exporter
suiyoubi Nov 26, 2024
79acd82
Apply isort and black reformatting
suiyoubi Nov 26, 2024
19337e3
Fix Unit Tests
suiyoubi Nov 26, 2024
b15051d
Apply isort and black reformatting
suiyoubi Nov 26, 2024
cca4c79
add sig
suiyoubi Nov 27, 2024
db19dd1
rename pretraining dataset testing for bert
suiyoubi Nov 28, 2024
a2a859c
delete pretraining dataset testing for bert
suiyoubi Nov 28, 2024
1e9bc3f
Merge branch 'main' into aot/bert-nemo-ux
suiyoubi Nov 28, 2024
69b6128
resolve TE dependency
suiyoubi Dec 2, 2024
27b56a9
Apply isort and black reformatting
suiyoubi Dec 2, 2024
56fa3f1
resolve TE dependency
suiyoubi Dec 2, 2024
9b11d84
resolve TE dependency
suiyoubi Dec 2, 2024
4a7f608
resolve TE dependency
suiyoubi Dec 2, 2024
5acf9aa
Apply isort and black reformatting
suiyoubi Dec 2, 2024
d363a11
resolve TE dependency
suiyoubi Dec 2, 2024
c0a7836
Apply isort and black reformatting
suiyoubi Dec 2, 2024
de1c788
Merge branch 'main' into aot/bert-nemo-ux
suiyoubi Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions nemo/collections/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@
safe_import("transformer_engine")

from nemo.collections.llm import peft
from nemo.collections.llm.bert.data import BERTPreTrainingDataModule
from nemo.collections.llm.bert.model import (
BertConfig,
BertModel,
HuggingFaceBertBaseConfig,
HuggingFaceBertConfig,
HuggingFaceBertLargeConfig,
HuggingFaceBertModel,
MegatronBertBaseConfig,
MegatronBertConfig,
MegatronBertLargeConfig,
)
from nemo.collections.llm.gpt.data import (
AlpacaDataModule,
DollyDataModule,
Expand Down
Empty file.
3 changes: 3 additions & 0 deletions nemo/collections/llm/bert/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from nemo.collections.llm.bert.data.pre_training import BERTPreTrainingDataModule

__all__ = ["BERTPreTrainingDataModule"]
335 changes: 335 additions & 0 deletions nemo/collections/llm/bert/data/pre_training.py

Large diffs are not rendered by default.

121 changes: 121 additions & 0 deletions nemo/collections/llm/bert/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor

from nemo.lightning.megatron_parallel import MaskedTokenLossReduction, MegatronLossReduction


class BERTLossReduction(MegatronLossReduction):
"""Bert Loss Function.
when add_sop_loss = False, only calculate Masked token loss.
"""

def __init__(self, validation_step: bool = False, val_drop_last: bool = True, add_sop_loss: bool = True) -> None:
super().__init__()
self.validation_step = validation_step
self.val_drop_last = val_drop_last
self.add_sop_loss = add_sop_loss
if not add_sop_loss:
# BERTLoss would act like MaskedTokenLossReduction when only use MLM loss
self.mlm = MaskedTokenLossReduction(validation_step, val_drop_last)

def forward(
self, batch: Dict[str, torch.Tensor], forward_out: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Perform Loss calculation on batch.
Currently, Context parallelism is not supported for SOP loss.
"""

# Update loss_mask to batch.
# Model forward did no update to loss_mask, but for unknown reason loss_mask can get lost (to None)
# in 'batch' during update. We use the original loss_mask in the dataloader as the ground truth.
batch['loss_mask'] = forward_out['loss_mask']
if not self.add_sop_loss:
return self.mlm.forward(batch, forward_out['lm_loss'])

from megatron.core import parallel_state

from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group

lm_loss_, sop_logits = forward_out['lm_loss'], forward_out['binary_logits']
assert sop_logits is not None, (
'Attempting to calculate Sentence Order Prediction Loss but SOP logits '
'are not provideds, Please Make sure you have added binary head.'
)

cp_size = parallel_state.get_context_parallel_world_size()
if cp_size == 1:
sop_loss_for_ub = sentence_order_prediction_loss(sop_logits, batch["is_random"])
lm_loss_for_ub = masked_token_with_zero(lm_loss_, batch["loss_mask"])
else:
raise NotImplementedError('CP is not supported for SOP loss yet')

loss_for_ub = sop_loss_for_ub + lm_loss_for_ub
reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
return loss_for_ub * cp_size, {"avg": reduced_loss}

def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor:
"""Taken from: https://github.com/NVIDIA/NeMo/blob/main
/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 ."""
if losses_reduced_per_micro_batch:
if "avg" in losses_reduced_per_micro_batch[0]:
loss_tensors_list = [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)

return loss_tensor.mean()

# Get the total loss since micro batches sizes are not uniform
loss_sum_tensors_list: List[torch.Tensor] = [
loss_sum["loss_sum_and_ub_size"]
for loss_sum in losses_reduced_per_micro_batch
if loss_sum["loss_sum_and_ub_size"][1] > 0
]
loss_sum = (
torch.vstack(loss_sum_tensors_list).sum(dim=0)
if len(loss_sum_tensors_list) > 0
else torch.tensor([0.0, 0.0], device=torch.cuda.current_device())
)
return loss_sum

return torch.tensor(0.0, device=torch.cuda.current_device())


def masked_token_with_zero(tensor: Tensor, mask: Tensor):
"""Calculate masked token loss with consideration of possible NaN.
Sometimes when the number of tokens is very small, none of the tokens get masked for prediction.
In that case loss mask is all zeros i.e Happens when the entire batch is masked out
(Practically when MBS=1 or 2, and the number of tokens in each batch is < 7 )
"""
losses = tensor.float()
loss_mask = mask.float()
if loss_mask.sum() == 0:
loss = torch.sum(losses.view(-1)) * 0.0
else:
loss = torch.sum(losses.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

return loss


def sentence_order_prediction_loss(tensor: Tensor, sentence_order: Tensor):
"""Calculate sentence order prediction loss."""
losses = tensor.view(-1, 2).float()
sentence_order = sentence_order.view(-1)
loss = F.cross_entropy(losses, sentence_order, ignore_index=-1)

return loss
22 changes: 22 additions & 0 deletions nemo/collections/llm/bert/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from nemo.collections.llm.bert.model.base import BertConfig, BertModel
from nemo.collections.llm.bert.model.bert import (
HuggingFaceBertBaseConfig,
HuggingFaceBertConfig,
HuggingFaceBertLargeConfig,
HuggingFaceBertModel,
MegatronBertBaseConfig,
MegatronBertConfig,
MegatronBertLargeConfig,
)

__all__ = [
"BertConfig",
"BertModel",
"HuggingFaceBertBaseConfig",
"HuggingFaceBertLargeConfig",
"HuggingFaceBertConfig",
"HuggingFaceBertModel",
"MegatronBertConfig",
"MegatronBertBaseConfig",
"MegatronBertLargeConfig",
]
Loading
Loading