diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py
index c3ef669f3..05d5a8a3e 100644
--- a/optimum/neuron/accelerate/accelerator.py
+++ b/optimum/neuron/accelerate/accelerator.py
@@ -389,7 +389,6 @@ def patch_model_for_neuron(
     def _prepare_model_for_mp(
         self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
     ):
-        import torch_xla.core.xla_model as xm
         from neuronx_distributed.pipeline import NxDPPModel
 
         if model in self._models or Parallelizer.was_parallelized(model):
@@ -442,10 +441,7 @@ def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings):
             cpu_ids[name]: xla_params[name] for name, _ in model.named_parameters()
         }
 
-        xm.mark_step()
-        device_placement = False
-
-        return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
+        return model
 
     @requires_torch_xla
     @requires_neuronx_distributed
@@ -491,8 +487,8 @@ def prepare_model(
             if should_apply_activation_checkpointing:
                 apply_activation_checkpointing(model)
             move_model_to_device(model, xm.xla_device())
-            device_placement = False
-            model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
+        device_placement = False
+        model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
         xm.mark_step()
         return model
 
diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index ef34ce4f3..1f55e02bb 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -759,8 +759,6 @@ def should_parallelize_layer_predicate_func(layer):
                         f"Could not find information for the parameter {name} to set its `requires_grad` attribute."
                     )
 
-        xm.mark_step()
-
         if is_main_worker():
             logger.info("Load and initialization of the weights done.")
 
diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index 69d077d6b..a62d0c2bc 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -1349,19 +1349,27 @@ def duplicate_module_with_random_weights_on_cpu(module: torch.nn.Module) -> torc
     """
     clone = torch.nn.Module()
 
+    children_names = {n for n, _ in module.named_children()}
     buffer_names = {n for n, _ in module.named_buffers()}
     parameter_names = {n for n, _ in module.named_parameters()}
 
     for name in dir(module):
         attr = getattr(module, name)
-        if name in buffer_names or parameter_names:
+        if name in (children_names | buffer_names | parameter_names) or name.startswith("__"):
             continue
         setattr(clone, name, copy.deepcopy(attr))
 
+    for name, mod in module.named_children():
+        clone.add_module(name, duplicate_module_with_random_weights_on_cpu(mod))
+
     for name, buffer in module.named_buffers():
+        if "." in name:
+            continue
         clone.register_buffer(name, torch.empty_like(buffer, device="cpu"))
 
     for name, param in module.named_parameters():
+        if "." in name:
+            continue
         clone.register_parameter(name, torch.nn.Parameter(torch.empty_like(param, device="cpu")))
 
     clone.__class__ = module.__class__
@@ -1585,11 +1593,12 @@ def wrapper(*args, **kwargs):
         patcher = Patcher(patching_specs=patching_specs)
     else:
         patcher = contextlib.nullcontext()
-    with patcher:
-        try:
-            yield
-        finally:
-            pass
+    try:
+        patcher.__enter__()
+        yield
+    finally:
+        patcher.__exit__(None, None, None)
+        pass
 
 
 def make_optimizer_constructor_lazy(optimizer_cls: Type["torch.optim.Optimizer"]):
diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 237018c9f..f63b6b469 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -446,8 +446,10 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
         reduced_tr_loss = self._reduce_loss(tr_loss)
 
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
-            # reset tr_loss to zero
-            tr_loss.zero_()
+            if isinstance(getattr(self, "_zero_loss_value", None), torch.Tensor):
+                tr_loss.data = self._zero_loss_value.data
+            else:
+                tr_loss.zero_()
 
             def log_closure(self, reduced_tr_loss, grad_norm):
                 if is_main_worker_for_metrics():
@@ -764,7 +766,7 @@ def _inner_training_loop(
         self.state.save_steps = args.save_steps
 
         # Activate gradient checkpointing if needed
-        # It is handled differentlt if pipeline parallelism is enabled.
+        # It is handled differently if pipeline parallelism is enabled.
         if args.gradient_checkpointing and args.pipeline_parallel_size == 1:
             if args.gradient_checkpointing_kwargs is None:
                 gradient_checkpointing_kwargs = {}
@@ -897,6 +899,9 @@ def _inner_training_loop(
         grad_norm: Optional[float] = None
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
 
+        # Mark step before training to materialize any pending tensors before creating the training graph.
+        xm.mark_step()
+
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
@@ -1052,6 +1057,15 @@ def _inner_training_loop(
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+
+                    # `_zero_loss_value` is used to reset the value of `tr_loss`.
+                    # By doing that, we do not have to call `tr_loss.zero_()` when logging the loss.
+                    # This way we do not insert a new op in the XLA graph (for `tr_loss.zero_()`) which would create
+                    # multiple graphs depending on whether we are logging or not.
+                    # Here we always create a scalar whose value is `0.0`, this way the graph stays the same whether or
+                    # not we are logging. The only difference when logging is that we set
+                    # `tr_loss.data = self._zero_loss_value.data`, which should not create new graph ops.
+                    self._zero_loss_value = torch.tensor(0.0, device=args.device)
                     self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)