diff --git a/.github/workflows/test-paddle.yaml b/.github/workflows/test-paddle.yaml
index e517fd1..985394f 100644
--- a/.github/workflows/test-paddle.yaml
+++ b/.github/workflows/test-paddle.yaml
@@ -47,7 +47,8 @@ jobs:
         COVERAGE_FILE=.coverage_1 CUDA_VISIBLE_DEVICES="" pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/1.log 2>&1
         COVERAGE_FILE=.coverage_2 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 --no-python pytest -s --cov=${LIBDIR} test_multi.py > /tmp/pytest-log/2.log 2>&1 &
         COVERAGE_FILE=.coverage_3 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 --no-python pytest -s --cov=${LIBDIR} test_multi.py > /tmp/pytest-log/3.log 2>&1
-        coverage combine .coverage_0 .coverage_1 .coverage_2 .coverage_3
+        python -m paddle.distributed.launch --nproc_per_node 2 run_distributed_paddle_test.py -s --cov=${LIBDIR} test_multi_paddle.py
+        coverage combine .coverage_0 .coverage_1 .coverage_2 .coverage_3 .coverage_4 .coverage_5
         coverage html
         mv htmlcov /tmp/pytest-log
     - name: upload pytest log
diff --git a/examples/paddle_case/run_parallel.py b/examples/paddle_case/run_parallel.py
index 59c273b..f132337 100644
--- a/examples/paddle_case/run_parallel.py
+++ b/examples/paddle_case/run_parallel.py
@@ -14,9 +14,9 @@ import paddle.distributed as dist
 from fastsafetensors import SafeTensorsFileLoader

 dist.init_parallel_env()
-backend = "nccl" if paddle.is_compiled_with_cuda() else "gloo"
+backend = "nccl" if paddle.device.cuda.device_count() else "gloo"
 pg = dist.new_group(ranks=[0,1], backend=backend)
-device = "gpu:0" if paddle.is_compiled_with_cuda() else "cpu"
+device = "gpu" if paddle.device.cuda.device_count() else "cpu"
 loader = SafeTensorsFileLoader(pg, device, nogds=False, debug_log=True, framework="paddle")
 loader.add_filenames({0: ["a_paddle.safetensors"], 1:["b_paddle.safetensors"]}) # {rank: files}
diff --git a/examples/paddle_case/run_single.py b/examples/paddle_case/run_single.py
index e7679a6..6d5b042 100644
--- a/examples/paddle_case/run_single.py
+++ b/examples/paddle_case/run_single.py
@@ -1,13 +1,13 @@
 import paddle
 from fastsafetensors import SafeTensorsFileLoader, SingleGroup
-device = "gpu:0" if paddle.is_compiled_with_cuda() else "cpu"
+device = "gpu:0" if paddle.device.cuda.device_count() else "cpu"
 loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=False, debug_log=True, framework="paddle")
 loader.add_filenames({0: ["a_paddle.safetensors", "b_paddle.safetensors"]}) # {rank: files}
 fb = loader.copy_files_to_device()
 tensor_a0 = fb.get_tensor(tensor_name="a0")
 tensor_b0 = fb.get_tensor(tensor_name="b0")
-print(f"a0: {tensor_a0}")
-mycat = paddle.concat([tensor_a0, tensor_b0], axis=1)
+print(f"a0: {tensor_a0}\nb0: {tensor_b0}")
+mycat = paddle.concat([tensor_a0, tensor_b0])
 print(f"cat: {mycat}, size={mycat.size}")
 fb.close()
 loader.close()
diff --git a/examples/run_paddle_parallel_cpu.sh b/examples/run_paddle_parallel_cpu.sh
new file mode 100755
index 0000000..6a95391
--- /dev/null
+++ b/examples/run_paddle_parallel_cpu.sh
@@ -0,0 +1,9 @@
+#!/usr/bin/env bash
+PIDS=()
+
+runner="python -m paddle.distributed.launch"
+
+cd paddle_case
+rm -rf log
+# This script requires the CPU build of paddlepaddle.
+${runner} --nproc_per_node 2 run_parallel.py
\ No newline at end of file
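Note: `python -m paddle.distributed.launch` spawns one worker process per device and hands each worker its rank through environment variables; `tests/run_distributed_paddle_test.py` below reads the same `PADDLE_TRAINER_ID`. A minimal sketch of what each spawned worker sees, assuming `PADDLE_TRAINERS_NUM` carries the world size (an assumption, not something this patch relies on):

    import os
    import paddle.distributed as dist

    # The launcher exports per-worker env vars before running the script.
    rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
    world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))  # assumed variable name

    dist.init_parallel_env()  # reads the launcher's env vars internally
    print(f"worker {rank} of {world_size} initialized")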
diff --git a/examples/run_paddle_parallel_gpu.sh b/examples/run_paddle_parallel_gpu.sh
new file mode 100755
index 0000000..8995129
--- /dev/null
+++ b/examples/run_paddle_parallel_gpu.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+PIDS=()
+
+runner="python -m paddle.distributed.launch"
+
+cd paddle_case
+rm -rf log
+# This script requires the GPU build (paddlepaddle-gpu).
+# Single machine, multiple GPUs (here: 1 machine, 2 GPUs).
+# Unlike the torch script, paddle's distributed launcher uses NCCL to communicate between GPUs.
+CUDA_VISIBLE_DEVICES=0,1 ${runner} --gpus 0,1 run_parallel.py
\ No newline at end of file
diff --git a/examples/run_paddle_parrallel.sh b/examples/run_paddle_parrallel.sh
deleted file mode 100755
index 108242e..0000000
--- a/examples/run_paddle_parrallel.sh
+++ /dev/null
@@ -1,15 +0,0 @@
-# !/usr/bin/env python3
-PIDS=()
-
-runner="python -m paddle.distributed.launch"
-# runner="torchrun"
-
-cd paddle_case
-rm -rf log
-# one machine multy gpu (case : 1 machine 2 gpus)
-# different to torch script because the paddle distributed use nccl to communicate in gpus
-CUDA_VISIBLE_DEVICES=0 ${runner} --nnodes=2 --master=127.0.0.1:8800 --rank=0 run_parallel.py &
-PIDS+=($!)
-CUDA_VISIBLE_DEVICES=1 ${runner} --nnodes=2 --master=127.0.0.1:8800 --rank=1 run_parallel.py &
-PIDS+=($!)
-wait "${PIDS[@]}"
\ No newline at end of file
diff --git a/fastsafetensors/loader.py b/fastsafetensors/loader.py
index 4c75910..a26cbb4 100644
--- a/fastsafetensors/loader.py
+++ b/fastsafetensors/loader.py
@@ -71,8 +71,17 @@ def __init__(self, pg: dist.ProcessGroup, device: torch.device, bbuf_size_kb: in
         if device == "cpu":
             d_id = None
         else:
-            d_id = device.split(":") # "gpu:0" or "gpu"
-            d_id = int(d_id[1]) if len(d_id) == 2 else 0
+            if isinstance(self.pg, SingleGroup):
+                # Single-process case: device is "gpu:x" or bare "gpu"
+                # ("gpu:0", "gpu:1", ...); bare "gpu" defaults to 0.
+                d_id = device.split(":")
+                d_id = int(d_id[1]) if len(d_id) == 2 else 0
+            else:
+                # Distributed case: the rank selects the GPU, so rank 0
+                # uses gpu:0, rank 1 uses gpu:1, wrapping around when
+                # there are more ranks than visible devices.
+                d_id = self.pg.rank() % paddle.device.cuda.device_count()
+            self.device = f"gpu:{d_id}"
         node = get_device_numa_node(d_id)
         if node is not None:
             fstcpp.set_numa_node(node)
@@ -140,7 +149,7 @@ def copy_files_to_device(self, dtype: torch.dtype=None, use_buf_register: bool=T
             if self.device.type != "cpu":
                 torch.cuda.set_device(self.device)
         elif paddle_loaded and self.framework == "paddle":
-            if self.device != paddle.CPUPlace():
+            if "gpu" in self.device:
                 paddle.set_device(self.device)

         need_wait: List[LazyTensorFactory] = []
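Note: the loader.py change above stops parsing the device string in the distributed case and instead derives the GPU id from the process-group rank (`d_id = self.pg.rank() % paddle.device.cuda.device_count()`). A pure-Python sketch of that mapping; `pick_device` is a hypothetical helper, not part of the library:

    def pick_device(rank: int, num_gpus: int) -> str:
        # Mirrors the distributed branch: the rank selects the GPU,
        # wrapping around when there are more ranks than devices.
        return "cpu" if num_gpus == 0 else f"gpu:{rank % num_gpus}"

    # With 2 visible GPUs: rank 0 -> gpu:0, rank 1 -> gpu:1, rank 2 wraps to gpu:0.
    assert pick_device(0, 2) == "gpu:0"
    assert pick_device(1, 2) == "gpu:1"
    assert pick_device(2, 2) == "gpu:0"
    assert pick_device(0, 0) == "cpu"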
diff --git a/tests/conftest.py b/tests/conftest.py
index 40ee22e..7b38f55 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,6 +4,7 @@ import torch.distributed as dist
 from fastsafetensors import cpp as fstcpp
 from fastsafetensors import SingleGroup
+from fastsafetensors.common import paddle_loaded
 from typing import List

 TESTS_DIR = os.path.dirname(__file__)
@@ -39,6 +40,24 @@ def pg():
     PG = SingleGroup()
     return PG

+@pytest.fixture(scope='session', autouse=True)
+def pg_paddle():
+    PG = SingleGroup()
+
+    if paddle_loaded:
+        # This only succeeds when the tests are launched with
+        # `python -m paddle.distributed.launch`; otherwise the
+        # fixture falls back to SingleGroup.
+        try:
+            import paddle
+            import paddle.distributed as dist
+            dist.init_parallel_env()
+            backend = "nccl" if paddle.device.cuda.device_count() else "gloo"
+            PG = dist.new_group(ranks=[0,1], backend=backend)
+        except Exception:
+            pass
+    return PG
+
 @pytest.fixture(scope='session', autouse=True)
 def dev_init() -> None:
     if torch.cuda.is_available():
diff --git a/tests/run_distributed_paddle_test.py b/tests/run_distributed_paddle_test.py
new file mode 100644
index 0000000..ceea4e0
--- /dev/null
+++ b/tests/run_distributed_paddle_test.py
@@ -0,0 +1,11 @@
+import pytest
+import sys
+import os
+
+if __name__ == "__main__":
+    # Four pytest commands run before this one (.coverage_0 .. .coverage_3),
+    # so rank N writes .coverage_{N+4}. The GPU distributed test needs at least 2 GPUs.
+    rank = int(os.getenv("PADDLE_TRAINER_ID")) + 4
+    os.environ["COVERAGE_FILE"] = f".coverage_{rank}"
+    pytest_args = sys.argv[1:]
+    sys.exit(pytest.main(pytest_args))
\ No newline at end of file
diff --git a/tests/test_fastsafetensors.py b/tests/test_fastsafetensors.py
index ccb5574..c32e159 100644
--- a/tests/test_fastsafetensors.py
+++ b/tests/test_fastsafetensors.py
@@ -228,7 +228,7 @@ def test_SafeTensorsFileLoader(fstcpp_log, input_files, framework="pytorch"):
     if framework == "pytorch":
         data_type = torch.float16
     elif framework == "paddle":
-        # There are some lack of accuracy in paddle.float16 (about 1e-4)
+        # paddle.float16 loses some accuracy (about 1e-4) on CPU.
         data_type = paddle.float32
     else:
         raise NotImplementedError(f"Do not support the framework: {framework}")
diff --git a/tests/test_multi_paddle.py b/tests/test_multi_paddle.py
new file mode 100644
index 0000000..6c8fda6
--- /dev/null
+++ b/tests/test_multi_paddle.py
@@ -0,0 +1,45 @@
+# Copyright 2024- IBM Inc. All rights reserved
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+import paddle
+from safetensors import safe_open
+
+from fastsafetensors import cpp as fstcpp
+from fastsafetensors import SafeTensorsFileLoader, SingleGroup, SafeTensorsMetadata
+from fastsafetensors.common import paddle_loaded
+
+def test_shuffle_paddle(fstcpp_log, input_files, pg_paddle):
+    if paddle_loaded:
+        device = "gpu" if paddle.device.cuda.device_count() else "cpu"
+        loader = SafeTensorsFileLoader(pg_paddle, device, nogds=True, debug_log=True, framework="paddle")
+        loader.add_filenames({0: input_files})
+        bufs = loader.copy_files_to_device()
+        key_dims = {key: -1 for key in loader.get_keys()}
+        for i in range(0, 12):
+            key_dims[f"h.{i}.mlp.c_proj.weight"] = 0
+            key_dims[f"h.{i}.mlp.c_fc.weight"] = 1
+        tensors = bufs.as_dict(key_dims)
+        with safe_open(input_files[0], framework="pt") as f:
+            for key in tensors.keys():
+                dim = key_dims[key]
+                if dim == 0 or dim == 1:
+                    t = f.get_slice(key)
+                    rank_slices = ()
+                    shape = t.get_shape()
+                    size = shape[dim]
+                    block_size = (size + pg_paddle.process_group.size() - 1) // pg_paddle.process_group.size()
+                    for i in range(0, len(shape)):
+                        if i < dim:
+                            rank_slices += (slice(None, None, None),)
+                        elif i == dim:
+                            rank_slices += (slice(pg_paddle.process_group.rank() * block_size, (pg_paddle.process_group.rank() + 1) * block_size, 1),)
+                            break
+                    t = t[rank_slices]
+                    t = t.clone().detach()
+                else:
+                    t = f.get_tensor(key)
+                assert paddle.all(paddle.to_tensor(t.numpy(), place=loader.device).equal(tensors[key]))
+        bufs.close()
+        loader.close()
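Note: test_shuffle_paddle above partitions a tensor dimension with ceil division, so every rank computes the same block size and the last rank's slice may extend past the end of the dimension; slicing clamps it. A self-contained sketch of that rule, with `shard_slice` as a hypothetical name:

    def shard_slice(size: int, world_size: int, rank: int) -> slice:
        # Ceil division, matching the test's block_size computation.
        block = (size + world_size - 1) // world_size
        return slice(rank * block, (rank + 1) * block, 1)

    assert shard_slice(10, 2, 0) == slice(0, 5, 1)
    assert shard_slice(10, 2, 1) == slice(5, 10, 1)
    assert list(range(9))[shard_slice(9, 2, 1)] == [5, 6, 7, 8]  # stop clamped to 9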