diff --git a/README.md b/README.md index e63a59f28..ef97cdc17 100644 --- a/README.md +++ b/README.md @@ -334,77 +334,18 @@ For a more detailed guide to the features available and how to configure them, s ## Mixture of Experts -GPT-NeoX includes multiple expert implementations for MoE. To select between them, specify `moe_type` of `megablocks` (default) or `deepspeed`. +GPT-NeoX includes support for Dropless Mixture of Experts (DMoE) through the `megablocks` library. It is compatible with both existing Megatron Tensor Parallelism and DeepSpeed Pipeline Parallel setups. -Both are based on the DeepSpeed MoE parallelism framework, which supports tensor-expert-data parallelism. -Both allow you to toggle between token-dropping and dropless (default, and this is what Megablocks was designed for). -Sinkhorn routing to come soon! +This implementation leverages the existing Tensor Parallel Group to also shard the expert weights. +It uses Sinkhorn routing to avoid the need for a load balancing loss. -For an example of a basic complete configuration, see configs/125M-dmoe.yml (for Megablocks dropless) or configs/125M-moe.yml. +For an example of a basic complete configuration, see configs/125M-dmoe.yml. -Most MoE related configuration arguments are prefixed with `moe`. Some common configuration parameters and their defaults are as follows: +Most MoE related configuration arguments are prefixed with `moe`. The bare minimum addition to your configuration to enable MoE is as follows: +```yaml +moe_num_experts: 1 # 1 disables MoE. 8 is a common value. ``` -moe_type: megablocks -moe_num_experts: 1 # 1 disables MoE. 8 is a reasonable value. -moe_loss_coeff: 0.1 -expert_interval: 2 # See details below -enable_expert_tensor_parallelism: false # See details below -moe_expert_parallel_size: 1 # See details below -moe_token_dropping: false -``` - -DeepSpeed can be further configured with the following: - -``` -moe_top_k: 1 -moe_min_capacity: 4 -moe_train_capacity_factor: 1.0 # Setting to 1.0 -moe_eval_capacity_factor: 1.0 # Setting to 1.0 -``` - -One MoE layer is present every `expert_interval` transformer layers including the first, so with 12 layers total: - -``` -0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 -``` - -Experts would be in these layers: - -``` -0, 2, 4, 6, 8, 10 -``` - -By default, we use expert-data parallelism, so any available tensor parallelism (`model_parallel_size`) will be used for expert routing. For instance, given the following: - -``` -expert_parallel_size: 4 -model_parallel_size: 2 # aka tensor parallelism -``` - -With 32 GPUs, the behavior will be look like: - -- In non-expert layers: - - Tensor parallelism is 2. (There are 32 / 2 = 16 such tensor parallel groups, each of size 2.) - - Data parallelism implicitly becomes 32 / 2 = 16. -- In expert layers: - - There is no tensor parallelism. - - Expert parallelism is 4. (There are 32 / 4 = 8 expert parallel groups, each of size 4.) - - Data parallelism implicitly becomes 32 / 4 = 8. Some cross-node token routing happens as a result of this redivision of data parallelism between 16 and 8. To avoid it, ensure that `expert_parallel_size == model_parallel_size`. - -Setting `enable_expert_tensor_parallelism` enables tensor-expert-data (TED) parallelism. The way to interpret the above would then be: - -- In non-expert layers: same as before. -- In expert layers: - - Tensor parallelism is 2. (There are 32 / 2 = 16 tensor parallel groups, each of size 2.) - - Expert parallelism is 4. (There are 32 / 4 = 8 expert parallel groups, each of size 4.) - - Data parallelism implicitly becomes 32 / (2 * 4) = 4. Again, cross-node token routing happens. To avoid, ensure `expert_parallel_size == 1` or `model_parallel_size == 1`. - -So note that DP must be divisible by (MP * EP). For more details, see the [TED paper]. - -Pipeline parallelism is not yet supported - coming soon! - -[TED paper]: https://arxiv.org/abs/2303.06318 # Datasets diff --git a/configs/125M-dmoe.yml b/configs/125M-dmoe.yml index 229191b4d..e712fc847 100644 --- a/configs/125M-dmoe.yml +++ b/configs/125M-dmoe.yml @@ -1,36 +1,28 @@ # GPT-2 pretraining setup { - # See README for MoE config docs! - "moe_type": "megablocks", - "moe_token_dropping": false, - # Have 4 experts per layer (every 2 layers by default) - "moe_num_experts": 4, - # parallelism settings - "enable_expert_tensor_parallelism": true, - "pipe_parallel_size": 1, # not yet supported for MoE - "model_parallel_size": 1, - "moe_expert_parallel_size": 1, + # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages + # across the node boundaries ) + "pipe_parallel_size": 2, # MoE supports PP + "model_parallel_size": 2, # MoE uses model parallel group to split both experts and attention weights # model settings "num_layers": 12, - "hidden_size": 768, - "num_attention_heads": 12, + "hidden_size": 1024, + "num_attention_heads": 16, "seq_length": 2048, "max_position_embeddings": 2048, "norm": "layernorm", "pos_emb": "rotary", "no_weight_tying": true, - "gpt_j_residual": false, - "output_layer_parallelism": "column", + + # moe settings + "moe_num_experts": 8, # these should provide some speedup but takes a while to build, set to true if desired "scaled_upper_triang_masked_softmax_fusion": false, "bias_gelu_fusion": false, "rope_fusion": false, - - # init methods - "init_method": "small_init", - "output_layer_init_method": "wang_init", + "layernorm_fusion": false, # optimizer settings @@ -38,12 +30,10 @@ "type": "Adam", "params": { "lr": 0.0006, - "betas": [0.9, 0.95], + "betas": [0.9, 0.999], "eps": 1.0e-8, } }, - "min_lr": 0.00006, - # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training "zero_optimization": { "stage": 0, @@ -58,6 +48,7 @@ # batch / data settings "train_micro_batch_size_per_gpu": 4, "data_impl": "mmap", + "split": "949,50,1", # activation checkpointing "checkpoint_activations": true, @@ -67,35 +58,26 @@ # regularization "gradient_clipping": 1.0, - "weight_decay": 0.1, + "weight_decay": 0.0, "hidden_dropout": 0.0, "attention_dropout": 0.0, - # precision settings - "fp16": { - "enabled": true, - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, + "precision": "bfloat16", + "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 # misc. training settings - "train_iters": 320000, + "train_iters": 5, "lr_decay_iters": 320000, "distributed_backend": "nccl", - "lr_decay_style": "cosine", - "warmup": 0.01, + "min_lr": 0.0006, + "warmup": 0.0, "checkpoint_factor": 10000, "eval_interval": 1000, "eval_iters": 10, # logging - "log_interval": 10, - "steps_per_print": 10, + "log_interval": 1, + "steps_per_print": 1, "keep_last_n_checkpoints": 4, "wall_clock_breakdown": true, - - # networking - "hostfile": "/mock_path" } diff --git a/configs/125M-moe.yml b/configs/125M-moe.yml deleted file mode 100644 index 1d08d78a4..000000000 --- a/configs/125M-moe.yml +++ /dev/null @@ -1,101 +0,0 @@ -# GPT-2 pretraining setup -{ - # See README for MoE config docs! - "moe_type": "deepspeed", - "moe_token_dropping": true, - # Have 4 experts per layer (every 2 layers by default) - "moe_num_experts": 4, - # parallelism settings - "enable_expert_tensor_parallelism": true, - "pipe_parallel_size": 1, # not yet supported for MoE - "model_parallel_size": 1, - "moe_expert_parallel_size": 1, - - # model settings - "num_layers": 12, - "hidden_size": 768, - "num_attention_heads": 12, - "seq_length": 2048, - "max_position_embeddings": 2048, - "norm": "layernorm", - "pos_emb": "rotary", - "no_weight_tying": true, - "gpt_j_residual": false, - "output_layer_parallelism": "column", - - # these should provide some speedup but takes a while to build, set to true if desired - "scaled_upper_triang_masked_softmax_fusion": false, - "bias_gelu_fusion": false, - "rope_fusion": false, - - # init methods - "init_method": "small_init", - "output_layer_init_method": "wang_init", - - - # optimizer settings - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.0006, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 0.00006, - - # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training - "zero_optimization": { - "stage": 1, - "allgather_partitions": True, - "allgather_bucket_size": 500000000, - "overlap_comm": True, - "reduce_scatter": True, - "reduce_bucket_size": 500000000, - "contiguous_gradients": True, - }, - - # batch / data settings - "train_micro_batch_size_per_gpu": 4, - "data_impl": "mmap", - - # activation checkpointing - "checkpoint_activations": true, - "checkpoint_num_layers": 1, - "partition_activations": true, - "synchronize_each_layer": true, - - # regularization - "gradient_clipping": 1.0, - "weight_decay": 0.1, - "hidden_dropout": 0.0, - "attention_dropout": 0.0, - - # precision settings - "fp16": { - "enabled": true, - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - - # misc. training settings - "train_iters": 320000, - "lr_decay_iters": 320000, - "distributed_backend": "nccl", - "lr_decay_style": "cosine", - "warmup": 0.01, - "checkpoint_factor": 10000, - "eval_interval": 1000, - "eval_iters": 10, - - # logging - "log_interval": 10, - "steps_per_print": 10, - "keep_last_n_checkpoints": 4, - "wall_clock_breakdown": true, - - # networking - "hostfile": "/mock_path" -} diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index c6d369524..a39b8a058 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 1b85a2f + Default = 613aeb9 current git hash of repository @@ -247,6 +247,62 @@ Logging Arguments +## NeoXArgsMoE + +Mixture of Expert (MoE) Arguments + + + +- **moe_num_experts**: int + + Default = 1 + + The number of experts in MoE layers. MoE layers not used if set to 1 + + + +- **moe_expert_interval**: int + + Default = 1 + + Have one MoE layer every expert_interval layers + + + +- **moe_top_k**: int + + Default = 1 + + The number of experts each token is routed to in MoE layers. + + + +- **moe_router_type**: typing.Literal['sinkhorn', 'topk'] + + Default = sinkhorn + + What token routing algorithm to use. Currently only sinkhorn is supported for training. + TopK is only used for inference/eval. + + + +- **moe_lbl_in_fp32**: bool + + Default = False + + Whether to compute the load balancing loss in fp32. + + + +- **moe_jitter_eps**: float + + Default = None + + Coefficient for MoE routing jitter. Jitter is + not used if set to None + + + ## NeoXArgsModel Model Arguments @@ -1056,14 +1112,6 @@ Parallelism Arguments -- **expert_interval**: int - - Default = 2 - - Have one MoE layer every expert_interval layers - - - ## NeoXArgsTemplate NeoXArgsTemplate() @@ -1185,135 +1233,6 @@ Text Generation arguments -- **moe_top_k**: int - - Default = 1 - - Activate top K experts in MoE - - - -- **use_tutel**: bool - - Default = False - - Use Tutel optimizations in MoE - - - -- **moe_num_experts**: int - - Default = 1 - - Number of MoE experts - - - -- **moe_loss_coeff**: float - - Default = 0.1 - - Coefficient for MoE loss - - - -- **moe_train_capacity_factor**: float - - Default = 1.0 - - The capacity of the expert at train time - - - -- **moe_eval_capacity_factor**: float - - Default = 1.0 - - The capacity of the expert at eval time - - - -- **moe_min_capacity**: int - - Default = 4 - - The minimum capacity per expert regardless of the capacity_factor - - - -- **moe_token_dropping**: bool - - Default = False - - Whether to drop tokens when exceeding capacity - - - -- **create_moe_param_group**: bool - - Default = True - - Whether to create a separate parameter group for MoE parameters - - - -- **moe_use_residual**: bool - - Default = True - - Whether to use residual in MoE - - - -- **moe_expert_parallel_size**: int - - Default = 1 - - Number of parallel experts in MoE - - - -- **moe_type**: str - - Default = megablocks - - Either `deepspeed` or `megablocks` - - - -- **moe_glu**: bool - - Default = False - - Use gated linear units in MoE - - - -- **moe_lbl_in_fp32**: bool - - Default = False - - Whether to compute the load balancing loss in fp32. - - - -- **moe_jitter_eps**: float - - Default = None - - Coefficient for MoE routing jitter. Jitter is - not used if set to None - - - -- **enable_expert_tensor_parallelism**: bool - - Default = False - - Enable expert tensor parallelism - - - ## NeoXArgsTokenizer Tokenizer Arguments diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py index 3694e964b..b100f0ca7 100644 --- a/megatron/fused_kernels/__init__.py +++ b/megatron/fused_kernels/__init__.py @@ -46,10 +46,10 @@ def load(neox_args=None): if int(bare_metal_minor) >= 1: cc_flag.append("-gencode") cc_flag.append("arch=compute_86,code=sm_86") - if int(bare_metal_minor) >= 4: + elif int(bare_metal_minor) >= 4: cc_flag.append("-gencode") cc_flag.append("arch=compute_87,code=sm_87") - if int(bare_metal_minor) >= 8: + elif int(bare_metal_minor) >= 8: cc_flag.append("-gencode") cc_flag.append("arch=compute_89,code=sm_89") if int(bare_metal_major) >= 12: diff --git a/megatron/model/megablocks_utils.py b/megatron/model/megablocks_utils.py deleted file mode 100644 index 6f94b2b2c..000000000 --- a/megatron/model/megablocks_utils.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Adapter to expose MegaBlocks package, if available.""" - -try: - import megablocks -except ImportError: - megablocks = None - - -def megablocks_is_available(): - return megablocks is not None - - -def assert_megablocks_is_available(): - assert ( - megablocks_is_available() - ), "MegaBlocks not available. Please run `pip install megablocks`." - - -moe = megablocks.layers.moe if megablocks_is_available() else None -dmoe = megablocks.layers.dmoe if megablocks_is_available() else None -arguments = megablocks.layers.arguments if megablocks_is_available() else None - - -def as_megablocks_args(neox_args): - import copy - - tmp = copy.copy(neox_args) - delattr(tmp, "mlp_type") - tmp.mlp_type = "mlp" - args = arguments.from_megatron(tmp) - args.moe_lbl_in_fp32 = True - args.fp16 = neox_args.precision == "fp16" - args.moe_loss_weight = neox_args.moe_loss_coeff - return args diff --git a/megatron/model/moe.py b/megatron/model/moe.py new file mode 100644 index 000000000..51791ece8 --- /dev/null +++ b/megatron/model/moe.py @@ -0,0 +1,259 @@ +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2023 MegaBlocks authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +from typing import Optional + +import megablocks.ops +import numpy as np +import torch + +from megatron import mpu +from megatron.mpu import get_expert_token_counts_for_rank +from megatron.mpu import get_expert_tokens_for_rank +from megatron.mpu import copy_to_expert_model_parallel_region +from megatron.mpu import gather_from_expert_model_parallel_region +from megatron.neox_arguments.arguments import NeoXArgs + +from .moe_mlp import ParallelGroupedLLaMAMLP, ParallelGroupedMLP +from .router import TopKTokenChoiceRouter, SinkhornRouter + + +class ParallelDroplessMLP(torch.nn.Module): + """ + This class defines MoE expert computation, using tensor (model) parallel size as the expert parallel size + + The implication of this parallelism decision is that the expert weights can only be sharded within a single node + """ + + def __init__( + self, + neox_args: NeoXArgs, + init_method, + output_layer_init_method, + ): + """ + + Bias is currently not supported + """ + super(ParallelDroplessMLP, self).__init__() + + # Calculate the number of experts to allocate on this rank + world_size = mpu.get_model_parallel_world_size() + assert neox_args.moe_num_experts % world_size == 0 + self.num_experts = neox_args.moe_num_experts + self.experts_per_rank = self.num_experts // world_size + self.top_k = neox_args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # decide which parallel grouped MLP implementation to use + if neox_args.mlp_type == "regular": + self.mlp = ParallelGroupedMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + ) + elif neox_args.mlp_type == "llama": + self.mlp = ParallelGroupedLLaMAMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + ) + else: + raise KeyError(neox_args.mlp_type) + + def indices_and_bins(self, top_expert: torch.Tensor): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + bin_ids, indices = megablocks.ops.sort(top_expert, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + tokens_per_expert = megablocks.ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = megablocks.ops.inclusive_cumsum(tokens_per_expert, 0) + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + input_: torch.Tensor, + tokens_per_expert: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + expert_weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + """ + grouped_permute_and_compute + + torch.distributed.all_reduce(tensor, op=, group=None, async_op=False) + + NOTE: Megablocks sets up all MLP tensors as column parallel and uses transposes on some of the grouped_gemm calls for the ops that would be row parallel. This seems to be fine and since we aren't using the underlying NeoX ColumnParallelLinear and RowParallelLinear classes, there doesn't seem to be a reason to change it...because that'd introduce a lot of additional complexity. + + column parallel linear forward + + ```python + def forward(self, input_): + if self.use_mup and self.mup_rescale_parameters: + input_ /= self.width_mult() + # Set up backprop all-reduce. + input_parallel = copy_to_model_parallel_region(input_) + # Matrix multiply. + + bias = self.bias if not self.skip_bias_add else None + output_parallel = F.linear(input_parallel, self.weight, bias) + if self.gather_output: + # All-gather across the partitions. + output = gather_from_model_parallel_region(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + ``` + """ + # Route the tokens for MoE computation. + ## stack (sl, bs, hs) into (sl * bs, hs) + input_ = input_.view(-1, input_.shape[-1]) + + ## repeat each token top_k times and shuffle tokens to group them by their respective experts + input_ = megablocks.ops.gather(input_, indices, bin_ids, bins, top_k) + + # get tokens routed to this rank's experts only + input_parallel = copy_to_expert_model_parallel_region(input_, tokens_per_expert) + + # get tokens_per_expert for this rank's experts only + # with torch.no_grad(): + local_tokens_per_expert = get_expert_token_counts_for_rank(tokens_per_expert) + # if torch.cuda.current_device() == 0: + # print(f"{torch.cuda.current_device()}: local_tokens_per_expert {local_tokens_per_expert}, global tokens {tokens_per_expert}") + + # Perform the expert computation for this rank's experts + output_parallel = self.mlp(input_parallel, local_tokens_per_expert) + + # all gather masked results from across Tensor parallel ranks here and cat them together + # this will replicate the calculation of each expert across all ranks + # NOTE: this combined all_gather and torch.cat operation is performed by gather_from_model_parallel_region(output_parallel) + # Unlike ColumnParallelLinear, it is nonsensical in the MoE world + # to optionally return the output_parallel result...we still have to scatter the tokens back to their original positions + output = gather_from_expert_model_parallel_region( + output_parallel, + tokens_per_expert, + ) + + # Un-route the data for the MoE output + return megablocks.ops.scatter( + output, + indices, + bin_ids, + expert_weights, + bins, + top_k, + ) + + def forward(self, x, expert_weights, expert_indices): + """ + grouped_forward_once + + x: [sl, bs, hs] + expert_weights: [sl * bs, top-k] + expert_indices: [sl * bs, top-k] + """ + # save shape so we can re-shape the outputs later + in_shape = x.size() + + # both are now (sl * bs * top_k) + expert_weights = expert_weights.flatten() + expert_indices = expert_indices.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins( + expert_indices + ) + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + self.top_k, + ) + + # restore input shape + x = x.view(in_shape) + return x + + +def cast_if_autocast_enabled(tensor: torch.Tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == "cuda": + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == "cpu": + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor + + +class ParallelDroplessMoE(torch.nn.Module): + def __init__( + self, + neox_args: NeoXArgs, + init_method, + output_layer_init_method, + ): + super(ParallelDroplessMoE, self).__init__() + + if neox_args.moe_router_type == "sinkhorn": + self.router = SinkhornRouter( + neox_args, + init_method, + ) + elif neox_args.moe_router_type == "topk": + self.router = TopKTokenChoiceRouter( + neox_args, + init_method, + ) + else: + raise ValueError(f"Invalid MoE Router type {neox_args.moe_router_type}") + + self.experts = ParallelDroplessMLP( + neox_args, + init_method, + output_layer_init_method, + ) + + def forward(self, x): + # we expect inputs as (sl, bs, hs) + # neox provides inputs as torch.Size([2048, 4, 768]) + # (sl, bs, hs) + + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth + x = cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments + expert_weights, expert_indices = self.router(x) + + # return value should be + return self.experts(x, expert_weights, expert_indices), None diff --git a/megatron/model/moe_mlp.py b/megatron/model/moe_mlp.py new file mode 100644 index 000000000..3e5917970 --- /dev/null +++ b/megatron/model/moe_mlp.py @@ -0,0 +1,451 @@ +# Copyright (c) 2024, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2023 MegaBlocks authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.model.activations import get_activation + +from megatron.mpu.layers import _initialize_affine_weight_gpu +from megatron.mpu.initialize import get_model_parallel_world_size +from megatron.mpu.utils import divide + +from megatron.neox_arguments.arguments import NeoXArgs + +from megablocks import grouped_gemm_util as gg + + +class ScaleGradient(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, x, scale): + ctx.scale = scale + return x + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, grad): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +class MemoryOptimizedParallelGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # x: [m, k], w1: [n, k], w2: [n, k] + if not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous(): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, ddsd_out): + if ( + not ctx.needs_input_grad[0] + or not ctx.needs_input_grad[1] + or not ctx.needs_input_grad[2] + ): + raise ValueError("Expected all MLP inputs to need grad.") + + # Unpack saved tensors + dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + dw2 = gg.backend.gmm(activation_fn_out, ddsd_out, batch_sizes, trans_a=True) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm(ddsd_out, w2, batch_sizes, trans_b=True, c=dactivation_fn_out) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedParallelGroupedMLP.apply + + +class ParallelGroupedMLP(torch.nn.Module): + def __init__( + self, + neox_args: NeoXArgs, + init_method, + output_layer_init_method, + stride=1, + multiple_of=256, + ): + """ + Copied from SparseMLP + """ + super(ParallelGroupedMLP, self).__init__() + + self.activation_func = get_activation(neox_args) + self.activation_type = neox_args.activation + + self.multiple_of = multiple_of + + world_size = get_model_parallel_world_size() + self.num_experts = neox_args.moe_num_experts + self.experts_per_rank = divide(self.num_experts, world_size) + + self.hidden_size = neox_args.hidden_size + + # Allow custom intermediate size + if neox_args.intermediate_size is not None: + per_expert_ff_dim = neox_args.intermediate_size + # Otherwise, 4 x hidden size, padded to multiple of 256 + else: + per_expert_ff_dim = 4 * self.hidden_size + per_expert_ff_dim = self.multiple_of * ( + (per_expert_ff_dim + multiple_of - 1) // multiple_of + ) + + self.per_expert_ff_dim = per_expert_ff_dim + # number of rows per rank is the number of experts * ff dimension + self.num_rows_per_rank = self.experts_per_rank * per_expert_ff_dim + + # input + self.w1 = torch.nn.Parameter( + torch.empty( + self.num_rows_per_rank, + self.hidden_size, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) + _initialize_affine_weight_gpu( + self.w1, init_method, partition_dim=0, stride=stride + ) + + # output + self.w2 = torch.nn.Parameter( + torch.empty( + self.num_rows_per_rank, + self.hidden_size, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) + _initialize_affine_weight_gpu( + self.w2, output_layer_init_method, partition_dim=0, stride=stride + ) + + # TODO: why do we need this? was in original megablocks code + self.gradient_scale = None + if world_size > 1: + self.gradient_scale = 1 / world_size + + def scale_grad(self, w: torch.Tensor): + """ + Copied from SparseMLP + """ + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x: torch.Tensor, tokens_per_expert: torch.Tensor): + grouped_gemm_batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs + w1 = w1.view(self.experts_per_rank, -1, self.hidden_size) + w2 = w2.view(self.experts_per_rank, -1, self.hidden_size) + + # Compute the MLP + x = gg.ops.gmm(x, w1, grouped_gemm_batch_sizes, trans_b=True) + x = self.activation_func(x) + return gg.ops.gmm(x, w2, grouped_gemm_batch_sizes) + + +class MemoryOptimizedParallelGroupedLLaMAMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, x, w1, w3, w2, batch_sizes, activation_fn): + # x: [m, k], w1: [n, k], w3: [n, k], w2: [n, k] + if ( + not x.is_contiguous() + or not w1.is_contiguous() + or not w3.is_contiguous() + or not w2.is_contiguous() + ): + raise ValueError("Expected contiguous 'x', 'w1', 'w3' and 'w2'.") + + # Layer 0: x @ w1.t(). + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + w3_out = gg.backend.gmm(x, w3, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * w3_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w3, w2, batch_sizes, x, sdd_out, w3_out) + return dsd_out + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, ddsd_out): + if ( + not ctx.needs_input_grad[0] + or not ctx.needs_input_grad[1] + or not ctx.needs_input_grad[2] + ): + raise ValueError("Expected all MLP inputs to need grad.") + + # Unpack saved tensors + dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w3, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, w3_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + w3_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * w3_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + dw2 = gg.backend.gmm(activation_fn_out, ddsd_out, batch_sizes, trans_a=True) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm(ddsd_out, w2, batch_sizes, trans_b=True, c=dactivation_fn_out) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dw3_out = w3_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dw3. + dw3 = gg.backend.gmm(dw3_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dw3_out, w3, batch_sizes) + return dx, dw1, dw3, dw2, None, None + + +memory_optimized_grouped_llama_mlp = MemoryOptimizedParallelGroupedLLaMAMLP.apply + + +class ParallelGroupedLLaMAMLP(torch.nn.Module): + def __init__( + self, + neox_args: NeoXArgs, + init_method, + output_layer_init_method, + stride=1, + multiple_of=256, + ): + """ + Copied from SparseMLP + """ + super(ParallelGroupedLLaMAMLP, self).__init__() + + self.activation_func = get_activation(neox_args) + self.activation_type = neox_args.activation + + self.multiple_of = multiple_of + + world_size = get_model_parallel_world_size() + self.num_experts = neox_args.moe_num_experts + self.experts_per_rank = divide(self.num_experts, world_size) + + self.hidden_size = neox_args.hidden_size + + # Allow custom intermediate size + if neox_args.intermediate_size is not None: + per_expert_ff_dim = neox_args.intermediate_size + # Otherwise, 8/3 x hidden size, padded to multiple of 256 + # TODO: why is this how we formulate it this way? + else: + per_expert_ff_dim = int(2 * neox_args.hidden_size * 4 / 3) + per_expert_ff_dim = self.multiple_of * ( + (per_expert_ff_dim + multiple_of - 1) // multiple_of + ) + + self.per_expert_ff_dim = per_expert_ff_dim + # number of rows per rank is the number of experts * ff dimension per expert + self.num_rows_per_rank = self.experts_per_rank * per_expert_ff_dim + + # input + self.w1 = torch.nn.Parameter( + torch.empty( + self.num_rows_per_rank, + self.hidden_size, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) + _initialize_affine_weight_gpu( + self.w1, init_method, partition_dim=0, stride=stride + ) + + # gate + self.w3 = torch.nn.Parameter( + torch.empty( + self.num_rows_per_rank, + self.hidden_size, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) + _initialize_affine_weight_gpu( + self.w3, init_method, partition_dim=0, stride=stride + ) + + # output + self.w2 = torch.nn.Parameter( + torch.empty( + self.num_rows_per_rank, + self.hidden_size, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) + _initialize_affine_weight_gpu( + self.w2, output_layer_init_method, partition_dim=0, stride=stride + ) + + # TODO: why do we need this? was in original megablocks code + self.gradient_scale = None + if world_size > 1: + self.gradient_scale = 1 / world_size + + def scale_grad(self, w: torch.Tensor): + """ + Copied from SparseMLP + """ + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x: torch.Tensor, tokens_per_expert: torch.Tensor): + grouped_gemm_batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w3, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.w3), + self.scale_grad(self.w2), + ) + + w1 = self.w1.view(self.experts_per_rank, -1, self.hidden_size) + w3 = w3.view(self.experts_per_rank, -1, self.hidden_size) + + w2 = w2.view(self.experts_per_rank, -1, self.hidden_size) + + # return memory_optimized_grouped_llama_mlp( + # x, + # w1, + # w3, + # w2, + # grouped_gemm_batch_sizes, + # self.activation_func + # ) + + llama_x_w1T = gg.ops.gmm(x, w1, grouped_gemm_batch_sizes, trans_b=True) + + llama_x_w3T = gg.ops.gmm(x, w3, grouped_gemm_batch_sizes, trans_b=True) + + llama_act_x_w1T = self.activation_func(llama_x_w1T) + + # self.w2(self.activation_func(w1_out) * w3_out) + llama_mlp_out = gg.ops.gmm( + llama_act_x_w1T + * llama_x_w3T, # activation results gated (element-wise) with w3 + w2, # w2 + grouped_gemm_batch_sizes, # batch_sizes + ) + + return llama_mlp_out diff --git a/megatron/model/router.py b/megatron/model/router.py new file mode 100644 index 000000000..86fbbe805 --- /dev/null +++ b/megatron/model/router.py @@ -0,0 +1,273 @@ +# Copyright (c) 2024, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2023 MegaBlocks authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron.neox_arguments.arguments import NeoXArgs +from megatron.mpu import get_model_parallel_group, get_model_parallel_rank + + +class SinkhornRouter(torch.nn.Module): + # TODO: reduce precision on expert_indices? it looks like it's currently int64 + # TODO: how do we ensure that all copies of the router get the same + # initializations and stay in sync over time? Or is this handled by RNG seeding? + + ### Sinkhorn + + # - https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py + # - https://github.com/fanshiqing/grouped_gemm + # - NVIDIA forked original implementation and is using this in Megatron Core now + # - https://github.com/NVIDIA/Megatron-LM/blob/cafda9529d9956578014d4cb89b69b741702b514/megatron/core/transformer/moe/router.py#L215: this his how megatron actually does its router forward pass + + def __init__( + self, + neox_args: NeoXArgs, + init_method, + ): + super().__init__() + self.top_k = neox_args.moe_top_k + self.params_dtype = neox_args.params_dtype + + # expert parallel group rank, for purposes of deciding if I should compute the router or wait for the result to be broadcast to me + self.expert_parallel_group = get_model_parallel_group() + self.expert_parallel_rank = get_model_parallel_rank() + + # Sinkhorn router parameters. + # + # NOTE: This weight matrix is not parallelized with expert tensor + # parallelism. Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. + self.layer = torch.nn.Linear( + neox_args.hidden_size, + neox_args.moe_num_experts, + bias=False, + dtype=neox_args.params_dtype, + device=torch.cuda.current_device(), + ) + init_method(self.layer.weight) + + def sinkhorn(self, cost: torch.Tensor, tol: float = 0.0001): + """Sinkhorn based MoE routing function""" + cost = torch.exp(cost) + d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype) + d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype) + + eps = 0.00000001 + error = 1e9 + d1_old = d1 + while error > tol: + d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps) + d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps) + error = torch.mean(torch.abs(d1_old - d1)) + d1_old = d1 + return d1 * cost * d0.unsqueeze(1) + + def sinkhorn_load_balancing(self, logits: torch.Tensor): + """Apply sinkhorn routing to the logits tensor. + + Args: + logits (torch.Tensor): The logits tensor, as (bs * sl, hidden_size) + + Returns: + torch.Tensor: The logits tensor after applying sinkhorn routing. + """ + + def _sinkhorn_activation(logits): + if self.top_k == 1: + logits = torch.sigmoid(logits) + else: # k > 1 + logits = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as( + logits + ) + return logits + + # assert self.config.moe_aux_loss_coeff == 0, "Sinkhorn routing does not support aux loss." + if self.training: + with torch.no_grad(): + norm_logits = self.sinkhorn( + logits.to(dtype=torch.float32) + ) # explicit fp32 conversion for stability + _, indices = torch.topk(norm_logits, k=self.top_k, dim=1) + logits = _sinkhorn_activation(logits) + scores = torch.gather(logits, 1, indices) + # at inference, just top_k it...sinkhorn algorithm doesn't support autoregressive generation + else: + logits = _sinkhorn_activation(logits) + scores, indices = torch.topk(logits, k=self.top_k, dim=1) + return scores, indices + + def forward(self, x): + """ + Forward pass through the Sinkhorn Router. + + Only compute on rank 0 in the expert parallel group and broadcast to everyone else to avoid weird states where things get out of sync. + + Args: + x (torch.Tensor): Input tensor to be routed. + (sl, bs, hs) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing + - expert_weights (sl * bs, top_k): Weights assigned to the selected experts + - expert_indices (sl * bs, top_k): Indices of the selected experts + """ + if self.expert_parallel_rank == 0: + # x.view shape: (sl * bs, hs)...every token as a row + # router_logits (float) shape: (sl * bs, num_experts)...expert rankings for every token + router_logits = self.layer(x.view(-1, x.shape[-1])) + + # expert_weights (float) shape: (sl * bs, top_k)...value(s) from scores corresponding to the top_k experts + # expert_indices (int) shape: (sl * bs, top_k)...index(indices) from scores corresponding to the top_k experts + expert_weights, expert_indices = self.sinkhorn_load_balancing(router_logits) + + # broadcast the routing result to all ranks + expert_weights_broadcast = torch.distributed.broadcast( + expert_weights, + src=torch.distributed.get_global_rank(self.expert_parallel_group, 0), + group=self.expert_parallel_group, + async_op=True, + ) + expert_indices_broadcast = torch.distributed.broadcast( + expert_indices, + src=torch.distributed.get_global_rank(self.expert_parallel_group, 0), + group=self.expert_parallel_group, + async_op=True, + ) + else: + # sl * bs + num_rows = x.view(-1, x.shape[-1]).shape[0] + expert_weights = torch.empty( + num_rows, + self.top_k, + device=torch.cuda.current_device(), + dtype=self.params_dtype, + ) + expert_indices = torch.empty( + num_rows, + self.top_k, + device=torch.cuda.current_device(), + dtype=torch.int64, + ) + + expert_weights_broadcast = torch.distributed.broadcast( + expert_weights, + src=torch.distributed.get_global_rank(self.expert_parallel_group, 0), + group=self.expert_parallel_group, + async_op=True, + ) + expert_indices_broadcast = torch.distributed.broadcast( + expert_indices, + src=torch.distributed.get_global_rank(self.expert_parallel_group, 0), + group=self.expert_parallel_group, + async_op=True, + ) + + # since both are executing asynchronously, it doesn't matter which one + # we wait for first + expert_weights_broadcast.wait() + expert_indices_broadcast.wait() + + return expert_weights, expert_indices + + +class TopKTokenChoiceRouter(torch.nn.Module): + # TODO: how do we ensure that all copies of the router get the same + # initializations and stay in sync over time? Or is this handled by RNG seeding? + + def __init__( + self, + neox_args: NeoXArgs, + init_method, + ): + super().__init__() + self.jitter_eps = neox_args.moe_jitter_eps + self.top_k = neox_args.moe_top_k + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert tensor + # parallelism. Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. + self.layer = torch.nn.Linear( + neox_args.hidden_size, + neox_args.moe_num_experts, + bias=False, + dtype=neox_args.params_dtype, + device=torch.cuda.current_device(), + ) + init_method(self.layer.weight) + + def jitter(self, x): + """ + Apply jittering to the input tensor during training. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Jittered input tensor. + """ + low = 1.0 - self.args.moe_jitter_eps + high = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores): + """ + Select the top-k experts based on input scores. + + Args: + scores (torch.Tensor): Input scores from the router. + (sl * bs, num_experts) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing expert weightings and indices of selected experts. + + + """ + if self.top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.top_k, dim=-1) + + def forward(self, x): + """ + Forward pass through the Learned Router. + + Args: + x (torch.Tensor): Input tensor to be routed. + (sl, bs, hs) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing + - expert_weights (sl * bs, top_k): Weights assigned to the selected experts + - expert_indices (sl * bs, top_k): Indices of the selected experts + """ + if self.training and self.jitter_eps is not None: + x = x * self.jitter(x) + + # x.view shape: (sl * bs, hs)...every token as a row + # scores (float) shape: (sl * bs, num_experts)...expert rankings for every token + scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1) + + # expert_weights (float) shape: (sl * bs, top_k)...value(s) from scores corresponding to the top_k experts + # expert_indices (int) shape: (sl * bs, top_k)...index(indices) from scores corresponding to the top_k experts + expert_weights, expert_indices = self._top_k(scores) + # expert_weights probability mass won't add up to 1 because we took + # the topk scores from the softmax + # TODO: placeholder for moe_normalize_expert_weights if necessary + + return expert_weights, expert_indices diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index c154b09f4..694d58166 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -24,9 +24,10 @@ from pkg_resources import packaging from importlib.metadata import version +from megatron.model.moe import ParallelDroplessMoE + from .norms import get_norm from megatron import mpu -from megatron.model import megablocks_utils from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.activations import get_activation from megatron.model.utils import exists, get_fusion_type @@ -46,7 +47,6 @@ bias_dropout_add_fused_inference, ) from megatron.model.utils import configure_sparse_attention -from deepspeed.moe.layer import MoE # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) @@ -93,8 +93,6 @@ def __init__( init_method, output_layer_init_method, parallel_output=False, - MOE=False, - MoE_mp_size=1, ): super().__init__() @@ -116,8 +114,6 @@ def __init__( gather_output=False, init_method=init_method, skip_bias_add=True, - MOE=MOE, - MoE_mp_size=MoE_mp_size, ) ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim # Project back to h. @@ -129,8 +125,6 @@ def __init__( init_method=output_layer_init_method, parallel_output=parallel_output, skip_bias_add=True, - MOE=MOE, - MoE_mp_size=MoE_mp_size, ) def forward(self, hidden_states): @@ -172,8 +166,6 @@ def __init__( output_layer_init_method, parallel_output=False, multiple_of=256, - MOE=False, - MoE_mp_size=1, ): super().__init__() @@ -197,8 +189,6 @@ def __init__( init_method=init_method, skip_bias_add=True, bias=False, - MOE=MOE, - MoE_mp_size=MoE_mp_size, ) self.w3 = mpu.ColumnParallelLinear( neox_args=neox_args, @@ -208,8 +198,6 @@ def __init__( init_method=init_method, skip_bias_add=True, bias=False, - MOE=MOE, - MoE_mp_size=MoE_mp_size, ) self.w2 = mpu.RowParallelLinear( neox_args=neox_args, @@ -220,8 +208,6 @@ def __init__( skip_bias_add=True, parallel_output=parallel_output, bias=False, - MOE=MOE, - MoE_mp_size=MoE_mp_size, ) def forward(self, hidden_states): @@ -277,55 +263,6 @@ def forward(self, hidden_states): return self.final_linear(hidden_states) -class _MegablocksAdapter(nn.Module): - def __init__( - self, neox_args, layer_cls, init_method, output_layer_init_method, ep_group - ): - super().__init__() - megablocks_utils.assert_megablocks_is_available() - args = megablocks_utils.as_megablocks_args(neox_args) - args.device = torch.cuda.current_device() - args.init_method = init_method - args.output_layer_init_method = output_layer_init_method - - # NOTE: Shard the MoE layers over the data parallel group. Expert - # parallel sharding and data parallel sharding could be decoupled - # by extending the optimizer to handle data parallel reductions for - # MoE and non-MoE parameters separately. - if args.moe_expert_model_parallelism: - args.expert_parallel_group = ep_group - - if neox_args.moe_glu: - args.mlp_type = "glu" - - self.moe = layer_cls(args) - - def forward(self, x): - return self.moe.forward(x) - - -class MbMoE(_MegablocksAdapter): - def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): - super().__init__( - neox_args, - megablocks_utils.moe.MoE, - init_method, - output_layer_init_method, - ep_group, - ) - - -class dMoE(_MegablocksAdapter): - def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): - super().__init__( - neox_args, - megablocks_utils.dmoe.dMoE, - init_method, - output_layer_init_method, - ep_group, - ) - - class ParallelSelfAttention(nn.Module): """Parallel self-attention layer abstract class. @@ -1008,7 +945,6 @@ def __init__( super().__init__() self.layer_number = layer_number - self.neox_args = neox_args norm, eps = get_norm(neox_args) @@ -1021,7 +957,11 @@ def __init__( self.gpt_j_residual = neox_args.gpt_j_residual self.gpt_j_tied = neox_args.gpt_j_tied self.mlp_type = neox_args.mlp_type - self.moe_type = neox_args.moe_type + self.num_experts = ( + neox_args.moe_num_experts + if layer_number % neox_args.moe_expert_interval == 0 + else 1 + ) if self.gpt_j_residual: self.reduce = mpu.mappings.reduce_from_model_parallel_region @@ -1044,7 +984,7 @@ def __init__( # leads to cleaner code self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps) - # MLP + # Dense MLP selector def get_mlp(mlp_type, **kw): if mlp_type == "regular": return ParallelMLP( @@ -1065,103 +1005,16 @@ def get_mlp(mlp_type, **kw): else: raise KeyError(mlp_type) - self.num_experts = ( - neox_args.moe_num_experts - if layer_number % neox_args.expert_interval == 0 - else 1 - ) - args = neox_args + # Dense MLP if self.num_experts <= 1: self.mlp = get_mlp(neox_args.mlp_type) + # Dropless MoE MLP else: - from torch import distributed as dist - - if self.num_experts > dist.get_world_size(): - moe_mp_size = 1 - else: - moe_mp_size = dist.get_world_size() // self.num_experts - - if neox_args.moe_type == "deepspeed": - self.mlp = MoE( - args.hidden_size, - get_mlp( - "regular", - MOE=True, - MoE_mp_size=moe_mp_size, - ), - num_experts=self.num_experts, - ep_size=args.moe_expert_parallel_size, - k=args.moe_top_k, - use_residual=args.moe_use_residual, - capacity_factor=args.moe_train_capacity_factor, - eval_capacity_factor=args.moe_eval_capacity_factor, - min_capacity=args.moe_min_capacity, - drop_tokens=args.moe_token_dropping, - use_tutel=args.use_tutel, - enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, - ) - elif neox_args.moe_type == "megablocks": - - def integrate_megablocks_with_ds_expert_parallelism(): - # We make megablocks work with DS parallelism. - # - # We fool DS into accepting these MoE parameters as its own DS MoE params, - # which makes things work with the underlying expert parallelism, - # including TED parallelism. - # - # Effectively, we want to: - # - # - Make DS's data parallel gradient all-reduction skip these params. - # - But make these params participate in the expert parallel all-reduction! - # - # Further background: - # - # Normally, with the original megablocks demo codebase, it - # only supports 1 copy of any expert throughout - # the network, since it uses EP group = DP group. - # - # First, we trigger DS initialization of the MoE expert parallel groups and internal state. - throwaway = MoE( - args.hidden_size, - get_mlp( - "regular", - MOE=True, - MoE_mp_size=moe_mp_size, - ), - num_experts=self.num_experts, - ep_size=args.moe_expert_parallel_size, - k=args.moe_top_k, - use_residual=args.moe_use_residual, - capacity_factor=args.moe_train_capacity_factor, - eval_capacity_factor=args.moe_eval_capacity_factor, - min_capacity=args.moe_min_capacity, - drop_tokens=args.moe_token_dropping, - use_tutel=args.use_tutel, - enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, - ) - throwaway.set_deepspeed_parallelism() - - ep_group = throwaway.deepspeed_moe.ep_group - if args.moe_token_dropping: - self.mlp = MbMoE( - neox_args, init_method, output_layer_init_method, ep_group - ) - else: - self.mlp = dMoE( - neox_args, init_method, output_layer_init_method, ep_group - ) - - # Next, we trick DS into seeing these as its own MoE params. - for param in self.mlp.parameters(): - if getattr(param, "expert_model_parallel", None) is not None: - # is_moe_param looks for this attr. - param.allreduce = False - param.group_name = throwaway.expert_group_name - - integrate_megablocks_with_ds_expert_parallelism() - - else: - raise KeyError(neox_args.moe_type) + self.mlp = ParallelDroplessMoE( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + ) self.layer_past = None # used to cache k/v pairs in inference @@ -1179,7 +1032,7 @@ def _get_bias_dropout(self): def forward(self, x, attention_mask, layer_past=None): layer_past = layer_past if layer_past is not None else self.layer_past bias_dropout_fn = self._get_bias_dropout() - moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype) + # x: [b, s, h] if self.gpt_j_residual: # pseudocode: @@ -1261,29 +1114,13 @@ def forward(self, x, attention_mask, layer_past=None): # output = x + mlp(ln2(x)) layernorm_output = self.post_attention_layernorm(attention_output) - mlp_bias = torch.tensor( - 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype - ) - if self.num_experts == 1: - mlp_output, mlp_bias = self.mlp(layernorm_output) - else: - if self.moe_type == "deepspeed": - mlp_output, moe_loss, _ = self.mlp(layernorm_output) - mlp_bias = ( - None # deepspeed.moe.layer.MoE.forward ignores the bias term - ) - elif self.moe_type == "megablocks": - mlp_output, mlp_bias = self.mlp(layernorm_output) - else: - raise KeyError(self.moe_type) + # call signatures of both dense and MoE are the same + mlp_output, mlp_bias = self.mlp(layernorm_output) with torch.enable_grad(): - if ( - self.mlp_type == "llama" - or self.num_experts > 1 - and self.moe_type == "deepspeed" - ): + # dense llama MLP and MoE don't support bias + if self.mlp_type == "llama" or self.num_experts > 1: # No dropout either assert mlp_bias is None output = mlp_output + attention_output @@ -1295,7 +1132,7 @@ def forward(self, x, attention_mask, layer_past=None): prob=self.hidden_dropout, ) - return output, moe_loss + return output class ParallelTransformerLayerPipe(ParallelTransformerLayer): @@ -1307,10 +1144,7 @@ def forward(self, args): ), "ParallelTransformerLayerPipe expects 2 arguments - hidden_states and attention_mask" hidden_states, attention_mask = args # we are returning just [hidden_states, mask] - output, moe_loss = super().forward(hidden_states, attention_mask) - # auxiliary output - self.last_moe_loss = moe_loss - return output, attention_mask + return super().forward(hidden_states, attention_mask), attention_mask class ParallelLinearPipe(ParallelLinear): diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py index 2365507d9..f12ba7da7 100644 --- a/megatron/mpu/__init__.py +++ b/megatron/mpu/__init__.py @@ -35,6 +35,8 @@ from .initialize import get_tensor_model_parallel_rank from .initialize import get_tensor_model_parallel_world_size from .initialize import get_io_parallel_group +from .initialize import get_expert_tokens_for_rank +from .initialize import get_expert_token_counts_for_rank from .initialize import initialize_model_parallel from .initialize import model_parallel_is_initialized @@ -44,7 +46,9 @@ from .layers import ParallelRelativePositionBias from .mappings import copy_to_model_parallel_region +from .mappings import copy_to_expert_model_parallel_region from .mappings import gather_from_model_parallel_region +from .mappings import gather_from_expert_model_parallel_region from .mappings import reduce_from_model_parallel_region from .mappings import scatter_to_model_parallel_region diff --git a/megatron/mpu/initialize.py b/megatron/mpu/initialize.py index 19d231524..9f73fd3dc 100644 --- a/megatron/mpu/initialize.py +++ b/megatron/mpu/initialize.py @@ -18,6 +18,7 @@ """Model and data parallel groups.""" +from typing import Optional import torch from .utils import ensure_divisibility @@ -266,6 +267,51 @@ def get_pipe_parallel_world_size(): return torch.distributed.get_world_size(group=get_pipe_parallel_group()) +def get_expert_tokens_for_rank( + routed_tokens: torch.Tensor, + tokens_per_expert: torch.Tensor, + rank: Optional[int] = None, +): + """ + Allow user to specify rank, fall back on this device + """ + # Calculate cumulative sums of tokens_per_expert, ensure the shapes are correct + world_size = get_model_parallel_world_size() + if rank is None: + rank = get_model_parallel_rank() + + # TODO: is this check necessary here/what does it cost us to redundantly do it in multiple places? + assert tokens_per_expert.shape[0] % world_size == 0 + + cumulative_sums = torch.cumsum(tokens_per_expert, dim=0) + assert cumulative_sums[-1] == routed_tokens.shape[0] + + # select the right starting and ending indices from the cumsum to figure out what tokens to select + rank_expert_indices = cumulative_sums.chunk(world_size) + start_index = rank_expert_indices[rank - 1][-1] if rank > 0 else 0 + end_index = rank_expert_indices[rank][-1] + + # Use indices to select the chunk of the tokens matrix + selected_experts = routed_tokens[start_index:end_index] + + return selected_experts + + +def get_expert_token_counts_for_rank( + tokens_per_expert: torch.Tensor, rank: Optional[int] = None +): + """ + Allow user to specify rank, fall back on this device + """ + # TODO: add bounds checking of size is 1D for tokens_per_expert + # should be (num_experts) long + world_size = get_model_parallel_world_size() + if rank is None: + rank = get_model_parallel_rank() + + return tokens_per_expert.chunk(world_size)[rank] + + def set_tensor_model_parallel_world_size(world_size): """Set the tensor model parallel size""" set_model_parallel_world_size(world_size) diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 0d14806ac..19dff0b5f 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -413,8 +413,6 @@ def __init__( stride=1, keep_master_weight_for_test=False, skip_bias_add=False, - MOE=False, - MoE_mp_size=1, mup_rescale_parameters=False, ): super(ColumnParallelLinear, self).__init__() @@ -424,7 +422,7 @@ def __init__( self.output_size = output_size self.gather_output = gather_output # Divide the weight matrix along the last dimension. - world_size = MoE_mp_size if MOE else get_model_parallel_world_size() + world_size = get_model_parallel_world_size() self.output_size_per_partition = divide(output_size, world_size) self.skip_bias_add = skip_bias_add self.init_method = init_method @@ -607,8 +605,6 @@ def __init__( stride=1, keep_master_weight_for_test=False, skip_bias_add=False, - MOE=False, - MoE_mp_size=1, parallel_output=False, mup_rescale_parameters=False, ): @@ -619,7 +615,7 @@ def __init__( self.output_size = output_size self.input_is_parallel = input_is_parallel # Divide the weight matrix along the last dimension. - world_size = MoE_mp_size if MOE else get_model_parallel_world_size() + world_size = get_model_parallel_world_size() self.input_size_per_partition = divide(input_size, world_size) self.skip_bias_add = skip_bias_add self.parallel_output = parallel_output diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py index 535fe6255..5a2880b46 100644 --- a/megatron/mpu/mappings.py +++ b/megatron/mpu/mappings.py @@ -18,10 +18,12 @@ import torch from .initialize import ( + get_expert_tokens_for_rank, get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank, get_fp32_allreduce, + get_expert_token_counts_for_rank, ) from .utils import split_tensor_along_last_dim @@ -89,7 +91,99 @@ def _gather(input_): torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group()) # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=last_dim).contiguous() + output = torch.cat(tensor_list, dim=last_dim) + + # Bf16 convert + if dt == torch.bfloat16 and get_fp32_allreduce(): + output = output.bfloat16() + + return output + + +def _dmoe_reduce(input_, tokens_per_expert): + """All-reduce the the dMoE input tensor across model parallel group.""" + # Bypass the function if we are using only 1 GPU. + if get_model_parallel_world_size() == 1: + return input_ + + # Bf16 convert + dt = input_.dtype + if dt == torch.bfloat16 and get_fp32_allreduce(): + input_ = input_.float() + + output = torch.zeros( + (sum(tokens_per_expert), input_.shape[-1]), + dtype=input_.dtype, + device=input_.device, + ) + world_size = get_model_parallel_world_size() + rank = get_model_parallel_rank() + + cumulative_sums = torch.cumsum(tokens_per_expert, dim=0) + + # select the right starting and ending indices from the cumsum to figure out what tokens to select + rank_expert_indices = cumulative_sums.chunk(world_size) + start_index = rank_expert_indices[rank - 1][-1] if rank > 0 else 0 + end_index = rank_expert_indices[rank][-1] + + output[start_index:end_index] = input_ + + # All-reduce. + torch.distributed.all_reduce(output, group=get_model_parallel_group()) + + # Bf16 convert + if dt == torch.bfloat16 and get_fp32_allreduce(): + output = output.bfloat16() + + return output + + +def _dmoe_split(input_, tokens_per_expert): + """Split the tensor along its first dimension according to where tokens + were routed, keeping the corresponding slice.""" + + world_size = get_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along first dimension, getting the expert tokens + output = get_expert_tokens_for_rank(input_, tokens_per_expert) + + return output + + +def _dmoe_gather(input_: torch.Tensor, tokens_per_expert: torch.Tensor): + """Gather tensors and concatinate along the first dimension)""" + + world_size = get_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Bf16 convert + dt = input_.dtype + if dt == torch.bfloat16 and get_fp32_allreduce(): + input_ = input_.float() + + # Gather along first dimension + gather_dim = 0 + rank = get_model_parallel_rank() + + tokens_by_rank = [ + get_expert_token_counts_for_rank(tokens_per_expert, r) + for r in range(world_size) + ] + # print(f"{torch.cuda.current_device()}: tokens_by_rank {tokens_by_rank}") + tensor_list = [ + torch.empty(sum(r), input_.shape[-1], device=input_.device, dtype=input_.dtype) + for r in tokens_by_rank + ] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group()) + + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=gather_dim) # Bf16 convert if dt == torch.bfloat16 and get_fp32_allreduce(): @@ -114,6 +208,31 @@ def backward(ctx, grad_output): return _reduce(grad_output) +class _CopyToExpertModelParallelRegion(torch.autograd.Function): + """Pass the input to the model parallel region.""" + + @staticmethod + def symbolic(graph, input_, tokens_per_expert): + # TODO: not sure if this is sufficient? not sure how this gets used downstream... + return get_expert_tokens_for_rank(input_, tokens_per_expert) + + @staticmethod + def forward(ctx, input_, tokens_per_expert): + # Save tokens_per_expert in the context for later use in the backward pass + ctx.save_for_backward(tokens_per_expert) + + return get_expert_tokens_for_rank(input_, tokens_per_expert) + + @staticmethod + def backward(ctx, grad_output): + # Retrieve the tokens_per_expert from the context + (tokens_per_expert,) = ctx.saved_tensors + + # no grad for tokens_per_expert + # return _dmoe_reduce(grad_output, tokens_per_expert), None + return _dmoe_gather(grad_output, tokens_per_expert), None + + class _ReduceFromModelParallelRegion(torch.autograd.Function): """All-reduce the input from the model parallel region.""" @@ -162,6 +281,34 @@ def backward(ctx, grad_output): return _split(grad_output) +class _GatherFromExpertModelParallelRegion(torch.autograd.Function): + """Gather the input from expert model parallel region and concatinate. + + The major difference between this and _GatherFromModelParallelRegion is in the + dMoE case, we need to gather & split along the first dimension, not the last + """ + + @staticmethod + def symbolic(graph, input_, tokens_per_expert): + # TODO: not sure if this is sufficient? not sure how this gets used downstream... + return _dmoe_gather(input_, tokens_per_expert) + + @staticmethod + def forward(ctx, input_, tokens_per_expert): + # Save tokens_per_expert in the context for later use in the backward pass + ctx.save_for_backward(tokens_per_expert) + + return _dmoe_gather(input_, tokens_per_expert) + + @staticmethod + def backward(ctx, grad_output): + # Retrieve the tokens_per_expert from the context + (tokens_per_expert,) = ctx.saved_tensors + + # no grad for tokens_per_expert + return _dmoe_split(grad_output, tokens_per_expert), None + + # ----------------- # Helper functions. # ----------------- @@ -171,6 +318,10 @@ def copy_to_model_parallel_region(input_): return _CopyToModelParallelRegion.apply(input_) +def copy_to_expert_model_parallel_region(input_, tokens_per_expert): + return _CopyToExpertModelParallelRegion.apply(input_, tokens_per_expert) + + def reduce_from_model_parallel_region(input_): return _ReduceFromModelParallelRegion.apply(input_) @@ -181,3 +332,7 @@ def scatter_to_model_parallel_region(input_): def gather_from_model_parallel_region(input_): return _GatherFromModelParallelRegion.apply(input_) + + +def gather_from_expert_model_parallel_region(input_, tokens_per_expert): + return _GatherFromExpertModelParallelRegion.apply(input_, tokens_per_expert) diff --git a/megatron/neox_arguments/__init__.py b/megatron/neox_arguments/__init__.py index 025464cbf..42c0f5e90 100644 --- a/megatron/neox_arguments/__init__.py +++ b/megatron/neox_arguments/__init__.py @@ -24,7 +24,7 @@ **code structure** -* NeoX args (in ./arguments) inherits from the following subclasses: NeoXArgsDeepspeedRunner, NeoXArgsDeepspeedConfig, NeoXArgsModel, NeoXArgsTokenizer, NeoXArgsTraining, NeoXArgsParallelism, NeoXArgsLogging, NeoXArgsOther, NeoXArgsTextgen +* NeoX args (in ./arguments) inherits from the following subclasses: NeoXArgsDeepspeedRunner, NeoXArgsDeepspeedConfig, NeoXArgsModel, NeoXArgsTokenizer, NeoXArgsTraining, NeoXArgsParallelism, NeoXArgsLogging, NeoXArgsOther, NeoXArgsTextgen, NeoXArgsMoE * The Subclasses group args according to their purpose * The attributes of NeoXArgsDeepspeedRunner are directly mapped to the expected command line args of deepspeed.launcher.runner.main; no attributes unknown to deepspeed should be included; no arguments relevant for deepspeed should be omitted * The attributes of NeoXArgsDeepspeedConfig are directly mapped to the expected keys of the deepspeed config; no arguments relevant for deepspeed should be omitted diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 98a444ea4..a41874971 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -42,6 +42,7 @@ NeoXArgsTokenizer, NeoXArgsTraining, NeoXArgsParallelism, + NeoXArgsMoE, NeoXArgsLogging, NeoXArgsOther, NeoXArgsTextgen, @@ -89,6 +90,7 @@ NeoXArgsDeepspeedRunner, NeoXArgsDeepspeedConfig, NeoXArgsModel, + NeoXArgsMoE, NeoXArgsLRScheduler, NeoXArgsOptimizer, NeoXArgsTokenizer, @@ -1031,18 +1033,9 @@ def calculate_derived(self): self.update_value("dynamic_loss_scale", self.loss_scale is None) # Update 'is pipe parallel' flag - # if we set pipe_parallel_size to 0 or 1, GPT2ModelPipe.to_sequential() is called, and we run training with + # if we set pipe_parallel_size to 0, GPT2ModelPipe.to_sequential() is called, and we run training with # the sequential model without the PipelineModule wrapper to avoid the overhead it incurs - self.update_value( - "is_pipe_parallel", - self.pipe_parallel_size > 1 and self.moe_num_experts == 1, - ) - if self.moe_num_experts > 1: - assert not ( - self.is_pipe_parallel or self.pipe_parallel_size > 1 - ), "MoE not supported with pipeline parallelism" - assert self.zero_optimization["stage"] != 3, "MoE not compatible with zero3" - assert self.mlp_type == "regular", "MoE not compatible with LLaMA" + self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1) # Attention config if self.attention_config is None: diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index febefb3c2..3083b7282 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -85,11 +85,6 @@ class NeoXArgsParallelism(NeoXArgsTemplate): according to pipeline parallel size. """ - expert_interval: int = 2 - """ - Have one MoE layer every expert_interval layers - """ - @dataclass class NeoXArgsModel(NeoXArgsTemplate): @@ -1266,69 +1261,32 @@ class NeoXArgsTextgen(NeoXArgsTemplate): NOTE: Requires internet connection """ - moe_top_k: int = 1 - """ - Activate top K experts in MoE - """ - use_tutel: bool = False +@dataclass +class NeoXArgsMoE(NeoXArgsTemplate): """ - Use Tutel optimizations in MoE + Mixture of Expert (MoE) Arguments """ moe_num_experts: int = 1 """ - Number of MoE experts - """ - - moe_loss_coeff: float = 0.1 - """ - Coefficient for MoE loss - """ - - moe_train_capacity_factor: float = 1.0 - """ - The capacity of the expert at train time - """ - - moe_eval_capacity_factor: float = 1.0 - """ - The capacity of the expert at eval time + The number of experts in MoE layers. MoE layers not used if set to 1 """ - moe_min_capacity: int = 4 + moe_expert_interval: int = 1 """ - The minimum capacity per expert regardless of the capacity_factor - """ - - moe_token_dropping: bool = False - """ - Whether to drop tokens when exceeding capacity - """ - - create_moe_param_group: bool = True - """ - Whether to create a separate parameter group for MoE parameters - """ - - moe_use_residual: bool = True - """ - Whether to use residual in MoE - """ - - moe_expert_parallel_size: int = 1 - """ - Number of parallel experts in MoE + Have one MoE layer every expert_interval layers """ - moe_type: str = "megablocks" + moe_top_k: int = 1 """ - Either `deepspeed` or `megablocks` + The number of experts each token is routed to in MoE layers. """ - moe_glu: bool = False + moe_router_type: Literal["sinkhorn", "topk"] = "sinkhorn" """ - Use gated linear units in MoE + What token routing algorithm to use. Currently only sinkhorn is supported for training. + TopK is only used for inference/eval. """ moe_lbl_in_fp32: bool = False @@ -1341,8 +1299,3 @@ class NeoXArgsTextgen(NeoXArgsTemplate): Coefficient for MoE routing jitter. Jitter is not used if set to None """ - - enable_expert_tensor_parallelism: bool = False - """ - Enable expert tensor parallelism - """ diff --git a/megatron/training.py b/megatron/training.py index 3265680c5..6a67d36f8 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -365,56 +365,6 @@ def get_batch_sequential(forward_input, neox_args): return (forward_input[0], forward_input[1], attention_mask) -def average_losses_across_data_parallel_group(losses): - """Reduce a tensor of losses across all GPUs.""" - averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses]) - torch.distributed.all_reduce(averaged_losses, group=mpu.get_data_parallel_group()) - averaged_losses = averaged_losses / torch.distributed.get_world_size( - group=mpu.get_data_parallel_group() - ) - - return averaged_losses - - -def mb_moe_loss_func(args, loss_mask, output_tensor=None): - from megatron.model import megablocks_utils - from megatron.model.megablocks_utils import moe - - # NOTE: For pipeline parallelism this function will be run on the - # non-final stages to calculate load balancing loss contribution - # for the MoE layers within the stage. For these cases, output_tensor - # will be None. - loss, loss_dict = (None, {}) - if False: - assert output_tensor is not None - loss, loss_dict = loss_func(loss_mask, output_tensor) - assert loss.numel() == 1 - - # NOTE: If recompute is enabled we will collect duplicate load - # balancing loss contributions. Prune these before calculating - # the load balancing loss. - if args.checkpoint_activations: - # Ignore load balancing loss contributions compute during - # the forward pass if recompute is turned on. - load_balancing_loss_data = moe.get_load_balancing_loss() - if args.num_layers * 2 == len(load_balancing_loss_data): - load_balancing_loss_data = load_balancing_loss_data[args.num_layers :] - moe.clear_load_balancing_loss() - for x in load_balancing_loss_data: - moe.save_load_balancing_loss(x) - - # Compute the load balancing loss for all MoE layers. - megablocks_args = args = megablocks_utils.as_megablocks_args(args) - lbl = moe.batched_load_balancing_loss(megablocks_args) - moe.clear_load_balancing_loss() - - # Average the load balancing loss across data parallel - # replicas and save for logging. - averaged_lbl = average_losses_across_data_parallel_group([lbl]) - loss_dict["load balancing loss"] = averaged_lbl[0] - return averaged_lbl, loss_dict - - def forward_step( data_iterator, model, neox_args, timers, return_logits=False, is_train=False ): @@ -438,13 +388,7 @@ def forward_step( if neox_args.memory_profiling: torch.cuda.nvtx.range_push(f"Forward pass") - # Sequential returns moe_losses, but this is not yet supported by pipe parallel - maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) - if type(maybe_tuple) is tuple: - outputs, moe_losses = maybe_tuple - else: - outputs = maybe_tuple - moe_losses = [] + outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args) if ( is_train and neox_args.curriculum_learning @@ -452,19 +396,9 @@ def forward_step( ): loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous() labels = labels[:, : neox_args.curriculum_seqlen].contiguous() - main_loss = cross_entropy( + loss = cross_entropy( outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy ) - if neox_args.moe_num_experts > 1: - if neox_args.moe_type == "deepspeed": - moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses) - elif neox_args.moe_type == "megablocks": - moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0] - else: - raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}") - else: - moe_loss = 0.0 - loss = main_loss + moe_loss if neox_args.memory_profiling: torch.cuda.nvtx.range_pop() if return_logits: @@ -560,16 +494,6 @@ def get_optimizer(model, neox_args): f'Configuring Optimizer type: {neox_args.optimizer_type} with params: {neox_args.optimizer["params"]}' ) - if neox_args.create_moe_param_group: - from deepspeed.moe.utils import ( - is_moe_param, - split_params_into_different_moe_groups_for_optimizer, - ) - - param_groups = split_params_into_different_moe_groups_for_optimizer( - param_groups - ) - # Add model parallel attribute if it is not set. for param_group in param_groups: for param in param_group["params"]: @@ -765,9 +689,6 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): # config_params=neox_args.deepspeed_config, mpu=mpu if not neox_args.is_pipe_parallel else None, ) - if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks": - # We need to additionally set this flag to ensure DS parallelism properly handles this foreign MoE. - model.has_moe_layers = True model.total_params = get_total_params(model.module) print_rank_0(f' > total params: {"{:,}".format(model.total_params)}') diff --git a/requirements/requirements-moe.txt b/requirements/requirements-moe.txt new file mode 100644 index 000000000..e75e5e9fd --- /dev/null +++ b/requirements/requirements-moe.txt @@ -0,0 +1,2 @@ +grouped-gemm==0.1.4 +megablocks==0.5.1 \ No newline at end of file