Support SPMD fsdp compute dtype #13

Merged: 4 commits, Sep 20, 2024

23 changes: 19 additions & 4 deletions torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -1,17 +1,19 @@
-from typing import (Any, Callable, Dict, Optional, Union)
 import warnings
+from typing import (Any, Callable, Dict, Optional, Union)
 
+import numpy as np
 import torch
 import torch.nn as nn
 from torch._prims_common import TensorLike, TensorSequenceType
 
-import numpy as np
-
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as spmd
-from torch_xla.distributed.fsdp.wrap import recursive_wrap
 from torch_xla.distributed.fsdp._init_utils import _materialize_module
+from torch_xla.distributed.fsdp.wrap import recursive_wrap
+from torch_xla.distributed.fsdp.xla_fully_sharded_data_parallel import _cast_floats_tensors
 
+FLOAT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
+
 
 def _prepare_spmd_partition_spec(param):
@@ -40,13 +42,18 @@ class SpmdFullyShardedDataParallel(nn.Module):
       The callable should have the signature (output, mesh) -> None.
       If None, the default implementation will shard the first tensor in the output.
       If the output is a tuple, only the first tensor will be sharded.
+    compute_dtype (torch.dtype, Optional):
+      dtype for full parameters for computation. This defaults to
+      ``torch.float32`` but can be set to ``torch.float16`` or
+      ``torch.bfloat16``. The sharded parameters will always be in FP32.
   """
 
   def __init__(
       self,
       module: nn.Module,
       mesh: Optional[spmd.Mesh] = None,
       shard_output: Optional[Callable] = None,
+      compute_dtype: Optional[torch.dtype] = None,
       auto_wrap_policy: Optional[Callable] = None,
       auto_wrapper_callable: Optional[Callable] = None,
   ):
@@ -96,6 +103,11 @@ def __init__(
       )
       self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs)
 
+    if compute_dtype is not None and compute_dtype not in FLOAT_DTYPES:
+      raise ValueError(
+          f"compute_dtype must be one of {FLOAT_DTYPES}, not {compute_dtype}")
+    self.compute_dtype = compute_dtype or torch.float32
+
     _materialize_module(
         module,
         None, [],
@@ -150,6 +162,9 @@ def module(self) -> nn.Module:
     return self._orig_module
 
   def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+    if self.compute_dtype != torch.float32:
+      # Cast the input float tensors to the specified compute_dtype
+      args, kwargs = _cast_floats_tensors(self.compute_dtype, *args, **kwargs)
     output = self.module(*args, **kwargs)
     # Need to shard the output of the forward to instruct the compiler
     # to enforce the FSDP algorithm.
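
Below is a minimal usage sketch of the new compute_dtype argument. It assumes an XLA SPMD environment (e.g. a TPU VM); the mesh shape, axis names, and toy model are illustrative placeholders, not part of this patch. Per the docstring, the sharded parameters stay in FP32, while the forward() change above casts floating-point inputs to the chosen dtype. Passing an unsupported dtype such as torch.float64 would trigger the ValueError added in __init__.

```python
import numpy as np
import torch
import torch.nn as nn

import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as spmd
import torch_xla.runtime as xr
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()  # enable SPMD execution mode

# Build a mesh over all devices; the axis names here are placeholder choices.
num_devices = xr.global_runtime_device_count()
mesh = spmd.Mesh(np.arange(num_devices), (num_devices, 1), ('fsdp', 'model'))

# Wrap a toy model: parameters are sharded in FP32, and floating-point
# inputs are cast to bfloat16 before the wrapped forward runs.
model = FSDPv2(
    nn.Linear(128, 64).to(xm.xla_device()),
    mesh=mesh,
    compute_dtype=torch.bfloat16)

x = torch.randn(8, 128, device=xm.xla_device())
out = model(x)
```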
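
The cast itself is delegated to the existing _cast_floats_tensors helper imported from xla_fully_sharded_data_parallel. As a rough illustration of the intended behavior only (a simplified stand-in written for this note, not the actual helper), floating-point tensors in args and kwargs are converted to compute_dtype, while integer tensors and non-tensor values pass through unchanged:

```python
from typing import Any, Dict, Tuple

import torch


def cast_floats_example(dtype: torch.dtype, *args: Any,
                        **kwargs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
  """Simplified stand-in for _cast_floats_tensors: cast float tensors to dtype."""

  def maybe_cast(value: Any) -> Any:
    if isinstance(value, torch.Tensor) and value.is_floating_point():
      return value.to(dtype)
    return value  # integer tensors and non-tensors are left untouched

  return (tuple(maybe_cast(a) for a in args),
          {k: maybe_cast(v) for k, v in kwargs.items()})


# Example: only the float tensor changes dtype.
x = torch.randn(2, 3)   # float32
ids = torch.arange(3)   # int64
args, kwargs = cast_floats_example(torch.bfloat16, x, ids=ids, flag=True)
assert args[0].dtype == torch.bfloat16
assert kwargs['ids'].dtype == torch.int64
```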