diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
index 10b939d4aecb..eafe10eeec2b 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -349,7 +349,7 @@ def fake_initialize_model_parallel(
 
     decoder_rank_generator = RankGenerator(
         tp=tensor_model_parallel_size,
-        ep=expert_model_parallel_size_,
+        ep=1,
         dp=data_parallel_size,
         pp=pipeline_model_parallel_size,
         cp=context_parallel_size,
@@ -357,7 +357,32 @@ def fake_initialize_model_parallel(
         rank_offset=encoder_world_size,
     )
 
-    def generator_wrapper(group_type, **kwargs):
+    # This configuration folds EP into CP first; if EP > CP, the remaining EP ranks are folded into DP.
+    expert_tensor_parallel_size = tensor_model_parallel_size
+    expert_tensor_model_pipeline_parallel_size = (
+        expert_tensor_parallel_size * expert_model_parallel_size_ * pipeline_model_parallel_size
+    )
+    expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size
+    if decoder_world_size % expert_tensor_model_pipeline_parallel_size != 0:
+        raise RuntimeError(
+            f"decoder world_size ({decoder_world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})"
+        )
+    expert_decoder_rank_generator = RankGenerator(
+        tp=expert_tensor_parallel_size,
+        ep=expert_model_parallel_size_,
+        dp=expert_data_parallel_size,
+        pp=pipeline_model_parallel_size,
+        cp=1,
+        order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
+        rank_offset=encoder_world_size,
+    )
+
+    assert decoder_rank_generator.get_ranks("pp") == expert_decoder_rank_generator.get_ranks(
+        "pp"
+    ), f"Pipeline parallel groups are expected to be the same for Non-Expert and Expert part, \
+    but got {decoder_rank_generator.get_ranks('pp')} and {expert_decoder_rank_generator.get_ranks('pp')}"
+
+    def generator_wrapper(group_type, is_expert=False, **kwargs):
         from itertools import cycle
 
         """The `RankGenerator` class produces a hyper-rectangle for a given set of
@@ -365,7 +390,10 @@ def generator_wrapper(group_type, **kwargs):
         in addition to the default decoder, we essentially instantiate two `RankGenerator`
         classes to construct the parallelism for each module separately, and we then have
         to stitch them together for the right groups. For now, this means pp and tp-pp."""
-        d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
+        if is_expert:
+            d_ranks = expert_decoder_rank_generator.get_ranks(group_type, **kwargs)
+        else:
+            d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
         if encoder_rank_generator is None:
             for x in d_ranks:
                 yield x
@@ -446,6 +474,6 @@ def generator_wrapper(group_type, **kwargs):
     # EP rank
     expert_model_parallel_rank = 0
     if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
-        for ranks in generator_wrapper('ep', independent_ep=True):
+        for ranks in generator_wrapper('ep', is_expert=True):
             if rank in ranks:
                 expert_model_parallel_rank = list(ranks).index(rank)