diff --git a/nemo/collections/llm/bert/model/base.py b/nemo/collections/llm/bert/model/base.py
index 967891082e65..c55ae1108845 100644
--- a/nemo/collections/llm/bert/model/base.py
+++ b/nemo/collections/llm/bert/model/base.py
@@ -20,9 +20,7 @@
 import torch.distributed
 from megatron.core import InferenceParams, parallel_state, tensor_parallel
 from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
-from megatron.core.models.bert import bert_layer_specs
 from megatron.core.models.bert.bert_lm_head import BertLMHead as MCoreBertLMHead
-from megatron.core.models.bert.bert_model import BertModel as MCoreBert
 from megatron.core.models.bert.pooler import Pooler
 from megatron.core.optimizer import OptimizerConfig
 from megatron.core.packed_seq_params import PackedSeqParams
@@ -46,8 +44,11 @@
 HAVE_TE = True
 try:
     import transformer_engine  # pylint: disable=W0611
+    from megatron.core.models.bert import bert_layer_specs
+    from megatron.core.models.bert.bert_model import BertModel as MCoreBert
 except (ImportError, ModuleNotFoundError) as e:
     HAVE_TE = False
+    MCoreBert = TransformerLayer  # Place holder for import checking. BERT requires TE installed.
 
 if TYPE_CHECKING:
     from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
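
For context on the pattern: the change moves the two Megatron-core BERT imports that pull in Transformer Engine inside the existing `try`/`except` guard, and on failure rebinds `MCoreBert` to `TransformerLayer` (presumably already imported elsewhere in `base.py`) so that module-level references such as subclass definitions still resolve when TE is absent. Below is a minimal, self-contained sketch of the same guarded-import-with-placeholder idea; the `heavy_dependency`, `RealBase`, `_PlaceholderLayer`, and `MyBert` names are hypothetical and not part of NeMo or Megatron-core.

```python
# Minimal sketch of the guarded-import pattern applied in the diff above.
# All names here (heavy_dependency, RealBase, _PlaceholderLayer, MyBert)
# are hypothetical, for illustration only; they are not NeMo or Megatron-core APIs.


class _PlaceholderLayer:
    """Stands in for the real base class when the optional backend is missing."""


HAVE_DEP = True
try:
    import heavy_dependency  # noqa: F401  # optional backend; may be absent
    from heavy_dependency.models import RealBase
except (ImportError, ModuleNotFoundError):
    HAVE_DEP = False
    RealBase = _PlaceholderLayer  # same role as `MCoreBert = TransformerLayer`


class MyBert(RealBase):
    """Importable and subclassable even when the optional backend is absent."""

    def __init__(self):
        if not HAVE_DEP:
            raise ImportError("heavy_dependency must be installed to use MyBert.")
        super().__init__()
```

With this pattern the module still imports cleanly on installs that lack the optional dependency, and the failure surfaces as an explicit `ImportError` when the class is actually used rather than as a `NameError` at import time.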