Fix sharding formula (#3)
kevinch-nv authored Nov 21, 2024
1 parent d62aa12 commit bd38276
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions docs/proposals/ONNXMultiDeviceProposal.md
@@ -115,10 +115,12 @@ A key observation in the above example shows how indexing is performed when mult
```
 split_tensors = []
 for a in range(num_shards_a):
-    a_index = a * input.shape[axis0] / num_shards_a
+    a_width = input.shape[axis0] / num_shards_a
+    a_index = a * a_width
     for b in range(num_shards_b):
-        b_index = b * input.shape[axis1] / num_shards_b
-        split = input[a_index : a_index + num_shards_a, b_index : b_index + num_shards_b]
+        b_width = input.shape[axis1] / num_shards_b
+        b_index = b * b_width
+        split = input[a_index : a_index + a_width, b_index : b_index + b_width]
         split_tensors.append(split)
```
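
For reference, the corrected formula can be sketched as runnable Python. This is only an illustration under stated assumptions, not code from the proposal: `shard_2d` is a hypothetical helper, a nested list stands in for a 2-D tensor, both dimensions are assumed to divide evenly by their shard counts, and `//` replaces `/` so the slice bounds stay integers.

```python
def shard_2d(rows, num_shards_a, num_shards_b):
    """Split a 2-D nested list into num_shards_a * num_shards_b equal blocks.

    Hypothetical helper for illustration; not part of the proposal.
    """
    n_rows, n_cols = len(rows), len(rows[0])
    # Assumption: both dimensions divide evenly by their shard counts.
    if n_rows % num_shards_a or n_cols % num_shards_b:
        raise ValueError("each dimension must be divisible by its shard count")
    a_width = n_rows // num_shards_a  # integer division keeps slice bounds integral
    b_width = n_cols // num_shards_b
    split_tensors = []
    for a in range(num_shards_a):
        a_index = a * a_width
        for b in range(num_shards_b):
            b_index = b * b_width
            # Each shard spans a_width rows and b_width columns.
            split = [row[b_index : b_index + b_width]
                     for row in rows[a_index : a_index + a_width]]
            split_tensors.append(split)
    return split_tensors


# 4 x 6 input holding 0..23 in row-major order
x = [[r * 6 + c for c in range(6)] for r in range(4)]
shards = shard_2d(x, num_shards_a=2, num_shards_b=3)
```

With 2 shards along the first axis and 3 along the second, each of the 6 shards is 2 x 2. Using the shard counts as the slice widths would instead produce 2 x 3 slices here, which overlap along the second axis.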
