diff --git a/configs/debug/debug.toml b/configs/debug/debug.toml
index eedfea20..ae283e9e 100644
--- a/configs/debug/debug.toml
+++ b/configs/debug/debug.toml
@@ -8,7 +8,7 @@ sharding_strategy = "SHARD_GRAD_OP"
 [optim]
 batch_size = 16
 warmup_steps = 10
-total_steps = 5000
+total_steps = 10
 
 [data]
-fake_data = true
\ No newline at end of file
+fake_data = true
diff --git a/configs/debug/diloco.toml b/configs/debug/diloco.toml
index 24a8602c..9283c964 100644
--- a/configs/debug/diloco.toml
+++ b/configs/debug/diloco.toml
@@ -8,10 +8,11 @@ sharding_strategy = "FULL_SHARD"
 [optim]
 batch_size = 16
 warmup_steps = 10
-total_steps = 5000
+total_steps = 10
 
 [data]
 fake_data = true
 
 [diloco]
-inner_steps = 10
\ No newline at end of file
+inner_steps = 5
+
diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py
index b669dd58..8e98716a 100644
--- a/tests/test_torchrun/test_train.py
+++ b/tests/test_torchrun/test_train.py
@@ -1,37 +1,59 @@
+import copy
+import os
 import subprocess
 
 import pytest
 import socket
 
 
-def get_random_available_port():
+def get_random_available_port_list(num_port):
     # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(("", 0))
-        return s.getsockname()[1]
+    ports = []
+    while len(ports) < num_port:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            new_port = s.getsockname()[1]
 
+        if new_port not in ports:
+            ports.append(new_port)
 
-@pytest.fixture()
-def random_available_port():
-    return get_random_available_port()
+    return ports
 
 
-@pytest.mark.parametrize("num_gpu", [1, 2])
-def test_multi_gpu(random_available_port, num_gpu):
-    cmd = [
-        "torchrun",
-        f"--nproc_per_node={num_gpu}",
-        "--rdzv-endpoint",
-        f"localhost:{random_available_port}",
-        "src/zeroband/train.py",
-        "@configs/debug/debug.toml",
-        "--optim.total_steps",
-        "10",
-    ]
-    result = subprocess.run(cmd)
+def get_random_available_port(num_port):
+    return get_random_available_port_list(num_port)[0]
 
-    if result.returncode != 0:
-        pytest.fail(f"Process {result} failed {result.stderr}")
+
+def gpus_to_use(num_nodes, num_gpu, rank):
+    return ",".join(map(str, range(rank * num_gpu, (rank + 1) * num_gpu)))
+
+
+@pytest.mark.parametrize("num_gpus", [[1, 1], [2, 1], [1, 2]])
+@pytest.mark.parametrize("config", ["debug/debug.toml", "debug/diloco.toml"])
+def test_multi_gpu(num_gpus, config):
+    num_nodes, num_gpu = num_gpus[0], num_gpus[1]
+
+    processes = []
+    ports = get_random_available_port_list(num_nodes)
+    for i in range(num_nodes):
+        cmd = [
+            "torchrun",
+            f"--nproc_per_node={num_gpu}",
+            "--rdzv-endpoint",
+            f"localhost:{ports[i]}",
+            "src/zeroband/train.py",
+            f"@configs/{config}",
+        ]
+
+        env = copy.deepcopy(os.environ)
+        env["CUDA_VISIBLE_DEVICES"] = gpus_to_use(num_nodes, num_gpu, i)
+        process = subprocess.Popen(cmd, env=env)
+        processes.append(process)
+
+    for process in processes:
+        result = process.wait()
+        if result != 0:
+            pytest.fail(f"Process {process.pid} failed with exit code {result}")
 
 
 @pytest.mark.parametrize("num_gpu", [1, 2])
@@ -44,7 +66,7 @@ def test_multi_gpu_diloco(random_available_port, num_gpu):
         "src/zeroband/train.py",
         "@configs/debug/diloco.toml",
         "--optim.total_steps",
-        "10",
+        "50",
     ]
     result = subprocess.run(cmd)
 
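A note on the new rendezvous-port logic: `get_random_available_port_list` binds to port 0 so the OS hands back a currently free ephemeral port, collects `num_port` distinct values, and each simulated node then passes its own port to torchrun's `--rdzv-endpoint`. A minimal sketch of the underlying pattern in isolation (the name `pick_free_port` is illustrative, not from the diff):

```python
import socket

def pick_free_port() -> int:
    # Binding to port 0 asks the kernel for any currently free ephemeral port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

# The socket is closed before torchrun rebinds the port, so there is a small
# race window in which another process could claim it -- acceptable for tests.
print(pick_free_port())  # e.g. 49731; varies per run
```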
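The GPU partitioning is the other half of the multi-node simulation: each launched "node" is pinned to a disjoint, contiguous slice of the machine's devices via `CUDA_VISIBLE_DEVICES`. A quick check of `gpus_to_use` as defined in the diff, exercising the parametrized `(num_nodes, num_gpu)` cases (the asserts are illustrative, not part of the PR):

```python
def gpus_to_use(num_nodes, num_gpu, rank):
    # num_nodes is accepted for symmetry with the call site but is unused:
    # rank r is assigned devices [r * num_gpu, (r + 1) * num_gpu).
    return ",".join(map(str, range(rank * num_gpu, (rank + 1) * num_gpu)))

assert gpus_to_use(1, 1, 0) == "0"    # [1, 1]: one node, one GPU
assert gpus_to_use(2, 1, 0) == "0"    # [2, 1]: two nodes with one GPU each,
assert gpus_to_use(2, 1, 1) == "1"    #         pinned to devices 0 and 1
assert gpus_to_use(1, 2, 0) == "0,1"  # [1, 2]: one node spanning two GPUs
```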
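One loose end: `test_multi_gpu_diloco` (second hunk of the test file) still requests the `random_available_port` fixture that the first hunk deletes, so it would error at setup with a "fixture not found" message. If that test is meant to keep running unchanged, a minimal re-implementation of the fixture on top of the new helper might look like this (a sketch, not part of the diff):

```python
import pytest

@pytest.fixture()
def random_available_port():
    # Restore the fixture test_multi_gpu_diloco still requests by name,
    # now backed by the list-based helper introduced above.
    return get_random_available_port_list(1)[0]
```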