3 changes: 2 additions & 1 deletion .github/workflows/test-paddle.yaml
@@ -47,7 +47,8 @@ jobs:
COVERAGE_FILE=.coverage_1 CUDA_VISIBLE_DEVICES="" pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/1.log 2>&1
COVERAGE_FILE=.coverage_2 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 --no-python pytest -s --cov=${LIBDIR} test_multi.py > /tmp/pytest-log/2.log 2>&1 &
COVERAGE_FILE=.coverage_3 torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 --no-python pytest -s --cov=${LIBDIR} test_multi.py > /tmp/pytest-log/3.log 2>&1
coverage combine .coverage_0 .coverage_1 .coverage_2 .coverage_3
python -m paddle.distributed.launch --nproc_per_node 2 run_distributed_paddle_test.py -s --cov=${LIBDIR} test_multi_paddle.py
coverage combine .coverage_0 .coverage_1 .coverage_2 .coverage_3 .coverage_4 .coverage_5
coverage html
mv htmlcov /tmp/pytest-log
- name: upload pytest log
4 changes: 2 additions & 2 deletions examples/paddle_case/run_parallel.py
@@ -14,9 +14,9 @@
import paddle.distributed as dist
from fastsafetensors import SafeTensorsFileLoader
dist.init_parallel_env()
backend = "nccl" if paddle.is_compiled_with_cuda() else "gloo"
backend = "nccl" if paddle.device.cuda.device_count() else "gloo"
pg = dist.new_group(ranks=[0,1], backend=backend)
device = "gpu:0" if paddle.is_compiled_with_cuda() else "cpu"
device = "gpu" if paddle.device.cuda.device_count() else "cpu"
loader = SafeTensorsFileLoader(pg, device, nogds=False, debug_log=True, framework="paddle")
loader.add_filenames({0: ["a_paddle.safetensors"], 1:["b_paddle.safetensors"]}) # {rank: files}

6 changes: 3 additions & 3 deletions examples/paddle_case/run_single.py
@@ -1,13 +1,13 @@
import paddle
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
device = "gpu:0" if paddle.is_compiled_with_cuda() else "cpu"
device = "gpu:0" if paddle.device.cuda.device_count() else "cpu"
loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=False, debug_log=True, framework="paddle")
loader.add_filenames({0: ["a_paddle.safetensors", "b_paddle.safetensors"]}) # {rank: files}
fb = loader.copy_files_to_device()
tensor_a0 = fb.get_tensor(tensor_name="a0")
tensor_b0 = fb.get_tensor(tensor_name="b0")
print(f"a0: {tensor_a0}")
mycat = paddle.concat([tensor_a0, tensor_b0], axis=1)
print(f"a0: {tensor_a0}\n b0 : {tensor_b0}")
mycat = paddle.concat([tensor_a0, tensor_b0])
print(f"cat: {mycat}, size={mycat.size}")
fb.close()
loader.close()
9 changes: 9 additions & 0 deletions examples/run_paddle_parallel_cpu.sh
@@ -0,0 +1,9 @@
#!/usr/bin/env bash
PIDS=()

runner="python -m paddle.distributed.launch"

cd paddle_case
rm -rf log
# This script requires the CPU build of paddlepaddle
${runner} --nproc_per_node 2 run_parallel.py
11 changes: 11 additions & 0 deletions examples/run_paddle_parallel_gpu.sh
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
PIDS=()

runner="python -m paddle.distributed.launch"

cd paddle_case
rm -rf log
# This script requires the GPU build (paddlepaddle-gpu)
# and a multi-GPU machine (here: 1 machine, 2 GPUs).
# Unlike the torch script, paddle distributed uses NCCL to communicate between GPUs.
CUDA_VISIBLE_DEVICES=0,1 ${runner} --gpus 0,1 run_parallel.py
15 changes: 0 additions & 15 deletions examples/run_paddle_parrallel.sh

This file was deleted.

15 changes: 12 additions & 3 deletions fastsafetensors/loader.py
@@ -71,8 +71,17 @@ def __init__(self, pg: dist.ProcessGroup, device: torch.device, bbuf_size_kb: in
if device == "cpu":
d_id = None
else:
d_id = device.split(":") # "gpu:0" or "gpu"
d_id = int(d_id[1]) if len(d_id) == 2 else 0
if isinstance(self.pg, SingleGroup):
# For a single process: device is "gpu:x" (e.g. gpu:0, gpu:1, ...) or bare "gpu"
d_id = device.split(":")
d_id = int(d_id[1]) if len(d_id) == 2 else 0
else:
# For distributed runs: the rank determines the GPU;
# rank 0 uses gpu:0, rank 1 uses gpu:1, ...
d_id = self.pg.rank() % paddle.device.cuda.device_count()
self.device = f"gpu:{d_id}"
Comment on lines -74 to +84
@zeroRains (Contributor, Author) commented on Jun 5, 2025:
In this part, fastsafetensors may not need to consider the distributed case at all.

We just need to load the tensors to the correct device, which is provided by the user.

On a machine with multiple GPUs, the user should set the device in their distributed code, e.g. device=f"gpu:{pg.rank()}", and then pass that device to the SafeTensorsFileLoader so that different processes can load tensors onto different GPUs.

What do you think?

@takeshi-yoshimura (Collaborator) replied on Jun 6, 2025:
I don't think so, because safetensors files that are distributed online are not composed that way.
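
For context, here is a minimal sketch of the user-side alternative proposed in the thread above (illustrative only, not the merged behavior; the merged change keeps the rank-to-GPU mapping inside the loader):

import paddle
import paddle.distributed as dist
from fastsafetensors import SafeTensorsFileLoader

dist.init_parallel_env()
backend = "nccl" if paddle.device.cuda.device_count() else "gloo"
pg = dist.new_group(ranks=[0, 1], backend=backend)
# Hypothetical: each rank names its own device before constructing the loader,
# instead of the loader deriving it from pg.rank() internally.
device = f"gpu:{pg.rank()}" if paddle.device.cuda.device_count() else "cpu"
loader = SafeTensorsFileLoader(pg, device, nogds=True, debug_log=True, framework="paddle")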

node = get_device_numa_node(d_id)
if node is not None:
fstcpp.set_numa_node(node)
@@ -140,7 +149,7 @@ def copy_files_to_device(self, dtype: torch.dtype=None, use_buf_register: bool=T
if self.device.type != "cpu":
torch.cuda.set_device(self.device)
elif paddle_loaded and self.framework == "paddle":
if self.device != paddle.CPUPlace():
if "gpu" in self.device:
paddle.set_device(self.device)

need_wait: List[LazyTensorFactory] = []
19 changes: 19 additions & 0 deletions tests/conftest.py
@@ -4,6 +4,7 @@
import torch.distributed as dist
from fastsafetensors import cpp as fstcpp
from fastsafetensors import SingleGroup
from fastsafetensors.common import paddle_loaded
from typing import List

TESTS_DIR = os.path.dirname(__file__)
@@ -39,6 +40,24 @@ def pg():
PG = SingleGroup()
return PG

@pytest.fixture(scope='session', autouse=True)
def pg_paddle():
PG = SingleGroup()

if paddle_loaded:
# The following succeeds only when the tests are launched via
# `python -m paddle.distributed.launch`
try:
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
backend = "nccl" if paddle.device.cuda.device_count() else "gloo"
PG = dist.new_group(ranks=[0,1], backend=backend)
except:
pass
return PG

@pytest.fixture(scope='session', autouse=True)
def dev_init() -> None:
if torch.cuda.is_available():
11 changes: 11 additions & 0 deletions tests/run_distributed_paddle_test.py
@@ -0,0 +1,11 @@
import pytest
import sys
import os

if __name__ == "__main__":
# Four pytest commands run before this one, so offset the rank by 4
# the GPU distributed run needs at least 2 GPUs
rank = int(os.getenv("PADDLE_TRAINER_ID")) + 4
os.environ["COVERAGE_FILE"] = f".coverage_{rank}"
pytest_args = sys.argv[1:]
sys.exit(pytest.main(pytest_args))
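
For reference, the CI workflow above invokes this wrapper as follows; with two trainers, PADDLE_TRAINER_ID is 0 or 1, so COVERAGE_FILE becomes .coverage_4 and .coverage_5, matching the coverage combine step:

python -m paddle.distributed.launch --nproc_per_node 2 run_distributed_paddle_test.py -s --cov=${LIBDIR} test_multi_paddle.py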
2 changes: 1 addition & 1 deletion tests/test_fastsafetensors.py
@@ -228,7 +228,7 @@ def test_SafeTensorsFileLoader(fstcpp_log, input_files, framework="pytorch"):
if framework == "pytorch":
data_type = torch.float16
elif framework == "paddle":
# There are some lack of accuracy in paddle.float16 (about 1e-4)
# paddle.float16 loses some accuracy (about 1e-4) on CPU.
data_type = paddle.float32
else:
raise NotImplementedError(f"Do not support the framework: {framework}")
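
As a quick illustration of the float16 rounding that motivates the float32 choice (values are arbitrary; the ~1e-4 figure is from the comment above, not a measured bound):

import paddle
x = paddle.to_tensor([0.1, 0.2, 0.3], dtype="float32")
roundtrip = x.astype("float16").astype("float32")
# The maximum round-trip error is on the order of 1e-4 for values near 0.1,
# which is why the paddle CPU test compares tensors in float32 instead.
print(paddle.abs(x - roundtrip).max())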
45 changes: 45 additions & 0 deletions tests/test_multi_paddle.py
@@ -0,0 +1,45 @@
# Copyright 2024- IBM Inc. All rights reserved
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import paddle
from safetensors import safe_open

from fastsafetensors import cpp as fstcpp
from fastsafetensors import SafeTensorsFileLoader, SingleGroup, SafeTensorsMetadata
from fastsafetensors.common import paddle_loaded

def test_shuffle_paddle(fstcpp_log, input_files, pg_paddle):
if paddle_loaded:
device = "gpu" if paddle.device.cuda.device_count() else "cpu"
loader = SafeTensorsFileLoader(pg_paddle, device, nogds=True, debug_log=True, framework="paddle")
loader.add_filenames({0: input_files})
bufs = loader.copy_files_to_device()
key_dims = {key: -1 for key in loader.get_keys()}
for i in range(0, 12):
key_dims[f"h.{i}.mlp.c_proj.weight"] = 0
key_dims[f"h.{i}.mlp.c_fc.weight"] = 1
tensors = bufs.as_dict(key_dims)
with safe_open(input_files[0], framework="pt") as f:
for key in tensors.keys():
dim = key_dims[key]
if dim == 0 or dim == 1:
t = f.get_slice(key)
rank_slices = ()
shape = t.get_shape()
size = shape[dim]
block_size = (size + pg_paddle.process_group.size() - 1) // pg_paddle.process_group.size()
for i in range(0, len(shape)):
if i < dim:
rank_slices += (slice(None,None,None),)
elif i == dim:
rank_slices += (slice(pg_paddle.process_group.rank() * block_size, (pg_paddle.process_group.rank() + 1) * block_size, 1),)
break
t = t[rank_slices]
t = t.clone().detach()
else:
t = f.get_tensor(key)
assert paddle.all(paddle.to_tensor(t.numpy(),place=loader.device).equal(tensors[key]))
bufs.close()
loader.close()
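
The block_size arithmetic above gives each rank a contiguous, near-equal shard of the split dimension; for illustration with hypothetical sizes:

# shape[dim] == 768 split across a 2-rank group:
# block_size = (768 + 2 - 1) // 2 = 384
# rank 0 takes slice(0, 384), rank 1 takes slice(384, 768);
# when size % world_size != 0, the trailing rank simply gets a shorter block.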