diff --git a/optimum/neuron/distributed/checkpointing.py b/optimum/neuron/distributed/checkpointing.py
index feaf76894..409b58d17 100644
--- a/optimum/neuron/distributed/checkpointing.py
+++ b/optimum/neuron/distributed/checkpointing.py
@@ -15,6 +15,7 @@
 """Functions handling checkpointing under parallel settings."""
 
 import json
+import os
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Union
 
@@ -29,7 +30,7 @@
 )
 
 from ..utils.peft_utils import ADAPTER_MODEL_PARALLEL_SHARDS_DIR_NAME
-from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors
+from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla
 from .utils import MODEL_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, compute_query_indices_for_rank
 
 
@@ -44,6 +45,31 @@
     PEFT_SAFETENSORS_WEIGHTS_NAME = PEFT_WEIGHTS_NAME = ""
 
 
+@requires_torch_xla
+def xser_load_on_cpu(path: str):
+    """
+    Modified version of the neuronx_distributed `_xser_load` function located at:
+    https://github.com/aws-neuron/neuronx-distributed/blob/main/src/neuronx_distributed/parallel_layers/checkpointing.py#L265-L283.
+
+    Instead of moving the loaded tensors to the XLA device, it keeps them on CPU.
+    """
+    import torch_xla.core.xla_model as xm
+    import torch_xla.utils.serialization as xser
+
+    ref_data = torch.load(path)
+
+    def convert_fn(tensors):
+        rewritten_tensors = []
+        for t in tensors:
+            rewritten_tensors.append(torch.load(os.path.join(path + ".tensors", "tensor_{}.pt".format(t.tid))))
+        return rewritten_tensors
+
+    def select_fn(v):
+        return type(v) == xser.TensorReference
+
+    return xm.ToXlaTensorArena(convert_fn, select_fn).transform(ref_data)
+
+
 def create_gqa_query_or_output_projection_weight_from_full_weight(
     full_weight: torch.Tensor,
     tp_size: int,
@@ -148,8 +174,6 @@ def consolidate_tensor_parallel_checkpoints(
 
 @requires_neuronx_distributed
 def consolidate_model_parallel_checkpoints(checkpoint_dir: Path) -> Dict[str, "torch.Tensor"]:
-    from neuronx_distributed.parallel_layers.checkpointing import _xser_load
-
     model_checkpoint_dir = checkpoint_dir / "model"
 
     # Case 1: the checkpoint was saved with xser.
@@ -159,7 +183,7 @@ def consolidate_model_parallel_checkpoints(checkpoint_dir: Path) -> Dict[str, "t
         sharded_checkpoints = [
             p for p in sharded_checkpoints if not (p.name.endswith("info.pt") or p.name.endswith("tensors"))
         ]
-        load_function = _xser_load
+        load_function = xser_load_on_cpu
 
     # Case 2: If no file was found, maybe the checkpoint was saved without xser.
     if not sharded_checkpoints:
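
For reference, a minimal usage sketch of the new loader, not part of the diff itself: it reads an xser-serialized shard and keeps every tensor on CPU, so checkpoints can be consolidated on a machine without an XLA device. The shard file name below is hypothetical (the diff only shows that shards match `dp_rank*` under `<checkpoint_dir>/model`), and the sketch assumes the loaded object is a mapping.

```python
# Usage sketch only, assuming the change above is applied; the shard path is hypothetical.
from optimum.neuron.distributed.checkpointing import xser_load_on_cpu

# An xser-serialized shard: the ".pt" file holds TensorReference placeholders, while the
# sibling "<shard>.pt.tensors/" directory holds the actual "tensor_<tid>.pt" payloads.
shard_path = "my_checkpoint/model/dp_rank_00_tp_rank_00_pp_rank_00.pt"

# Tensors are materialized on CPU instead of being moved to the XLA device,
# which is what neuronx_distributed's _xser_load would do.
checkpoint = xser_load_on_cpu(shard_path)
print({name: type(value).__name__ for name, value in checkpoint.items()})
```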