From dd9b80d95174a4787a37f4655ddd307ba7ecde85 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 3 Dec 2024 18:40:27 +0100 Subject: [PATCH] [WIP] integrate NxDOptimizer --- optimum/neuron/accelerate/accelerator.py | 69 +++++++++++++++++++----- optimum/neuron/accelerate/optimizer.py | 17 +++--- optimum/neuron/accelerate/state.py | 30 +++++++---- optimum/neuron/trainers.py | 53 ++++++++++++++++-- 4 files changed, 133 insertions(+), 36 deletions(-) diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index dc4915cc3..643500c10 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -23,7 +23,7 @@ import warnings from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, Dict import torch from accelerate import Accelerator @@ -81,6 +81,7 @@ xm = None if is_neuronx_distributed_available(): + import neuronx_distributed as nxd from neuronx_distributed.utils.model_utils import move_model_to_device @@ -98,6 +99,7 @@ class NeuronAccelerator(Accelerator): def __init__( self, + nxd_config: Dict[str, Any], *args, mp_plugin: Optional[ModelParallelismPlugin] = None, zero_1: bool = False, @@ -146,13 +148,17 @@ def patched_is_torch_xla_available(check_is_tpu: bool = False, check_is_gpu: boo accelerate.state.is_torch_xla_available = patched_is_torch_xla_available - patched_accelerator_state = partial( - NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend - ) - with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]): + self.mp_plugin = mp_plugin + self.nxd_config = nxd_config + + # patched_accelerator_state = partial( + # NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend + # ) + # with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]): + with Patcher([("accelerate.accelerator.AcceleratorState", NeuronAcceleratorState)]): super().__init__(**full_kwargs) - self.zero_1 = zero_1 + self.zero_1 = self.nxd_config["optimizer_config"]["zero_one_enabled"] if self.autocast_handler is None: enabled = self.state.mixed_precision == "bf16" and autocast_backend is AutocastBackend.AMP @@ -300,17 +306,32 @@ def _prepare_optimizer_for_zero_1(self, optimizer: torch.optim.Optimizer, device ) return zero_1_optimizer + @requires_neuronx_distributed @patch_within_function(("accelerate.accelerator.AcceleratedOptimizer", NeuronAcceleratedOptimizer)) def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement: Optional[bool] = None): - if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: - optimizer = self._prepare_optimizer_for_mp(optimizer, device_placement=device_placement) - if self.zero_1: - optimizer = self._prepare_optimizer_for_zero_1(optimizer, device_placement=device_placement) + import neuronx_distributed as nxd + + #cpu_parameters_to_xla = collections.ChainMap(*self._model_cpu_parameters_to_xla.values()) + #xla_parameters, _ = Parallelizer.optimizer_cpu_params_to_xla_params(optimizer, cpu_parameters_to_xla) + #print(xla_parameters) + + optimizer = nxd.initialize_parallel_optimizer( + self.nxd_config, + optimizer.__class__, + # xla_parameters, + optimizer.param_groups, + **optimizer.defaults, + ) + optimizer.zero_grad() + # if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: + # optimizer = self._prepare_optimizer_for_mp(optimizer, device_placement=device_placement) + # if self.zero_1: + # optimizer = self._prepare_optimizer_for_zero_1(optimizer, device_placement=device_placement) # Edge case: if the optimizer was created lazily outside of the Model Parallelism and/or ZeRO-1 setting, we make # sure to actually load the proper parameters. - if hasattr(optimizer, "_args_to_recreate"): - args, kwargs = optimizer._args_to_recreate - optimizer = optimizer.__class__(*args, **kwargs) + # if hasattr(optimizer, "_args_to_recreate"): + # args, kwargs = optimizer._args_to_recreate + # optimizer = optimizer.__class__(*args, **kwargs) return super().prepare_optimizer(optimizer, device_placement=device_placement) @@ -449,6 +470,7 @@ def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings): def prepare_model( self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False ): + print("Prepare model") # If the model was already prepared, we skip. if model in self._models: return model @@ -500,7 +522,26 @@ def backward(self, loss, **kwargs): self.scaler.scale(loss).backward(**kwargs) else: loss.backward(**kwargs) - + # vector_norm = [torch.vector_norm(p.grad, 2) for p in self._models[0].parameters() if p.requires_grad] + # norm = torch.nn.utils.clip_grad_norm_([p for p in self._models[0].parameters() if p.requires_grad], 1.0) + # xm.mark_step() + # print(vector_norm) + # self._models[0].to("cpu") + # print(self._models[0]) + # print(norm) + # for n, p in self._models[0].named_parameters(): + # if not p.requires_grad or p.grad is None: + # continue + # p = p.grad + # print(f"Gradient of {n}") + # print(f"Min: {p.min():.3f}") + # print(f"Max: {p.max():.3f}") + # print(f"Mean: {p.mean():.3f}") + # print(f"Std: {p.std():.3f}") + # print(f"L1 norm: {p.norm(p=1):.3f}") + # print(f"L2 norm: {p.norm(p=2):.3f}") + # assert 3==2 + @contextlib.contextmanager def autocast(self, cache_enabled: bool = False, autocast_handler: Optional[AutocastKwargs] = None): if cache_enabled: diff --git a/optimum/neuron/accelerate/optimizer.py b/optimum/neuron/accelerate/optimizer.py index c75074c20..2e7df5a32 100644 --- a/optimum/neuron/accelerate/optimizer.py +++ b/optimum/neuron/accelerate/optimizer.py @@ -73,14 +73,18 @@ def __init__( self.parameters = [p for group in self.optimizer.param_groups for p in group["params"]] self.parameter_ids = {id(p) for p in self.parameters} + self.total_grad_norm = [] + # TODO: might be needed to override this soon. def load_state_dict(self, state_dict): return super().load_state_dict(state_dict) def prepare_clip_grad_norm(self, parameters, max_norm, norm_type=2): parameter_ids = {id(p) for p in parameters} - if parameter_ids == self.parameter_ids or isinstance(self.optimizer, ZeroRedundancyOptimizer): - self.clip_grad_norm_to_perform = {"max_norm": max_norm, "norm_type": norm_type} + # if parameter_ids == self.parameter_ids or isinstance(self.optimizer, ZeroRedundancyOptimizer): + # assert 3==2 + self.clip_grad_norm_to_perform = {"max_norm": max_norm, "norm_type": norm_type} + return self.total_grad_norm @requires_neuronx_distributed def step(self, closure=None): @@ -88,9 +92,8 @@ def step(self, closure=None): from neuronx_distributed.parallel_layers.grads import bucket_allreduce_gradients if self.gradient_state.sync_gradients: - # For sequence-parallel, we have to explicitly all-reduce the layernorm gradients. - if self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: - allreduce_sequence_parallel_gradients(self.optimizer) + self.optimizer.step() + return if isinstance(self.optimizer, ZeroRedundancyOptimizer): if self.clip_grad_norm_to_perform is not None: @@ -113,7 +116,9 @@ def step(self, closure=None): if parallel_layers.parallel_state.get_data_parallel_size() > 1: bucket_allreduce_gradients(xm._fetch_gradients(self.optimizer)) if self.clip_grad_norm_to_perform is not None: - parallel_layers.clip_grad_norm(self.parameters, **self.clip_grad_norm_to_perform) + self.total_grad_norm.clear() + total_grad_norm = parallel_layers.clip_grad_norm(self.parameters, **self.clip_grad_norm_to_perform) + self.total_grad_norm.append(total_grad_norm) self.clip_grad_norm_to_perform = None self.optimizer.step() elif self.scaler is not None: diff --git a/optimum/neuron/accelerate/state.py b/optimum/neuron/accelerate/state.py index 51f87f9d1..001abf859 100644 --- a/optimum/neuron/accelerate/state.py +++ b/optimum/neuron/accelerate/state.py @@ -45,7 +45,8 @@ import torch_xla.core.xla_model as xm if is_neuronx_distributed_available(): - from neuronx_distributed.parallel_layers import parallel_state + import neuronx_distributed as nxd + # from neuronx_distributed.parallel_layers import parallel_state logger = logging.get_logger() @@ -146,7 +147,7 @@ def __init__( os.environ["ACCELERATE_USE_AMP"] = "true" NeuronPartialState(cpu, **kwargs) self.__dict__.update(NeuronPartialState._shared_state) - self._check_initialized(mixed_precision, cpu, autocast_backend) + self._check_initialized(mixed_precision, cpu) #, autocast_backend) if not self.initialized: self.deepspeed_plugin = None self.ipex_plugin = None @@ -200,11 +201,18 @@ def __init__( self.mp_plugin = mp_plugin - if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized(): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size=self.mp_plugin.tensor_parallel_size, - pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size, - ) + # nxd_config = nxd.neuronx_distributed_config( + # tensor_parallel_size=self.mp_plugin.tensor_parallel_size, + # pipeline_parallel_size=self.mp_plugin.pipeline_parallel_size, + # expert_parallel_size=1, # TODO: add proper argument here once we support MOE + + # ) + + # if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized(): + # parallel_state.initialize_model_parallel( + # tensor_model_parallel_size=self.mp_plugin.tensor_parallel_size, + # pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size, + # ) if self.distributed_type is DistributedType.NO: if is_ipex_available(): @@ -221,16 +229,16 @@ def __init__( PartialState._shared_state["distributed_type"] = self.distributed_type - def _check_initialized(self, mixed_precision=None, cpu=None, autocast_backend=None): + def _check_initialized(self, mixed_precision=None, cpu=None): # autocast_backend=None): "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized" super()._check_initialized(mixed_precision=mixed_precision, cpu=cpu) err = ( "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and " "pass `{flag}` to `Accelerator()`." ) - if self.initialized: - if autocast_backend is not None and autocast_backend != self.autocast_backend: - raise ValueError(err.format(flag=f"autocast_backend='{autocast_backend}'")) + # if self.initialized: + # if autocast_backend is not None and autocast_backend != self.autocast_backend: + # raise ValueError(err.format(flag=f"autocast_backend='{autocast_backend}'")) @property def autocast_backend(self): diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index b735ad2ab..8e25a421a 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -245,6 +245,7 @@ def prepare_for_precompilation(self, args: "TrainingArguments"): logger.info("Disabling prediction during precompilation as this is not well supported yet.") args.do_predict = False + @requires_neuronx_distributed def create_accelerator_and_postprocess(self): grad_acc_kwargs = {} if self.args.accelerator_config.gradient_accumulation_kwargs is not None: @@ -286,7 +287,25 @@ def create_accelerator_and_postprocess(self): } # create accelerator object + import neuronx_distributed as nxd + optimizer_config = { + "zero_one_enabled": self.args.zero_1, + "grad_clipping": self.args.max_grad_norm is not None and self.args.max_grad_norm > 0, + "max_grad_norm": self.args.max_grad_norm, + } + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=self.args.mp_plugin.tensor_parallel_size, + pipeline_parallel_size=self.args.mp_plugin.pipeline_parallel_size, + expert_parallel_size=1, # TODO: enable that once MOE is supported + pipeline_config=None, # TODO: integrate support for Pipeline with the NxD config + optimizer_config=optimizer_config, + activation_checkpoint_config=None, # TODO: integrate support for activation checkpointing with the NxD config + pad_model=False, + sequence_parallel=not self.args.disable_sequence_parallel, + # TODO: integrate other parameters as well. + ) self.accelerator = NeuronAccelerator( + nxd_config, mp_plugin=self.args.mp_plugin, zero_1=self.args.zero_1, mixed_precision="bf16" if self.args.bf16 else "no", @@ -372,6 +391,7 @@ def get_optimizer_cls_and_kwargs( args: TrainingArguments, model: Optional[PreTrainedModel] = None ) -> Tuple[Any, Any]: optimizer_cls, optimizer_kwargs = transformers_get_optimizer_cls_and_kwargs(args, model=model) + print(optimizer_kwargs) lazy_load = args.mp_plugin.should_parallelize or args.zero_1 if lazy_load: optimizer_cls = make_optimizer_constructor_lazy(optimizer_cls) @@ -379,6 +399,7 @@ def get_optimizer_cls_and_kwargs( @patch_within_function(("transformers.Trainer.get_optimizer_cls_and_kwargs", get_optimizer_cls_and_kwargs)) def create_optimizer(self): + print("Create optimizer") return super().create_optimizer() def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]: @@ -799,6 +820,7 @@ def _inner_training_loop( debug_overflow = DebugUnderflowOverflow(self.model) # noqa delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled + delay_optimizer_creation = True # We need to reset the scheduler, as its parameters may be different on subsequent calls if self._created_lr_scheduler: @@ -1069,7 +1091,6 @@ def _inner_training_loop( f"{tr_loss_step.device}" ) tr_loss += tr_loss_step - print(tr_loss) self.current_flos += float(self.floating_point_ops(inputs)) @@ -1101,20 +1122,42 @@ def _inner_training_loop( args.max_grad_norm, ) else: - _grad_norm = self.accelerator.clip_grad_norm_( - model.parameters(), - args.max_grad_norm, - ) + _grad_norm = None + # _grad_norm = self.accelerator.clip_grad_norm_( + # model.parameters(), + # args.max_grad_norm, + # ) grad_norm = _grad_norm + old_weights = {} + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + old_weights[name] = param.clone().detach() + # Optimizer step self.optimizer.step() + grad_norm = self.optimizer.optimizer.grad_norm + + for name, param in model.named_parameters(): + continue + if not param.requires_grad: + continue + diff = (param - old_weights[name]).abs().mean() + xm.master_print(f"Layer {name} weight change: {diff.item()}, {diff.device}") + + # for n, p in model.named_parameters(): + # if p.requires_grad and p.grad is not None: + # print(f"{n} => {p.grad.norm().item()}, {p.grad.min()}, {p.grad.max()}") + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped if optimizer_was_run: # Delay optimizer scheduling until metrics are generated if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): self.lr_scheduler.step() self.optimizer.zero_grad() + if isinstance(grad_norm, list) and len(grad_norm) > 0: + grad_norm = grad_norm[0] self.state.global_step += 1 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch