File tree 1 file changed +10
-6
lines changed
optimum/neuron/distributed
1 file changed +10
-6
lines changed Original file line number Diff line number Diff line change @@ -145,17 +145,21 @@ def consolidate_tensor_parallel_checkpoints(
145
145
# This might not be the case anymore when `ParameterMetadata` uses slices.
146
146
sharded_metadata = sharded_metadatas [name ]
147
147
if sharded_metadata .is_tied :
148
- consolidated_state_dict [original_name ] = state_dicts [0 ][name ].to ("cpu" )
148
+ consolidated_state_dict [original_name ] = state_dicts [0 ][name ].to ("cpu" ). contiguous ()
149
149
else :
150
150
# Ensure that all tensors are contiguous before concatenating or further processing
151
151
weights = [state_dict [name ].contiguous () for state_dict in state_dicts ]
152
152
tp_size = len (weights )
153
153
154
- full_weight = torch .cat (
155
- weights ,
156
- dim = sharded_metadata .partition_dim ,
157
- ).contiguous () # Ensure the result is also contiguous
158
-
154
+ full_weight = (
155
+ torch .cat (
156
+ weights ,
157
+ dim = sharded_metadata .partition_dim ,
158
+ )
159
+ .to ("cpu" )
160
+ .contiguous ()
161
+ ) # Ensure the result is also contiguous
162
+
159
163
if weight_name in ["weight_k" , "weight_v" , "bias_k" , "bias_v" ]:
160
164
full_weight = (
161
165
torch .chunk (full_weight , gqa_qkv_metadata ["kv_size_multiplier" ], dim = 0 )[0 ].detach ().clone ()
You can’t perform that action at this time.
0 commit comments