Skip to content

Commit 15f87cc

Browse files
gioannidesdacorvo
gioannides
authored and committed
Update PR with more edge cases where a tensor may not be contiguous after being placed on CPU
1 parent ece4f98 commit 15f87cc

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

optimum/neuron/distributed/checkpointing.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -145,17 +145,21 @@ def consolidate_tensor_parallel_checkpoints(
145145
# This might not be the case anymore when `ParameterMetadata` uses slices.
146146
sharded_metadata = sharded_metadatas[name]
147147
if sharded_metadata.is_tied:
148-
consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu")
148+
consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu").contiguous()
149149
else:
150150
# Ensure that all tensors are contiguous before concatenating or further processing
151151
weights = [state_dict[name].contiguous() for state_dict in state_dicts]
152152
tp_size = len(weights)
153153

154-
full_weight = torch.cat(
155-
weights,
156-
dim=sharded_metadata.partition_dim,
157-
).contiguous() # Ensure the result is also contiguous
158-
154+
full_weight = (
155+
torch.cat(
156+
weights,
157+
dim=sharded_metadata.partition_dim,
158+
)
159+
.to("cpu")
160+
.contiguous()
161+
) # Ensure the result is also contiguous
162+
159163
if weight_name in ["weight_k", "weight_v", "bias_k", "bias_v"]:
160164
full_weight = (
161165
torch.chunk(full_weight, gqa_qkv_metadata["kv_size_multiplier"], dim=0)[0].detach().clone()

0 commit comments

Comments
 (0)