Add multi-device execution support in ONNX
Signed-off-by: Kevin Chen <[email protected]>
kevinch-nv committed Dec 18, 2024
1 parent 25a134a commit 33afe10
Showing 2 changed files with 250 additions and 0 deletions.
179 changes: 179 additions & 0 deletions docs/proposals/ONNXMultiDeviceProposal.md
@@ -0,0 +1,179 @@
<!--
Copyright (c) ONNX Project Contributors
-->

<!--- SPDX-License-Identifier: Apache-2.0 -->

# ONNX Multi-Device Proposal

## Background

The recent trend toward increasingly large models has spurred interest in distributed inference. A key performance bottleneck for inference with these large models has been the memory limits of GPUs and other accelerators, as well as communication bandwidth. Efficient distributed inference therefore typically requires parallelizing the computation across multiple devices while taking memory and bandwidth into account.

Our goal is to extend ONNX so that it can serve as a representation of a parallelized model. This is driven by the current state-of-the-art techniques used for distributed inference (e.g., see [GSPMD: General and Scalable Parallelization for ML Computation Graphs](https://arxiv.org/pdf/2105.04663.pdf)). In particular, two techniques of interest are tensor parallelism and pipeline parallelism. In tensor parallelism (also known as horizontal parallelism or operator parallelism), the computation of a single operator (node) in the graph is parallelized across multiple devices by sharding its inputs. In pipeline parallelism, different subgraphs are assigned to different devices.


## Design

See [this commit](https://github.com/kevinch-nv/onnx/commit/07e97452096b28ba7c46fec6927d195907431e07) for the proposed additions to the ONNX spec.

The key point of this design is that all multi-device-specific annotations live at the node level and do not affect the main computational graph. This means:
- All communication operations required for multi-device execution are implicit
- A backend may choose to ignore the annotations if the provided configurations are not supported or not available

### Sharding Specification

Sharding refers to partitioning a tensor into multiple parts to be distributed across multiple devices. A tensor may be sharded across any of its axes.

This modification generally falls into two categories: splitting and duplication.

#### Sharding as a Split

For example, consider the following 2x2 tensor:

`[[1, 2], [3, 4]]`

If a sharding across axis 0 is specified over two devices, then:
- Device 0 will receive a tensor of shape 1x2 with data `[[1, 2]]`
- Device 1 will receive a tensor of shape 1x2 with data `[[3, 4]]`

The corresponding ShardingSpecProto for the above will look like:
```
{
  device = [0, 1]
  sharded_dim = [
    {
      axis = 0
      simple_sharding = [
        {
          num_shards = 2
        }
      ]
    }
  ]
}
```
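
For reference, a minimal NumPy sketch of the axis-0 split described above (the tensor, axis, and shard count follow the example; the printing loop is illustrative only):

```
import numpy as np

# The logical 2x2 tensor from the example.
logical = np.array([[1, 2], [3, 4]])

# Axis-0 sharding over two devices: each device receives one 1x2 row.
shards = np.split(logical, 2, axis=0)

for device_id, shard in enumerate(shards):
    print(f"device {device_id}: shape={shard.shape}, data={shard.tolist()}")
# device 0: shape=(1, 2), data=[[1, 2]]
# device 1: shape=(1, 2), data=[[3, 4]]
```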

If a sharding across axis 1 is specified over two devices, then:
- Device 0 will receive a tensor of shape 2x1 with data `[[1], [3]]`
- Device 1 will receive a tensor of shape 2x1 with data `[[2], [4]]`

The corresponding ShardingSpecProto for the above will look like:
```
{
  device = [0, 1]
  sharded_dim = [
    {
      axis = 1
      simple_sharding = [
        {
          num_shards = 2
        }
      ]
    }
  ]
}
```

If a sharding across axis 0 and axis 1 is specified over four devices, then:
- Device 0 will receive a tensor of shape 1x1 with data `[[1]]`
- Device 1 will receive a tensor of shape 1x1 with data `[[2]]`
- Device 2 will receive a tensor of shape 1x1 with data `[[3]]`
- Device 3 will receive a tensor of shape 1x1 with data `[[4]]`

The corresponding ShardingSpecProto for the above will look like:
```
{
  device = [0, 1, 2, 3]
  sharded_dim = [
    {
      axis = 0
      simple_sharding = [
        {
          num_shards = 2
        }
      ]
    },
    {
      axis = 1
      simple_sharding = [
        {
          num_shards = 2
        }
      ]
    }
  ]
}
```

A key observation in the above example is how indexing is performed when multiple sharding axes are provided. In general, the splitting is done as follows:

```
split_tensors = []
shard_size_a = input.shape[axis0] // num_shards_a
shard_size_b = input.shape[axis1] // num_shards_b
for a in range(num_shards_a):
    a_index = a * shard_size_a
    for b in range(num_shards_b):
        b_index = b * shard_size_b
        split = input[a_index : a_index + shard_size_a, b_index : b_index + shard_size_b]
        split_tensors.append(split)
```
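
Applied to the four-device example above, a runnable NumPy version of this loop (variable and function names are illustrative) produces the expected 1x1 shards:

```
import numpy as np

def shard_2d(tensor, num_shards_a, num_shards_b):
    """Split a 2-D tensor along both axes, iterating the shard grid in row-major order."""
    shard_size_a = tensor.shape[0] // num_shards_a
    shard_size_b = tensor.shape[1] // num_shards_b
    shards = []
    for a in range(num_shards_a):
        for b in range(num_shards_b):
            shards.append(tensor[a * shard_size_a:(a + 1) * shard_size_a,
                                 b * shard_size_b:(b + 1) * shard_size_b])
    return shards

# Devices 0..3 receive [[1]], [[2]], [[3]], [[4]] respectively.
print([s.tolist() for s in shard_2d(np.array([[1, 2], [3, 4]]), 2, 2)])
```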

Note that the above examples assume that the number of shards evenly divides the axis being sharded. While this is not a hard restriction, it is up to the backend to decide how to handle non-evenly divisible cases.
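
For illustration only, one policy a backend could adopt for the non-evenly divisible case is the behavior of NumPy's `array_split`, where the leading shards receive one extra element; this is an assumption about backend behavior, not part of the proposal:

```
import numpy as np

# A length-5 axis sharded over two devices cannot be split evenly.
tensor = np.arange(5)

# np.array_split tolerates uneven division: shard 0 gets 3 elements, shard 1 gets 2.
shards = np.array_split(tensor, 2)
print([s.tolist() for s in shards])  # [[0, 1, 2], [3, 4]]
```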


#### Sharding as a Broadcast

There may be cases where the data in a tensor must be duplicated across multiple devices to ensure that operations stay functionally correct.

For example, consider replicating the same 2x2 tensor across two devices. We can do so by providing the following ShardingSpecProto:

```
{
  device = [-1]              // keys into device_map
  device_map = {-1: [0, 1]}
  sharded_dim = []
}
```

It is also possible to mix splitting and broadcasting. Consider the following ShardingSpecProto:

```
{
  device = [-1, -2]                      // keys into device_map
  device_map = {-1: [0, 1], -2: [2, 3]}
  sharded_dim = [
    {
      axis = 0
      simple_sharding = [
        {
          num_shards = 2
        }
      ]
    }
  ]
}
```

- Devices 0 and 1 will receive a tensor of shape 1x2 with data `[[1, 2]]`
- Devices 2 and 3 will receive a tensor of shape 1x2 with data `[[3, 4]]`
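
Sketch of how the mixed spec above could be resolved: axis 0 is split into two shards, and each shard is then replicated across its device group (all names are illustrative):

```
import numpy as np

logical = np.array([[1, 2], [3, 4]])

device = [-1, -2]                       # one entry per shard, keyed into the map
device_map = {-1: [0, 1], -2: [2, 3]}   # each shard is replicated across its group

shards = np.split(logical, 2, axis=0)   # axis 0, num_shards = 2

placement = {}
for shard, key in zip(shards, device):
    for device_id in device_map[key]:
        placement[device_id] = shard

print({d: t.tolist() for d, t in placement.items()})
# {0: [[1, 2]], 1: [[1, 2]], 2: [[3, 4]], 3: [[3, 4]]}
```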

### Pipeline Parallelism

Pipeline stages are represented as an optional integer value in a node's NodeConfigurationProto. It is a hint to the backend on how to run a model in a pipelined fashion across multiple devices. For example, consider the following diagram:

```
Nodes below have a pipeline id of 1:
A -> B -> C -> D -> E
                    |    Nodes below have a pipeline id of 2:
                    v
                    F -> G -> H -> I -> J -> K
```
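
For illustration, a sketch of how a backend might group nodes by their pipeline-stage hint before assigning each stage to a device (the node names and stage values mirror the diagram; reading the annotation out of each node's NodeConfigurationProto is assumed rather than shown):

```
from collections import defaultdict

# Hypothetical (node name, pipeline_stage) pairs, as if read from each node's
# NodeConfigurationProto; the values mirror the diagram above.
annotated_nodes = [
    ("A", 1), ("B", 1), ("C", 1), ("D", 1), ("E", 1),
    ("F", 2), ("G", 2), ("H", 2), ("I", 2), ("J", 2), ("K", 2),
]

# Group nodes into pipeline stages; each stage can then be scheduled on its own device.
stages = defaultdict(list)
for name, stage in annotated_nodes:
    stages[stage].append(name)

print(dict(stages))
# {1: ['A', 'B', 'C', 'D', 'E'], 2: ['F', 'G', 'H', 'I', 'J', 'K']}
```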

It is possible to have both pipeline and tensor parallel annotations in the same ONNX graph.

71 changes: 71 additions & 0 deletions onnx/onnx.in.proto
@@ -231,6 +231,63 @@ message NodeProto {

// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 9;

// Configuration of multi-device annotations.
repeated NodeConfigurationProto configuration = 10;
}

// Multi-device configuration proto for NodeProto.
message NodeConfigurationProto {
// ID of the configuration.
string configuration_id = 1;
// Sharding spec for the node.
repeated ShardingSpecProto sharding_spec = 2;
// Pipeline stage of this node.
optional int32 pipeline_stage = 3;
}

// ShardingSpecProto: This describes the sharding spec for a specific
// input/output of a node.
message ShardingSpecProto {
// Identifies the input/output of the node that is being sharded.
// It is called `logical tensor` in subsequent descriptions.
string tensor_name = 1;

// The following is the list of devices across which the logical
// tensor is sharded or replicated.
repeated int64 device = 2;

// Each element v in above field devices may represent either a
// device or a set of devices (when we want the same shard/tensor
// to be replicated across a subset of devices), as indicated by
// the following optional map. If the map contains an entry for v,
// then v represents a device group, and the map indicates the set
// of devices in that group.
map<int64, IntListProto> index_to_device_group_map = 3;

// The following is the sharded-shape of the tensor, consisting of
// the sharding-spec for each axis of the tensor.
repeated ShardedDimProto sharded_dim = 4;
}

// ShardedDimProto: This describes the sharding spec for a single
// axis of a sharded tensor.
message ShardedDimProto {
int32 axis = 1; // the axis this sharding corresponds to
// The common-case is described by a single instance of SimpleShardedDimProto
// We use multiple instances to handle cases produced when a sharded
// tensor is reshaped, fusing multiple axes into one.
repeated SimpleShardedDimProto simple_sharding = 2;
}

// SimpleShardedDimProto: Indicates that N blocks are divided into M shards.
// Here, N is allowed to be symbolic, while M is required to be a constant.
message SimpleShardedDimProto {
oneof dim {
int64 dim_value = 1;
string dim_param = 2;
}
optional int32 num_shards = 3;
}

// Training information
@@ -430,8 +487,22 @@ message ModelProto {
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated FunctionProto functions = 25;

// Describes different target configurations for a multi-device use case.
// A model can describe multiple multi-device configurations for execution.
repeated ConfigurationProto configuration = 26;
};

// ConfigurationProto describes a multi-device configuration for a model.
message ConfigurationProto {
// Name of the configuration.
string name = 1;
// Name of the device.
string device = 2;
// Number of devices inside this configuration.
int32 num_devices = 3;
}

// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
