Skip to content

Commit

Permalink
OnnxProto changes for multidevice
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinch-nv committed Sep 16, 2024
1 parent 84d4f7b commit 07e9745
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions onnx/onnx.in.proto
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,63 @@ message NodeProto {

// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 9;

// Configuration of multi-device annotations.
repeated NodeConfigurationProto configuration = 10;
}

// Multi-device configuration proto for NodeProto.
// Selects how a node behaves under one of the model-level
// multi-device configurations.
message NodeConfigurationProto {
  // ID of the configuration this annotation applies to.
  // NOTE(review): presumably matches a ConfigurationProto.name declared
  // in ModelProto.configuration — confirm against the spec.
  string configuration_id = 1;
  // Sharding spec for the node, one entry per sharded input/output tensor.
  repeated ShardingSpecProto sharding_spec = 2;
  // Pipeline stage of this node.
  // (`int` is not a protobuf scalar type; int32 is the correct type.)
  optional int32 pipeline_stage = 3;
}

// ShardingSpecProto: This describes the sharding spec for a specific
// input/output of a node.
message ShardingSpecProto {
  // Identifies the input/output of the node that is being sharded.
  // It is called `logical tensor` in subsequent descriptions.
  string tensor_name = 1;

  // The following is the list of devices across which the logical
  // tensor is sharded or replicated.
  repeated int64 device = 2;

  // Each element v in above field devices may represent either a
  // device or a set of devices (when we want the same shard/tensor
  // to be replicated across a subset of devices), as indicated by
  // the following map. If the map contains an entry for v, then v
  // represents a device group, and the map indicates the set of
  // devices in that group.
  //
  // Map fields cannot carry the `optional` label in protobuf (an empty
  // map already means "absent"), and `int` is not a valid key type;
  // the key is int64 to match the `device` field above.
  map<int64, IntListProto> index_to_device_group_map = 3;

  // The following is the sharded-shape of the tensor, consisting of
  // the sharding-spec for each axis of the tensor.
  repeated ShardedDimProto sharded_dim = 4;
}

// ShardedDimProto: This describes the sharding spec for a single
// axis of a sharded tensor.
message ShardedDimProto {
  // The tensor axis this sharding corresponds to.
  int32 axis = 1;
  // The common-case is described by a single instance of
  // SimpleShardedDimProto. We use multiple instances to handle cases
  // produced when a sharded tensor is reshaped, fusing multiple axes
  // into one.
  repeated SimpleShardedDimProto simple_sharding = 2;
}

// SimpleShardedDimProto: Indicates that N blocks are divided into M shards.
// Here, N is allowed to be symbolic, while M is required to be a constant.
// (Name fixed from `SimpledShardedDimProto` — a typo; the field
// ShardedDimProto.simple_sharding references SimpleShardedDimProto.)
message SimpleShardedDimProto {
  // The dimension N being sharded: either a constant value or a
  // symbolic dimension name.
  oneof dim {
    int64 dim_value = 1;
    string dim_param = 2;
  }
  // The number of shards M; must be a constant.
  optional int32 num_shards = 3;
}

// Training information
Expand Down Expand Up @@ -430,8 +487,22 @@ message ModelProto {
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated FunctionProto functions = 25;

// Describes different target configurations for a multi-device use case.
// A model can describe multiple multi-device configurations for execution.
repeated ConfigurationProto configuration = 26;
};

// ConfigurationProto describes a multi-device configuration for a model.
message ConfigurationProto {
  // Name of the configuration.
  // NOTE(review): presumably referenced by
  // NodeConfigurationProto.configuration_id — confirm against the spec.
  string name = 1;
  // Name of the device, e.g. a device kind identifier.
  // NOTE(review): exact semantics (device type vs. instance) are not
  // visible here — confirm.
  string device = 2;
  // Number of devices inside this configuration.
  int32 num_devices = 3;
}

// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
Expand Down

0 comments on commit 07e9745

Please sign in to comment.