From a23345b3d6501a80bd5255833c83a1c5724df1eb Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Sat, 30 Nov 2024 14:31:11 -0800
Subject: [PATCH 1/5] Fix fake_initialize_model_parallel for MoE models

* Make the use of RankGenerator consistent with recent Mcore changes !1940
  (https://github.com/NVIDIA/Megatron-LM/commit/7f22e210cddc3215adda25d9e16ea512dc32458c)
  - use ep=1 for decoder_rank_generator, making it treat EP as part of DP
  - define a new expert_decoder_rank_generator to handle EP groups/ranks only

Signed-off-by: Guyue Huang
---
 .../modules/common/megatron/megatron_init.py | 27 +++++++++++++++++++++++----
 1 file changed, 23 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
index 10b939d4aecb..260c1302adaa 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -349,15 +349,31 @@ def fake_initialize_model_parallel(
 
     decoder_rank_generator = RankGenerator(
         tp=tensor_model_parallel_size,
-        ep=expert_model_parallel_size_,
+        ep=1,
         dp=data_parallel_size,
         pp=pipeline_model_parallel_size,
         cp=context_parallel_size,
         order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
         rank_offset=encoder_world_size,
     )
+
+    # the default setting uses DEP (expert-parallel ranks for FFN are data-parallel ranks for Attention. This definition follows that rule.)
+    expert_decoder_rank_generator = RankGenerator(
+        tp=tensor_model_parallel_size, # the same as Attention part
+        ep=expert_model_parallel_size_,
+        dp=(decoder_world_size // (expert_model_parallel_size_ * tensor_model_parallel_size * pipeline_model_parallel_size)),
+        pp=pipeline_model_parallel_size,
+        cp=1,
+        order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
+        rank_offset=encoder_world_size,
+    )
 
-    def generator_wrapper(group_type, **kwargs):
+    assert decoder_rank_generator.get_ranks("pp") == expert_decoder_rank_generator.get_ranks(
+        "pp"
+    ), f"Pipeline parallel groups are expected to be the same for the non-expert and expert parts, \
+        but got {decoder_rank_generator.get_ranks('pp')} and {expert_decoder_rank_generator.get_ranks('pp')}"
+
+    def generator_wrapper(group_type, is_expert=False, **kwargs):
         from itertools import cycle
 
         """The `RankGenerator` class produces a hyper-rectangle for a given set of
@@ -365,7 +381,10 @@ def generator_wrapper(group_type, **kwargs):
         in addition to the default decoder, we essentially instantiate two `RankGenerator`
         classes to construct the parallelism for each module separately, and we then have
         to stitch them together for the right groups. For now, this means pp and tp-pp."""
-        d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
+        if is_expert:
+            d_ranks = expert_decoder_rank_generator.get_ranks(group_type, **kwargs)
+        else:
+            d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
         if encoder_rank_generator is None:
             for x in d_ranks:
                 yield x
@@ -446,7 +465,7 @@ def generator_wrapper(group_type, **kwargs):
     # EP rank
     expert_model_parallel_rank = 0
     if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
-        for ranks in generator_wrapper('ep', independent_ep=True):
+        for ranks in generator_wrapper('ep', is_expert=True):
             if rank in ranks:
                 expert_model_parallel_rank = list(ranks).index(rank)
 

From b12892b42c508d5b69bc952af2e797096f16d362 Mon Sep 17 00:00:00 2001
From: guyueh1
Date: Sat, 30 Nov 2024 22:43:10 +0000
Subject: [PATCH 2/5] Apply isort and black reformatting

Signed-off-by: guyueh1
---
 .../nlp/modules/common/megatron/megatron_init.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
index 260c1302adaa..d58b97e3dd10 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -356,12 +356,15 @@ def fake_initialize_model_parallel(
         order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
         rank_offset=encoder_world_size,
     )
-    
+
     # the default setting uses DEP (expert-parallel ranks for FFN are data-parallel ranks for Attention. This definition follows that rule.)
     expert_decoder_rank_generator = RankGenerator(
-        tp=tensor_model_parallel_size, # the same as Attention part
+        tp=tensor_model_parallel_size,  # the same as Attention part
         ep=expert_model_parallel_size_,
-        dp=(decoder_world_size // (expert_model_parallel_size_ * tensor_model_parallel_size * pipeline_model_parallel_size)),
+        dp=(
+            decoder_world_size
+            // (expert_model_parallel_size_ * tensor_model_parallel_size * pipeline_model_parallel_size)
+        ),
         pp=pipeline_model_parallel_size,
         cp=1,
         order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
         rank_offset=encoder_world_size,
     )

From 23befa1f6a0f15a7023a91786dd8f6704b2f668d Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Tue, 3 Dec 2024 10:30:14 -0800
Subject: [PATCH 3/5] Fix expert rank generator

Signed-off-by: Guyue Huang
---
 .../nlp/modules/common/megatron/megatron_init.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
index d58b97e3dd10..ff16d36fb4bb 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -361,12 +361,9 @@ def fake_initialize_model_parallel(
     expert_decoder_rank_generator = RankGenerator(
         tp=tensor_model_parallel_size,  # the same as Attention part
         ep=expert_model_parallel_size_,
-        dp=(
-            decoder_world_size
-            // (expert_model_parallel_size_ * tensor_model_parallel_size * pipeline_model_parallel_size)
-        ),
+        dp=(data_parallel_size // expert_model_parallel_size_),
         pp=pipeline_model_parallel_size,
-        cp=1,
+        cp=context_parallel_size,
         order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
         rank_offset=encoder_world_size,
     )

From 8105144e45e33e7f593283d02754292f16703d38 Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Tue, 3 Dec 2024 21:03:14 -0800
Subject: [PATCH 4/5] Make rank generation consistent with mcore (fold EP into
 CP then DP)

Signed-off-by: Guyue Huang
---
 .../modules/common/megatron/megatron_init.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
index ff16d36fb4bb..eafe10eeec2b 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -357,13 +357,22 @@ def fake_initialize_model_parallel(
         rank_offset=encoder_world_size,
     )
 
-    # the default setting uses DEP (expert-parallel ranks for FFN are data-parallel ranks for Attention. This definition follows that rule.)
+    # This configuration folds EP first into CP, and if EP > CP, folds the remaining EP ranks into DP
+    expert_tensor_parallel_size = tensor_model_parallel_size
+    expert_tensor_model_pipeline_parallel_size = (
+        expert_tensor_parallel_size * expert_model_parallel_size_ * pipeline_model_parallel_size
+    )
+    expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size
+    if decoder_world_size % expert_tensor_model_pipeline_parallel_size == 0:
+        raise RuntimeError(
+            f"decoder world_size ({decoder_world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})"
+        )
     expert_decoder_rank_generator = RankGenerator(
-        tp=tensor_model_parallel_size,  # the same as Attention part
+        tp=expert_tensor_parallel_size,
         ep=expert_model_parallel_size_,
-        dp=(data_parallel_size // expert_model_parallel_size_),
+        dp=expert_data_parallel_size,
         pp=pipeline_model_parallel_size,
-        cp=context_parallel_size,
+        cp=1,
         order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
         rank_offset=encoder_world_size,
     )

From b8b6fd47805dd309bf431963726fcf9320df8a81 Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Wed, 4 Dec 2024 10:01:27 -0800
Subject: [PATCH 5/5] fix

Signed-off-by: Guyue Huang
---
 nemo/collections/nlp/modules/common/megatron/megatron_init.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
index eafe10eeec2b..1811264f2093 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -363,7 +363,7 @@ def fake_initialize_model_parallel(
         expert_tensor_parallel_size * expert_model_parallel_size_ * pipeline_model_parallel_size
     )
     expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size
-    if decoder_world_size % expert_tensor_model_pipeline_parallel_size == 0:
+    if decoder_world_size % expert_tensor_model_pipeline_parallel_size != 0:
         raise RuntimeError(
             f"decoder world_size ({decoder_world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})"
         )
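
To illustrate the grouping that the final version of these patches produces, here is a minimal
standalone sketch. It is plain Python, not Megatron's actual RankGenerator: the get_ranks helper,
its fastest-first ordering semantics, and the example sizes (world_size = 8, tp = 2, ep = 2,
cp = 1, pp = 1) are all assumptions made for this illustration.

from itertools import product

def get_ranks(sizes, order, dim):
    # Illustrative stand-in for RankGenerator.get_ranks (assumed semantics):
    # dimensions earlier in `order` vary fastest in the global rank index.
    strides, stride = {}, 1
    for d in order:
        strides[d] = stride
        stride *= sizes[d]
    others = [d for d in order if d != dim]
    groups = []
    # Fix a coordinate on every other dimension, then sweep `dim` with its stride.
    for coords in product(*(range(sizes[d]) for d in others)):
        base = sum(c * strides[d] for c, d in zip(coords, others))
        groups.append([base + i * strides[dim] for i in range(sizes[dim])])
    return groups

order = ['tp', 'cp', 'ep', 'dp', 'pp']
dense = dict(tp=2, cp=1, ep=1, dp=4, pp=1)   # decoder_rank_generator: ep=1, EP folded into DP
expert = dict(tp=2, cp=1, ep=2, dp=2, pp=1)  # expert generator: cp=1, dp = 8 // (tp * ep * pp)

print(get_ranks(dense, order, 'dp'))   # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(get_ranks(expert, order, 'ep'))  # [[0, 2], [4, 6], [1, 3], [5, 7]]

Each expert EP group is a subset of a dense DP group, i.e. expert-parallel ranks for the MoE FFN
are data-parallel ranks for attention, which is the DEP behavior the patches preserve while
matching the Mcore rank ordering.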