Apply isort and black reformatting
Signed-off-by: suiyoubi <[email protected]>
suiyoubi committed Dec 2, 2024
1 parent 4a7f608 commit 5acf9aa
Showing 2 changed files with 10 additions and 7 deletions.
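Both files change only in import order and whitespace: isort moves the two megatron.core.models.bert imports in base.py into their alphabetically sorted positions and separates the third-party block from the first-party nemo imports with a blank line, and in bert_spec.py it merges the two from megatron.core.transformer.spec_utils imports into one statement. As a minimal sketch of that behavior (assuming the isort package is installed; treating nemo as first-party is an assumption about the repository's configuration, not something stated in this commit), isort's Python API can be run directly on a source string:

import isort

# Deliberately misordered imports, mirroring the patterns fixed in the diff below.
MESSY = (
    "from megatron.core.transformer.spec_utils import build_module\n"
    "from torch import Tensor, nn\n"
    "from megatron.core.transformer.spec_utils import ModuleSpec\n"
    "from megatron.core.models.bert.bert_model import BertModel as MCoreBert\n"
    "from nemo.collections.llm import fn\n"
)

# isort alphabetizes within the third-party section, merges the duplicate
# spec_utils imports into one line, and places the nemo import in a separate
# first-party group after a blank line.
print(isort.code(MESSY, known_first_party=["nemo"]))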
5 changes: 3 additions & 2 deletions nemo/collections/llm/bert/model/base.py
@@ -20,7 +20,9 @@
 import torch.distributed
 from megatron.core import InferenceParams, parallel_state, tensor_parallel
 from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
+from megatron.core.models.bert import bert_layer_specs
 from megatron.core.models.bert.bert_lm_head import BertLMHead as MCoreBertLMHead
+from megatron.core.models.bert.bert_model import BertModel as MCoreBert
 from megatron.core.models.bert.pooler import Pooler
 from megatron.core.optimizer import OptimizerConfig
 from megatron.core.packed_seq_params import PackedSeqParams
@@ -29,10 +31,9 @@
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
 from megatron.core.transformer.utils import get_linear_layer as mcore_get_linear_layer
-from megatron.core.models.bert import bert_layer_specs
 from megatron.core.utils import make_viewless_tensor
 from torch import Tensor, nn
-from megatron.core.models.bert.bert_model import BertModel as MCoreBert
+
 from nemo.collections.llm import fn
 from nemo.collections.llm.bert.loss import BERTLossReduction
 from nemo.collections.llm.bert.model.bert_spec import (
12 changes: 7 additions & 5 deletions nemo/collections/llm/bert/model/bert_spec.py
@@ -18,12 +18,11 @@
 from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
-from megatron.core.transformer.spec_utils import build_module
 from megatron.core.transformer.dot_product_attention import DotProductAttention
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
-from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.spec_utils import ModuleSpec, build_module
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
 from megatron.core.utils import make_viewless_tensor
 
@@ -39,16 +38,17 @@
         TEDotProductAttention,
         TENorm,
         TERowParallelLinear,
-
     )
 
     HAVE_TE = True
 except (ImportError, ModuleNotFoundError) as e:
     HAVE_TE = False
 
+
 @dataclass
 class TransformerLayerSubmodulesWithPostLNSupport(TransformerLayerSubmodules):
     """TransformerLayerSubmodules with post layer norm"""
+
     def __init__(self, post_att_layernorm, post_mlp_layernorm, **kwargs):
         super(TransformerLayerSubmodulesWithPostLNSupport, self).__init__(**kwargs)
         self.post_att_layernorm = post_att_layernorm
@@ -57,6 +57,7 @@ def __init__(self, post_att_layernorm, post_mlp_layernorm, **kwargs):
 
 class TransformerLayerWithPostLNSupport(TransformerLayer):
     """TransformerLayer with post layer norm."""
+
     def __init__(self, *args, **kwargs):
         super(TransformerLayerWithPostLNSupport, self).__init__(*args, **kwargs)
         ## [Module add: Post attention LN]
@@ -88,7 +89,7 @@ def forward(
         packed_seq_params=None,
     ):
         """Copy from megatron/core/transformer/transformer_layer.py with modification of applying
-extra post layer norm if needed."""
+        extra post layer norm if needed."""
         # hidden_states: [s, b, h]
 
         # Residual connection.
@@ -201,6 +202,7 @@ def get_bert_layer_with_transformer_engine_spec_postln():
         ),
     )
 
+
 # Use this spec for an implementation using only modules in megatron core
 def get_bert_layer_local_spec_postln():
     """Retrieve the Layer Spec when using MCore Engine"""
@@ -230,4 +232,4 @@ def get_bert_layer_local_spec_postln():
             mlp_bda=get_bias_dropout_add,
             post_mlp_layernorm=FusedLayerNorm,
         ),
-    )
+    )
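Past the import block, the bert_spec.py hunks are pure black normalization: the stray blank line inside the Transformer Engine import parentheses is dropped, two blank lines are enforced before top-level definitions, a blank line is added after each class docstring, and the continuation line of the forward docstring is re-indented. A minimal sketch of the same normalization using black's Python API on a toy class (assuming a recent black release is installed; PostLNLayer is an illustrative stand-in, not code from the repository):

import black

SRC = (
    "class PostLNLayer:\n"
    '    """Toy stand-in for TransformerLayerWithPostLNSupport."""\n'
    "    def __init__(self):\n"
    "        self.scale = 1.0\n"
    "def make_layer():\n"
    "    return PostLNLayer()\n"
)

# black.format_str applies the default style: it separates top-level
# definitions with two blank lines and (in recent releases) inserts a blank
# line after the class docstring, matching the spacing changes above.
print(black.format_str(SRC, mode=black.Mode()))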
