Skip to content

Commit

Permalink
reenable last all reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackmin801 committed Oct 2, 2024
1 parent 23ad8e5 commit 74a4707
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions tests/test_dist/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ def foo(**kwargs):
edm.maybe_reinit_global_pg()
if test_value == 1:
return
# assert edm.mesh_count == 2
assert edm.mesh_count == 2
assert edm.global_pg.size() == global_world_size

# a = torch.arange(3) * (test_value + 1)
# sum_ints = global_world_size * (global_world_size + 1) // 2 + 100 - 2
# dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg)
# assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]))
a = torch.arange(3) * (test_value + 1)
sum_ints = global_world_size * (global_world_size + 1) // 2 + 100 - 2
dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg)
assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]))

dist.barrier(edm.global_pg)

Expand All @@ -155,13 +155,13 @@ def bar(**kwargs):
assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]))

edm.maybe_reinit_global_pg()
# assert edm.mesh_count == 2
assert edm.mesh_count == 2
assert edm.global_pg.size() == global_world_size

# a = torch.arange(3) * test_value
# sum_ints = global_world_size * (global_world_size + 1) // 2 + 100 - 2
# dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg)
# assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]))
a = torch.arange(3) * test_value
sum_ints = global_world_size * (global_world_size + 1) // 2 + 100 - 2
dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg)
assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]))

dist.barrier(edm.global_pg)

Expand All @@ -186,7 +186,6 @@ def bar(**kwargs):
"GLOBAL_RANK": str(global_rank),
"GLOBAL_WORLD_SIZE": str(global_world_size),
"ZERO_BAND_LOG_LEVEL": "DEBUG",
"ZERO_BAND_LOG_ALL_RANK": "true",
"TEST_VALUE": str(global_rank),
},
)
Expand Down

0 comments on commit 74a4707

Please sign in to comment.