Updates to the sharding formalism doc (#2)
* Add explanation of broadcast composition

* Fix image name

* Fix spacing

* Improve explanation

* Update docs/proposals/ShardingFormalism.md

---------

Co-authored-by: Kevin Chen <[email protected]>
gramalingam and kevinch-nv authored Nov 20, 2024
1 parent 6c2e9c0 commit d62aa12
Showing 2 changed files with 123 additions and 18 deletions.
docs/ShardingFormalism.md → docs/proposals/ShardingFormalism.md (141 changes: 123 additions & 18 deletions)

**Constraints on input sharding**
* For any non-broadcast axis, the sharding spec of the two (or more) inputs must be identical
* Any broadcast axis of size 1 (in the unsharded original tensor) must be replicated across all devices
that participate in the parallel computation (that is, all devices identified in the node's sharding spec).
* The case where there are two or more broadcast axes is more involved. Some conditions must be satisfied
to ensure that the natural output (without extra communication ops) has a proper (complete) sharding.
The constraint is that the sharding specs of the multiple broadcast axes must be *composable*,
which is illustrated below.

**Inference of output sharding**
* The sharding spec for any axis of the output is the same as the sharding spec for the corresponding
input axes in the case of non-broadcast.
* In the case of a single broadcast axis, the output axis derives the sharding spec from the corresponding
input axes with a size other than 1, if any.
* In the special case where all corresponding input axes have a size of 1, the output axis inherits
the same sharding (that is, replicated across all devices of the node op).

* In the case of two or more broadcast axes, the output axis derives the sharding spec from the corresponding
input axes with a size other than 1, if any. However, the device assignment is inferred by composing the
sharding specs of all broadcast axes (where each output shard resides in the intersection of the sets of
devices that contain the corresponding input shards used to compute that output shard). See below for
an illustration of this.
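
The per-axis rules above can be summarized in a short, non-normative sketch. The function name and the
simplified representation of an axis's sharding (either `None` for unsharded/replicated, or a list of
per-shard device sets) are invented for this illustration; composition of device assignments across
multiple broadcast axes is treated separately in the next subsection.

```python
def infer_elementwise_axis_sharding(input_sizes, input_shardings, node_devices):
    """Infer the output sharding of ONE axis of an n-ary elementwise op.

    input_sizes:     size of this axis in each (unsharded) input tensor.
    input_shardings: sharding of this axis in each input: None (replicated) or
                     a list of device sets, one set per shard.
    node_devices:    all devices participating in the node's computation.
    """
    non_broadcast = [s for size, s in zip(input_sizes, input_shardings) if size != 1]
    if len(non_broadcast) == len(input_sizes):
        # Non-broadcast axis: all inputs must carry an identical sharding spec.
        assert all(s == non_broadcast[0] for s in non_broadcast)
        return non_broadcast[0]
    # Broadcast axis: every size-1 input axis must be replicated on all node devices.
    for size, s in zip(input_sizes, input_shardings):
        if size == 1:
            assert s is None or s == [set(node_devices)]
    if non_broadcast:
        # Output inherits the sharding of the input axes whose size is not 1.
        assert all(s == non_broadcast[0] for s in non_broadcast)
        return non_broadcast[0]
    # All inputs have size 1 on this axis: the output axis is replicated.
    return [set(node_devices)]
```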

**Composing Sharding Specs on Different Axes**

Consider the example of an `Add (Input1, Input2)` op. Consider the case where `Input1` has shape `[M, 1]` and
`Input2` has shape `[1, N]`. The output has shape `[M, N]`, as a result of broadcasting.

The figure below shows how we can use sharding for both the `M` and `N` axes:

![Composing sharding specs on different axes](images/composing_broadcast_axes.png)

Note that in this example, both the `M` and `N` axes are split into two shards each.
This means that the output itself has 4 shards, as shown in the figure.
In this example, we want each output-shard to be on one device, as described by
the sharding spec
```
{
  device = [0, 1, 2, 3]
  sharded_dim = [
    {
      axis = 0
      simple_sharding = [
        {
          num_shards = 2
        }
      ]
    }
    {
      axis = 1
      simple_sharding = [
        {
          num_shards = 2
        }
      ]
    }
  ]
}
```
To produce this output, however, we need to ensure that each input-shard is
available on two devices, as shown in the figure above. In particular,
the first shard of `Input1` is needed by both devices 0 and 1, as it is used
to compute the first two output shards. Likewise, the first shard of `Input2`
is needed by both devices 0 and 2.

Thus, the sharding spec for `Input1` is as below:

```
{
  device = [-1, -2]  // keys into device_map
  device_map = {-1: [0, 1], -2: [2, 3]}
  sharded_dim = [
    {
      axis = 0
      simple_sharding = [
        {
          num_shards = 2
        }
      ]
    }
  ]
}
```
The sharding spec for `Input2` is analogous, as explained and shown in the figure above.

This leads to the following constraint for input-sharding and inference rule
for output-sharding in the presence of two broadcast axes:
* The (inferred) set of devices for `output-shard[i,j]` is the intersection of the set of devices
for `input-1-shard[i]` and the set of devices for `input-2-shard[j]`. If this intersection is empty,
then the input sharding specs are not compatible (for broadcast composition).

This rule extends analogously to the case of more than two broadcast axes.
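
This composition rule can be sketched in a few lines (illustration only; the device sets below follow
the figure above, and all variable names are invented):

```python
# Devices holding each shard, following the figure: Input1 (shape [M, 1]) is
# split into two shards along axis 0, Input2 (shape [1, N]) along axis 1.
input1_shard_devices = [{0, 1}, {2, 3}]   # shard i of Input1 lives on these devices
input2_shard_devices = [{0, 2}, {1, 3}]   # shard j of Input2 lives on these devices

output_shard_devices = {}
for i, devs1 in enumerate(input1_shard_devices):
    for j, devs2 in enumerate(input2_shard_devices):
        devices = devs1 & devs2           # intersection rule for broadcast composition
        if not devices:
            raise ValueError(f"incompatible sharding for output shard ({i}, {j})")
        output_shard_devices[(i, j)] = devices

print(output_shard_devices)
# {(0, 0): {0}, (0, 1): {1}, (1, 0): {2}, (1, 1): {3}}
```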

### Reduction ops

**Constraints on input sharding**
* No constraints on input sharding.
* Sharding along non-reduction axes is straightforward. It indicates
parallelization of the iteration over the non-reduction axes.
* Sharding along reduction axes is permitted. It indicates parallelization of the reduction
loop, but this involves performing the reduction in two steps: in the first step, the
reduction is done locally on each shard, and in the second step the reduction is done
across the different shards, as sketched below. This can typically be mapped to a collective-reduce operation.
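
The two-step reduction can be illustrated with a small NumPy sketch (illustrative only; the
cross-shard step here stands in for a collective-reduce across devices):

```python
import numpy as np

# ReduceSum over axis 0 of a [6, 4] tensor whose reduction axis is split into 3 shards.
x = np.arange(24, dtype=np.float32).reshape(6, 4)
shards = np.split(x, 3, axis=0)               # one [2, 4] shard per device

local_sums = [s.sum(axis=0) for s in shards]  # step 1: local reduction on each shard
result = np.sum(local_sums, axis=0)           # step 2: reduction across shards
                                              # (a collective-reduce in practice)

assert np.allclose(result, x.sum(axis=0))
```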

**Inference of output sharding**
* Non-reduction axes inherit the sharding of the corresponding axes of the input.
* Since the size of the reduction axis is one after the reduction, it can't be used
for any meaningful sharding. The axis may even be omitted from the output shape,
depending on the value of the attribute `keep_dims`. If the axis is retained, it
is treated as having no sharding.

In the case where the inputs are only sharded along one or more reduction axes,
there will be no sharded axis in the inferred output sharding specification.
However, there is still a choice as to whether the computed output is replicated
on all the devices that participate in this operation, or whether it is stored
only on some distinguished device. Collective-reduce operations typically
support both variations. The default inferred output specification is to
broadcast the computed result to all devices that participate in the particular
reduction (the first option).

### MatMul-like ops

List of operations: MatMul, Gemm, quantized variations of these ops, special cases of EinSum

The constraints for these ops follow analogous cases above. Consider the simple case of matrix multiplication
of two matrices of dimensions `[M, K]` and `[K, N]` producing an output matrix of dimension `[M, N]`.
This operation is essentially a broadcast-reduction operation, where the first
input is interpreted to have the shape `[M, K, 1]` and the second input is interpreted to have
the shape `[1, K, N]`, and we perform a broadcast element-wise multiplication, followed
by a reduce-sum along the `K` axis. The constraints and inference for the operation follow
from the corresponding rules for broadcast and reduction described above.
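
This broadcast-reduction reading of MatMul can be checked with a short NumPy sketch (illustration only):

```python
import numpy as np

M, K, N = 3, 5, 4
A = np.random.rand(M, K)
B = np.random.rand(K, N)

# Interpret A as [M, K, 1] and B as [1, K, N], multiply with broadcasting,
# then reduce-sum along the K axis.
product = A.reshape(M, K, 1) * B.reshape(1, K, N)
C = product.sum(axis=1)

assert np.allclose(C, A @ B)
```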

Axis 0 of the first input (with value `M`) is conceptually broadcast to the second input.
Hence, its constraints and handling are similar to the treatment of broadcast axes for n-ary
elementwise ops. Specifically, since only the first input has this axis, the partitioning of
this axis is not constrained by the partitioning of the second input. Furthermore, the output
matrix will inherit the partitioning for the corresponding axis from the partitioning of axis
0 of the first input.

Axis 1 of the second input (with value `N`) is also handled similarly.

The two axes with size value `K` (the _reduction_ axes) are both required to
have the same sharding (similar to non-broadcast axes in a binary elementwise operation above).

The output device assignment follows the rules described above for broadcast axes.
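
The simple case above, with `M` and `N` each split into two shards and `K` unsharded, can be sketched
as follows (illustration only; output shard `(i, j)` needs row-shard `i` of the first input and
column-shard `j` of the second input, and its device set follows the intersection rule for broadcast
axes; a sharded `K` axis would additionally require the two-step reduction described earlier):

```python
import numpy as np

M, K, N = 4, 6, 8
A = np.random.rand(M, K)
B = np.random.rand(K, N)

A_shards = np.split(A, 2, axis=0)   # shards along the broadcast axis M
B_shards = np.split(B, 2, axis=1)   # shards along the broadcast axis N

# Output shard (i, j) only needs A_shards[i] and B_shards[j].
out_shards = {(i, j): A_shards[i] @ B_shards[j]
              for i in range(2) for j in range(2)}

full = np.block([[out_shards[(0, 0)], out_shards[(0, 1)]],
                 [out_shards[(1, 0)], out_shards[(1, 1)]]])
assert np.allclose(full, A @ B)
```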

### Pooling and Convolution ops
