1313from hivemind .p2p import P2P
1414
1515
16- @pytest .mark .forked
1716@pytest .mark .asyncio
17+ @pytest .mark .forked
1818async def test_partitioning ():
1919 all_tensors = [
2020 torch .randn (30_000 , 128 ),
@@ -61,8 +61,8 @@ async def write_tensors():
6161 ],
6262)
6363@pytest .mark .parametrize ("peer_fractions" , [(0.33 , 0.44 , 0.23 ), (0.5 , 0.5 ), (0.1 , 0.0 , 0.9 ), (1.0 ,), (0.1 ,) * 9 ])
64- @pytest .mark .forked
6564@pytest .mark .asyncio
65+ @pytest .mark .forked
6666async def test_partitioning_edge_cases (tensors : Sequence [torch .Tensor ], peer_fractions : Sequence [float ]):
6767 partition = TensorPartContainer (tensors , peer_fractions , part_size_bytes = 16 )
6868 for peer_index in range (len (peer_fractions )):
@@ -75,8 +75,8 @@ async def test_partitioning_edge_cases(tensors: Sequence[torch.Tensor], peer_fra
7575 tensor_index += 1
7676
7777
78- @pytest .mark .forked
7978@pytest .mark .asyncio
79+ @pytest .mark .forked
8080async def test_partitioning_asynchronous ():
8181 """ensure that tensor partitioning does not interfere with asynchronous code"""
8282 tensors = [torch .randn (2048 , 2048 ), torch .randn (1024 , 4096 ), torch .randn (4096 , 1024 ), torch .randn (30_000 , 1024 )]
@@ -114,8 +114,8 @@ async def wait_synchronously():
114114@pytest .mark .parametrize ("num_senders" , [1 , 2 , 4 , 10 ])
115115@pytest .mark .parametrize ("num_parts" , [0 , 1 , 100 ])
116116@pytest .mark .parametrize ("synchronize_prob" , [1.0 , 0.1 , 0.0 ])
117- @pytest .mark .forked
118117@pytest .mark .asyncio
118+ @pytest .mark .forked
119119async def test_reducer (num_senders : int , num_parts : int , synchronize_prob : float ):
120120 tensor_part_shapes = [torch .Size ([i ]) for i in range (num_parts )]
121121 reducer = TensorPartReducer (tensor_part_shapes , num_senders )
@@ -168,8 +168,8 @@ async def send_tensors(sender_index: int):
168168 ((AUX , AUX , AUX , AUX ), (0.0 , 0.0 , 0.0 , 0.0 ), (1 , 2 , 3 , 4 ), 2 ** 20 ),
169169 ],
170170)
171- @pytest .mark .forked
172171@pytest .mark .asyncio
172+ @pytest .mark .forked
173173async def test_allreduce_protocol (peer_modes , averaging_weights , peer_fractions , part_size_bytes ):
174174 """Run group allreduce protocol manually without grpc, see if the internal logic is working as intended"""
175175
0 commit comments