[WIP] integrate NxDOptimizer
michaelbenayoun committed Dec 3, 2024
1 parent 22d00ab commit dd9b80d
Showing 4 changed files with 133 additions and 36 deletions.
69 changes: 55 additions & 14 deletions optimum/neuron/accelerate/accelerator.py
@@ -23,7 +23,7 @@
import warnings
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, Dict

import torch
from accelerate import Accelerator
@@ -81,6 +81,7 @@
xm = None

if is_neuronx_distributed_available():
import neuronx_distributed as nxd
from neuronx_distributed.utils.model_utils import move_model_to_device


@@ -98,6 +99,7 @@
class NeuronAccelerator(Accelerator):
def __init__(
self,
nxd_config: Dict[str, Any],
*args,
mp_plugin: Optional[ModelParallelismPlugin] = None,
zero_1: bool = False,
@@ -146,13 +148,17 @@ def patched_is_torch_xla_available(check_is_tpu: bool = False, check_is_gpu: boo

accelerate.state.is_torch_xla_available = patched_is_torch_xla_available

patched_accelerator_state = partial(
NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend
)
with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]):
self.mp_plugin = mp_plugin
self.nxd_config = nxd_config

# patched_accelerator_state = partial(
# NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend
# )
# with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]):
with Patcher([("accelerate.accelerator.AcceleratorState", NeuronAcceleratorState)]):
super().__init__(**full_kwargs)

self.zero_1 = zero_1
self.zero_1 = self.nxd_config["optimizer_config"]["zero_one_enabled"]

if self.autocast_handler is None:
enabled = self.state.mixed_precision == "bf16" and autocast_backend is AutocastBackend.AMP
@@ -300,17 +306,32 @@ def _prepare_optimizer_for_zero_1(self, optimizer: torch.optim.Optimizer, device
)
return zero_1_optimizer

@requires_neuronx_distributed
@patch_within_function(("accelerate.accelerator.AcceleratedOptimizer", NeuronAcceleratedOptimizer))
def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement: Optional[bool] = None):
if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
optimizer = self._prepare_optimizer_for_mp(optimizer, device_placement=device_placement)
if self.zero_1:
optimizer = self._prepare_optimizer_for_zero_1(optimizer, device_placement=device_placement)
import neuronx_distributed as nxd

#cpu_parameters_to_xla = collections.ChainMap(*self._model_cpu_parameters_to_xla.values())
#xla_parameters, _ = Parallelizer.optimizer_cpu_params_to_xla_params(optimizer, cpu_parameters_to_xla)
#print(xla_parameters)

optimizer = nxd.initialize_parallel_optimizer(
self.nxd_config,
optimizer.__class__,
# xla_parameters,
optimizer.param_groups,
**optimizer.defaults,
)
optimizer.zero_grad()
# if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
# optimizer = self._prepare_optimizer_for_mp(optimizer, device_placement=device_placement)
# if self.zero_1:
# optimizer = self._prepare_optimizer_for_zero_1(optimizer, device_placement=device_placement)
# Edge case: if the optimizer was created lazily outside of the Model Parallelism and/or ZeRO-1 setting, we make
# sure to actually load the proper parameters.
if hasattr(optimizer, "_args_to_recreate"):
args, kwargs = optimizer._args_to_recreate
optimizer = optimizer.__class__(*args, **kwargs)
# if hasattr(optimizer, "_args_to_recreate"):
# args, kwargs = optimizer._args_to_recreate
# optimizer = optimizer.__class__(*args, **kwargs)

return super().prepare_optimizer(optimizer, device_placement=device_placement)

@@ -449,6 +470,7 @@ def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings):
def prepare_model(
self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
):
print("Prepare model")
# If the model was already prepared, we skip.
if model in self._models:
return model
@@ -500,7 +522,26 @@ def backward(self, loss, **kwargs):
self.scaler.scale(loss).backward(**kwargs)
else:
loss.backward(**kwargs)

# vector_norm = [torch.vector_norm(p.grad, 2) for p in self._models[0].parameters() if p.requires_grad]
# norm = torch.nn.utils.clip_grad_norm_([p for p in self._models[0].parameters() if p.requires_grad], 1.0)
# xm.mark_step()
# print(vector_norm)
# self._models[0].to("cpu")
# print(self._models[0])
# print(norm)
# for n, p in self._models[0].named_parameters():
# if not p.requires_grad or p.grad is None:
# continue
# p = p.grad
# print(f"Gradient of {n}")
# print(f"Min: {p.min():.3f}")
# print(f"Max: {p.max():.3f}")
# print(f"Mean: {p.mean():.3f}")
# print(f"Std: {p.std():.3f}")
# print(f"L1 norm: {p.norm(p=1):.3f}")
# print(f"L2 norm: {p.norm(p=2):.3f}")
# assert 3==2

@contextlib.contextmanager
def autocast(self, cache_enabled: bool = False, autocast_handler: Optional[AutocastKwargs] = None):
if cache_enabled:
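For context, the new `prepare_optimizer` path above hands the NxD config to `neuronx_distributed` instead of wrapping the optimizer separately for model parallelism and ZeRO-1. Below is a minimal sketch of that call pattern, assuming a plain `torch` optimizer and an already-built `nxd_config`; the helper name and the AdamW choice are illustrative, not part of the commit.

```python
# Sketch only: mirrors the nxd.initialize_parallel_optimizer(...) call added in
# prepare_optimizer above. Assumes neuronx_distributed is installed and an
# nxd_config dict has already been produced by nxd.neuronx_distributed_config(...).
import torch
import neuronx_distributed as nxd


def build_parallel_optimizer(nxd_config, model: torch.nn.Module):
    # A regular optimizer is built first only to reuse its param_groups and
    # defaults, which is what the diff does with the optimizer it receives.
    base = torch.optim.AdamW(model.parameters(), lr=1e-4)  # illustrative choice
    parallel_optimizer = nxd.initialize_parallel_optimizer(
        nxd_config,
        base.__class__,     # the optimizer class, e.g. AdamW
        base.param_groups,  # parameter groups taken from the plain optimizer
        **base.defaults,    # hyperparameters (lr, betas, weight_decay, ...)
    )
    parallel_optimizer.zero_grad()
    return parallel_optimizer
```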
17 changes: 11 additions & 6 deletions optimum/neuron/accelerate/optimizer.py
@@ -73,24 +73,27 @@ def __init__(
self.parameters = [p for group in self.optimizer.param_groups for p in group["params"]]
self.parameter_ids = {id(p) for p in self.parameters}

self.total_grad_norm = []

# TODO: might be needed to override this soon.
def load_state_dict(self, state_dict):
return super().load_state_dict(state_dict)

def prepare_clip_grad_norm(self, parameters, max_norm, norm_type=2):
parameter_ids = {id(p) for p in parameters}
if parameter_ids == self.parameter_ids or isinstance(self.optimizer, ZeroRedundancyOptimizer):
self.clip_grad_norm_to_perform = {"max_norm": max_norm, "norm_type": norm_type}
# if parameter_ids == self.parameter_ids or isinstance(self.optimizer, ZeroRedundancyOptimizer):
# assert 3==2
self.clip_grad_norm_to_perform = {"max_norm": max_norm, "norm_type": norm_type}
return self.total_grad_norm

@requires_neuronx_distributed
def step(self, closure=None):
from neuronx_distributed import parallel_layers
from neuronx_distributed.parallel_layers.grads import bucket_allreduce_gradients

if self.gradient_state.sync_gradients:
# For sequence-parallel, we have to explicitly all-reduce the layernorm gradients.
if self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
allreduce_sequence_parallel_gradients(self.optimizer)
self.optimizer.step()
return

if isinstance(self.optimizer, ZeroRedundancyOptimizer):
if self.clip_grad_norm_to_perform is not None:
@@ -113,7 +116,9 @@ def step(self, closure=None):
if parallel_layers.parallel_state.get_data_parallel_size() > 1:
bucket_allreduce_gradients(xm._fetch_gradients(self.optimizer))
if self.clip_grad_norm_to_perform is not None:
parallel_layers.clip_grad_norm(self.parameters, **self.clip_grad_norm_to_perform)
self.total_grad_norm.clear()
total_grad_norm = parallel_layers.clip_grad_norm(self.parameters, **self.clip_grad_norm_to_perform)
self.total_grad_norm.append(total_grad_norm)
self.clip_grad_norm_to_perform = None
self.optimizer.step()
elif self.scaler is not None:
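The `total_grad_norm` list introduced above lets callers receive the clipped gradient norm lazily: `prepare_clip_grad_norm` only records the request and returns the (still empty) list, and `step` fills it in once clipping is actually applied. A standalone sketch of that bookkeeping follows, with the distributed pieces stubbed out; the `clip_fn` callable stands in for `parallel_layers.clip_grad_norm` and is an assumption of this example.

```python
# Minimal, framework-free sketch of the deferred clipping bookkeeping used by
# NeuronAcceleratedOptimizer in this diff. Not the actual class.
from typing import Callable, Iterable, List, Optional

import torch


class DeferredGradClipper:
    def __init__(self) -> None:
        self.clip_grad_norm_to_perform: Optional[dict] = None
        # Handed to the caller before clipping happens; filled in during step().
        self.total_grad_norm: List[torch.Tensor] = []

    def prepare_clip_grad_norm(self, max_norm: float, norm_type: float = 2) -> List[torch.Tensor]:
        # Record the request; the caller keeps the list reference and reads it later.
        self.clip_grad_norm_to_perform = {"max_norm": max_norm, "norm_type": norm_type}
        return self.total_grad_norm

    def apply_pending_clip(
        self,
        clip_fn: Callable[..., torch.Tensor],
        parameters: Iterable[torch.nn.Parameter],
    ) -> None:
        # Called from step(): perform the clipping recorded earlier and publish
        # the resulting total norm through the shared list.
        if self.clip_grad_norm_to_perform is not None:
            self.total_grad_norm.clear()
            self.total_grad_norm.append(clip_fn(parameters, **self.clip_grad_norm_to_perform))
            self.clip_grad_norm_to_perform = None
```

A caller can use `torch.nn.utils.clip_grad_norm_` as `clip_fn`, since it accepts `max_norm` and `norm_type` keyword arguments.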
30 changes: 19 additions & 11 deletions optimum/neuron/accelerate/state.py
@@ -45,7 +45,8 @@
import torch_xla.core.xla_model as xm

if is_neuronx_distributed_available():
from neuronx_distributed.parallel_layers import parallel_state
import neuronx_distributed as nxd
# from neuronx_distributed.parallel_layers import parallel_state


logger = logging.get_logger()
@@ -146,7 +147,7 @@ def __init__(
os.environ["ACCELERATE_USE_AMP"] = "true"
NeuronPartialState(cpu, **kwargs)
self.__dict__.update(NeuronPartialState._shared_state)
self._check_initialized(mixed_precision, cpu, autocast_backend)
self._check_initialized(mixed_precision, cpu) #, autocast_backend)
if not self.initialized:
self.deepspeed_plugin = None
self.ipex_plugin = None
@@ -200,11 +201,18 @@ def __init__(

self.mp_plugin = mp_plugin

if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized():
parallel_state.initialize_model_parallel(
tensor_model_parallel_size=self.mp_plugin.tensor_parallel_size,
pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size,
)
# nxd_config = nxd.neuronx_distributed_config(
# tensor_parallel_size=self.mp_plugin.tensor_parallel_size,
# pipeline_parallel_size=self.mp_plugin.pipeline_parallel_size,
# expert_parallel_size=1, # TODO: add proper argument here once we support MOE

# )

# if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized():
# parallel_state.initialize_model_parallel(
# tensor_model_parallel_size=self.mp_plugin.tensor_parallel_size,
# pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size,
# )

if self.distributed_type is DistributedType.NO:
if is_ipex_available():
@@ -221,16 +229,16 @@

PartialState._shared_state["distributed_type"] = self.distributed_type

def _check_initialized(self, mixed_precision=None, cpu=None, autocast_backend=None):
def _check_initialized(self, mixed_precision=None, cpu=None): # autocast_backend=None):
"Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
super()._check_initialized(mixed_precision=mixed_precision, cpu=cpu)
err = (
"AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and "
"pass `{flag}` to `Accelerator()`."
)
if self.initialized:
if autocast_backend is not None and autocast_backend != self.autocast_backend:
raise ValueError(err.format(flag=f"autocast_backend='{autocast_backend}'"))
# if self.initialized:
# if autocast_backend is not None and autocast_backend != self.autocast_backend:
# raise ValueError(err.format(flag=f"autocast_backend='{autocast_backend}'"))

@property
def autocast_backend(self):
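The block commented out in `NeuronAcceleratorState.__init__` above is the explicit model-parallel group initialization that the nxd_config-driven path is meant to subsume. For reference, a sketch of that older pattern, with illustrative parallel sizes:

```python
# Sketch of the explicit initialization that the diff comments out; the calls
# themselves are taken from the removed code, the sizes are placeholders.
import torch.distributed
from neuronx_distributed.parallel_layers import parallel_state


def init_model_parallel_groups(tensor_parallel_size: int = 2, pipeline_parallel_size: int = 1) -> None:
    # Only initialize once, and only after torch.distributed itself is up,
    # mirroring the guard in the commented-out block.
    if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized():
        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size=tensor_parallel_size,
            pipeline_model_parallel_size=pipeline_parallel_size,
        )
```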
53 changes: 48 additions & 5 deletions optimum/neuron/trainers.py
@@ -245,6 +245,7 @@ def prepare_for_precompilation(self, args: "TrainingArguments"):
logger.info("Disabling prediction during precompilation as this is not well supported yet.")
args.do_predict = False

@requires_neuronx_distributed
def create_accelerator_and_postprocess(self):
grad_acc_kwargs = {}
if self.args.accelerator_config.gradient_accumulation_kwargs is not None:
@@ -286,7 +287,25 @@ def create_accelerator_and_postprocess(self):
}

# create accelerator object
import neuronx_distributed as nxd
optimizer_config = {
"zero_one_enabled": self.args.zero_1,
"grad_clipping": self.args.max_grad_norm is not None and self.args.max_grad_norm > 0,
"max_grad_norm": self.args.max_grad_norm,
}
nxd_config = nxd.neuronx_distributed_config(
tensor_parallel_size=self.args.mp_plugin.tensor_parallel_size,
pipeline_parallel_size=self.args.mp_plugin.pipeline_parallel_size,
expert_parallel_size=1, # TODO: enable that once MOE is supported
pipeline_config=None, # TODO: integrate support for Pipeline with the NxD config
optimizer_config=optimizer_config,
activation_checkpoint_config=None, # TODO: integrate support for activation checkpointing with the NxD config
pad_model=False,
sequence_parallel=not self.args.disable_sequence_parallel,
# TODO: integrate other parameters as well.
)
self.accelerator = NeuronAccelerator(
nxd_config,
mp_plugin=self.args.mp_plugin,
zero_1=self.args.zero_1,
mixed_precision="bf16" if self.args.bf16 else "no",
@@ -372,13 +391,15 @@ def get_optimizer_cls_and_kwargs(
args: TrainingArguments, model: Optional[PreTrainedModel] = None
) -> Tuple[Any, Any]:
optimizer_cls, optimizer_kwargs = transformers_get_optimizer_cls_and_kwargs(args, model=model)
print(optimizer_kwargs)
lazy_load = args.mp_plugin.should_parallelize or args.zero_1
if lazy_load:
optimizer_cls = make_optimizer_constructor_lazy(optimizer_cls)
return optimizer_cls, optimizer_kwargs

@patch_within_function(("transformers.Trainer.get_optimizer_cls_and_kwargs", get_optimizer_cls_and_kwargs))
def create_optimizer(self):
print("Create optimizer")
return super().create_optimizer()

def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
@@ -799,6 +820,7 @@ def _inner_training_loop(
debug_overflow = DebugUnderflowOverflow(self.model) # noqa

delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
delay_optimizer_creation = True

# We need to reset the scheduler, as its parameters may be different on subsequent calls
if self._created_lr_scheduler:
@@ -1069,7 +1091,6 @@ def _inner_training_loop(
f"{tr_loss_step.device}"
)
tr_loss += tr_loss_step
print(tr_loss)

self.current_flos += float(self.floating_point_ops(inputs))

@@ -1101,20 +1122,42 @@ def _inner_training_loop(
args.max_grad_norm,
)
else:
_grad_norm = self.accelerator.clip_grad_norm_(
model.parameters(),
args.max_grad_norm,
)
_grad_norm = None
# _grad_norm = self.accelerator.clip_grad_norm_(
# model.parameters(),
# args.max_grad_norm,
# )
grad_norm = _grad_norm

old_weights = {}
for name, param in model.named_parameters():
if not param.requires_grad:
continue
old_weights[name] = param.clone().detach()

# Optimizer step
self.optimizer.step()
grad_norm = self.optimizer.optimizer.grad_norm

for name, param in model.named_parameters():
continue
if not param.requires_grad:
continue
diff = (param - old_weights[name]).abs().mean()
xm.master_print(f"Layer {name} weight change: {diff.item()}, {diff.device}")

# for n, p in model.named_parameters():
# if p.requires_grad and p.grad is not None:
# print(f"{n} => {p.grad.norm().item()}, {p.grad.min()}, {p.grad.max()}")

optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step()
self.optimizer.zero_grad()
if isinstance(grad_norm, list) and len(grad_norm) > 0:
grad_norm = grad_norm[0]

self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
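With this change the trainer assembles the NxD configuration itself and passes it to `NeuronAccelerator`. A condensed sketch of that construction, assuming `neuronx_distributed` is available; only the keyword names come from the diff, the concrete values are placeholders rather than defaults from the commit.

```python
# Sketch of the nxd_config construction performed in
# create_accelerator_and_postprocess above.
import neuronx_distributed as nxd


def make_nxd_config(tensor_parallel_size: int = 2,
                    pipeline_parallel_size: int = 1,
                    zero_1: bool = False,
                    max_grad_norm: float = 1.0,
                    sequence_parallel: bool = True):
    optimizer_config = {
        "zero_one_enabled": zero_1,
        "grad_clipping": max_grad_norm is not None and max_grad_norm > 0,
        "max_grad_norm": max_grad_norm,
    }
    return nxd.neuronx_distributed_config(
        tensor_parallel_size=tensor_parallel_size,
        pipeline_parallel_size=pipeline_parallel_size,
        expert_parallel_size=1,           # MoE not supported yet in this WIP
        pipeline_config=None,             # pipeline support not wired into the config yet
        optimizer_config=optimizer_config,
        activation_checkpoint_config=None,
        pad_model=False,
        sequence_parallel=sequence_parallel,
    )
```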
