Fix consolidation for TP (#649)
michaelbenayoun authored Jul 5, 2024
1 parent 542328d commit 17fe854
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions optimum/neuron/distributed/checkpointing.py
@@ -15,6 +15,7 @@
"""Functions handling checkpointing under parallel settings."""

import json
+import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Union

@@ -29,7 +30,7 @@
)

from ..utils.peft_utils import ADAPTER_MODEL_PARALLEL_SHARDS_DIR_NAME
-from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors
+from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla
from .utils import MODEL_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, compute_query_indices_for_rank


@@ -44,6 +45,31 @@
PEFT_SAFETENSORS_WEIGHTS_NAME = PEFT_WEIGHTS_NAME = ""


+@requires_torch_xla
+def xser_load_on_cpu(path: str):
+    """
+    Modified version of the neuronx_distributed `_xser_load` function located at:
+    https://github.com/aws-neuron/neuronx-distributed/blob/main/src/neuronx_distributed/parallel_layers/checkpointing.py#L265-L283.
+    Instead of moving the loaded tensors to the XLA device, it keeps them on CPU.
+    """
+    import torch_xla.core.xla_model as xm
+    import torch_xla.utils.serialization as xser
+
+    ref_data = torch.load(path)
+
+    def convert_fn(tensors):
+        rewritten_tensors = []
+        for t in tensors:
+            rewritten_tensors.append(torch.load(os.path.join(path + ".tensors", "tensor_{}.pt".format(t.tid))))
+        return rewritten_tensors
+
+    def select_fn(v):
+        return type(v) == xser.TensorReference
+
+    return xm.ToXlaTensorArena(convert_fn, select_fn).transform(ref_data)


def create_gqa_query_or_output_projection_weight_from_full_weight(
full_weight: torch.Tensor,
tp_size: int,
@@ -148,8 +174,6 @@ def consolidate_tensor_parallel_checkpoints(

@requires_neuronx_distributed
def consolidate_model_parallel_checkpoints(checkpoint_dir: Path) -> Dict[str, "torch.Tensor"]:
-    from neuronx_distributed.parallel_layers.checkpointing import _xser_load
-
model_checkpoint_dir = checkpoint_dir / "model"

# Case 1: the checkpoint was saved with xser.
@@ -159,7 +183,7 @@ def consolidate_model_parallel_checkpoints(checkpoint_dir: Path) -> Dict[str, "torch.Tensor"]:
sharded_checkpoints = [
p for p in sharded_checkpoints if not (p.name.endswith("info.pt") or p.name.endswith("tensors"))
]
-load_function = _xser_load
+load_function = xser_load_on_cpu

# Case 2: If no file was found, maybe the checkpoint was saved without xser.
if not sharded_checkpoints:
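
The new `xser_load_on_cpu` helper lets consolidation read an xser-serialized shard while keeping every tensor on CPU, so no Neuron/XLA device is needed. As a rough illustration (not part of the commit), it could also be called directly on a single shard; the path below is made up, but the `"<name>.pt"` plus `"<name>.pt.tensors"` layout matches the `path + ".tensors"` join in the function above:

from optimum.neuron.distributed.checkpointing import xser_load_on_cpu

# Illustrative shard path; real shard file names depend on the training run.
# An xser checkpoint pairs "<name>.pt" with a "<name>.pt.tensors" directory
# holding the individual tensor files.
shard_path = "checkpoint/model/dp_rank_00_tp_rank_00_pp_rank_00.pt"

# Tensors are materialized on CPU instead of being moved to the XLA device.
shard_data = xser_load_on_cpu(shard_path)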
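
For context, a hedged sketch of how the patched entry point might be invoked after training; the directory name is hypothetical, but `consolidate_model_parallel_checkpoints` and its signature come straight from the file being changed:

from pathlib import Path

from optimum.neuron.distributed.checkpointing import consolidate_model_parallel_checkpoints

# Hypothetical output directory of a tensor-parallel run; per the code above,
# the function looks for the sharded checkpoints under "<checkpoint_dir>/model".
checkpoint_dir = Path("training_output/shards")

# Returns a single state dict mapping parameter names to consolidated tensors.
consolidated = consolidate_model_parallel_checkpoints(checkpoint_dir)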
