From d62aa1235616c9f1ff0279a1d376aff672926fc7 Mon Sep 17 00:00:00 2001
From: "G. Ramalingam"
Date: Wed, 20 Nov 2024 14:58:19 -0800
Subject: [PATCH] Updates to the sharding formalism doc (#2)

* Add explanation of broadcast composition

* Fix image name

* Fix spacing

* Improve explanation

* Update docs/proposals/ShardingFormalism.md

---------

Co-authored-by: Kevin Chen <45886021+kevinch-nv@users.noreply.github.com>
---
 docs/{ => proposals}/ShardingFormalism.md | 141 +++++++++++++++---
 .../images/composing_broadcast_axes.png   | Bin 0 -> 6903 bytes
 2 files changed, 123 insertions(+), 18 deletions(-)
 rename docs/{ => proposals}/ShardingFormalism.md (50%)
 create mode 100644 docs/proposals/images/composing_broadcast_axes.png

diff --git a/docs/ShardingFormalism.md b/docs/proposals/ShardingFormalism.md
similarity index 50%
rename from docs/ShardingFormalism.md
rename to docs/proposals/ShardingFormalism.md
index a4973b913bb..416ea4670c8 100644
--- a/docs/ShardingFormalism.md
+++ b/docs/proposals/ShardingFormalism.md
@@ -86,32 +86,125 @@ _Add, And, BitShift, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, Equal, Great

 **Constraints on input sharding**

 * For any non-broadcast axis, the sharding spec of the two (or more) inputs must be identical
-* Any broadcast axis of size 1 (in the unsharded original tensor) must be replicated across all devices that participate in the parallel computation (that is, all devices identified in the node's sharding spec).
+* Any broadcast axis of size 1 (in the unsharded original tensor) must be replicated across all devices
+that participate in the parallel computation (that is, all devices identified in the node's sharding spec).
+* The case where there are two or more broadcast axes is more involved. Some conditions must be satisfied
+to ensure that the natural output (without extra communication ops) has a proper (complete) sharding.
+The constraint is that the sharding specs of the multiple broadcast axes must be *composable*,
+as illustrated below.

 **Inference of output sharding**

-* The sharding spec for any axes of the output is the same as the sharding spec for the axes of the
-corresponding input axes in the case of non-broadcast. In the case of broadcast, the output axes
-derives the sharding spec from the corresponding input axes with a size other than 1, if any.
-In the special case where all corresponding input axes have a size of 1, the output axis inherits
+* For a non-broadcast axis, the sharding spec of the output axis is the same as the (identical)
+sharding specs of the corresponding input axes.
+* For a single broadcast axis, the output axis derives its sharding spec from the corresponding
+input axes with a size other than 1, if any.
+* In the special case where all corresponding input axes have a size of 1, the output axis inherits
 the same sharding (that is, replicated across all devices of the node op).
-
-_Note_: The above can be generalized, but the generalization is hard to describe in words.
-TODO: either add example figures or code to describe more complex scenarios.
+* In the case of two or more broadcast axes, each output axis derives its sharding spec from the corresponding
+input axes with a size other than 1, if any. However, the device assignment is inferred by composing the
+sharding specs of all broadcast axes: each output shard resides on the intersection of the sets of
+devices that hold the corresponding input shards used to compute that output shard. See below for
+an illustration of this.
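+
+The following Python sketch (illustrative only; the helper name `compose_broadcast_shards` is
+hypothetical, not part of the proposed spec) makes the composition rule concrete: the device set
+of each output shard is the intersection of the device sets of the input shards used to compute
+it, and an empty intersection means the input sharding specs are not composable. The example
+device sets correspond to the `Add` example worked out below.
+
+```
+def compose_broadcast_shards(input1_shard_devices, input2_shard_devices):
+    """input1_shard_devices[i] / input2_shard_devices[j]: sets of devices holding
+    shard i of input 1 and shard j of input 2 (each sharded along its broadcast axis).
+    Returns the inferred device sets for the output shards [i, j]."""
+    output = {}
+    for i, devs1 in enumerate(input1_shard_devices):
+        for j, devs2 in enumerate(input2_shard_devices):
+            devs = devs1 & devs2  # output shard [i, j] needs both input shards
+            if not devs:
+                raise ValueError("sharding specs are not composable")
+            output[(i, j)] = devs
+    return output
+
+# Input1 shards (axis M) on devices {0, 1} and {2, 3}; Input2 shards (axis N) on {0, 2} and {1, 3}.
+print(compose_broadcast_shards([{0, 1}, {2, 3}], [{0, 2}, {1, 3}]))
+# {(0, 0): {0}, (0, 1): {1}, (1, 0): {2}, (1, 1): {3}}
+```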
+
+**Composing Sharding Specs on Different Axes**
+
+Consider the example of an `Add (Input1, Input2)` op, where `Input1` has shape `[M, 1]` and
+`Input2` has shape `[1, N]`. The output has shape `[M, N]`, as a result of broadcasting.
+
+The figure below shows how we can use sharding for both the `M` and `N` axes:
+
+![Composing sharding specs on different axes](images/composing_broadcast_axes.png)
+
+Note that in this example, both the `M` and `N` axes are split into two shards each.
+This means that the output itself has 4 shards, as shown in the figure.
+In this example, we want each output shard to reside on one device, as described by
+the following sharding spec:
+```
+{
+  device = [0, 1, 2, 3]
+  sharded_dim = [
+    {
+      axis = 0
+      simple_sharding =
+        [
+          {
+            num_shards = 2
+          }
+        ]
+    }
+    {
+      axis = 1
+      simple_sharding =
+        [
+          {
+            num_shards = 2
+          }
+        ]
+    }
+  ]
+}
+```
+To produce this output, however, we need to ensure that each input shard is
+available on two devices, as shown in the figure above. In particular,
+the first shard of `Input1` is needed by both devices 0 and 1, as it is used
+to compute the first two output shards. Likewise, the first shard of `Input2`
+is needed by both devices 0 and 2.
+
+Thus, the sharding spec for `Input1` is as below:
+
+```
+{
+  device = [-1, -2] // keys into device_map
+  device_map = {-1: [0, 1], -2: [2, 3]}
+  sharded_dim = [
+    {
+      axis = 0
+      simple_sharding =
+        [
+          {
+            num_shards = 2
+          }
+        ]
+    }
+  ]
+}
+```
+The sharding spec for `Input2` is analogous, as shown in the figure above.
+
+This leads to the following constraint on input sharding and inference rule
+for output sharding in the presence of two broadcast axes:
+* The (inferred) set of devices for `output-shard[i,j]` is the intersection of the sets of devices
+for `input-1-shard[i]` and `input-2-shard[j]`. If this intersection is empty, then the input
+sharding specs are not compatible (for broadcast composition).
+
+This rule extends accordingly to the case of more than two broadcast axes.

 ### Reduction ops

 **Constraints on input sharding**

 * No constraints on input sharding.
-* Sharding along non-reduction axes is straightforward, since parallel iteration over the non-reduction
-axes is possible.
-* Sharding along reduction axes can be supported, but it requires an implicit collective-reduce operation.
+* Sharding along non-reduction axes is straightforward. It indicates
+parallelization of the iteration over the non-reduction axes.
+* Sharding along reduction axes is permitted. It indicates parallelization of the reduction
+loop, but it involves performing the reduction in two steps: in the first step, the
+reduction is done locally within each shard, and in the second step the reduction is done
+across the different shards. This can typically be mapped to a collective-reduce operation
+(see the sketch below).

 **Inference of output sharding**

 * Non-reduction axes inherit the sharding of the corresponding axes of the input.
-* Two natural possibilities exist for the reduction axes, if they are sharded. The result can be
-broadcast to all devices containing some shard along the reduction axes, or just to the devices
-containing a distinguished shard (say, the first one). As a default, we assume a broadcast (the
-first option).
+* Since the size of the reduction axis is one after the reduction, it can't be used
+for any meaningful sharding. The axis may even be omitted from the output shape,
+depending on the value of the attribute `keepdims`. If the axis is retained, it
+is treated as having no sharding.
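+
+As a concrete illustration of the two-step reduction described above, the following NumPy
+sketch (illustrative only, not part of the proposed spec) mimics a `ReduceSum` over an axis
+that is sharded across two devices: each device first reduces its own shard locally, and the
+partial results are then reduced across shards (which, in a real runtime, would be a
+collective-reduce over the participating devices).
+
+```
+import numpy as np
+
+X = np.arange(12, dtype=np.float64).reshape(3, 4)
+
+# Axis 1 is sharded into two shards (conceptually, one per device).
+shards = np.split(X, 2, axis=1)
+
+# Step 1: local reduction within each shard.
+partials = [s.sum(axis=1, keepdims=True) for s in shards]
+
+# Step 2: reduction across shards (a collective-reduce in practice).
+result = partials[0] + partials[1]
+
+assert np.allclose(result, X.sum(axis=1, keepdims=True))
+```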
+
+In the case where the inputs are only sharded along one or more reduction axes,
+there will be no sharded axis in the inferred output sharding specification.
+However, there is still a choice as to whether the computed output is replicated
+on all the devices that participate in this operation, or whether it is stored
+only on some distinguished device. Collective-reduce operations typically
+support both variations. The default inferred output specification is to
+broadcast the computed result to all devices that participate in the particular
+reduction (the first option).

 ### MatMul-like ops

@@ -119,13 +212,25 @@ List of operations: MatMul, Gemm, quantized variations of these ops, special cas

 The constraints for these ops follow analogous cases above. Consider the simple case of matrix
 multiplication of two matrices of dimensions `[M, K]` and `[K, N]` producing an output matrix
 of dimension `[M, N]`.

+This operation is essentially a broadcast-reduction operation, where the first
+input is interpreted to have the shape `[M, K, 1]` and the second input is interpreted to have
+the shape `[1, K, N]`, and we perform a broadcast element-wise multiplication, followed
+by a reduce-sum along the `K` axis. The constraints and inference for the operation follow
+from the corresponding rules for broadcast and reduction described above.
+
 Axis 0 of the first input (with value `M`) is conceptually broadcast to the second input.
 Hence, its constraints and handling are similar to the treatment of broadcast axes for n-ary
-elementwise ops.
+elementwise ops. Specifically, since only the first input has this axis, the partitioning of
+this axis is not constrained by the partitioning of the second input. Furthermore, the output
+matrix will inherit the partitioning for the corresponding axis from the partitioning of axis
+0 of the first input.
+
 Axis 1 of the second input (with value `N`) is also handled similarly.

-The axes with size value `K` represent reduction axes. The corresponding two axes must have
-compatible sharding.
+The two axes with size value `K` (the _reduction_ axes) are both required to
+have the same sharding (similar to non-broadcast axes in a binary operation above).
+
+The output device assignment follows the rules described above for broadcast axes.
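+
+The following NumPy sketch (illustrative only, not part of the proposed spec) shows why the two
+`K` axes must be sharded identically: each device multiplies its matching pair of `K`-shards to
+produce a partial product, and the partial products are then combined by a cross-shard reduction
+(a collective-reduce in practice), exactly as in the sharded-reduction case above.
+
+```
+import numpy as np
+
+M, K, N = 4, 6, 3
+A = np.arange(M * K, dtype=np.float64).reshape(M, K)
+B = np.arange(K * N, dtype=np.float64).reshape(K, N)
+
+# Shard the reduction axis K identically on both inputs (two shards, one per device).
+A_shards = np.split(A, 2, axis=1)
+B_shards = np.split(B, 2, axis=0)
+
+# Step 1: each device computes a local partial product from its own pair of shards.
+partials = [a @ b for a, b in zip(A_shards, B_shards)]
+
+# Step 2: reduce the partial products across shards (a collective-reduce in practice).
+result = partials[0] + partials[1]
+
+assert np.allclose(result, A @ B)
+```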

 ### Pooling and Convolution ops

diff --git a/docs/proposals/images/composing_broadcast_axes.png b/docs/proposals/images/composing_broadcast_axes.png
new file mode 100644
index 0000000000000000000000000000000000000000..d14705673142d92971fea02fbddcdd2fe084a714
GIT binary patch