Commit 5e5537f

add better tests

samsja committed Sep 22, 2024
1 parent 242c93e commit 5e5537f

Showing 3 changed files with 50 additions and 27 deletions.
4 changes: 2 additions & 2 deletions configs/debug/debug.toml
@@ -8,7 +8,7 @@ sharding_strategy = "SHARD_GRAD_OP"
 [optim]
 batch_size = 16
 warmup_steps = 10
-total_steps = 5000
+total_steps = 10
 
 [data]
-fake_data = true
+fake_data = true
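
Dropping total_steps from 5000 to 10 turns this config into a true smoke test: the torchrun run finishes in seconds instead of doing a full training run. A minimal sketch of reading the file back, assuming Python 3.11+ for the stdlib tomllib (path and keys taken from the diff above):

import tomllib

# Load the debug config and confirm it stays cheap enough for CI.
with open("configs/debug/debug.toml", "rb") as f:
    cfg = tomllib.load(f)

assert cfg["optim"]["total_steps"] == 10
assert cfg["data"]["fake_data"] is True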
5 changes: 3 additions & 2 deletions configs/debug/diloco.toml
@@ -8,10 +8,11 @@ sharding_strategy = "FULL_SHARD"
 [optim]
 batch_size = 16
 warmup_steps = 10
-total_steps = 5000
+total_steps = 10
 
 [data]
 fake_data = true
 
 [diloco]
-inner_steps = 10
+inner_steps = 5
+
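This config gets the same total_steps cut, and inner_steps is halved from 10 to 5. Assuming inner_steps is the number of local optimizer steps between two DiLoCo outer synchronizations (the usual DiLoCo meaning), the 10-step debug run now crosses the outer-sync boundary twice rather than once. A rough check under that assumption:

total_steps = 10  # [optim] above
inner_steps = 5   # [diloco] above

# Outer synchronizations performed by the debug run.
assert total_steps // inner_steps == 2
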
68 changes: 45 additions & 23 deletions tests/test_torchrun/test_train.py
@@ -1,37 +1,59 @@
+import copy
+import os
 import subprocess
 import pytest
 import socket
 
 
-def get_random_available_port():
+def get_random_available_port_list(num_port):
     # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(("", 0))
-        return s.getsockname()[1]
+    ports = []
+
+    while len(ports) < num_port:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            new_port = s.getsockname()[1]
+
+        if new_port not in ports:
+            ports.append(new_port)
+
+    return ports
 
 
-@pytest.fixture()
-def random_available_port():
-    return get_random_available_port()
+def get_random_available_port(num_port):
+    return get_random_available_port_list(num_port)[0]
 
 
-@pytest.mark.parametrize("num_gpu", [1, 2])
-def test_multi_gpu(random_available_port, num_gpu):
-    cmd = [
-        "torchrun",
-        f"--nproc_per_node={num_gpu}",
-        "--rdzv-endpoint",
-        f"localhost:{random_available_port}",
-        "src/zeroband/train.py",
-        "@configs/debug/debug.toml",
-        "--optim.total_steps",
-        "10",
-    ]
-
-    result = subprocess.run(cmd)
-
-    if result.returncode != 0:
-        pytest.fail(f"Process {result} failed {result.stderr}")
+def gpus_to_use(num_nodes, num_gpu, rank):
+    return ",".join(map(str, range(rank * num_gpu, (rank + 1) * num_gpu)))
+
+
+@pytest.mark.parametrize("num_gpus", [[1, 1], [2, 1], [1, 2]])
+@pytest.mark.parametrize("config", ["debug/debug.toml", "debug/diloco.toml"])
+def test_multi_gpu(num_gpus, config):
+    num_nodes, num_gpu = num_gpus[0], num_gpus[1]
+
+    processes = []
+    ports = get_random_available_port_list(num_nodes)
+    for i in range(num_nodes):
+        cmd = [
+            "torchrun",
+            f"--nproc_per_node={num_gpu}",
+            "--rdzv-endpoint",
+            f"localhost:{ports[i]}",
+            "src/zeroband/train.py",
+            f"@configs/{config}",
+        ]
+
+        env = copy.deepcopy(os.environ)
+        env["CUDA_VISIBLE_DEVICES"] = gpus_to_use(num_nodes, num_gpu, i)
+        process1 = subprocess.Popen(cmd, env=env)
+        processes.append(process1)
+
+    for process in processes:
+        result = process.wait()
+        if result != 0:
+            pytest.fail(f"Process {result} failed {result}")
 
 
 @pytest.mark.parametrize("num_gpu", [1, 2])
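
The rewritten test now simulates multi-node training on a single machine: each simulated node gets its own rendezvous port from get_random_available_port_list and a disjoint GPU slice via CUDA_VISIBLE_DEVICES. A standalone check of the slicing, with gpus_to_use copied verbatim from the diff (num_nodes is unused in the body but mirrors the call site):

def gpus_to_use(num_nodes, num_gpu, rank):
    # Rank r is assigned GPUs [r * num_gpu, (r + 1) * num_gpu).
    return ",".join(map(str, range(rank * num_gpu, (rank + 1) * num_gpu)))

assert gpus_to_use(2, 1, 0) == "0"    # [2, 1] case: node 0 -> GPU 0
assert gpus_to_use(2, 1, 1) == "1"    # [2, 1] case: node 1 -> GPU 1
assert gpus_to_use(1, 2, 0) == "0,1"  # [1, 2] case: one node, both GPUs
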
@@ -44,7 +66,7 @@ def test_multi_gpu_diloco(random_available_port, num_gpu):
         "src/zeroband/train.py",
         "@configs/debug/diloco.toml",
         "--optim.total_steps",
-        "10",
+        "50",
     ]
 
     result = subprocess.run(cmd)
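
Before this commit the single-node DiLoCo test ran 10 steps against inner_steps = 10, i.e. a single outer synchronization; the new 50-step override against inner_steps = 5 exercises ten (50 // 5), so repeated outer updates are covered rather than barely one cycle. Both tests run under plain pytest, e.g. pytest tests/test_torchrun/test_train.py.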