# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
+from enum import Enum
from typing import List, Optional

import torch
import torch.distributed as dist
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch import nn
if os.uname().sysname != "Darwin":
    from torch.distributed import _functional_collectives as funcol
else:
    funcol = None

from model import Attention, FeedForward, Transformer
-from quantize import WeightOnlyInt4Linear
+from quantize import WeightOnlyInt4Linear, WeightOnlyInt8Linear


def _get_rank() -> int:
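
The new imports above bring in the DTensor tensor-parallel API (DeviceMesh, ColwiseParallel, RowwiseParallel, parallelize_module) that the rest of this diff builds on. For reference, here is a minimal sketch of that API on its own; the toy nn.Sequential, the sizes, and the multi-GPU torchrun launch are illustrative assumptions, not part of this commit.

# Minimal sketch of the DTensor TP API used below. Assumes a multi-GPU
# torchrun launch with an already-initialized NCCL process group (as
# maybe_init_dist() sets up in this file); names and sizes are illustrative.
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def _demo_dtensor_tp() -> torch.Tensor:
    mesh = dist.init_device_mesh("cuda", (dist.get_world_size(),))
    mlp = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16)).cuda()
    # Shard the up-projection column-wise and the down-projection row-wise;
    # RowwiseParallel performs the concluding all-reduce itself.
    parallelize_module(mlp, mesh, {"0": ColwiseParallel(), "2": RowwiseParallel()})
    return mlp(torch.randn(4, 16, device="cuda"))

Because DTensor handles the final reduction, no hand-written forward hook is needed on this path, which is exactly the distinction the TPMode enum below encodes.
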
@@ -33,6 +36,12 @@ def local_break():
def _get_world_size() -> int:
    return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

+global device_mesh
+
+def _get_tp_mesh():
+    # device_mesh has only TP dimension for now
+    return device_mesh
+
def maybe_init_dist() -> Optional[int]:
    try:
        # provided by torchrun
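
The hunk above stores a module-level device_mesh and exposes it through _get_tp_mesh(); the mesh is one-dimensional because this change only supports tensor parallelism. A rough sketch of what that 1-D mesh looks like, assuming a hypothetical 4-GPU torchrun launch (values are illustrative, not from this commit):

# Hedged sketch of the 1-D mesh that _get_tp_mesh() returns; the 4-GPU world
# size is a made-up example, and init_device_mesh also creates the default
# process group when run under torchrun.
import torch.distributed as dist

def _inspect_tp_mesh() -> None:
    mesh = dist.init_device_mesh("cuda", (4,))
    assert mesh.ndim == 1          # a single, TP-only dimension
    print(mesh.size())             # 4: every rank participates in TP
    print(mesh.mesh)               # tensor([0, 1, 2, 3]): ranks along that dimension
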
@@ -48,86 +57,97 @@ def maybe_init_dist() -> Optional[int]:

    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+    global device_mesh
+    device_mesh = dist.init_device_mesh(
+        "cuda",
+        (world_size,),  # Only TP dimension for now
+    )
    return rank

+class TPMode(Enum):
+    MANUAL = 0
+    DTENSOR = 1

-def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None:
+def _apply_tp_linear(linear: nn.Linear, style: str) -> TPMode:
    rank = _get_rank()
    world_size = _get_world_size()
+    tp_mesh = _get_tp_mesh()

    # Linear's weight matrix is transposed, and is of shape
    # (linear.out_features, linear.in_features)
    dim_lookup = {
-        "colwise": (0, "out_features"),
-        "rowwise": (1, "in_features")
+        "colwise": (0, "out_features", ColwiseParallel()),
+        "rowwise": (1, "in_features", RowwiseParallel()),
    }
    assert style in dim_lookup
-    shard_dim, size_attr = dim_lookup[style]
+    shard_dim, size_attr, tp_plan = dim_lookup[style]

    # ensure we can shard evenly
    assert getattr(linear, size_attr) % world_size == 0
    def shard(x, dim):
        assert x.size(dim=dim) % world_size == 0
        return torch.tensor_split(x, world_size, dim=dim)[rank]

-    def shard_qkv(qkv, dim, weight_splits):
-        q, k, v = qkv.split(weight_splits, dim=dim)
-        q = shard(q, dim)
-        k = shard(k, dim)
-        v = shard(v, dim)
-        return torch.cat((q, k, v), dim=dim)
-
-    # shard
-    if weight_splits:
-        # attention
-        assert len(weight_splits) == 3
-
-        if isinstance(linear, WeightOnlyInt4Linear):
-            sharded_weight = shard_qkv(linear.weight, shard_dim, [i // 8 for i in weight_splits])
-            linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits)
-        else:
-            sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
-        if hasattr(linear, "scales") and style == "colwise":
-            linear.scales = shard_qkv(linear.scales, 0, weight_splits)
-    else:
-        sharded_weight = shard(linear.weight, shard_dim)
-        if isinstance(linear, WeightOnlyInt4Linear):
+    def shard_scale(linear, shard_dim):
+        if hasattr(linear, "scales_and_zeros"):
            linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
            if style == "rowwise":
                assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3]
                assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8
-        if hasattr(linear, "scales") and style == "colwise":
-            linear.scales = shard(linear.scales, 0)
+        elif hasattr(linear, "scales"):
+            if style == "colwise":
+                linear.scales = shard(linear.scales, 0)
+
+    # shard
+    tp_mode: TPMode
+    if isinstance(linear, (WeightOnlyInt4Linear, WeightOnlyInt8Linear)):
+        # TODO: DTensor doesn't have a way to distribute quantized tensor yet.
+        # Should revisit when that capability is added.
+        sharded_weight = shard(linear.weight, shard_dim)
+        linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
+        shard_scale(linear, shard_dim)
+        tp_mode = TPMode.MANUAL
+    else:
+        # Use DTensor based TP
+        parallelize_module(linear, tp_mesh, tp_plan)
+        tp_mode = TPMode.DTENSOR

    # local_break()
-    linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
    setattr(linear, size_attr, getattr(linear, size_attr) // world_size)

    # shape info should still be synced
    # assert linear.weight.shape == (linear.out_features, linear.in_features)
+    return tp_mode


def _apply_tp_ffn(mlp: FeedForward) -> None:
    assert hasattr(mlp, "w1")
    assert hasattr(mlp, "w3")
    assert hasattr(mlp, "w2")

-    _apply_tp_linear(mlp.w1, "colwise")
-    _apply_tp_linear(mlp.w3, "colwise")
-    _apply_tp_linear(mlp.w2, "rowwise")
+    tp_mode = _apply_tp_linear(mlp.w1, "colwise")
+    tp_mode = _apply_tp_linear(mlp.w3, "colwise")
+    tp_mode = _apply_tp_linear(mlp.w2, "rowwise")

-    world_size = _get_world_size()
-    mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
-        output, "sum", list(range(world_size))))
+    if tp_mode == TPMode.MANUAL:
+        # In manual mode, we need to manually add an all-reduce at the end
+        world_size = _get_world_size()
+        mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
+            output, "sum", list(range(world_size))))


def _apply_tp_attn(attn: Attention) -> None:
-    assert hasattr(attn, "wqkv")
+    assert hasattr(attn, "wq")
+    assert hasattr(attn, "wk")
+    assert hasattr(attn, "wv")
    assert hasattr(attn, "wo")

    kv_size = attn.n_local_heads * attn.head_dim
-    _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size])
-    _apply_tp_linear(attn.wo, "rowwise")
+    tp_mode = _apply_tp_linear(attn.wq, "colwise")
+    tp_mode = _apply_tp_linear(attn.wk, "colwise")
+    tp_mode = _apply_tp_linear(attn.wv, "colwise")
+    tp_mode = _apply_tp_linear(attn.wo, "rowwise")

    # overwrite
    world_size = _get_world_size()
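
In _apply_tp_ffn above, the up-projections are sharded column-wise and the down-projection row-wise, and only the MANUAL path has to register the funcol.all_reduce hook; in the DTENSOR path, RowwiseParallel performs that reduction itself. The single-process sketch below, with made-up shapes and a 2-way split, shows why summing the per-rank partial outputs (which is all the all-reduce does) recovers the unsharded result.

# Single-process sketch of the colwise -> rowwise TP math; the sizes and the
# 2-way split are illustrative, not taken from this commit.
import torch

x = torch.randn(2, 8)
w1 = torch.randn(16, 8)   # up-projection, sharded colwise (dim 0, out_features)
w2 = torch.randn(8, 16)   # down-projection, sharded rowwise (dim 1, in_features)

full = (x @ w1.t()) @ w2.t()

# Each "rank" holds half of w1's output rows and the matching half of w2's
# input columns, so it produces a partial sum of the final output.
h0, h1 = x @ w1[:8].t(), x @ w1[8:].t()
partial0, partial1 = h0 @ w2[:, :8].t(), h1 @ w2[:, 8:].t()

# Summing the partials (what the all-reduce does across ranks) matches the
# unsharded computation.
assert torch.allclose(full, partial0 + partial1, atol=1e-4)
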
@@ -136,8 +156,10 @@ def _apply_tp_attn(attn: Attention) -> None:
    attn.head_dim = attn.dim // attn.n_head
    attn.n_local_heads = attn.n_local_heads // world_size

-    attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
-        output[0], "sum", list(range(world_size))))
+    if tp_mode == TPMode.MANUAL:
+        # In manual mode, we need to manually add an all-reduce at the end
+        attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
+            output[0], "sum", list(range(world_size))))


def _apply_tp_Transformer(Transformer: Transformer) -> None:
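
Once the q/k/v/o projections are sharded, _apply_tp_attn rewrites the attention module's cached sizes so downstream code sees per-rank shapes. A worked example of that arithmetic with hypothetical Llama-7B-like numbers and 2-way TP (none of these values come from this commit):

# Hypothetical numbers only: how the attention "overwrite" block scales the
# cached sizes for 2-way TP. head_dim is unchanged because dim and n_head
# shrink by the same factor.
world_size = 2
dim, n_head, n_local_heads = 4096, 32, 32   # pre-shard: model dim, Q heads, KV heads

dim //= world_size                           # 2048: per-rank model dim after colwise shard
n_head //= world_size                        # 16 query heads per rank
head_dim = dim // n_head                     # still 128
n_local_heads //= world_size                 # 16 KV heads per rank

assert dim == n_head * head_dim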