Skip to content

Commit c1394d1

Browse files
committed
Try forking before asyncio usage
1 parent 2eed33a commit c1394d1

File tree

3 files changed: +7 additions, -7 deletions

tests/test_allreduce.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from hivemind.p2p import P2P
1414

1515

16-
@pytest.mark.forked
1716
@pytest.mark.asyncio
17+
@pytest.mark.forked
1818
async def test_partitioning():
1919
all_tensors = [
2020
torch.randn(30_000, 128),
@@ -61,8 +61,8 @@ async def write_tensors():
6161
],
6262
)
6363
@pytest.mark.parametrize("peer_fractions", [(0.33, 0.44, 0.23), (0.5, 0.5), (0.1, 0.0, 0.9), (1.0,), (0.1,) * 9])
64-
@pytest.mark.forked
6564
@pytest.mark.asyncio
65+
@pytest.mark.forked
6666
async def test_partitioning_edge_cases(tensors: Sequence[torch.Tensor], peer_fractions: Sequence[float]):
6767
partition = TensorPartContainer(tensors, peer_fractions, part_size_bytes=16)
6868
for peer_index in range(len(peer_fractions)):
@@ -75,8 +75,8 @@ async def test_partitioning_edge_cases(tensors: Sequence[torch.Tensor], peer_fra
7575
tensor_index += 1
7676

7777

78-
@pytest.mark.forked
7978
@pytest.mark.asyncio
79+
@pytest.mark.forked
8080
async def test_partitioning_asynchronous():
8181
"""ensure that tensor partitioning does not interfere with asynchronous code"""
8282
tensors = [torch.randn(2048, 2048), torch.randn(1024, 4096), torch.randn(4096, 1024), torch.randn(30_000, 1024)]
@@ -114,8 +114,8 @@ async def wait_synchronously():
114114
@pytest.mark.parametrize("num_senders", [1, 2, 4, 10])
115115
@pytest.mark.parametrize("num_parts", [0, 1, 100])
116116
@pytest.mark.parametrize("synchronize_prob", [1.0, 0.1, 0.0])
117-
@pytest.mark.forked
118117
@pytest.mark.asyncio
118+
@pytest.mark.forked
119119
async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float):
120120
tensor_part_shapes = [torch.Size([i]) for i in range(num_parts)]
121121
reducer = TensorPartReducer(tensor_part_shapes, num_senders)
@@ -168,8 +168,8 @@ async def send_tensors(sender_index: int):
168168
((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4), 2**20),
169169
],
170170
)
171-
@pytest.mark.forked
172171
@pytest.mark.asyncio
172+
@pytest.mark.forked
173173
async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
174174
"""Run group allreduce protocol manually without grpc, see if the internal logic is working as intended"""
175175

tests/test_averaging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from test_utils.dht_swarms import launch_dht_instances
1818

1919

20-
@pytest.mark.forked
2120
@pytest.mark.asyncio
21+
@pytest.mark.forked
2222
async def test_key_manager():
2323
dht = hivemind.DHT(start=True)
2424
key_manager = GroupKeyManager(

tests/test_dht.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ def test_run_coroutine():
104104
dht.shutdown()
105105

106106

107-
@pytest.mark.forked
108107
@pytest.mark.asyncio
108+
@pytest.mark.forked
109109
async def test_dht_get_visible_maddrs():
110110
# test 1: IPv4 localhost multiaddr is visible by default
111111

Comments (0)