diff --git a/.github/workflows/test-paddle.yaml b/.github/workflows/test-paddle.yaml
index 3cb7881..e517fd1 100644
--- a/.github/workflows/test-paddle.yaml
+++ b/.github/workflows/test-paddle.yaml
@@ -55,4 +55,4 @@ jobs:
         uses: actions/upload-artifact@v4
         with:
           name: pytest-log-paddle-${{ matrix.python-version }}
-          path: /tmp/pytest-log
\ No newline at end of file
+          path: /tmp/pytest-log
diff --git a/.gitignore b/.gitignore
index 3ebed4a..acc54b9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,8 @@ dist/
 htmlcov/
 .coverage
 .coverage_*
-.pytest_cache/
\ No newline at end of file
+.pytest_cache/
+.vscode
+*.log
+*.pyc
+examples/paddle_case/log
\ No newline at end of file
diff --git a/examples/paddle_case/a_paddle.safetensors b/examples/paddle_case/a_paddle.safetensors
new file mode 100644
index 0000000..345a469
Binary files /dev/null and b/examples/paddle_case/a_paddle.safetensors differ
diff --git a/examples/paddle_case/b_paddle.safetensors b/examples/paddle_case/b_paddle.safetensors
new file mode 100644
index 0000000..372c037
Binary files /dev/null and b/examples/paddle_case/b_paddle.safetensors differ
diff --git a/examples/paddle_case/gen.py b/examples/paddle_case/gen.py
new file mode 100644
index 0000000..8ca69fd
--- /dev/null
+++ b/examples/paddle_case/gen.py
@@ -0,0 +1,6 @@
+import os
+import paddle
+t0 = paddle.concat([paddle.full((1,8), i, dtype=paddle.float16) for i in range(0, 16)], axis=0)
+from safetensors.paddle import save_file
+for file_prefix in ["a", "b"]:
+    save_file({f"{file_prefix}0": t0}, f"{file_prefix}_paddle.safetensors", metadata={"fst": "sample"})
diff --git a/examples/paddle_case/run_parallel.py b/examples/paddle_case/run_parallel.py
new file mode 100644
index 0000000..59c273b
--- /dev/null
+++ b/examples/paddle_case/run_parallel.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+# PIDS=()
+
+# runner="python -m paddle.distributed.launch"
+
+# cd paddle_case
+# ${runner} --nnodes=2 --master=127.0.0.1:12345 --rank=0 run_parallel.py &
+# PIDS+=($!)
+# ${runner} --nnodes=2 --master=127.0.0.1:12345 --rank=1 run_parallel.py &
+# PIDS+=($!)
+# wait "${PIDS[@]}" + +import paddle +import paddle.distributed as dist +from fastsafetensors import SafeTensorsFileLoader +dist.init_parallel_env() +backend = "nccl" if paddle.is_compiled_with_cuda() else "gloo" +pg = dist.new_group(ranks=[0,1], backend=backend) +device = "gpu:0" if paddle.is_compiled_with_cuda() else "cpu" +loader = SafeTensorsFileLoader(pg, device, nogds=False, debug_log=True, framework="paddle") +loader.add_filenames({0: ["a_paddle.safetensors"], 1:["b_paddle.safetensors"]}) # {rank: files} + +# load a.safetensors to rank 0 GPU and b.safetensors to rank 1 GPU +fb = loader.copy_files_to_device() + +# every rank must call get_tensor and get_sharded in the same order since they internally call paddle.distributed collective ops +tensor_a0 = fb.get_tensor(tensor_name="a0") # broadcast +tensor_b0_sharded = fb.get_sharded(tensor_name="b0", dim=1) # partition and scatter +print(f"RANK {pg.process_group.rank()}: tensor_a0={tensor_a0}") +print(f"RANK {pg.process_group.rank()}: tensor_b0_sharded={tensor_b0_sharded}") +fb.close() +loader.close() diff --git a/examples/paddle_case/run_single.py b/examples/paddle_case/run_single.py new file mode 100644 index 0000000..e7679a6 --- /dev/null +++ b/examples/paddle_case/run_single.py @@ -0,0 +1,13 @@ +import paddle +from fastsafetensors import SafeTensorsFileLoader, SingleGroup +device = "gpu:0" if paddle.is_compiled_with_cuda() else "cpu" +loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=False, debug_log=True, framework="paddle") +loader.add_filenames({0: ["a_paddle.safetensors", "b_paddle.safetensors"]}) # {rank: files} +fb = loader.copy_files_to_device() +tensor_a0 = fb.get_tensor(tensor_name="a0") +tensor_b0 = fb.get_tensor(tensor_name="b0") +print(f"a0: {tensor_a0}") +mycat = paddle.concat([tensor_a0, tensor_b0], axis=1) +print(f"cat: {mycat}, size={mycat.size}") +fb.close() +loader.close() diff --git a/examples/run_paddle_parrallel.sh b/examples/run_paddle_parrallel.sh new file mode 100755 index 0000000..108242e --- /dev/null +++ b/examples/run_paddle_parrallel.sh @@ -0,0 +1,15 @@ +# !/usr/bin/env python3 +PIDS=() + +runner="python -m paddle.distributed.launch" +# runner="torchrun" + +cd paddle_case +rm -rf log +# one machine multy gpu (case : 1 machine 2 gpus) +# different to torch script because the paddle distributed use nccl to communicate in gpus +CUDA_VISIBLE_DEVICES=0 ${runner} --nnodes=2 --master=127.0.0.1:8800 --rank=0 run_parallel.py & +PIDS+=($!) +CUDA_VISIBLE_DEVICES=1 ${runner} --nnodes=2 --master=127.0.0.1:8800 --rank=1 run_parallel.py & +PIDS+=($!) 
+wait "${PIDS[@]}" \ No newline at end of file diff --git a/examples/run_parallel.py b/examples/run_parallel.py index 0d7fd53..eb25aae 100644 --- a/examples/run_parallel.py +++ b/examples/run_parallel.py @@ -14,7 +14,7 @@ dist.barrier() pg = dist.group.WORLD device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -loader = SafeTensorsFileLoader(pg, device, nogds=True, debug_log=True) +loader = SafeTensorsFileLoader(pg, device, nogds=False, debug_log=True) loader.add_filenames({0: ["a.safetensors"], 1:["b.safetensors"]}) # {rank: files} # load a.safetensors to rank 0 GPU and b.safetensors to rank 1 GPU diff --git a/examples/run_reuse_loader.py b/examples/run_reuse_loader.py index d37358c..de15f83 100644 --- a/examples/run_reuse_loader.py +++ b/examples/run_reuse_loader.py @@ -5,7 +5,7 @@ sys.path.insert(0, "/nvme/manish/repos/fastsafetensors/fastsafetensors") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -loader = SafeTensorsFileLoader(SingleGroup(), device)#, nogds=True, debug_log=True) +loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=True, debug_log=True) loader.add_filenames({0: ["a.safetensors"]}) # {rank: files} fb = loader.copy_files_to_device() diff --git a/examples/run_torch_parrallel.sh b/examples/run_torch_parrallel.sh new file mode 100755 index 0000000..9ca6d11 --- /dev/null +++ b/examples/run_torch_parrallel.sh @@ -0,0 +1,8 @@ +# !/usr/bin/env python3 +PIDS=() + +torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 run_parallel.py & +PIDS+=$($!) +torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 run_parallel.py & +PIDS+=$($!) +wait ${PIDS[@]} \ No newline at end of file diff --git a/fastsafetensors/common.py b/fastsafetensors/common.py index b4de0f6..b8ac341 100644 --- a/fastsafetensors/common.py +++ b/fastsafetensors/common.py @@ -2,6 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import torch +try: + import paddle + from paddle.framework import core as paddle_core + paddle_loaded = True +except: + paddle_loaded = False import os import json from collections import OrderedDict @@ -32,9 +38,27 @@ def rank(self): torch.float8_e4m3fn: torch.int8, } +if paddle_loaded: + framework_index = { + "pytorch": 1, + "paddle": 2, + } + dtype_convert = { + 'BOOL': (1, torch.bool, paddle.bool), 'U8': (1, torch.uint8, paddle.uint8), 'I8': (1, torch.int8, paddle.int8), 'F8_E5M2': (1, torch.float8_e5m2, paddle.float8_e5m2), 'F8_E4M3': (1, torch.float8_e4m3fn, paddle.float8_e4m3fn), + 'I16': (2, torch.int16, paddle.int16), 'U16': (2, torch.int16, paddle.bfloat16), 'I32': (4, torch.int32, paddle.int32), 'U32': (4, torch.int32, paddle.int32), 'I64': (8, torch.int64, paddle.int64), 'U64': (8, torch.int64, paddle.int64), + 'F16': (2, torch.float16, paddle.float16), 'BF16': (2, torch.bfloat16, paddle.bfloat16), 'F32': (4, torch.float32, paddle.float32), 'F64': (8, torch.float64, paddle.float64) + } + + need_workaround_dtypes = { + torch.float8_e5m2: torch.int8, + torch.float8_e4m3fn: torch.int8, + paddle.float8_e5m2 : paddle.int8, + paddle.float8_e4m3fn : paddle.int8 + } + def str_to_dtype(dtype_str: str, framework: str="pytorch")->torch.dtype: - if framework != "pytorch": - raise NotImplementedError(f"str_to_dtype: Not implemented for other frameworks than pytorch") + if framework not in framework_index.keys(): + raise NotImplementedError(f"str_to_dtype: Not implemented for other frameworks than {framework_index.keys()}") if dtype_str not in dtype_convert: raise ValueError(f"str_to_dtype: Not supported 
dtype: {dtype_str}") return dtype_convert[dtype_str][framework_index[framework]] @@ -53,24 +77,31 @@ def get_device_numa_node(device: int): with open(syspath) as f: return int(f.read().strip()) -def alloc_tensor_memory(length: int, dev: torch.device)->fstcpp.gds_device_buffer: - if dev.type == 'cuda': +def alloc_tensor_memory(length: int, dev: torch.device, framework: str="pytorch")->fstcpp.gds_device_buffer: + dev_is_gpu = True + if framework == "pytorch" and dev.type == 'cuda': rbuf = torch.cuda.caching_allocator_alloc(length) + elif paddle_loaded and framework == "paddle" and "gpu" in dev: + rbuf = fstcpp.gpu_malloc(length) else: + dev_is_gpu = False rbuf = fstcpp.cpu_malloc(length) - return fstcpp.gds_device_buffer(rbuf, length, dev.type == 'cuda') + return fstcpp.gds_device_buffer(rbuf, length, dev_is_gpu) -def free_tensor_memory(gbuf: fstcpp.gds_device_buffer, dev: torch.device): - if dev.type == 'cuda': +def free_tensor_memory(gbuf: fstcpp.gds_device_buffer, dev: torch.device, framework: str="pytorch"): + if framework =="pytorch" and dev.type == 'cuda': rbuf = torch.cuda.caching_allocator_delete(gbuf.get_base_address()) + elif paddle_loaded and framework == "paddle" and "gpu" in dev: + rbuf = fstcpp.gpu_free(gbuf.get_base_address()) else: rbuf = fstcpp.cpu_free(gbuf.get_base_address()) return rbuf class SafeTensorsMetadata: - def __init__(self, string: str, header_length: int, size_bytes: int, src: str="", keep_orig_dict: bool=False): + def __init__(self, string: str, header_length: int, size_bytes: int, src: str="", keep_orig_dict: bool=False, framework: str="pytorch"): self.src = src + self.framework = framework ser = json.loads(string, object_pairs_hook=OrderedDict) self.metadata = ser.get('__metadata__', '') if self.metadata: @@ -83,7 +114,7 @@ def __init__(self, string: str, header_length: int, size_bytes: int, src: str="" start = 0 for _, (k, buffer) in enumerate(sorted(ser.items(), key=lambda x: x[1]['data_offsets'][0])): - t = TensorFrame.from_buffer(buffer) + t = TensorFrame.from_buffer(buffer, self.framework) self.tensors[k] = t # validation s, e = t.data_offsets @@ -95,7 +126,11 @@ def __init__(self, string: str, header_length: int, size_bytes: int, src: str="" nelements = 1 for sh in t.shape: nelements *= sh - nbytes = nelements * t.dtype.itemsize + if self.framework == "pytorch": + t_dtype_size = t.dtype.itemsize + elif paddle_loaded and self.framework == "paddle": + t_dtype_size = paddle_core.size_of_dtype(t.dtype) + nbytes = nelements * t_dtype_size if (e - s) != nbytes: raise Exception(f"validate(tensor {k}): TensorInvalidInfo, e-s={e-s}, nbytes={nbytes}, src={src}") self.size_bytes = size_bytes @@ -120,7 +155,7 @@ def from_buffer(self, buf: int, buffer_len: int, filename: str): return SafeTensorsMetadata(string, n + 8, buffer_len) @classmethod - def from_fd(self, fd: int, filename: str, keep_orig_dict: bool=False): + def from_fd(self, fd: int, filename: str, keep_orig_dict: bool=False, framework: str="pytorch"): status = os.fstat(fd) buffer_len = status.st_size if buffer_len < 8: @@ -136,12 +171,12 @@ def from_fd(self, fd: int, filename: str, keep_orig_dict: bool=False): # NOTE: Add when we move to 0.4.0 #if string.startswith('{'): # raise Exception(f"{filename}: InvalidHeaderStart") - return SafeTensorsMetadata(string, n + 8, buffer_len, filename, keep_orig_dict=keep_orig_dict) + return SafeTensorsMetadata(string, n + 8, buffer_len, filename, keep_orig_dict=keep_orig_dict, framework=framework) @classmethod - def from_file(self, filename: str): + def from_file(self, 
filename: str, framework: str="pytorch"): fd = os.open(filename, os.O_RDONLY, 0o644) - ret = self.from_fd(fd, filename, keep_orig_dict=False) + ret = self.from_fd(fd, filename, keep_orig_dict=False, framework=framework) os.close(fd) return ret @@ -149,20 +184,36 @@ def get_tensors(self, gbuf: fstcpp.gds_device_buffer, device: torch.device, copy ret = {} for tensor_name, t in self.tensors.items(): dst_dev_ptr = gbuf.get_base_address() + self.header_length + t.data_offsets[0]-copy_start_offset - if t.dtype in need_workaround_dtypes: - t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[t.dtype], device)).view(t.dtype) - else: - t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, t.dtype, device)) - if dtype is not None and dtype != t.dtype: - if dtype.itemsize > t.dtype.itemsize: - raise Exception(f"Online type conversion to larger sizes is not supported ({t.dtype} -> {dtype})") - t3 = t2.to(dtype=dtype) - if dtype in need_workaround_dtypes: - t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[dtype], device)).view(dtype) + if self.framework == "pytorch": + if t.dtype in need_workaround_dtypes: + t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[t.dtype], device)).view(t.dtype) + else: + t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, t.dtype, device)) + if dtype is not None and dtype != t.dtype: + if dtype.itemsize > t.dtype.itemsize: + raise Exception(f"Online type conversion to larger sizes is not supported ({t.dtype} -> {dtype})") + t3 = t2.to(dtype=dtype) + if dtype in need_workaround_dtypes: + t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[dtype], device)).view(dtype) + else: + t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, dtype, device)) + t2.copy_(t3) + self.tensors[tensor_name].dtype = dtype + elif paddle_loaded and self.framework == "paddle": + if t.dtype in need_workaround_dtypes: + t2 = paddle.utils.dlpack.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[t.dtype], device)) else: - t2 = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, dtype, device)) - t2.copy_(t3) - self.tensors[tensor_name].dtype = dtype + t2 = paddle.utils.dlpack.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, t.dtype, device)) + if dtype is not None and dtype != t.dtype: + if paddle_core.size_of_dtype(dtype) > paddle_core.size_of_dtype(t.dtype): + raise Exception(f"Online type conversion to larger sizes is not supported ({t.dtype} -> {dtype})") + t3 = t2.to(dtype=dtype) + if t.dtype in need_workaround_dtypes: + t2 = paddle.utils.dlpack.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, need_workaround_dtypes[dtype], device)) + else: + t2 = paddle.utils.dlpack.from_dlpack(from_cuda_buffer(dst_dev_ptr, t.shape, t.strides, dtype, device)) + paddle.assign(t3, output=t2) + self.tensors[tensor_name].dtype = dtype ret[tensor_name] = t2 return ret @@ -180,8 +231,10 @@ def __init__(self, dtype: torch.dtype, shape: torch.Size, data_offsets: List[int @classmethod def from_buffer(self, entry: OrderedDict[str, List[int]], framework:str="pytorch"): - dtype = str_to_dtype(entry['dtype']) - shape = torch.Size(entry['shape']) + dtype = str_to_dtype(entry['dtype'], framework=framework) + shape = entry['shape'] + if framework == "pytorch": + shape = torch.Size(shape) data_offsets = 
list(entry['data_offsets']) strides = [] offsets = [] diff --git a/fastsafetensors/copier/gds.py b/fastsafetensors/copier/gds.py index 6473f33..65a9b36 100644 --- a/fastsafetensors/copier/gds.py +++ b/fastsafetensors/copier/gds.py @@ -4,7 +4,9 @@ import torch from .. import cpp as fstcpp from typing import Dict -from ..common import alloc_tensor_memory, free_tensor_memory, SafeTensorsMetadata, ALIGN, CUDA_PTR_ALIGN +from ..common import alloc_tensor_memory, free_tensor_memory, SafeTensorsMetadata, ALIGN, CUDA_PTR_ALIGN, paddle_loaded +if paddle_loaded: + import paddle class GdsFileCopier: def __init__(self, metadata: SafeTensorsMetadata, device: torch.device, reader: fstcpp.gds_file_reader, debug_log: bool=False): @@ -16,13 +18,24 @@ def __init__(self, metadata: SafeTensorsMetadata, device: torch.device, reader: self.fh = 0 self.copy_reqs: Dict[int, int] = {} self.aligned_length = 0 - self.o_direct = False + try: + if self.metadata.framework == "pytorch": + cuda_vers_list = torch.version.cuda.split('.') + elif paddle_loaded and self.metadata.framework == "paddle": + cuda_vers_list = paddle.version.cuda().split('.') + cudavers = list(map(int, cuda_vers_list)) + # CUDA 12.2 (GDS version 1.7) introduces support for non O_DIRECT file descriptors + # Compatible with CUDA 11.x + self.o_direct = not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2)) + except: + self.o_direct = True def set_o_direct(self, enable: bool): self.o_direct = enable def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gds_device_buffer: - self.fh = fstcpp.gds_file_handle(self.metadata.src, self.o_direct, self.device.type == 'cuda') + dev_is_cuda = (self.metadata.framework == "pytorch" and self.device.type == 'cuda') or (paddle_loaded and self.metadata.framework == "paddle" and "gpu" in self.device) + self.fh = fstcpp.gds_file_handle(self.metadata.src, self.o_direct, dev_is_cuda) offset = self.metadata.header_length length = self.metadata.size_bytes - self.metadata.header_length head_bytes = offset % ALIGN @@ -34,7 +47,7 @@ def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gd aligned_length = length + head_bytes aligned_offset = offset - head_bytes - gbuf = alloc_tensor_memory(aligned_length, self.device) + gbuf = alloc_tensor_memory(aligned_length, self.device, self.metadata.framework) if use_buf_register: count = 0 while count < aligned_length: @@ -75,7 +88,7 @@ def wait_io(self, gbuf: fstcpp.gds_device_buffer, dtype: torch.dtype=None, noali if not noalign and not self.metadata.aligned and self.aligned_length > 0: misaligned_bytes = self.metadata.header_length % CUDA_PTR_ALIGN length = 1024*1024*1024 - tmp_gbuf = alloc_tensor_memory(length, self.device) + tmp_gbuf = alloc_tensor_memory(length, self.device, self.metadata.framework) count = 0 while count + misaligned_bytes < self.aligned_length: l = self.aligned_length - misaligned_bytes - count @@ -85,6 +98,6 @@ def wait_io(self, gbuf: fstcpp.gds_device_buffer, dtype: torch.dtype=None, noali print("wait_io: fix misalignment, src=0x{:x}, misaligned_bytes={}, count={}, tmp=0x{:x}".format(gbuf.get_base_address(), misaligned_bytes, count, tmp_gbuf.get_base_address())) gbuf.memmove(count, misaligned_bytes + count, tmp_gbuf, l) count += l - free_tensor_memory(tmp_gbuf, self.device) + free_tensor_memory(tmp_gbuf, self.device, self.metadata.framework) self.aligned_offset += misaligned_bytes return self.metadata.get_tensors(gbuf, self.device, self.aligned_offset, dtype=dtype) diff --git 
a/fastsafetensors/copier/nogds.py b/fastsafetensors/copier/nogds.py index 1c619c2..f8359b7 100644 --- a/fastsafetensors/copier/nogds.py +++ b/fastsafetensors/copier/nogds.py @@ -20,7 +20,7 @@ def __init__(self, metadata: SafeTensorsMetadata, device: torch.device, reader: def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gds_device_buffer: total_length = self.metadata.size_bytes - self.metadata.header_length - gbuf = alloc_tensor_memory(total_length, self.device) + gbuf = alloc_tensor_memory(total_length, self.device, self.metadata.framework) count = 0 while count < total_length: l = total_length - count diff --git a/fastsafetensors/cpp/ext.cpp b/fastsafetensors/cpp/ext.cpp index c524a18..c88e5f2 100644 --- a/fastsafetensors/cpp/ext.cpp +++ b/fastsafetensors/cpp/ext.cpp @@ -135,7 +135,9 @@ static void load_nvidia_functions() { mydlsym(&cuda_fns.cudaHostAlloc, handle_cudart, "cudaHostAlloc"); mydlsym(&cuda_fns.cudaFreeHost, handle_cudart, "cudaFreeHost"); mydlsym(&cuda_fns.cudaDeviceGetPCIBusId, handle_cudart, "cudaDeviceGetPCIBusId"); - bool success = cuda_fns.cudaMemcpy && cuda_fns.cudaDeviceSynchronize && cuda_fns.cudaHostAlloc && cuda_fns.cudaFreeHost && cuda_fns.cudaDeviceGetPCIBusId; + mydlsym(&cuda_fns.cudaDeviceMalloc, handle_cudart, "cudaMalloc"); + mydlsym(&cuda_fns.cudaDeviceFree, handle_cudart, "cudaFree"); + bool success = cuda_fns.cudaMemcpy && cuda_fns.cudaDeviceSynchronize && cuda_fns.cudaHostAlloc && cuda_fns.cudaFreeHost && cuda_fns.cudaDeviceGetPCIBusId && cuda_fns.cudaDeviceMalloc && cuda_fns.cudaDeviceFree; if (!success) { cuda_found = false; if (init_log) { @@ -349,6 +351,18 @@ void cpu_free(uintptr_t addr) { free(p); } +uintptr_t gpu_malloc(uint64_t length) { + void *p; + if (cuda_fns.cudaDeviceMalloc(&p, length) != cudaSuccess) { + return 0; + } + return reinterpret_cast(p); +} + +void gpu_free(uintptr_t addr) { + cuda_fns.cudaDeviceFree(reinterpret_cast(addr)); +} + const int gds_device_buffer::cufile_register(uint64_t offset, uint64_t length) { CUfileError_t err; void * dst = reinterpret_cast(this->_devPtr_base->get_uintptr() + offset); @@ -722,6 +736,8 @@ PYBIND11_MODULE(__MOD_NAME__, m) m.def("read_buffer", &read_buffer); m.def("cpu_malloc", &cpu_malloc); m.def("cpu_free", &cpu_free); + m.def("gpu_malloc", &gpu_malloc); + m.def("gpu_free", &gpu_free); m.def("load_nvidia_functions", &load_nvidia_functions); pybind11::class_(m, "gds_device_buffer") diff --git a/fastsafetensors/cpp/ext.hpp b/fastsafetensors/cpp/ext.hpp index 66b9530..acd82d7 100644 --- a/fastsafetensors/cpp/ext.hpp +++ b/fastsafetensors/cpp/ext.hpp @@ -193,6 +193,8 @@ typedef struct ext_funcs { cudaError_t (*cudaHostAlloc)(void **, size_t, unsigned int); cudaError_t (*cudaFreeHost)(void *); cudaError_t (*cudaDeviceGetPCIBusId)(char *, int, int); + cudaError_t (*cudaDeviceMalloc)(void **, size_t); + cudaError_t (*cudaDeviceFree)(void *); int (*numa_run_on_node)(int); } ext_funcs_t; diff --git a/fastsafetensors/dlpack.py b/fastsafetensors/dlpack.py index 58b2c32..f9dc820 100644 --- a/fastsafetensors/dlpack.py +++ b/fastsafetensors/dlpack.py @@ -6,16 +6,29 @@ import ctypes import torch +from .common import paddle_loaded from typing import List +if paddle_loaded: + import paddle _c_str_dltensor = b"dltensor" class DLDevice(ctypes.Structure): def __init__(self, device: torch.device): - self.device_type = self.TYPE_MAP[device.type] - self.device_id = 0 - if device.index: - self.device_id = device.index + if isinstance(device, str): + self.device_id = 0 + if device == "cpu": + 
self.device_type = self.TYPE_MAP[device] + else: + device = device.split(":") + if len(device) == 2: + self.device_id = int(device[1]) + self.device_type = self.TYPE_MAP[device[0]] + else: + self.device_type = self.TYPE_MAP[device.type] + self.device_id = 0 + if device.index: + self.device_id = device.index kDLCPU = 1 kDLCUDA = 2 @@ -26,6 +39,7 @@ def __init__(self, device: torch.device): TYPE_MAP= { "cpu": kDLCPU, "cuda": kDLCUDA, + "gpu": kDLCUDA } @@ -63,7 +77,30 @@ class DLDataType(ctypes.Structure): torch.float64: (2, 64, 1), torch.bfloat16: (4, 16, 1), } - + if paddle_loaded: + TYPE_MAP = { + torch.bool: (6, 8, 1), + torch.int8: (0, 8, 1), + torch.int16: (0, 16, 1), + torch.int32: (0, 32, 1), + torch.int: (0, 32, 1), + torch.int64: (0, 64, 1), + torch.uint8: (1, 8, 1), + torch.float16: (2, 16, 1), + torch.float32: (2, 32, 1), + torch.float64: (2, 64, 1), + torch.bfloat16: (4, 16, 1), + paddle.bool: (6, 8, 1), + paddle.int8: (0, 8, 1), + paddle.int16: (0, 16, 1), + paddle.int32: (0, 32, 1), + paddle.int64: (0, 64, 1), + paddle.uint8: (1, 8, 1), + paddle.float16: (2, 16, 1), + paddle.float32: (2, 32, 1), + paddle.float64: (2, 64, 1), + paddle.bfloat16: (4, 16, 1), + } class DLTensor(ctypes.Structure): _fields_ = [ @@ -181,4 +218,4 @@ def from_cuda_buffer(dev_ptr: int, shape: List[int], strides: List[int], dtype: _c_str_dltensor, _numpy_pycapsule_deleter, ) - return pycapsule \ No newline at end of file + return pycapsule diff --git a/fastsafetensors/file_buffer.py b/fastsafetensors/file_buffer.py index 2d66252..9244585 100644 --- a/fastsafetensors/file_buffer.py +++ b/fastsafetensors/file_buffer.py @@ -8,6 +8,9 @@ from collections import OrderedDict from .tensor_factory import LazyTensorFactory +from .common import SingleGroup, paddle_loaded +if paddle_loaded: + import paddle class FilesBufferOnDevice: r""" Device buffer for .safetensors files. @@ -27,7 +30,7 @@ class FilesBufferOnDevice: Examples: See examples/run_single.py and examples/run_parallel.py. """ - def __init__(self, rank_loaders: Dict[int, List[LazyTensorFactory]], pg: dist.ProcessGroup, auto_mem_delete=True): + def __init__(self, rank_loaders: Dict[int, List[LazyTensorFactory]], pg: dist.ProcessGroup, auto_mem_delete=True, framework="pytorch"): self.rank_loaders: Dict[int, List[LazyTensorFactory]] = rank_loaders self.key_to_rank_lidx: Dict[str, Tuple[int, int]] = {} self.instantiated: Dict[int, Dict[int, Dict[str, bool]]] = {} # rank, key name @@ -39,8 +42,14 @@ def __init__(self, rank_loaders: Dict[int, List[LazyTensorFactory]], pg: dist.Pr raise Exception(f"FilesBufferOnDevice: key {key} must be unique among files") self.key_to_rank_lidx[key] = (rank, lidx) self.instantiated[rank][lidx] = {} - self.auto_mem_delete = auto_mem_delete and pg.size() > 1 - self.pg = pg + self.framework = framework + if self.framework == "pytorch" or isinstance(pg, SingleGroup): + self.pg = pg + self.group = None + elif paddle_loaded and self.framework == "paddle": + self.pg = pg.process_group + self.group = pg + self.auto_mem_delete = auto_mem_delete and self.pg.size() > 1 def close(self): for _, loaders in self.rank_loaders.items(): @@ -84,7 +93,7 @@ def get_sharded(self, tensor_name: str, dim: int, device: torch.device=None, dty A special dim is -1, which broadcast a tensor to all the ranks (== get_tensor()). 
""" (rank, lidix) = self._get_rank_lidx(tensor_name) - t = self.rank_loaders[rank][lidix].shuffle(self.pg, tensor_name, dim) + t = self.rank_loaders[rank][lidix].shuffle(self.pg, tensor_name, dim, group=self.group) return self._get_tensor(rank, lidix, tensor_name, t, device, dtype) def get_tensor(self, tensor_name: str, device: torch.device=None, dtype: torch.dtype=None)->torch.Tensor: @@ -104,12 +113,12 @@ def push_tensor(self, tensor_name: str, dst_rank: int, device: torch.device=Non Other ranks do nothing. """ (rank, lidix) = self._get_rank_lidx(tensor_name) - t = self.rank_loaders[rank][lidix].push(self.pg, tensor_name, dst_rank, rank) + t = self.rank_loaders[rank][lidix].push(self.pg, tensor_name, dst_rank, rank, group=self.group) return self._get_tensor(rank, lidix, tensor_name, t, device, dtype) def get_sharded_packed_qkv(self, tensor_name: str, device: torch.device=None, dtype: torch.dtype=None)->torch.Tensor: (rank, lidix) = self._get_rank_lidx(tensor_name) - t = self.rank_loaders[rank][lidix].shuffle_packed_qkv(self.pg, tensor_name) + t = self.rank_loaders[rank][lidix].shuffle_packed_qkv(self.pg, tensor_name, group=self.group) return self._get_tensor(rank, lidix, tensor_name, t, device, dtype) def get_multi_cols(self, tensor_names: List[str], dim: int, device: torch.device=None, dtype: torch.dtype=None)->torch.Tensor: @@ -122,7 +131,7 @@ def get_multi_cols(self, tensor_names: List[str], dim: int, device: torch.device rank_lidixs[ranklidx] = [tensor_name] ts: List[torch.Tensor] = [] for (rank, lidix), tns in sorted(rank_lidixs.items(), key=lambda x:x[0]): - ts.append(self.rank_loaders[rank][lidix].shuffle_multi_cols(self.pg, tns, dim)) + ts.append(self.rank_loaders[rank][lidix].shuffle_multi_cols(self.pg, tns, dim,group=self.group)) if len(ts) == 1: # fastpath: tensors at the same layer are often in the same file return self._get_tensor(rank, lidix, rank_lidixs[(rank, lidix)][0], ts[0], device, dtype) @@ -145,7 +154,7 @@ def as_dict(self, tensor_shard_dim: OrderedDict[str, int])->Dict[str, torch.Tens for tensor_name, dim in tensor_shard_dim.items(): (rank, lidx) = self._get_rank_lidx(tensor_name) loader = self.rank_loaders[rank][lidx] - tensors[tensor_name] = loader.shuffle(self.pg, tensor_name, dim) + tensors[tensor_name] = loader.shuffle(self.pg, tensor_name, dim, group=self.group) if self.auto_mem_delete: self.instantiated[rank][lidx][tensor_name] = True if len(self.instantiated[rank][lidx]) == len(loader.metadata.tensors): diff --git a/fastsafetensors/loader.py b/fastsafetensors/loader.py index 5e98a32..4c75910 100644 --- a/fastsafetensors/loader.py +++ b/fastsafetensors/loader.py @@ -9,9 +9,11 @@ from typing import List, Dict, Tuple, OrderedDict, Union import warnings -from .common import SafeTensorsMetadata, ALIGN, CUDA_PTR_ALIGN, TensorFrame, get_device_numa_node, SingleGroup +from .common import paddle_loaded, SafeTensorsMetadata, ALIGN, CUDA_PTR_ALIGN, TensorFrame, get_device_numa_node, SingleGroup from .tensor_factory import LazyTensorFactory from .file_buffer import FilesBufferOnDevice +if paddle_loaded: + import paddle initialized: bool = False loaded_nvidia: bool = False @@ -19,6 +21,10 @@ fstcpp.load_nvidia_functions() loaded_nvidia = True +support_framework = ["pytorch", "pt"] +if paddle_loaded: + support_framework.append("paddle") + class SafeTensorsFileLoader: r""" Load .safetensors files lazily. 
@@ -40,17 +46,34 @@ class SafeTensorsFileLoader: >> print(bufs.get_tensor(loader.get_keys()[0])) >> loader.close() """ - def __init__(self, pg: dist.ProcessGroup, device: torch.device, bbuf_size_kb: int = 16 * 1024, max_pinned_memory_in_kb: int = 64 * 1024 * 1024, max_threads: int=16, nogds: bool=False, debug_log: bool=False): + def __init__(self, pg: dist.ProcessGroup, device: torch.device, bbuf_size_kb: int = 16 * 1024, max_pinned_memory_in_kb: int = 64 * 1024 * 1024, max_threads: int=16, nogds: bool=False, debug_log: bool=False, framework="pytorch"): + if framework not in support_framework: + raise NotImplementedError(f"fastsafetensors only supports {support_framework} framework") self.device = device self.debug_log = debug_log self.meta: Dict[str, Tuple[SafeTensorsMetadata, int]] = {} self.need_gds_close = False self.frames: OrderedDict[str, TensorFrame] = {} - self.pg = pg + self.framework = framework + if self.framework == "pytorch" or isinstance(pg, SingleGroup): + self.pg = pg + self.group = pg + elif paddle_loaded and self.framework == "paddle": + self.pg = pg.process_group + self.group = pg + self.nogds = nogds global initialized if not initialized: fstcpp.set_debug_log(debug_log) - node = get_device_numa_node(device.index) + if self.framework == "pytorch": + d_id = device.index + elif paddle_loaded and self.framework == "paddle": + if device == "cpu": + d_id = None + else: + d_id = device.split(":") # "gpu:0" or "gpu" + d_id = int(d_id[1]) if len(d_id) == 2 else 0 + node = get_device_numa_node(d_id) if node is not None: fstcpp.set_numa_node(node) if False and fstcpp.is_cufile_found() and not nogds: # TODO: init_gds should be called but too slow for parallel initialization @@ -58,15 +81,16 @@ def __init__(self, pg: dist.ProcessGroup, device: torch.device, bbuf_size_kb: in raise Exception(f"[FAIL] init_gds max_io_block_in_kb={max_io_block_in_kb}, max_pinned_memory_in_kb={max_pinned_memory_in_kb}") self.need_gds_close = True initialized = True - if not device.type == "cpu" and not fstcpp.is_cuda_found(): + device_is_not_cpu = not (paddle_loaded and self.framework == "paddle" and device == "cpu") and not (self.framework == "pytorch" and device.type == "cpu") + if device_is_not_cpu and not fstcpp.is_cuda_found(): raise Exception("[FAIL] libcudart.so does not exist") if not fstcpp.is_cufile_found() and not nogds: warnings.warn("libcufile.so does not exist but nogds is False. use nogds=True", UserWarning) nogds = True if nogds: - self.reader = fstcpp.nogds_file_reader(False, bbuf_size_kb, max_threads, device.type != "cpu") + self.reader = fstcpp.nogds_file_reader(False, bbuf_size_kb, max_threads, device_is_not_cpu) else: - self.reader = fstcpp.gds_file_reader(max_threads, device != "cpu") + self.reader = fstcpp.gds_file_reader(max_threads, device_is_not_cpu) self.nogds = nogds def reset(self): @@ -97,7 +121,7 @@ def add_filenames(self, filenames: Dict[int, List[str]]): next_idx = rank_next_idx[rank] if next_idx < len(filenames[rank]): realpath = filenames[rank][next_idx] #os.path.realpath(filename) - metadata = SafeTensorsMetadata.from_file(realpath) + metadata = SafeTensorsMetadata.from_file(realpath, self.framework) self.meta[realpath] = (metadata, rank) self.frames.update(metadata.tensors) if self.debug_log and rank == self.pg.rank(): @@ -112,8 +136,12 @@ def copy_files_to_device(self, dtype: torch.dtype=None, use_buf_register: bool=T At this moment, we do not instantiate tensors but just creating copies at device buffers with or without GDS. 
Users can instantiate and/or partition tensors with FilesBufferOnDevice returned by this function. """ - if self.device.type != "cpu": - torch.cuda.set_device(self.device) + if self.framework == "pytorch": + if self.device.type != "cpu": + torch.cuda.set_device(self.device) + elif paddle_loaded and self.framework == "paddle": + if self.device != paddle.CPUPlace(): + paddle.set_device(self.device) need_wait: List[LazyTensorFactory] = [] factories: Dict[int, List[LazyTensorFactory]] = {} @@ -133,7 +161,11 @@ def copy_files_to_device(self, dtype: torch.dtype=None, use_buf_register: bool=T lidx += 1 for factory in need_wait: factory.wait_io(dtype=dtype, noalign=self.nogds) - return FilesBufferOnDevice(factories, pg=self.pg) + if self.framework == "pytorch": + return FilesBufferOnDevice(factories, pg=self.pg) + elif paddle_loaded and self.framework == "paddle": + return FilesBufferOnDevice(factories, pg=self.group, framework=self.framework) + return None class fastsafe_open: """ @@ -142,7 +174,7 @@ class fastsafe_open: Args: filenames (:obj:`str`|`list[str]`|`dict[int, str]`): The filename(s) or rank-file map to open - framework (:obj:`str`): `pt` is only supported currently + framework (:obj:`str`): `pt` and `paddle` are only supported currently device (:obj:`str`, defaults to :obj:`"cpu"`): The device on which you want the tensors. """ @@ -150,19 +182,20 @@ def __init__(self, filenames: Union[str, List[str], Dict[int, str]], framework: str="pt", pg: dist.ProcessGroup=SingleGroup(), device: Union[str, torch.device]="cpu", nogds: bool=False, - debug_log: bool=False): - if framework != "pt": + debug_log: bool=False, + max_copy_block_size: int=16*1024*1024*1024): + if framework not in support_framework: raise NotImplementedError("pytorch is only a framework that current fastsafetensors supports") - if isinstance(device, str): + if isinstance(device, str) and framework == "pt": device = torch.device(device) - self.loader = SafeTensorsFileLoader(pg, device, nogds=nogds, debug_log=debug_log) + self.loader = SafeTensorsFileLoader(pg, device, nogds=nogds, debug_log=debug_log, framework= framework if framework != "pt" else "pytorch") if isinstance(filenames, str): filenames = [filenames] if isinstance(filenames, list): self.loader.add_filenames({0: filenames}) elif isinstance(filenames, dict): self.loader.add_filenames(filenames) - self.fb = self.loader.copy_files_to_device() + self.fb = self.loader.copy_files_to_device(max_copy_block_size=max_copy_block_size) def metadata(self)->Dict[str, Dict[str, str]]: ret = {} diff --git a/fastsafetensors/tensor_factory.py b/fastsafetensors/tensor_factory.py index b751749..6556aee 100644 --- a/fastsafetensors/tensor_factory.py +++ b/fastsafetensors/tensor_factory.py @@ -7,10 +7,14 @@ from collections import OrderedDict from . 
import cpp as fstcpp -from .common import SafeTensorsMetadata, free_tensor_memory +from .common import SafeTensorsMetadata, free_tensor_memory, paddle_loaded from .copier.gds import GdsFileCopier from .copier.nogds import NoGdsFileCopier +if paddle_loaded: + import paddle + import paddle.distributed as pdist + class LazyTensorFactory: def __init__(self, metadata: SafeTensorsMetadata, device: torch.device, rank: int, local_rank: bool, factory_idx_bits: int, lidx: int, nogds: bool, reader, debug_log: bool=False): self.metadata = metadata @@ -42,7 +46,7 @@ def wait_io(self, dtype: torch.dtype=None, noalign: bool=False): print(f"wait_io: tensor={name}") self.copier = None - def push(self, pg: dist.ProcessGroup, tensor_name: str, dst_rank: int, src_rank: int)->torch.Tensor: + def push(self, pg: dist.ProcessGroup, tensor_name: str, dst_rank: int, src_rank: int, group = None)->torch.Tensor: if pg.size() == 1: return self.tensors[tensor_name] tag = (self.next_tag << self.factory_idx_bits) + self.lidx @@ -62,15 +66,24 @@ def push(self, pg: dist.ProcessGroup, tensor_name: str, dst_rank: int, src_rank: t = self.tensors[tensor_name].clone().detach() if self.debug_log: print(f"push: send, tensor_name={tensor_name}, shape={frame.shape}, dst_rank={dst_rank}, pg.rank()={pg.rank()}, tag={tag}") - dist.send(t, dst_rank, group=pg, tag=tag) + if self.metadata.framework == "pytorch": + dist.send(t, dst_rank, group=pg, tag=tag) + elif paddle_loaded and self.metadata.framework == "paddle": + pdist.send(t, dst_rank, group=group) return None - t = torch.empty(size=frame.shape, dtype=frame.dtype, device=self.device) + if self.debug_log: print(f"push: recv, tensor_name={tensor_name}, shape={frame.shape}, src_rank={src_rank}, pg.rank()={pg.rank()}, tag={tag}") - dist.recv(t, src_rank, group=pg, tag=tag) + + if self.metadata.framework == "pytorch": + t = torch.empty(size=frame.shape, dtype=frame.dtype, device=self.device) + dist.recv(t, src_rank, group=pg, tag=tag) + elif paddle_loaded and self.metadata.framework == "paddle": + t = paddle.to_tensor(paddle.empty(size=frame.shape, dtype=frame.dtype), place=self.device) + pdist.recv(t,src_rank, group=group) return t - def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int)->torch.Tensor: + def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int, group = None)->torch.Tensor: if pg.size() == 1: return self.tensors[tensor_name] if tensor_name in self.shuffled: @@ -83,10 +96,17 @@ def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int)->torch.Tens if tensor_name in self.tensors: dst = self.tensors[tensor_name].clone().detach() else: - dst = torch.empty(size=frame.shape, dtype=frame.dtype, device=self.device) + if self.metadata.framework == "pytorch": + dst = torch.empty(size=frame.shape, dtype=frame.dtype, device=self.device) + elif paddle_loaded and self.metadata.framework == "paddle": + dst = paddle.to_tensor(paddle.empty(shape=frame.shape, dtype=frame.dtype), place=self.device) + if self.debug_log: print(f"shuffle: broadcast, tensor_name={tensor_name}, shape={frame.shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, has_tensor={tensor_name in self.tensors}") - dist.broadcast(dst, self.rank, group=pg) + if self.metadata.framework == "pytorch": + dist.broadcast(dst, self.rank, group=pg) + elif paddle_loaded and self.metadata.framework == "paddle": + pdist.broadcast(dst, self.rank, group=group, sync_op=False) else: rank_slices: List[Tuple] = [() for i in range(0, pg.size())] size = frame.shape[dim] @@ -100,7 +120,11 @@ def 
shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int)->torch.Tens break scatter_list: List[torch.Tensor] = [] new_frame = frame[rank_slices[pg.rank()]] - dst = torch.empty(size=new_frame.shape, dtype=new_frame.dtype, device=self.device) + + if self.metadata.framework == "pytorch": + dst = torch.empty(size=new_frame.shape, dtype=new_frame.dtype, device=self.device) + elif paddle_loaded and self.metadata.framework == "paddle": + dst = paddle.to_tensor(paddle.empty(shape=new_frame.shape, dtype=frame.dtype), place=self.device) if self.rank == pg.rank(): if tensor_name not in self.tensors: raise Exception(f"shuffle: tensor {tensor_name} was not found, released? lidx={self.lidx}") @@ -108,11 +132,14 @@ def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int)->torch.Tens scatter_list = [t[rank_slices[rank]].contiguous() for rank in range(0, pg.size())] # scatter requires contiguous tensor if self.debug_log: print(f"shuffle: scatter, tensor_name={tensor_name}, shape={frame.shape}->{new_frame.shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, rank_slices={rank_slices}, len(scatter_list)={len(scatter_list)}") - dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg) + if self.metadata.framework == "pytorch": + dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg) + elif paddle_loaded and self.metadata.framework == "paddle": + pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group, sync_op=False) self.shuffled[tensor_name] = dst return dst - def shuffle_packed_qkv(self, pg: dist.ProcessGroup, tensor_name: str)->torch.Tensor: + def shuffle_packed_qkv(self, pg: dist.ProcessGroup, tensor_name: str, group = None)->torch.Tensor: if tensor_name in self.shuffled: if self.debug_log: print(f"shuffle: use cache, tensor_name={tensor_name}") @@ -129,19 +156,28 @@ def shuffle_packed_qkv(self, pg: dist.ProcessGroup, tensor_name: str)->torch.Ten q = t[(slice(rank * block_size, (rank + 1) * block_size, 1))] k = t[(slice(single_size + rank * block_size, single_size + (rank + 1) * block_size, 1))] v = t[(slice(single_size * 2 + rank * block_size, single_size * 2 + (rank + 1) * block_size, 1))] - scatter_list.append(torch.cat([q, k, v], dim=0)) + if self.metadata.framework == "pytorch": + cat_res = torch.cat([q, k, v], dim=0) + elif paddle_loaded and self.metadata.framework == "paddle": + cat_res = paddle.concat([q, k, v], axis=0) + scatter_list.append(cat_res) if pg.size() == 1: self.shuffled[tensor_name] = scatter_list[0] return scatter_list[0] new_shape = (block_size * 3,) + frame.shape[1:] - dst = torch.empty(size=new_shape, dtype=frame.dtype, device=self.device) + if self.debug_log: print(f"shuffle_packed_qkv: scatter, tensor_name={tensor_name}, shape={frame.shape}->{new_shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, len(scatter_list)={len(scatter_list)}") - dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg) + if self.metadata.framework == "pytorch": + dst = torch.empty(size=new_shape, dtype=frame.dtype, device=self.device) + dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg) + elif paddle_loaded and self.metadata.framework == "paddle": + dst = paddle.to_tensor(paddle.empty(shape=new_shape, dtype=frame.dtype),place=self.device) + pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group, sync_op=False) self.shuffled[tensor_name] = dst return dst - def shuffle_multi_cols(self, pg: dist.ProcessGroup, tensor_names: List[str], dim: int)->torch.Tensor: + def shuffle_multi_cols(self, pg: 
dist.ProcessGroup, tensor_names: List[str], dim: int, group = None)->torch.Tensor: rank_tensors: List[List[torch.Tensor]] = [[] for i in range(0, pg.size())] new_shape: List = [] for tensor_name in tensor_names: @@ -161,21 +197,30 @@ def shuffle_multi_cols(self, pg: dist.ProcessGroup, tensor_names: List[str], dim for rank in range(0, pg.size()): rank_tensors[rank].append(t[(slice(rank * block_size, (rank + 1) * block_size, 1))]) if pg.size() == 1: - return torch.cat(rank_tensors[self.rank], dim=dim) + if self.metadata.framework == "pytorch": + return torch.cat(rank_tensors[self.rank], dim=dim) + elif paddle_loaded and self.metadata.framework == "paddle": + return paddle.concat(rank_tensors[self.rank], axis=dim) + return None scatter_list: List[torch.Tensor] = [] - dst = torch.empty(size=new_shape, dtype=frame.dtype, device=self.device) # dst should be eariler than scatter_list for less fragmentation + if len(rank_tensors[0]) > 0: for rank in range(0, pg.size()): scatter_list.append(torch.cat(rank_tensors[rank], dim=dim)) if self.debug_log: print(f"shuffle_multi_cols: scatter, tensor_name={tensor_name}, shape={frame.shape}->{new_shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, len(scatter_list)={len(scatter_list)}") - dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg) + if self.metadata.framework == "pytorch": + dst = torch.empty(size=new_shape, dtype=frame.dtype, device=self.device) # dst should be eariler than scatter_list for less fragmentation + dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg) + elif paddle_loaded and self.metadata.framework == "paddle": + dst = paddle.to_tensor(paddle.empty(shape=new_shape, dtype=frame.dtype), place=self.device )# dst should be eariler than scatter_list for less fragmentation + pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group, sync_op=False) return dst def free_dev_ptrs(self): self.tensors = {} if self.gbuf is not None: - free_tensor_memory(self.gbuf, self.device) + free_tensor_memory(self.gbuf, self.device, self.metadata.framework) self.gbuf = None def shuffle_all(self, pg: dist.ProcessGroup, tensor_shard_dim: OrderedDict[str, int])->Tuple[int, Dict[str, torch.Tensor]]: @@ -183,4 +228,4 @@ def shuffle_all(self, pg: dist.ProcessGroup, tensor_shard_dim: OrderedDict[str, for tensor_name, dim in tensor_shard_dim.items(): if tensor_name in self.metadata.tensors: ret[tensor_name] = self.shuffle(pg, tensor_name, dim) - return (0, ret) \ No newline at end of file + return (0, ret) diff --git a/tests/test_fastsafetensors.py b/tests/test_fastsafetensors.py index 742f50a..ccb5574 100644 --- a/tests/test_fastsafetensors.py +++ b/tests/test_fastsafetensors.py @@ -11,32 +11,49 @@ from fastsafetensors import SafeTensorsFileLoader, SingleGroup, SafeTensorsMetadata, fastsafe_open from fastsafetensors.copier.gds import GdsFileCopier from fastsafetensors.copier.nogds import NoGdsFileCopier -from fastsafetensors.common import alloc_tensor_memory, free_tensor_memory, need_workaround_dtypes +from fastsafetensors.common import alloc_tensor_memory, free_tensor_memory, need_workaround_dtypes, paddle_loaded from fastsafetensors import cpp as fstcpp +if paddle_loaded: + import paddle + from safetensors.paddle import save_file as paddle_save_file -def run_nogds_file_read(input_file: str)->Tuple[SafeTensorsMetadata, fstcpp.gds_device_buffer]: +def get_and_check_device(framework="pytorch"): + dev_is_gpu = fstcpp.is_cuda_found() + if framework == "pytorch" or framework == "pt": + device = torch.device("cuda:0" if 
dev_is_gpu else "cpu") + elif paddle_loaded and framework == "paddle": + device = "gpu:0" if dev_is_gpu else "cpu" + else: + raise NotImplementedError(f"Do not support framework: {framework}") + return device, dev_is_gpu + +def run_nogds_file_read(input_file: str, framework="pytorch")->Tuple[SafeTensorsMetadata, fstcpp.gds_device_buffer]: fd = os.open(input_file, os.O_RDONLY, 0o644) - meta = SafeTensorsMetadata.from_file(input_file) + meta = SafeTensorsMetadata.from_file(input_file, framework=framework) size = meta.size_bytes - meta.header_length - device = torch.device("cuda:0" if fstcpp.is_cuda_found() else "cpu") - gbuf = alloc_tensor_memory(size, device) - reader = fstcpp.nogds_file_reader(False, 20 * 1024, 1, device.type=="cuda") + device, dev_is_gpu = get_and_check_device(framework) + gbuf = alloc_tensor_memory(size, device, framework=framework) + reader = fstcpp.nogds_file_reader(False, 20 * 1024, 1, dev_is_gpu) req = reader.submit_read(fd, gbuf, meta.header_length, size, 0) assert req > 0 assert reader.wait_read(req) >= 0 os.close(fd) return (meta, gbuf) -def test_load_metadata_and_dlpack(fstcpp_log, input_files): +def test_load_metadata_and_dlpack(fstcpp_log, input_files, framework="pytorch"): print("test_load_metadata_and_dlpack") assert len(input_files) > 0 - device = torch.device("cuda:0" if fstcpp.is_cuda_found() else "cpu") + device, _ = get_and_check_device(framework) for input_file in input_files: expected_tensors: Dict[str, torch.Tensor] = {} with safe_open(input_file, framework="pt") as f: for k in f.keys(): - expected_tensors[k] = f.get_tensor(k).to(device=device) - meta, gbuf = run_nogds_file_read(input_file) + expected_tensors[k] = f.get_tensor(k) + if framework == "pytorch": + expected_tensors[k] = expected_tensors[k].to(device=device) + elif framework == "paddle": + expected_tensors[k] = paddle.to_tensor(expected_tensors[k].numpy(), place=device) + meta, gbuf = run_nogds_file_read(input_file, framework=framework) assert meta.header_length > 0 assert meta.size_bytes > 0 assert len(meta.tensors) > 0 @@ -45,15 +62,30 @@ def test_load_metadata_and_dlpack(fstcpp_log, input_files): dst_dev_ptr = gbuf.get_base_address() + actual_meta.data_offsets[0] if actual_meta.dtype in need_workaround_dtypes: wdtype = need_workaround_dtypes[actual_meta.dtype] - actual = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, actual_meta.shape, actual_meta.strides, wdtype, device)).view(actual_meta.dtype) + cu_buf = from_cuda_buffer(dst_dev_ptr, actual_meta.shape, actual_meta.strides, wdtype, device) + if framework == "pytorch": + actual = torch.from_dlpack(cu_buf).view(actual_meta.dtype) + elif framework == "paddle": + actual = paddle.utils.dlpack.from_dlpack(cu_buf).view(actual_meta.dtype) else: - actual = torch.from_dlpack(from_cuda_buffer(dst_dev_ptr, actual_meta.shape, actual_meta.strides, actual_meta.dtype, device)) + cu_buf = from_cuda_buffer(dst_dev_ptr, actual_meta.shape, actual_meta.strides, actual_meta.dtype, device) + if framework == "pytorch": + actual = torch.from_dlpack(cu_buf) + elif framework == "paddle": + actual = paddle.utils.dlpack.from_dlpack(cu_buf) exp = expected_tensors[name] - assert torch.all(exp.eq(actual)) + if framework == "pytorch": + assert torch.all(exp.eq(actual)) + elif framework == "paddle": + assert paddle.all(exp.equal(actual)) if not printed: print(actual_meta.__repr__()) printed = True +def test_load_metadata_and_dlpack_for_paddle(fstcpp_log, input_files): + if paddle_loaded: + test_load_metadata_and_dlpack(fstcpp_log, input_files, "paddle") + def 
test_set_debug_log(): fstcpp.set_debug_log(False) assert True @@ -78,37 +110,55 @@ def test_get_device_pci_bus(fstcpp_log): def test_set_numa_node(fstcpp_log): assert fstcpp.set_numa_node(0) == 0 -def test_alloc_gds_buffer(fstcpp_log): +def test_alloc_gds_buffer(fstcpp_log, framework="pytorch"): print("test_alloc_gds_buffer") - device = torch.device("cuda:0" if fstcpp.is_cuda_found() else "cpu") - gbuf = alloc_tensor_memory(1024, device) + device, _ = get_and_check_device(framework) + gbuf = alloc_tensor_memory(1024, device, framework=framework) addr = gbuf.get_base_address() assert addr != 0 -def test_cufile_register_deregister(fstcpp_log): +def test_alloc_gds_buffer_for_paddle(fstcpp_log): + if paddle_loaded: + test_alloc_gds_buffer(fstcpp_log, "paddle") + +def test_cufile_register_deregister(fstcpp_log, framework="pytorch"): print("test_cufile_register_deregister") - device = torch.device("cuda:0" if fstcpp.is_cuda_found() else "cpu") - gbuf = alloc_tensor_memory(1024, device) + device, _ = get_and_check_device(framework) + gbuf = alloc_tensor_memory(1024, device, framework=framework) assert gbuf.cufile_register(0, 256) == 0 assert gbuf.cufile_register(256, 1024-256) == 0 assert gbuf.cufile_deregister(0) == 0 assert gbuf.cufile_deregister(256) == 0 -def test_memmove(fstcpp_log): +def test_cufile_register_deregister_for_paddle(fstcpp_log): + if paddle_loaded: + test_alloc_gds_buffer(fstcpp_log, "paddle") + +def test_memmove(fstcpp_log , framework="pytorch"): print("test_memmove") - device = torch.device("cuda:0" if fstcpp.is_cuda_found() else "cpu") - gbuf = alloc_tensor_memory(1024, device) - tmp = alloc_tensor_memory(1024, device) - assert gbuf.memmove(0, 12, tmp, 1024) == 0 + device, _ = get_and_check_device(framework) + gbuf = alloc_tensor_memory(1024, device, framework=framework) + tmp = alloc_tensor_memory(1024, device, framework=framework) + assert gbuf.memmove(0, 12, tmp, 256*3) == 0 + # Confuse about this test : gbuf.memmove(0, 12, tmp, 1024) + # I think this test should start copying a section of 1024 memory + # from the position of gbuf+12 to the position of gbuf+0. + # However, this piece of memory itself is only 1024. + # After offsetting by 12, there is no 1024 left in the remaining memory. + # This part really puzzles me. 
So I change the moving size to 256*3 (<1024) -def test_nogds_file_reader(fstcpp_log, input_files): +def test_memmove_for_paddle(fstcpp_log): + if paddle_loaded: + test_memmove(fstcpp_log, "paddle") + +def test_nogds_file_reader(fstcpp_log, input_files, framework="pytorch"): print("test_nogds_file_reader") fd = os.open(input_files[0], os.O_RDONLY, 0o644) s = os.fstat(fd) assert fd > 0 - device = torch.device("cuda:0" if fstcpp.is_cuda_found() else "cpu") - gbuf = alloc_tensor_memory(s.st_size, device) - reader = fstcpp.nogds_file_reader(False, 256 * 1024, 4, device.type == "cuda") + device, dev_is_gpu = get_and_check_device(framework) + gbuf = alloc_tensor_memory(s.st_size, device, framework=framework) + reader = fstcpp.nogds_file_reader(False, 256 * 1024, 4, dev_is_gpu) step = s.st_size // 4 reqs = [] off = 0 @@ -126,48 +176,77 @@ def test_nogds_file_reader(fstcpp_log, input_files): assert reader.wait_read(req) > 0 os.close(fd) -def test_NoGdsFileCopier(fstcpp_log, input_files): +def test_nogds_file_reader_for_paddle(fstcpp_log, input_files): + if paddle_loaded: + test_nogds_file_reader(fstcpp_log, input_files, "paddle") + +def test_NoGdsFileCopier(fstcpp_log, input_files, framework="pytorch"): print("test_NoGdsFileCopier") - meta = SafeTensorsMetadata.from_file(input_files[0]) - device = torch.device("cuda:0" if fstcpp.is_cuda_found() else "cpu") - reader = fstcpp.nogds_file_reader(False, 256 * 1024, 4, device.type == "cuda") + meta = SafeTensorsMetadata.from_file(input_files[0], framework) + device, dev_is_gpu = get_and_check_device(framework) + reader = fstcpp.nogds_file_reader(False, 256 * 1024, 4, dev_is_gpu) copier = NoGdsFileCopier(meta, device, reader, True) gbuf = copier.submit_io(False, 10 * 1024 * 1024 * 1024) tensors = copier.wait_io(gbuf, None) with safe_open(input_files[0], framework="pt") as f: for key in tensors.keys(): - assert torch.all(f.get_tensor(key).to(device=device).eq(tensors[key])) - free_tensor_memory(gbuf, device) + if framework == "pytorch": + assert torch.all(f.get_tensor(key).to(device=device).eq(tensors[key])) + elif framework == "paddle": + assert paddle.all(paddle.to_tensor(f.get_tensor(key).numpy(), place=device).equal(tensors[key])) + free_tensor_memory(gbuf, device, framework) -def test_GdsFileCopier(fstcpp_log, input_files): +def test_NoGdsFileCopier_for_paddle(fstcpp_log, input_files): + if paddle_loaded: + test_NoGdsFileCopier(fstcpp_log, input_files,"paddle") + +def test_GdsFileCopier(fstcpp_log, input_files, framework="pytorch"): print("test_GdsFileCopier") if not fstcpp.is_cufile_found(): pytest.skip("cufile.so is not found") return - meta = SafeTensorsMetadata.from_file(input_files[0]) - device = torch.device("cuda:0" if fstcpp.is_cuda_found() else "cpu") - reader = fstcpp.gds_file_reader(4, device.type == "cuda") + meta = SafeTensorsMetadata.from_file(input_files[0], framework=framework) + device, dev_is_gpu = get_and_check_device(framework) + reader = fstcpp.gds_file_reader(4, dev_is_gpu) copier = GdsFileCopier(meta, device, reader, True) gbuf = copier.submit_io(False, 10 * 1024 * 1024 * 1024) tensors = copier.wait_io(gbuf, None) with safe_open(input_files[0], framework="pt") as f: for key in tensors.keys(): - assert torch.all(f.get_tensor(key).to(device=device).eq(tensors[key])) - free_tensor_memory(gbuf, device) + if framework == "torch": + assert torch.all(f.get_tensor(key).to(device=device).eq(tensors[key])) + elif framework == "paddle": + assert paddle.all(paddle.to_tensor(f.get_tensor(key).numpy(), place=device).equal(tensors[key])) + 
free_tensor_memory(gbuf, device, framework=framework) + +def test_GdsFileCopier_for_paddle(fstcpp_log, input_files): + if paddle_loaded: + test_GdsFileCopier(fstcpp_log, input_files, "paddle") -def test_SafeTensorsFileLoader(fstcpp_log, input_files): - device = torch.device("cuda:0" if fstcpp.is_cuda_found() else "cpu") - loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=False, debug_log=True) +def test_SafeTensorsFileLoader(fstcpp_log, input_files, framework="pytorch"): + device, _ = get_and_check_device(framework) + if framework == "pytorch": + data_type = torch.float16 + elif framework == "paddle": + # There are some lack of accuracy in paddle.float16 (about 1e-4) + data_type = paddle.float32 + else: + raise NotImplementedError(f"Do not support the framework: {framework}") + loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=False, debug_log=True, framework=framework) loader.add_filenames({0: input_files}) - bufs = loader.copy_files_to_device(dtype=torch.float16, use_buf_register=True, max_copy_block_size=256*1024*1024) + bufs = loader.copy_files_to_device(dtype=data_type, use_buf_register=True, max_copy_block_size=256*1024*1024) key_dims = {key: -1 for key in loader.get_keys()} tensors = bufs.as_dict(key_dims) last_key = "" last_shape: torch.Size = None with safe_open(input_files[0], framework="pt") as f: for key in tensors.keys(): - exp = f.get_tensor(key).to(device=device, dtype=torch.float16) - assert torch.all(exp.eq(bufs.get_tensor(key))) + if framework == "pytorch": + exp = f.get_tensor(key).to(device=device, dtype=data_type) + assert torch.all(exp.eq(bufs.get_tensor(key))) + elif framework == "paddle": + exp = paddle.to_tensor(f.get_tensor(key).numpy(), place=device, dtype=data_type) + assert paddle.all(exp.equal(bufs.get_tensor(key))) last_key = key last_shape = exp.shape if last_key != "": @@ -178,33 +257,53 @@ def test_SafeTensorsFileLoader(fstcpp_log, input_files): bufs.close() loader.close() +def test_SafeTensorsFileLoader_for_paddle(fstcpp_log, input_files): + if paddle_loaded: + test_SafeTensorsFileLoader(fstcpp_log, input_files,"paddle") -def test_SafeTensorsFileLoaderNoGds(fstcpp_log, input_files): - device = torch.device("cuda:0" if fstcpp.is_cuda_found() else "cpu") - loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=True, debug_log=True) +def test_SafeTensorsFileLoaderNoGds(fstcpp_log, input_files, framework="pytorch"): + device, _ = get_and_check_device(framework) + loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=True, debug_log=True, framework=framework) loader.add_filenames({0: input_files}) bufs = loader.copy_files_to_device() key_dims = {key: -1 for key in loader.get_keys()} tensors = bufs.as_dict(key_dims) with safe_open(input_files[0], framework="pt") as f: for key in tensors.keys(): - assert torch.all(f.get_tensor(key).to(device=device).eq(tensors[key])) + if framework == "pytorch": + assert torch.all(f.get_tensor(key).to(device=device).eq(tensors[key])) + elif framework == "paddle": + assert paddle.all(paddle.to_tensor(f.get_tensor(key).numpy(), place=device).equal(tensors[key])) bufs.close() loader.close() -def test_fastsafe_open(fstcpp_log, input_files): - device = torch.device("cuda:0" if fstcpp.is_cuda_found() else "cpu") +def test_SafeTensorsFileLoaderNoGds_for_paddle(fstcpp_log, input_files): + if paddle_loaded: + test_SafeTensorsFileLoaderNoGds(fstcpp_log, input_files, "paddle") + +def test_fastsafe_open(fstcpp_log, input_files, framework="pt"): + device, _ = get_and_check_device(framework) def 
weight_iterator(): - with fastsafe_open(input_files, pg=SingleGroup(), device=device, nogds=True, debug_log=True) as f: + with fastsafe_open(input_files, pg=SingleGroup(), device=device, nogds=True, debug_log=True, framework=framework) as f: for k in f.get_keys(): t = f.get_tensor(k) yield k, t tensors = {} with safe_open(input_files[0], framework="pt") as f: for key in f.keys(): - tensors[key] = f.get_tensor(key).to(device=device) + if framework == "pt": + tensors[key] = f.get_tensor(key).to(device=device) + elif framework == "paddle": + tensors[key] = paddle.to_tensor(f.get_tensor(key).numpy(), place=device) for k, t in weight_iterator(): - assert torch.all(tensors[k].eq(t)) + if framework == "pt": + assert torch.all(tensors[k].eq(t)) + elif framework == "paddle": + assert paddle.all(tensors[k].equal(t)) + +def test_fastsafe_open_for_paddle(fstcpp_log, input_files): + if paddle_loaded: + test_fastsafe_open(fstcpp_log, input_files, "paddle") def _test_type(tmp_dir, dtype, device): filename = os.path.join(tmp_dir, f"a.safetensors") @@ -218,11 +317,29 @@ def _test_type(tmp_dir, dtype, device): t2 = f.get_tensor(key) assert torch.all(t2.eq(t1)) +def _test_type_for_paddle(tmp_dir, dtype, device): + filename = os.path.join(tmp_dir, f"a.safetensors") + t0 = paddle.randn((8, 16), dtype=paddle.float32).to(dtype=dtype) + paddle_save_file({f"a": t0}, filename, metadata={"fst": "sample"}) + with fastsafe_open(filenames=[filename], nogds=True, device=device, debug_log=True, framework="paddle") as f: + for key in f.get_keys(): + t1 = f.get_tensor(key).clone().detach() + with safe_open(filename, framework='pt') as f: + for key in f.keys(): + t2 = paddle.to_tensor(f.get_tensor(key).numpy(), place=device) + assert paddle.all(t2.equal(t1)) + def test_int8(fstcpp_log, tmp_dir): _test_type(tmp_dir, torch.int8, "cuda:0" if fstcpp.is_cuda_found() else "cpu") + if paddle_loaded: + _test_type_for_paddle(tmp_dir, paddle.int8, "gpu:0" if fstcpp.is_cuda_found() else "cpu") def test_float8_e5m2(fstcpp_log, tmp_dir): _test_type(tmp_dir, torch.float8_e5m2, "cuda:0" if fstcpp.is_cuda_found() else "cpu") + if paddle_loaded: + _test_type_for_paddle(tmp_dir, paddle.float8_e5m2, "gpu:0" if fstcpp.is_cuda_found() else "cpu") def test_float8_e4m3fn(fstcpp_log, tmp_dir): _test_type(tmp_dir, torch.float8_e4m3fn, "cuda:0" if fstcpp.is_cuda_found() else "cpu") + if paddle_loaded: + _test_type_for_paddle(tmp_dir, paddle.float8_e4m3fn, "gpu:0" if fstcpp.is_cuda_found() else "cpu")
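
For reference, a minimal usage sketch of the paddle path enabled by this change (not part of the diff). It mirrors examples/paddle_case/run_single.py and the new tests; "model_paddle.safetensors" is a placeholder filename, and the device string follows the "gpu:0"/"cpu" convention the paddle branch accepts.

import paddle
from fastsafetensors import fastsafe_open

# Pick a device string that the paddle-specific code path understands.
device = "gpu:0" if paddle.is_compiled_with_cuda() else "cpu"

# framework="paddle" routes the loader through the paddle allocation and
# DLPack instantiation added in this change; nogds=True avoids requiring libcufile.
with fastsafe_open(["model_paddle.safetensors"], framework="paddle",
                   device=device, nogds=True, debug_log=False) as f:
    for key in f.get_keys():
        t = f.get_tensor(key)  # paddle.Tensor materialized from the device buffer
        print(key, t.shape, t.dtype)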