diff --git a/docs/proposals/ONNXMultiDeviceProposal.md b/docs/proposals/ONNXMultiDeviceProposal.md index 209fbac97c2..ca15fef4bdd 100644 --- a/docs/proposals/ONNXMultiDeviceProposal.md +++ b/docs/proposals/ONNXMultiDeviceProposal.md @@ -115,10 +115,12 @@ A key observation in the above example shows how indexing is performed when mult ``` split_tensors = [] for a in range(num_shards_a): - a_index = a * input.shape[axis0] / num_shards_a + a_width = input.shape[axis0] / num_shards_a + a_index = a * a_width for b in range(num_shards_b): - b_index = b * input.shape[axis1] / num_shards_b - split = input[a_index : a_index + num_shards_a, b_index : b_index + num_shards_b] + b_width = input.shape[axis1] / num_shards_b + b_index = b * b_width + split = input[a_index : a_index + a_width, b_index : b_index + b_width] split_tensors.append(split) ```