Tensor Parallelism - torch.distributed.tensor.parallel
=======================================================

Tensor Parallelism (TP) is built on top of the PyTorch DistributedTensor (DTensor) and provides several parallelism styles: Rowwise, Colwise, and Pairwise Parallelism.

.. warning::
    Tensor Parallelism APIs are experimental and subject to change.

The entry point to parallelize your ``nn.Module`` using Tensor Parallelism is:

.. automodule:: torch.distributed.tensor.parallel

.. currentmodule:: torch.distributed.tensor.parallel

.. autofunction::  parallelize_module
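For example, a two-layer MLP can be parallelized across a 1-D device mesh with a single call. The sketch below assumes the default process group is already initialized with four ranks and that ``DeviceMesh`` is importable from ``torch.distributed._tensor``, which may vary across releases:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.distributed._tensor import DeviceMesh
    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed.tensor.parallel.style import PairwiseParallel

    # Assumes torch.distributed is already initialized with a world size of 4.
    mesh = DeviceMesh("cuda", torch.arange(4))

    # Two back-to-back Linear layers form the "pair" that PairwiseParallel
    # shards column-wise and then row-wise.
    model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16)).cuda()
    model = parallelize_module(model, mesh, PairwiseParallel())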

Tensor Parallelism supports the following parallel styles:

.. autoclass:: torch.distributed.tensor.parallel.style.RowwiseParallel
  :members:

.. autoclass:: torch.distributed.tensor.parallel.style.ColwiseParallel
  :members:

.. autoclass:: torch.distributed.tensor.parallel.style.PairwiseParallel
  :members:
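To apply different styles to different submodules, ``parallelize_module`` can also take a dict that maps fully qualified submodule names to styles. A minimal sketch, reusing ``mesh`` and ``model`` from the example above and assuming the model has submodules registered as ``net1`` and ``net2`` (illustrative names):

.. code-block:: python

    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel

    # Shard the first Linear column-wise and the second row-wise so only the
    # final output needs to be reduced across the mesh.
    plan = {"net1": ColwiseParallel(), "net2": RowwiseParallel()}
    model = parallelize_module(model, mesh, plan)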

.. warning::
    Sequence Parallelism is still experimental and no evaluation has been performed yet.

.. autoclass:: torch.distributed.tensor.parallel.style.SequenceParallel
  :members:

Since Tensor Parallelism is built on top of DTensor, we need to specify the DTensor placements of a module's inputs and outputs so that it interacts as expected with the modules before and after it. The following functions are used for input/output preparation (a usage sketch follows the list):

.. currentmodule:: torch.distributed.tensor.parallel.style

.. autofunction::  make_input_replicate_1d
.. autofunction::  make_input_reshard_replicate
.. autofunction::  make_input_shard_1d
.. autofunction::  make_input_shard_1d_last_dim
.. autofunction::  make_output_replicate_1d
.. autofunction::  make_output_reshard_tensor
.. autofunction::  make_output_shard_1d
.. autofunction::  make_output_tensor
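For example, a row-wise parallel layer whose input arrives sharded on the last dimension and whose output should be replicated for a following non-parallelized module could be configured roughly as below. This is a sketch that assumes the parallel styles accept ``_prepare_input`` and ``_prepare_output`` arguments; consult the class signatures above for the exact keywords:

.. code-block:: python

    from torch.distributed.tensor.parallel.style import (
        RowwiseParallel,
        make_input_shard_1d_last_dim,
        make_output_replicate_1d,
    )

    # Shard the input on its last dimension before the row-wise Linear and
    # replicate its output so downstream non-parallel modules see a full tensor.
    style = RowwiseParallel(
        _prepare_input=make_input_shard_1d_last_dim,
        _prepare_output=make_output_replicate_1d,
    )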

Currently, there are some constraints that make it hard for the ``nn.MultiheadAttention`` module to work out of the box with Tensor Parallelism, so we provide this custom multihead attention module for Tensor Parallelism users. In addition, ``parallelize_module`` automatically swaps ``nn.MultiheadAttention`` with this custom module when ``PairwiseParallel`` is specified.

.. autoclass:: torch.distributed.tensor.parallel.multihead_attention_tp.TensorParallelMultiheadAttention
  :members:
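For instance, parallelizing a small attention block with ``PairwiseParallel`` requires no extra user code for the swap; it happens inside ``parallelize_module``. A sketch with an illustrative model definition, reusing ``mesh`` from the earlier examples:

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed.tensor.parallel.style import PairwiseParallel

    class Block(nn.Module):
        def __init__(self, dim=16, heads=4):
            super().__init__()
            self.attn = nn.MultiheadAttention(dim, heads)
            self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

        def forward(self, x):
            out, _ = self.attn(x, x, x)
            return self.ffn(out)

    # nn.MultiheadAttention inside Block is replaced with the Tensor Parallel
    # multihead attention module during parallelization.
    block = parallelize_module(Block().cuda(), mesh, PairwiseParallel())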

We also enable 2D parallelism so that Tensor Parallelism can be composed with FullyShardedDataParallel (FSDP). Users just need to call the following API explicitly:

.. currentmodule:: torch.distributed.tensor.parallel.fsdp
.. autofunction::  enable_2d_with_fsdp
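A sketch of composing the two, assuming a 2-D ``DeviceMesh`` whose first dimension is used by FSDP and whose second is used by Tensor Parallelism, a ``tp_mesh_dim`` argument on ``parallelize_module``, and that ``enable_2d_with_fsdp()`` returns whether the integration could be registered:

.. code-block:: python

    import torch
    from torch.distributed._tensor import DeviceMesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp
    from torch.distributed.tensor.parallel.style import PairwiseParallel

    # Register the DTensor extensions with FSDP before wrapping; returns False
    # if the running PyTorch build does not support the integration.
    assert enable_2d_with_fsdp()

    # 2-D mesh on 8 ranks: dim 0 for FSDP data parallelism, dim 1 for TP.
    mesh_2d = DeviceMesh("cuda", torch.arange(8).reshape(4, 2))
    model = parallelize_module(model, mesh_2d, PairwiseParallel(), tp_mesh_dim=1)
    model = FSDP(model)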