From 07e97452096b28ba7c46fec6927d195907431e07 Mon Sep 17 00:00:00 2001
From: Kevin Chen
Date: Mon, 16 Sep 2024 16:34:48 -0700
Subject: [PATCH] OnnxProto changes for multidevice

---
 onnx/onnx.in.proto | 78 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 78 insertions(+)

diff --git a/onnx/onnx.in.proto b/onnx/onnx.in.proto
index d30e9393cc1..a0633237b7d 100644
--- a/onnx/onnx.in.proto
+++ b/onnx/onnx.in.proto
@@ -231,6 +231,70 @@ message NodeProto {
 
   // Named metadata values; keys should be distinct.
   repeated StringStringEntryProto metadata_props = 9;
+
+  // Configuration of multi-device annotations.
+  repeated NodeConfigurationProto configuration = 10;
+}
+
+// Multi-device configuration proto for NodeProto.
+message NodeConfigurationProto {
+  // ID of the configuration.
+  string configuration_id = 1;
+  // Sharding spec for the node.
+  repeated ShardingSpecProto sharding_spec = 2;
+  // Pipeline stage of this node.
+  optional int32 pipeline_stage = 3;
+}
+
+// IntIntListEntryProto: a (key, value-list) entry, used below in place
+// of a proto3 map (proto3 map values may not be repeated fields).
+message IntIntListEntryProto {
+  int64 key = 1;
+  repeated int64 value = 2;
+}
+
+// ShardingSpecProto: This describes the sharding spec for a specific
+// input/output of a node.
+message ShardingSpecProto {
+  // Identifies the input/output of the node that is being sharded.
+  // It is called `logical tensor` in subsequent descriptions.
+  string tensor_name = 1;
+
+  // The following is the list of devices across which the logical
+  // tensor is sharded or replicated.
+  repeated int64 device = 2;
+
+  // Each element v in above field devices may represent either a
+  // device or a set of devices (when we want the same shard/tensor
+  // to be replicated across a subset of devices), as indicated by
+  // the following map. If the map contains an entry for v,
+  // then v represents a device group, and the entry indicates the
+  // set of devices in that group.
+  repeated IntIntListEntryProto index_to_device_group_map = 3;
+
+  // The following is the sharded-shape of the tensor, consisting of
+  // the sharding-spec for each axis of the tensor.
+  repeated ShardedDimProto sharded_dim = 4;
+}
+
+// ShardedDimProto: This describes the sharding spec for a single
+// axis of a sharded tensor.
+message ShardedDimProto {
+  int32 axis = 1; // the axis this sharding corresponds to
+  // The common-case is described by a single instance of SimpleShardedDimProto.
+  // We use multiple instances to handle cases produced when a sharded
+  // tensor is reshaped, fusing multiple axes into one.
+  repeated SimpleShardedDimProto simple_sharding = 2;
+}
+
+// SimpleShardedDimProto: Indicates that N blocks are divided into M shards.
+// Here, N is allowed to be symbolic, while M is required to be a constant.
+message SimpleShardedDimProto {
+  oneof dim {
+    int64 dim_value = 1;
+    string dim_param = 2;
+  }
+  optional int32 num_shards = 3;
 }
 
 // Training information
@@ -430,8 +494,22 @@ message ModelProto {
   // One FunctionProto can reference other FunctionProto in the model, however, recursive reference
   // is not allowed.
   repeated FunctionProto functions = 25;
+
+  // Describes different target configurations for a multi-device use case.
+  // A model can describe multiple multi-device configurations for execution.
+  repeated ConfigurationProto configuration = 26;
 };
+
+// ConfigurationProto describes a multi-device configuration for a model.
+message ConfigurationProto {
+  // Name of the configuration.
+  string name = 1;
+  // Name of the device.
+  string device = 2;
+  // Number of devices inside this configuration.
+  int32 num_devices = 3;
+}
 
 // StringStringEntryProto follows the pattern for cross-proto-version maps.
 // See https://developers.google.com/protocol-buffers/docs/proto3#maps
 message StringStringEntryProto {