Add and remove some mark steps (#644)
michaelbenayoun authored Jul 5, 2024
1 parent 281bad8 commit 542328d
Showing 4 changed files with 35 additions and 18 deletions.
optimum/neuron/accelerate/accelerator.py (10 changes: 3 additions & 7 deletions)

@@ -389,7 +389,6 @@ def patch_model_for_neuron(
     def _prepare_model_for_mp(
         self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
     ):
-        import torch_xla.core.xla_model as xm
         from neuronx_distributed.pipeline import NxDPPModel

         if model in self._models or Parallelizer.was_parallelized(model):
@@ -442,10 +441,7 @@ def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings):
             cpu_ids[name]: xla_params[name] for name, _ in model.named_parameters()
         }

-        xm.mark_step()
-        device_placement = False
-
-        return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
+        return model

     @requires_torch_xla
     @requires_neuronx_distributed
@@ -491,8 +487,8 @@ def prepare_model(
             if should_apply_activation_checkpointing:
                 apply_activation_checkpointing(model)
             move_model_to_device(model, xm.xla_device())
-            device_placement = False
-            model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
+        device_placement = False
+        model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
         xm.mark_step()
         return model

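Context for the mark-step moves in this file: torch_xla executes tensors lazily, so operations only accumulate into a graph until xm.mark_step() cuts it and triggers compilation and execution. Placing the call after super().prepare_model(...) lets the whole preparation materialize as one graph. A minimal standalone sketch of that behavior, assuming a working torch_xla install (not code from this repo):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

x = torch.randn(4, 4, device=device)  # recorded lazily, nothing executes yet
y = (x @ x).sum()                     # still only graph nodes

xm.mark_step()   # cut the graph: compile and run everything recorded so far
print(y.item())  # fetching the value would also have forced execution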
optimum/neuron/distributed/base.py (2 changes: 0 additions & 2 deletions)

@@ -759,8 +759,6 @@ def should_parallelize_layer_predicate_func(layer):
f"Could not find information for the parameter {name} to set its `requires_grad` attribute."
)

xm.mark_step()

if is_main_worker():
logger.info("Load and initialization of the weights done.")

optimum/neuron/distributed/utils.py (21 changes: 15 additions & 6 deletions)

@@ -1349,19 +1349,27 @@ def duplicate_module_with_random_weights_on_cpu(module: torch.nn.Module) -> torc
"""
clone = torch.nn.Module()

children_names = {n for n, _ in module.named_children()}
buffer_names = {n for n, _ in module.named_buffers()}
parameter_names = {n for n, _ in module.named_parameters()}

for name in dir(module):
attr = getattr(module, name)
if name in buffer_names or parameter_names:
if name in (children_names | buffer_names | parameter_names) or name.startswith("__"):
continue
setattr(clone, name, copy.deepcopy(attr))

for name, mod in module.named_children():
clone.add_module(name, duplicate_module_with_random_weights_on_cpu(mod))

for name, buffer in module.named_buffers():
if "." in name:
continue
clone.register_buffer(name, torch.empty_like(buffer, device="cpu"))

for name, param in module.named_parameters():
if "." in name:
continue
clone.register_parameter(name, torch.nn.Parameter(torch.empty_like(param, device="cpu")))

clone.__class__ = module.__class__
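The condition replaced in this hunk is a classic precedence bug: `in` binds tighter than `or`, so the old test parsed as `(name in buffer_names) or bool(parameter_names)`, which is truthy for any module with at least one parameter, so nearly every attribute was skipped. A tiny illustration with made-up names:

buffer_names = {"running_mean"}
parameter_names = {"weight", "bias"}

name = "training"  # an ordinary attribute, neither a buffer nor a parameter
print(name in buffer_names or parameter_names)   # {'weight', 'bias'} -> truthy, attribute wrongly skipped
print(name in (buffer_names | parameter_names))  # False, attribute now correctly copied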
@@ -1585,11 +1593,12 @@ def wrapper(*args, **kwargs):
         patcher = Patcher(patching_specs=patching_specs)
     else:
         patcher = contextlib.nullcontext()
-    with patcher:
-        try:
-            yield
-        finally:
-            pass
+    try:
+        patcher.__enter__()
+        yield
+    finally:
+        patcher.__exit__(None, None, None)
+        pass


 def make_optimizer_constructor_lazy(optimizer_cls: Type["torch.optim.Optimizer"]):
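The rewritten hunk drives the context-manager protocol by hand instead of relying on a `with` block, so the patcher stays active while the generator is suspended at `yield` and is always exited afterwards. A self-contained sketch of the same pattern, with a hypothetical ExamplePatcher standing in for the repo's Patcher:

import contextlib

class ExamplePatcher:
    def __enter__(self):
        print("patching")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        print("unpatching")
        return False  # never swallow exceptions

@contextlib.contextmanager
def maybe_patched(use_patcher: bool):
    patcher = ExamplePatcher() if use_patcher else contextlib.nullcontext()
    try:
        patcher.__enter__()
        yield
    finally:
        # mirrors the committed code: __exit__ always runs, even if the body raises
        patcher.__exit__(None, None, None)

with maybe_patched(use_patcher=True):
    print("inside")  # prints: patching / inside / unpatching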
optimum/neuron/trainers.py (20 changes: 17 additions & 3 deletions)

@@ -446,8 +446,10 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
         reduced_tr_loss = self._reduce_loss(tr_loss)

         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
-            # reset tr_loss to zero
-            tr_loss.zero_()
+            if isinstance(getattr(self, "_zero_loss_value"), torch.Tensor):
+                tr_loss.data = self._zero_loss_value.data
+            else:
+                tr_loss.zero_()

         def log_closure(self, reduced_tr_loss, grad_norm):
             if is_main_worker_for_metrics():
@@ -764,7 +766,7 @@ def _inner_training_loop(
         self.state.save_steps = args.save_steps

         # Activate gradient checkpointing if needed
-        # It is handled differentlt if pipeline parallelism is enabled.
+        # It is handled differently if pipeline parallelism is enabled.
         if args.gradient_checkpointing and args.pipeline_parallel_size == 1:
             if args.gradient_checkpointing_kwargs is None:
                 gradient_checkpointing_kwargs = {}
@@ -897,6 +899,9 @@ def _inner_training_loop(
         grad_norm: Optional[float] = None
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

+        # Mark step before training to materialize any tensor before creating the training graph.
+        xm.mark_step()
+
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
@@ -1052,6 +1057,15 @@
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+
+                    # `_zero_loss_value` is used to reset the value of `tr_loss`.
+                    # By doing that, we do not have to do `tr_loss.zero_()` when logging the loss.
+                    # This way we do not insert a new op in the XLA graph (for `tr_loss.zero_()`) which would create
+                    # multiple graphs depending on whether we are logging or not.
+                    # Here we always create a scalar whose value is `0.0`, this way the graph stays the same whether or
+                    # not we are logging. The only difference when logging is that we set
+                    # `tr_loss.data = self._zero_loss_value.data`, which should not create new graph ops.
+                    self._zero_loss_value = torch.tensor(0.0, device=args.device)
                     self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
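Why the `_zero_loss_value` dance matters: under XLA's lazy tracing, `tr_loss.zero_()` would add an op that appears only on logging steps, yielding two distinct step graphs and a recompilation; swapping `.data` instead repoints the accumulator at an existing zero tensor without tracing anything new. A minimal sketch outside the Trainer, assuming a torch_xla device is available (names here are illustrative, not from the repo):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
tr_loss = torch.tensor(0.0, device=device)
zero_value = torch.tensor(0.0, device=device)  # the commit recreates this every step

for step in range(4):
    tr_loss += torch.rand((), device=device)  # stand-in for the real loss computation
    if step % 2 == 0:                         # pretend we only log on even steps
        tr_loss.data = zero_value.data        # reset: repoints storage, traces no zero_() op
    xm.mark_step()  # the traced step graph keeps the same shape either way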
