diff --git a/nemo/collections/llm/bert/model/base.py b/nemo/collections/llm/bert/model/base.py
index 173f29d94c07..f22a575e6cc3 100644
--- a/nemo/collections/llm/bert/model/base.py
+++ b/nemo/collections/llm/bert/model/base.py
@@ -29,15 +29,15 @@
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
 from megatron.core.transformer.utils import get_linear_layer as mcore_get_linear_layer
+from megatron.core.models.bert import bert_layer_specs
 from megatron.core.utils import make_viewless_tensor
 from torch import Tensor, nn
-
+from megatron.core.models.bert.bert_model import BertModel as MCoreBert
 from nemo.collections.llm import fn
 from nemo.collections.llm.bert.loss import BERTLossReduction
 from nemo.collections.llm.bert.model.bert_spec import (
-    bert_layer_local_spec_postln,
-    bert_layer_with_transformer_engine_spec_postln,
-    megatron_layer_local_spec_preln,
+    get_bert_layer_local_spec_postln,
+    get_bert_layer_with_transformer_engine_spec_postln,
 )
 from nemo.lightning import get_vocab_size, io
 from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule
@@ -45,9 +45,7 @@
 HAVE_TE = True
 try:
     import transformer_engine  # pylint: disable=W0611
-    from megatron.core.models.bert import bert_layer_specs
-    from megatron.core.models.bert.bert_model import BertModel as MCoreBert
-except (ImportError, ModuleNotFoundError):
+except (ImportError, ModuleNotFoundError) as e:
     HAVE_TE = False
 
 if TYPE_CHECKING:
@@ -114,12 +112,12 @@ def default_layer_spec(config: "BertConfig") -> ModuleSpec:
         if bert_type == 'megatron':
             return bert_layer_specs.bert_layer_with_transformer_engine_spec
         else:
-            return bert_layer_with_transformer_engine_spec_postln
+            return get_bert_layer_with_transformer_engine_spec_postln()
 
     if bert_type == 'megatron':
-        return megatron_layer_local_spec_preln
+        return bert_layer_specs.bert_layer_local_spec
     else:
-        return bert_layer_local_spec_postln
+        return get_bert_layer_local_spec_postln()
 
 
 @dataclass
@@ -542,6 +540,15 @@ def __init__(
         tokenizer: Optional["TokenizerSpec"] = None,
         model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
     ):
+        # Megatron-LM's BERT implementation depends heavily on TE; the MCore BERT cannot
+        # be instantiated without the TE package. Two issues remain:
+        # 1. bert_layer_specs.py is not free of TE dependencies.
+        # 2. In bert_model.py, _sanity_check_attention_and_get_attn_mask_dimension() checks
+        #    whether transformer_layer_spec is identical to bert_layer_local_spec to decide
+        #    whether TE is required; since NeMo uses a customized BERT layer spec, it always
+        #    assumes TE is in use.
+        # Both issues must be addressed before NeMo BERT can run TE-free.
+        assert HAVE_TE, "NeMo BERT requires Transformer Engine to be installed."
         super().__init__()
         self.config = config
         self.tokenizer = tokenizer
diff --git a/nemo/collections/llm/bert/model/bert_spec.py b/nemo/collections/llm/bert/model/bert_spec.py
index 453b323d5632..a5943be323bf 100644
--- a/nemo/collections/llm/bert/model/bert_spec.py
+++ b/nemo/collections/llm/bert/model/bert_spec.py
@@ -11,139 +11,223 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults
+from dataclasses import dataclass
 
 try:
     from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
     from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
     from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
     from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
-    from megatron.core.transformer.custom_layers.transformer_engine import (
-        TEColumnParallelLinear,
-        TEDotProductAttention,
-        TENorm,
-        TERowParallelLinear,
-    )
+    from megatron.core.transformer.spec_utils import build_module
     from megatron.core.transformer.dot_product_attention import DotProductAttention
     from megatron.core.transformer.enums import AttnMaskType
     from megatron.core.transformer.identity_op import IdentityOp
     from megatron.core.transformer.mlp import MLP, MLPSubmodules
     from megatron.core.transformer.spec_utils import ModuleSpec
     from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
+    from megatron.core.utils import make_viewless_tensor
 
     HAVE_MEGATRON_CORE = True
 
-except (ImportError, ModuleNotFoundError):
+except (ImportError, ModuleNotFoundError) as e:
     TransformerConfig = ApexGuardDefaults
     HAVE_MEGATRON_CORE = False
 
 try:
-    import apex  # pylint: disable=unused-import
-
-    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
-
-    HAVE_APEX = True
-    LNImpl = FusedLayerNorm
-except ImportError:
-    import warnings
-
-    from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm
-
-    warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm')
-    LNImpl = WrappedTorchLayerNorm
-
-from nemo.collections.nlp.models.language_modeling.megatron.bert.bert_model import (
-    TransformerLayerSubmodulesWithPostLNSupport,
-    TransformerLayerWithPostLNSupport,
-)
+    from megatron.core.transformer.custom_layers.transformer_engine import (
+        TEColumnParallelLinear,
+        TEDotProductAttention,
+        TENorm,
+        TERowParallelLinear,
+    )
+
+    HAVE_TE = True
+except (ImportError, ModuleNotFoundError) as e:
+    HAVE_TE = False
+
+
+@dataclass
+class TransformerLayerSubmodulesWithPostLNSupport(TransformerLayerSubmodules):
+    """TransformerLayerSubmodules with post layer norm"""
+
+    def __init__(self, post_att_layernorm, post_mlp_layernorm, **kwargs):
+        super(TransformerLayerSubmodulesWithPostLNSupport, self).__init__(**kwargs)
+        self.post_att_layernorm = post_att_layernorm
+        self.post_mlp_layernorm = post_mlp_layernorm
+
+
+class TransformerLayerWithPostLNSupport(TransformerLayer):
+    """TransformerLayer with post layer norm."""
+
+    def __init__(self, *args, **kwargs):
+        super(TransformerLayerWithPostLNSupport, self).__init__(*args, **kwargs)
+        ## [Module add: Post attention LN]
+        self.post_att_layernorm = build_module(
+            self.submodules_config.post_att_layernorm,
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon,
+        )
+        ## [Module add: Post MLP LN]
+        self.post_mlp_layernorm = build_module(
+            self.submodules_config.post_mlp_layernorm,
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon,
+        )
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        context=None,
+        context_mask=None,
+        rotary_pos_emb=None,
+        rotary_pos_cos=None,
+        rotary_pos_sin=None,
+        attention_bias=None,
+        inference_params=None,
+        packed_seq_params=None,
+    ):
+        """Copy from megatron/core/transformer/transformer_layer.py with modification of applying
+        extra post layer norm if needed."""
+        # hidden_states: [s, b, h]
+
+        # Residual connection.
+        residual = hidden_states
+
+        # Optional Input Layer norm
+        input_layernorm_output = self.input_layernorm(hidden_states)
+
+        # Self attention.
+        attention_output_with_bias = self.self_attention(
+            input_layernorm_output,
+            attention_mask=attention_mask,
+            inference_params=inference_params,
+            rotary_pos_emb=rotary_pos_emb,
+            packed_seq_params=packed_seq_params,
+        )
+
+        # TODO: could we move `bias_dropout_add_exec_handler` itself
+        # inside the module provided in the `bias_dropout_add_spec` module?
+        with self.bias_dropout_add_exec_handler():
+            hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
+                attention_output_with_bias, residual, self.hidden_dropout
+            )
+
+        # Residual connection.
+        residual = hidden_states
+
+        # Post-LN after Self Attention
+        hidden_states = self.post_att_layernorm(hidden_states)
+
+        # Optional Layer norm after self-attention
+        pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states)
+
+        # Cross attention.
+        attention_output_with_bias = self.cross_attention(
+            pre_cross_attn_layernorm_output,
+            attention_mask=context_mask,
+            key_value_states=context,
+            inference_params=inference_params,
+        )
+
+        if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
+            context = attention_output_with_bias["context"]
+
+        # TODO: could we move `bias_dropout_add_exec_handler` itself
+        # inside the module provided in the `bias_dropout_add_spec` module?
+        with self.bias_dropout_add_exec_handler():
+            hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
+                attention_output_with_bias, residual, self.hidden_dropout
+            )
+
+        # Residual connection.
+        residual = hidden_states
+
+        # Optional Layer norm post the cross-attention.
+        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
+
+        # MLP.
+        mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
+
+        # TODO: could we move `bias_dropout_add_exec_handler` itself
+        # inside the module provided in the `bias_dropout_add_spec` module?
+        with self.bias_dropout_add_exec_handler():
+            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
+                mlp_output_with_bias, residual, self.hidden_dropout
+            )
+
+        # Post-LN after MLP
+        hidden_states = self.post_mlp_layernorm(hidden_states)
+
+        # Jit compiled function creates 'view' tensor. This tensor
+        # potentially gets saved in the MPU checkpoint function context,
+        # which rejects view tensors. While making a viewless tensor here
+        # won't result in memory savings (like the data loader, or
+        # p2p_communication), it serves to document the origin of this
+        # 'view' tensor.
+        output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True)
+
+        return output, context
 
 # Use this spec to use lower level Transformer Engine modules (required for fp8 training)
-bert_layer_with_transformer_engine_spec_postln = ModuleSpec(
-    module=TransformerLayerWithPostLNSupport,
-    submodules=TransformerLayerSubmodulesWithPostLNSupport(
-        self_attention=ModuleSpec(
-            module=SelfAttention,
-            params={"attn_mask_type": AttnMaskType.padding},
-            submodules=SelfAttentionSubmodules(
-                linear_qkv=TEColumnParallelLinear,
-                core_attention=TEDotProductAttention,
-                linear_proj=TERowParallelLinear,
-                q_layernorm=IdentityOp,
-                k_layernorm=IdentityOp,
-            ),
-        ),
-        self_attn_bda=get_bias_dropout_add,
-        post_att_layernorm=TENorm,
-        mlp=ModuleSpec(
-            module=MLP,
-            submodules=MLPSubmodules(
-                linear_fc1=TEColumnParallelLinear,
-                linear_fc2=TERowParallelLinear,
-            ),
-        ),
-        mlp_bda=get_bias_dropout_add,
-        post_mlp_layernorm=TENorm,
-    ),
-)
+def get_bert_layer_with_transformer_engine_spec_postln():
+    """Retrieve the Layer Spec when using Transformer Engine"""
+    return ModuleSpec(
+        module=TransformerLayerWithPostLNSupport,
+        submodules=TransformerLayerSubmodulesWithPostLNSupport(
+            self_attention=ModuleSpec(
+                module=SelfAttention,
+                params={"attn_mask_type": AttnMaskType.padding},
+                submodules=SelfAttentionSubmodules(
+                    linear_qkv=TEColumnParallelLinear,
+                    core_attention=TEDotProductAttention,
+                    linear_proj=TERowParallelLinear,
+                    q_layernorm=IdentityOp,
+                    k_layernorm=IdentityOp,
+                ),
+            ),
+            self_attn_bda=get_bias_dropout_add,
+            post_att_layernorm=TENorm,
+            mlp=ModuleSpec(
+                module=MLP,
+                submodules=MLPSubmodules(
+                    linear_fc1=TEColumnParallelLinear,
+                    linear_fc2=TERowParallelLinear,
+                ),
+            ),
+            mlp_bda=get_bias_dropout_add,
+            post_mlp_layernorm=TENorm,
+        ),
+    )
 
 # Use this spec for an implementation using only modules in megatron core
-bert_layer_local_spec_postln = ModuleSpec(
-    module=TransformerLayerWithPostLNSupport,
-    submodules=TransformerLayerSubmodulesWithPostLNSupport(
-        self_attention=ModuleSpec(
-            module=SelfAttention,
-            params={"attn_mask_type": AttnMaskType.padding},
-            submodules=SelfAttentionSubmodules(
-                linear_qkv=ColumnParallelLinear,
-                core_attention=DotProductAttention,
-                linear_proj=RowParallelLinear,
-                q_layernorm=IdentityOp,
-                k_layernorm=IdentityOp,
-            ),
-        ),
-        self_attn_bda=get_bias_dropout_add,
-        post_att_layernorm=FusedLayerNorm,
-        mlp=ModuleSpec(
-            module=MLP,
-            submodules=MLPSubmodules(
-                linear_fc1=ColumnParallelLinear,
-                linear_fc2=RowParallelLinear,
-            ),
-        ),
-        mlp_bda=get_bias_dropout_add,
-        post_mlp_layernorm=FusedLayerNorm,
-    ),
-)
-
-# We copy the Mcore's local spec here to avoid TE dependency issue.
-# Megatron-LM's core/models/bert/bert_layer_specs.py always requires
-# TE dependency to load. Avoid it by copying paste the local specs in NeMo.
-megatron_layer_local_spec_preln = ModuleSpec(
-    module=TransformerLayer,
-    submodules=TransformerLayerSubmodules(
-        input_layernorm=LNImpl,
-        self_attention=ModuleSpec(
-            module=SelfAttention,
-            params={"attn_mask_type": AttnMaskType.padding},
-            submodules=SelfAttentionSubmodules(
-                linear_qkv=ColumnParallelLinear,
-                core_attention=DotProductAttention,
-                linear_proj=RowParallelLinear,
-                q_layernorm=IdentityOp,
-                k_layernorm=IdentityOp,
-            ),
-        ),
-        self_attn_bda=get_bias_dropout_add,
-        pre_mlp_layernorm=LNImpl,
-        mlp=ModuleSpec(
-            module=MLP,
-            submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear),
-        ),
-        mlp_bda=get_bias_dropout_add,
-        sharded_state_dict_keys_map={
-            'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
-            'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
-        },
-    ),
-)
+def get_bert_layer_local_spec_postln():
+    """Retrieve the Layer Spec when using MCore Engine"""
+    return ModuleSpec(
+        module=TransformerLayerWithPostLNSupport,
+        submodules=TransformerLayerSubmodulesWithPostLNSupport(
+            self_attention=ModuleSpec(
+                module=SelfAttention,
+                params={"attn_mask_type": AttnMaskType.padding},
+                submodules=SelfAttentionSubmodules(
+                    linear_qkv=ColumnParallelLinear,
+                    core_attention=DotProductAttention,
+                    linear_proj=RowParallelLinear,
+                    q_layernorm=IdentityOp,
+                    k_layernorm=IdentityOp,
+                ),
+            ),
+            self_attn_bda=get_bias_dropout_add,
+            post_att_layernorm=FusedLayerNorm,
+            mlp=ModuleSpec(
+                module=MLP,
+                submodules=MLPSubmodules(
+                    linear_fc1=ColumnParallelLinear,
+                    linear_fc2=RowParallelLinear,
+                ),
+            ),
+            mlp_bda=get_bias_dropout_add,
+            post_mlp_layernorm=FusedLayerNorm,
+        ),
+    )
\ No newline at end of file
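
Usage sketch (not part of the patch): the point of replacing the module-level `bert_layer_*_postln` specs with `get_bert_layer_*_postln()` factories is that importing `nemo.collections.llm.bert.model.bert_spec` no longer pulls in Transformer Engine at import time; a spec is only materialized when a getter is called. The snippet below is a minimal sketch of that deferred construction, assuming megatron-core is installed (and transformer-engine for the TE branch); `pick_postln_layer_spec` is a hypothetical helper that mirrors the branching `default_layer_spec()` performs in base.py. Note that, per the new assertion in `BertModel.__init__`, the TE-free branch is not yet reachable from NeMo's BERT model itself.

```python
# Hypothetical illustration of the lazy spec construction introduced by this patch.
from megatron.core.transformer.spec_utils import ModuleSpec

from nemo.collections.llm.bert.model.bert_spec import (
    get_bert_layer_local_spec_postln,
    get_bert_layer_with_transformer_engine_spec_postln,
)

try:
    import transformer_engine  # noqa: F401  # pylint: disable=unused-import

    HAVE_TE = True
except (ImportError, ModuleNotFoundError):
    HAVE_TE = False


def pick_postln_layer_spec() -> ModuleSpec:
    """Hypothetical helper: choose the post-LN BERT layer spec based on TE availability."""
    if HAVE_TE:
        # The ModuleSpec is built only at call time, so merely importing bert_spec.py
        # no longer requires Transformer Engine.
        return get_bert_layer_with_transformer_engine_spec_postln()
    return get_bert_layer_local_spec_postln()


spec = pick_postln_layer_spec()
assert isinstance(spec, ModuleSpec)
print(spec.module.__name__)  # -> TransformerLayerWithPostLNSupport
```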