From 5acf9aa2af80d8c1edfc8ce6d7e74a4cb87e2e75 Mon Sep 17 00:00:00 2001
From: suiyoubi
Date: Mon, 2 Dec 2024 15:45:35 +0000
Subject: [PATCH] Apply isort and black reformatting

Signed-off-by: suiyoubi
---
 nemo/collections/llm/bert/model/base.py      |  5 +++--
 nemo/collections/llm/bert/model/bert_spec.py | 12 +++++++-----
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/nemo/collections/llm/bert/model/base.py b/nemo/collections/llm/bert/model/base.py
index f22a575e6cc3..967891082e65 100644
--- a/nemo/collections/llm/bert/model/base.py
+++ b/nemo/collections/llm/bert/model/base.py
@@ -20,7 +20,9 @@
 import torch.distributed
 from megatron.core import InferenceParams, parallel_state, tensor_parallel
 from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
+from megatron.core.models.bert import bert_layer_specs
 from megatron.core.models.bert.bert_lm_head import BertLMHead as MCoreBertLMHead
+from megatron.core.models.bert.bert_model import BertModel as MCoreBert
 from megatron.core.models.bert.pooler import Pooler
 from megatron.core.optimizer import OptimizerConfig
 from megatron.core.packed_seq_params import PackedSeqParams
@@ -29,10 +31,9 @@
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
 from megatron.core.transformer.utils import get_linear_layer as mcore_get_linear_layer
-from megatron.core.models.bert import bert_layer_specs
 from megatron.core.utils import make_viewless_tensor
 from torch import Tensor, nn
-from megatron.core.models.bert.bert_model import BertModel as MCoreBert
+
 from nemo.collections.llm import fn
 from nemo.collections.llm.bert.loss import BERTLossReduction
 from nemo.collections.llm.bert.model.bert_spec import (
diff --git a/nemo/collections/llm/bert/model/bert_spec.py b/nemo/collections/llm/bert/model/bert_spec.py
index a5943be323bf..c6165a24cdfb 100644
--- a/nemo/collections/llm/bert/model/bert_spec.py
+++ b/nemo/collections/llm/bert/model/bert_spec.py
@@ -18,12 +18,11 @@
 from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
-from megatron.core.transformer.spec_utils import build_module
 from megatron.core.transformer.dot_product_attention import DotProductAttention
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
-from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.spec_utils import ModuleSpec, build_module
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
 from megatron.core.utils import make_viewless_tensor
 
@@ -39,16 +38,17 @@
         TEDotProductAttention,
         TENorm,
         TERowParallelLinear,
-
     )
 
     HAVE_TE = True
 except (ImportError, ModuleNotFoundError) as e:
     HAVE_TE = False
 
+
 @dataclass
 class TransformerLayerSubmodulesWithPostLNSupport(TransformerLayerSubmodules):
     """TransformerLayerSubmodules with post layer norm"""
+
     def __init__(self, post_att_layernorm, post_mlp_layernorm, **kwargs):
         super(TransformerLayerSubmodulesWithPostLNSupport, self).__init__(**kwargs)
         self.post_att_layernorm = post_att_layernorm
@@ -57,6 +57,7 @@ def __init__(self, post_att_layernorm, post_mlp_layernorm, **kwargs):
 
 class TransformerLayerWithPostLNSupport(TransformerLayer):
     """TransformerLayer with post layer norm."""
+
     def __init__(self, *args, **kwargs):
         super(TransformerLayerWithPostLNSupport, self).__init__(*args, **kwargs)
         ## [Module add: Post attention LN]
@@ -88,7 +89,7 @@ def forward(
         packed_seq_params=None,
     ):
         """Copy from megatron/core/transformer/transformer_layer.py with modification of applying
-            extra post layer norm if needed."""
+        extra post layer norm if needed."""
         # hidden_states: [s, b, h]
 
         # Residual connection.
@@ -201,6 +202,7 @@ def get_bert_layer_with_transformer_engine_spec_postln():
         ),
     )
 
+
 # Use this spec for an implementation using only modules in megatron core
 def get_bert_layer_local_spec_postln():
     """Retrieve the Layer Spec when using MCore Engine"""
@@ -230,4 +232,4 @@ def get_bert_layer_local_spec_postln():
             mlp_bda=get_bias_dropout_add,
             post_mlp_layernorm=FusedLayerNorm,
         ),
-    )
\ No newline at end of file
+    )