2 changes: 1 addition & 1 deletion .github/workflows/test-paddle.yaml
@@ -55,4 +55,4 @@ jobs:
uses: actions/upload-artifact@v4
with:
name: pytest-log-paddle-${{ matrix.python-version }}
path: /tmp/pytest-log
path: /tmp/pytest-log
6 changes: 5 additions & 1 deletion .gitignore
@@ -5,4 +5,8 @@ dist/
htmlcov/
.coverage
.coverage_*
.pytest_cache/
.pytest_cache/
.vscode
*.log
*.pyc
examples/paddle_case/log
Binary file added examples/paddle_case/a_paddle.safetensors
Binary file not shown.
Binary file added examples/paddle_case/b_paddle.safetensors
Binary file not shown.
6 changes: 6 additions & 0 deletions examples/paddle_case/gen.py
@@ -0,0 +1,6 @@
import os
import paddle
t0 = paddle.concat([paddle.full((1,8), i, dtype=paddle.float16) for i in range(0, 16)], dim=0)
from safetensors.paddle import save_file
t0 = paddle.concat([paddle.full((1,8), i, dtype=paddle.float16) for i in range(0, 16)], axis=0)
for file_prefix in ["a", "b"]:
save_file({f"{file_prefix}0": t0}, f"{file_prefix}_paddle.safetensors", metadata={"fst": "sample"})
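For a quick sanity check of the generated files, a minimal sketch (not part of this PR; assumes the safetensors paddle backend is installed):

from safetensors.paddle import load_file

for file_prefix in ["a", "b"]:
    tensors = load_file(f"{file_prefix}_paddle.safetensors")
    t = tensors[f"{file_prefix}0"]
    # gen.py concatenates 16 rows of shape (1, 8), so each tensor should be (16, 8) float16
    print(file_prefix, t.shape, t.dtype)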
32 changes: 32 additions & 0 deletions examples/paddle_case/run_parallel.py
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
# PIDS=()

# runner="python -m paddle.distributed.launch"

# cd paddle_case
# ${runner} --nnodes=2 --master=127.0.0.1:12345 --rank=0 run_parallel.py &
# PIDS+=($!)
# ${runner} --nnodes=2 --master=127.0.0.1:12345 --rank=1 run_parallel.py &
# PIDS+=($!)
# wait "${PIDS[@]}"

import paddle
import paddle.distributed as dist
from fastsafetensors import SafeTensorsFileLoader
dist.init_parallel_env()
backend = "nccl" if paddle.is_compiled_with_cuda() else "gloo"
pg = dist.new_group(ranks=[0,1], backend=backend)
device = "gpu:0" if paddle.is_compiled_with_cuda() else "cpu"
loader = SafeTensorsFileLoader(pg, device, nogds=False, debug_log=True, framework="paddle")
loader.add_filenames({0: ["a_paddle.safetensors"], 1:["b_paddle.safetensors"]}) # {rank: files}

# load a.safetensors to rank 0 GPU and b.safetensors to rank 1 GPU
fb = loader.copy_files_to_device()

# every rank must call get_tensor and get_sharded in the same order since they internally call paddle.distributed collective ops
tensor_a0 = fb.get_tensor(tensor_name="a0") # broadcast
tensor_b0_sharded = fb.get_sharded(tensor_name="b0", dim=1) # partition and scatter
print(f"RANK {pg.process_group.rank()}: tensor_a0={tensor_a0}")
print(f"RANK {pg.process_group.rank()}: tensor_b0_sharded={tensor_b0_sharded}")
fb.close()
loader.close()
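Given the (16, 8) float16 tensors written by gen.py, a0 is broadcast in full while b0 is split along dim=1 across the two ranks; a hedged check that could be appended to run_parallel.py (variable names as defined above):

# a0 is broadcast: every rank sees the full (16, 8) tensor
assert tuple(tensor_a0.shape) == (16, 8)
# b0 is scattered along dim=1 over 2 ranks: each rank holds a (16, 4) slice
assert tuple(tensor_b0_sharded.shape) == (16, 4)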
13 changes: 13 additions & 0 deletions examples/paddle_case/run_single.py
@@ -0,0 +1,13 @@
import paddle
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
device = "gpu:0" if paddle.is_compiled_with_cuda() else "cpu"
loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=False, debug_log=True, framework="paddle")
loader.add_filenames({0: ["a_paddle.safetensors", "b_paddle.safetensors"]}) # {rank: files}
fb = loader.copy_files_to_device()
tensor_a0 = fb.get_tensor(tensor_name="a0")
tensor_b0 = fb.get_tensor(tensor_name="b0")
print(f"a0: {tensor_a0}")
mycat = paddle.concat([tensor_a0, tensor_b0], axis=1)
print(f"cat: {mycat}, size={mycat.size}")
fb.close()
loader.close()
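Since a0 and b0 are each (16, 8), the axis=1 concatenation in run_single.py should yield a (16, 16) tensor with 256 elements; a small expectation sketch (not from the PR, shapes assumed from gen.py):

assert tuple(mycat.shape) == (16, 16)
assert mycat.size == 256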
15 changes: 15 additions & 0 deletions examples/run_paddle_parrallel.sh
@@ -0,0 +1,15 @@
#!/usr/bin/env bash
PIDS=()

runner="python -m paddle.distributed.launch"
# runner="torchrun"

cd paddle_case
rm -rf log
# one machine, multiple GPUs (case: 1 machine, 2 GPUs)
# differs from the torch script because paddle distributed uses NCCL to communicate between GPUs
CUDA_VISIBLE_DEVICES=0 ${runner} --nnodes=2 --master=127.0.0.1:8800 --rank=0 run_parallel.py &
PIDS+=($!)
CUDA_VISIBLE_DEVICES=1 ${runner} --nnodes=2 --master=127.0.0.1:8800 --rank=1 run_parallel.py &
PIDS+=($!)
wait "${PIDS[@]}"
2 changes: 1 addition & 1 deletion examples/run_parallel.py
@@ -14,7 +14,7 @@
dist.barrier()
pg = dist.group.WORLD
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loader = SafeTensorsFileLoader(pg, device, nogds=True, debug_log=True)
loader = SafeTensorsFileLoader(pg, device, nogds=False, debug_log=True)
loader.add_filenames({0: ["a.safetensors"], 1:["b.safetensors"]}) # {rank: files}

# load a.safetensors to rank 0 GPU and b.safetensors to rank 1 GPU
2 changes: 1 addition & 1 deletion examples/run_reuse_loader.py
@@ -5,7 +5,7 @@
sys.path.insert(0, "/nvme/manish/repos/fastsafetensors/fastsafetensors")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loader = SafeTensorsFileLoader(SingleGroup(), device)#, nogds=True, debug_log=True)
loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=True, debug_log=True)

loader.add_filenames({0: ["a.safetensors"]}) # {rank: files}
fb = loader.copy_files_to_device()
8 changes: 8 additions & 0 deletions examples/run_torch_parrallel.sh
@@ -0,0 +1,8 @@
#!/usr/bin/env bash
PIDS=()

torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 run_parallel.py &
PIDS+=($!)
torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 run_parallel.py &
PIDS+=($!)
wait "${PIDS[@]}"
111 changes: 82 additions & 29 deletions fastsafetensors/common.py
@@ -2,6 +2,12 @@
# SPDX-License-Identifier: Apache-2.0

import torch
try:
import paddle
from paddle.framework import core as paddle_core
paddle_loaded = True
except ImportError:
paddle_loaded = False
import os
import json
from collections import OrderedDict
@@ -32,9 +38,27 @@ def rank(self):
torch.float8_e4m3fn: torch.int8,
}

if paddle_loaded:
framework_index = {
"pytorch": 1,
"paddle": 2,
}
dtype_convert = {
'BOOL': (1, torch.bool, paddle.bool), 'U8': (1, torch.uint8, paddle.uint8), 'I8': (1, torch.int8, paddle.int8), 'F8_E5M2': (1, torch.float8_e5m2, paddle.float8_e5m2), 'F8_E4M3': (1, torch.float8_e4m3fn, paddle.float8_e4m3fn),
'I16': (2, torch.int16, paddle.int16), 'U16': (2, torch.int16, paddle.bfloat16), 'I32': (4, torch.int32, paddle.int32), 'U32': (4, torch.int32, paddle.int32), 'I64': (8, torch.int64, paddle.int64), 'U64': (8, torch.int64, paddle.int64),
'F16': (2, torch.float16, paddle.float16), 'BF16': (2, torch.bfloat16, paddle.bfloat16), 'F32': (4, torch.float32, paddle.float32), 'F64': (8, torch.float64, paddle.float64)
}

need_workaround_dtypes = {
torch.float8_e5m2: torch.int8,
torch.float8_e4m3fn: torch.int8,
paddle.float8_e5m2 : paddle.int8,
paddle.float8_e4m3fn : paddle.int8
}

def str_to_dtype(dtype_str: str, framework: str="pytorch")->torch.dtype:
if framework != "pytorch":
raise NotImplementedError(f"str_to_dtype: Not implemented for other frameworks than pytorch")
if framework not in framework_index:
raise NotImplementedError(f"str_to_dtype: framework must be one of {list(framework_index.keys())}, got {framework}")
if dtype_str not in dtype_convert:
raise ValueError(f"str_to_dtype: Not supported dtype: {dtype_str}")
return dtype_convert[dtype_str][framework_index[framework]]
@@ -53,24 +77,31 @@ def get_device_numa_node(device: int):
with open(syspath) as f:
return int(f.read().strip())

def alloc_tensor_memory(length: int, dev: torch.device)->fstcpp.gds_device_buffer:
if dev.type == 'cuda':
def alloc_tensor_memory(length: int, dev: torch.device, framework: str="pytorch")->fstcpp.gds_device_buffer:
dev_is_gpu = True
if framework == "pytorch" and dev.type == 'cuda':
rbuf = torch.cuda.caching_allocator_alloc(length)
elif paddle_loaded and framework == "paddle" and "gpu" in dev:
rbuf = fstcpp.gpu_malloc(length)
else:
dev_is_gpu = False
rbuf = fstcpp.cpu_malloc(length)
return fstcpp.gds_device_buffer(rbuf, length, dev.type == 'cuda')
return fstcpp.gds_device_buffer(rbuf, length, dev_is_gpu)

def free_tensor_memory(gbuf: fstcpp.gds_device_buffer, dev: torch.device):
if dev.type == 'cuda':
def free_tensor_memory(gbuf: fstcpp.gds_device_buffer, dev: torch.device, framework: str="pytorch"):
if framework =="pytorch" and dev.type == 'cuda':
rbuf = torch.cuda.caching_allocator_delete(gbuf.get_base_address())
elif paddle_loaded and framework == "paddle" and "gpu" in dev:
rbuf = fstcpp.gpu_free(gbuf.get_base_address())
else:
rbuf = fstcpp.cpu_free(gbuf.get_base_address())
return rbuf


class SafeTensorsMetadata:
def __init__(self, string: str, header_length: int, size_bytes: int, src: str="", keep_orig_dict: bool=False):
def __init__(self, string: str, header_length: int, size_bytes: int, src: str="", keep_orig_dict: bool=False, framework: str="pytorch"):
self.src = src
self.framework = framework
ser = json.loads(string, object_pairs_hook=OrderedDict)
self.metadata = ser.get('__metadata__', '')
if self.metadata:
@@ -83,7 +114,7 @@ def __init__(self, string: str, header_length: int, size_bytes: int, src: str=""

start = 0
for _, (k, buffer) in enumerate(sorted(ser.items(), key=lambda x: x[1]['data_offsets'][0])):
t = TensorFrame.from_buffer(buffer)
t = TensorFrame.from_buffer(buffer, self.framework)
self.tensors[k] = t
# validation
s, e = t.data_offsets
@@ -95,7 +126,11 @@ def __init__(self, string: str, header_length: int, size_bytes: int, src: str=""
nelements = 1
for sh in t.shape:
nelements *= sh
nbytes = nelements * t.dtype.itemsize
if self.framework == "pytorch":
t_dtype_size = t.dtype.itemsize
elif paddle_loaded and self.framework == "paddle":
t_dtype_size = paddle_core.size_of_dtype(t.dtype)
nbytes = nelements * t_dtype_size
if (e - s) != nbytes:
raise Exception(f"validate(tensor {k}): TensorInvalidInfo, e-s={e-s}, nbytes={nbytes}, src={src}")
self.size_bytes = size_bytes
@@ -120,7 +155,7 @@ def from_buffer(self, buf: int, buffer_len: int, filename: str):
return SafeTensorsMetadata(string, n + 8, buffer_len)

@classmethod
def from_fd(self, fd: int, filename: str, keep_orig_dict: bool=False):
def from_fd(self, fd: int, filename: str, keep_orig_dict: bool=False, framework: str="pytorch"):
status = os.fstat(fd)
buffer_len = status.st_size
if buffer_len < 8:
@@ -136,33 +171,49 @@ def from_fd(self, fd: int, filename: str, keep_orig_dict: bool=False):
# NOTE: Add when we move to 0.4.0
#if string.startswith('{'):
# raise Exception(f"{filename}: InvalidHeaderStart")
return SafeTensorsMetadata(string, n + 8, buffer_len, filename, keep_orig_dict=keep_orig_dict)
return SafeTensorsMetadata(string, n + 8, buffer_len, filename, keep_orig_dict=keep_orig_dict, framework=framework)

@classmethod
def from_file(self, filename: str):
def from_file(self, filename: str, framework: str="pytorch"):
fd = os.open(filename, os.O_RDONLY, 0o644)
ret = self.from_fd(fd, filename, keep_orig_dict=False)
ret = self.from_fd(fd, filename, keep_orig_dict=False, framework=framework)
os.close(fd)
return ret

def get_tensors(self, gbuf: fstcpp.gds_device_buffer, device: torch.device, copy_start_offset: int, dtype: torch.dtype=None) -> Dict[str, torch.Tensor]:
ret = {}
for tensor_name, t in self.tensors.items():
dst_dev_ptr = gbuf.get_base_address() + self.header_length + t.data_offsets[0]-copy_start_offset
if t.dtype in need_workaround_dtypes:
t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[t.dtype], device)).view(t.dtype)
else:
t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, t.dtype, device))
if dtype is not None and dtype != t.dtype:
if dtype.itemsize > t.dtype.itemsize:
raise Exception(f"Online type conversion to larger sizes is not supported ({t.dtype} -> {dtype})")
t3 = t2.to(dtype=dtype)
if dtype in need_workaround_dtypes:
t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[dtype], device)).view(dtype)
if self.framework == "pytorch":
if t.dtype in need_workaround_dtypes:
t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[t.dtype], device)).view(t.dtype)
else:
t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, t.dtype, device))
if dtype is not None and dtype != t.dtype:
if dtype.itemsize > t.dtype.itemsize:
raise Exception(f"Online type conversion to larger sizes is not supported ({t.dtype} -> {dtype})")
t3 = t2.to(dtype=dtype)
if dtype in need_workaround_dtypes:
t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[dtype], device)).view(dtype)
else:
t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, dtype, device))
t2.copy_(t3)
self.tensors[tensor_name].dtype = dtype
elif paddle_loaded and self.framework == "paddle":
if t.dtype in need_workaround_dtypes:
t2 = paddle.utils.dlpack.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[t.dtype], device))
else:
t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, dtype, device))
t2.copy_(t3)
self.tensors[tensor_name].dtype = dtype
t2 = paddle.utils.dlpack.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, t.dtype, device))
if dtype is not None and dtype != t.dtype:
if paddle_core.size_of_dtype(dtype) > paddle_core.size_of_dtype(t.dtype):
raise Exception(f"Online type conversion to larger sizes is not supported ({t.dtype} -> {dtype})")
t3 = t2.to(dtype=dtype)
if t.dtype in need_workaround_dtypes:
t2 = paddle.utils.dlpack.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[dtype], device))
else:
t2 = paddle.utils.dlpack.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, dtype, device))
paddle.assign(t3, output=t2)
self.tensors[tensor_name].dtype = dtype
ret[tensor_name] = t2
return ret

@@ -180,8 +231,10 @@ def __init__(self, dtype: torch.dtype, shape: torch.Size, data_offsets: List[int

@classmethod
def from_buffer(self, entry: OrderedDict[str, List[int]], framework:str="pytorch"):
dtype = str_to_dtype(entry['dtype'])
shape = torch.Size(entry['shape'])
dtype = str_to_dtype(entry['dtype'], framework=framework)
shape = entry['shape']
if framework == "pytorch":
shape = torch.Size(shape)
data_offsets = list(entry['data_offsets'])
strides = []
offsets = []
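As a reference for the framework-aware dtype handling added above, a minimal sketch of how str_to_dtype resolves the same safetensors dtype string per framework (not part of this PR; assumes paddle is importable so the table is defined):

import torch
import paddle
from fastsafetensors.common import str_to_dtype

# dtype_convert stores (itemsize, torch dtype, paddle dtype); framework_index
# picks index 1 for "pytorch" and index 2 for "paddle".
assert str_to_dtype("F16", framework="pytorch") == torch.float16
assert str_to_dtype("F16", framework="paddle") == paddle.float16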
25 changes: 19 additions & 6 deletions fastsafetensors/copier/gds.py
@@ -4,7 +4,9 @@
import torch
from .. import cpp as fstcpp
from typing import Dict
from ..common import alloc_tensor_memory, free_tensor_memory, SafeTensorsMetadata, ALIGN, CUDA_PTR_ALIGN
from ..common import alloc_tensor_memory, free_tensor_memory, SafeTensorsMetadata, ALIGN, CUDA_PTR_ALIGN, paddle_loaded
if paddle_loaded:
import paddle

class GdsFileCopier:
def __init__(self, metadata: SafeTensorsMetadata, device: torch.device, reader: fstcpp.gds_file_reader, debug_log: bool=False):
@@ -16,13 +18,24 @@ def __init__(self, metadata: SafeTensorsMetadata, device: torch.device, reader:
self.fh = 0
self.copy_reqs: Dict[int, int] = {}
self.aligned_length = 0
self.o_direct = False
try:
if self.metadata.framework == "pytorch":
cuda_vers_list = torch.version.cuda.split('.')
elif paddle_loaded and self.metadata.framework == "paddle":
cuda_vers_list = paddle.version.cuda().split('.')
cudavers = list(map(int, cuda_vers_list))
# CUDA 12.2 (GDS version 1.7) introduces support for non O_DIRECT file descriptors
# Compatible with CUDA 11.x
self.o_direct = not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2))
except Exception:
self.o_direct = True

def set_o_direct(self, enable: bool):
self.o_direct = enable

def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gds_device_buffer:
self.fh = fstcpp.gds_file_handle(self.metadata.src, self.o_direct, self.device.type == 'cuda')
dev_is_cuda = (self.metadata.framework == "pytorch" and self.device.type == 'cuda') or (paddle_loaded and self.metadata.framework == "paddle" and "gpu" in self.device)
self.fh = fstcpp.gds_file_handle(self.metadata.src, self.o_direct, dev_is_cuda)
offset = self.metadata.header_length
length = self.metadata.size_bytes - self.metadata.header_length
head_bytes = offset % ALIGN
@@ -34,7 +47,7 @@ def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gd
aligned_length = length + head_bytes
aligned_offset = offset - head_bytes

gbuf = alloc_tensor_memory(aligned_length, self.device)
gbuf = alloc_tensor_memory(aligned_length, self.device, self.metadata.framework)
if use_buf_register:
count = 0
while count < aligned_length:
@@ -75,7 +88,7 @@ def wait_io(self, gbuf: fstcpp.gds_device_buffer, dtype: torch.dtype=None, noali
if not noalign and not self.metadata.aligned and self.aligned_length > 0:
misaligned_bytes = self.metadata.header_length % CUDA_PTR_ALIGN
length = 1024*1024*1024
tmp_gbuf = alloc_tensor_memory(length, self.device)
tmp_gbuf = alloc_tensor_memory(length, self.device, self.metadata.framework)
count = 0
while count + misaligned_bytes < self.aligned_length:
l = self.aligned_length - misaligned_bytes - count
@@ -85,6 +98,6 @@ def wait_io(self, gbuf: fstcpp.gds_device_buffer, dtype: torch.dtype=None, noali
print("wait_io: fix misalignment, src=0x{:x}, misaligned_bytes={}, count={}, tmp=0x{:x}".format(gbuf.get_base_address(), misaligned_bytes, count, tmp_gbuf.get_base_address()))
gbuf.memmove(count, misaligned_bytes + count, tmp_gbuf, l)
count += l
free_tensor_memory(tmp_gbuf, self.device)
free_tensor_memory(tmp_gbuf, self.device, self.metadata.framework)
self.aligned_offset += misaligned_bytes
return self.metadata.get_tensors(gbuf, self.device, self.aligned_offset, dtype=dtype)
2 changes: 1 addition & 1 deletion fastsafetensors/copier/nogds.py
@@ -20,7 +20,7 @@ def __init__(self, metadata: SafeTensorsMetadata, device: torch.device, reader:

def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gds_device_buffer:
total_length = self.metadata.size_bytes - self.metadata.header_length
gbuf = alloc_tensor_memory(total_length, self.device)
gbuf = alloc_tensor_memory(total_length, self.device, self.metadata.framework)
count = 0
while count < total_length:
l = total_length - count