Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions fastsafetensors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Copyright 2024 IBM Inc. All rights reserved
# SPDX-License-Identifier: Apache-2.0

from importlib.metadata import version

__version__ = version(__name__)

from .common import SafeTensorsMetadata, SingleGroup, TensorFrame, get_device_numa_node
from .file_buffer import FilesBufferOnDevice
from .loader import SafeTensorsFileLoader, fastsafe_open
4 changes: 2 additions & 2 deletions fastsafetensors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import json
import os
import platform
import sys
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
Expand All @@ -15,7 +15,7 @@


def get_device_numa_node(device: Optional[int]) -> Optional[int]:
if device is None or platform.system() != "Linux":
if device is None or not sys.platform.startswith("linux"):
return None
pci_addr = fstcpp.get_device_pci_bus(device)
if pci_addr == "":
Expand Down
7 changes: 7 additions & 0 deletions fastsafetensors/cpp.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class gds_device_buffer:
self, dst_off: int, src_off: int, tmp: "gds_device_buffer", length: int
) -> int: ...
def get_base_address(self) -> int: ...
def get_length(self) -> int: ...

class nogds_file_reader:
def __init__(
Expand All @@ -35,6 +36,11 @@ class gds_file_reader:
) -> int: ...
def wait_read(self, id: int) -> int: ...

# Type stub for the `cpp_metrics` class that the native extension registers with pybind11.
class cpp_metrics:
# Read/write integer field. In the value returned by get_cpp_metrics() it mirrors the native
# counter that grows when nogds_file_reader.submit_read succeeds in cudaHostAlloc and shrinks
# when ~nogds_file_reader calls cudaFreeHost.
bounce_buffer_bytes: int

# Default constructor; takes no arguments.
def __init__(self) -> None: ...

def is_cuda_found() -> bool: ...
def is_cufile_found() -> bool: ...
def cufile_version() -> int: ...
Expand All @@ -50,3 +56,4 @@ def cpu_free(addr: int) -> None: ...
def gpu_malloc(length: int) -> int: ...
def gpu_free(addr: int) -> None: ...
def load_nvidia_functions() -> None: ...
def get_cpp_metrics() -> cpp_metrics: ...
51 changes: 36 additions & 15 deletions fastsafetensors/cpp/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

static bool debug_log = false;

static cpp_metrics_t mc = {.bounce_buffer_bytes = 0};

/* cpu_mode functions: for tests and debugs */

static CUfileError_t cpu_cuFileDriverOpen() { return CUfileError_t{.err = CU_FILE_SUCCESS}; }
Expand Down Expand Up @@ -255,7 +257,7 @@ int init_gds()
}
if (debug_log) {
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::printf("[DEBUG] init_gds: cuFileDriverOpen=%lld us\n",
std::printf("[DEBUG] init_gds: cuFileDriverOpen=%" PRId64 " us\n",
std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count());
}
return 0;
Expand All @@ -275,7 +277,7 @@ int close_gds()
}
if (debug_log) {
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::printf("[DEBUG] close_gds: cuFileDriverClose, elapsed=%lld us\n",
std::printf("[DEBUG] close_gds: cuFileDriverClose, elapsed=%" PRId64 " us\n",
std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count());
}
return 0;
Expand Down Expand Up @@ -352,7 +354,7 @@ const int gds_device_buffer::cufile_register(uint64_t offset, uint64_t length) {
}
if (debug_log) {
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::printf("[DEBUG] gds_device_buffer.cufile_register: addr=%p, offset=%" PRIu64 ", length=%" PRIu64 ", register=%lld us\n", dst, offset, length,
std::printf("[DEBUG] gds_device_buffer.cufile_register: addr=%p, offset=%" PRIu64 ", length=%" PRIu64 ", register=%" PRId64 " us\n", dst, offset, length,
std::chrono::duration_cast<std::chrono::microseconds>(end - begin_register).count());
}
return 0;
Expand All @@ -369,7 +371,7 @@ const int gds_device_buffer::cufile_deregister(uint64_t offset) {
}
if (debug_log) {
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::printf("[DEBUG] gds_device_buffer.cufile_deregister: addr=%p, offset=%" PRIu64 ", elapsed=%lld us\n", dst, offset,
std::printf("[DEBUG] gds_device_buffer.cufile_deregister: addr=%p, offset=%" PRIu64 ", elapsed=%" PRId64 " us\n", dst, offset,
std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count());
}
return 0;
Expand Down Expand Up @@ -410,7 +412,7 @@ const int gds_device_buffer::memmove(uint64_t _dst_off, uint64_t _src_off, const
}
if (debug_log) {
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::printf("[DEBUG] gds_device_buffer.memmove: dst=%p, src=%p, tmp=%p, length=%" PRIu64 ", elapsed=%lld us\n", dst, src, tmp, length,
std::printf("[DEBUG] gds_device_buffer.memmove: dst=%p, src=%p, tmp=%p, length=%" PRIu64 ", elapsed=%" PRId64 " us\n", dst, src, tmp, length,
std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count());
}
return 0;
Expand All @@ -434,7 +436,7 @@ void nogds_file_reader::_thread(const int thread_id, ext_funcs_t *fns, const int
}
if (debug_log) {
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::printf("[DEBUG] nogds_file_reader._thread: mmap, fd=%d, offset=%" PRIu64 ", length=%" PRIu64 ", elapsed=%lld us\n",
std::printf("[DEBUG] nogds_file_reader._thread: mmap, fd=%d, offset=%" PRIu64 ", length=%" PRIu64 ", elapsed=%" PRId64 " us\n",
fd, offset, length, std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count());
}
}
Expand Down Expand Up @@ -469,7 +471,7 @@ void nogds_file_reader::_thread(const int thread_id, ext_funcs_t *fns, const int
count += c;
if (debug_log) {
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::printf("[DEBUG] nogds_file_reader._thread: read (mmap=%d), fd=%d, offset=%" PRIu64 ", count=%" PRIi64 ", c=%" PRIi64 ", copy=%lld us, cuda_copy=%lld us\n",
std::printf("[DEBUG] nogds_file_reader._thread: read (mmap=%d), fd=%d, offset=%" PRIu64 ", count=%" PRIi64 ", c=%" PRIi64 ", copy=%" PRId64 " us, cuda_copy=%" PRId64 " us\n",
s->_use_mmap, fd, offset, count, c, std::chrono::duration_cast<std::chrono::microseconds>(memcpy_begin - begin).count(), std::chrono::duration_cast<std::chrono::microseconds>(end - memcpy_begin).count());
}
}
Expand Down Expand Up @@ -500,15 +502,18 @@ const int nogds_file_reader::submit_read(const int fd, const gds_device_buffer&
if (this->_s._read_buffer == nullptr) {
cudaError_t err;
std::chrono::steady_clock::time_point alloc_begin = std::chrono::steady_clock::now();
err = _fns->cudaHostAlloc(&this->_s._read_buffer, this->_s._bbuf_size_kb * 1024 * this->_s._max_threads, 0);
auto buf_len = this->_s._bbuf_size_kb * 1024 * this->_s._max_threads;
err = _fns->cudaHostAlloc(&this->_s._read_buffer, buf_len, 0);
if (err != cudaSuccess) {
std::printf("nogds_file_reader.submit_read: cudaHostAlloc(%" PRIi64 ") failed\n", this->_s._bbuf_size_kb * 1024 * this->_s._max_threads);
std::printf("nogds_file_reader.submit_read: cudaHostAlloc(%" PRIi64 ") failed\n", buf_len);
return -1;
}
mc.bounce_buffer_bytes += buf_len;
if (debug_log) {
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::printf("[DEBUG] nogds_file_reader.submit_read: cudaHostAlloc, size=%" PRIi64 ", elapsed=%lld us\n",
this->_s._bbuf_size_kb * 1024, std::chrono::duration_cast<std::chrono::microseconds>(end - alloc_begin).count());
std::printf("[DEBUG] nogds_file_reader.submit_read: cudaHostAlloc, addr=%p, size=%" PRIi64 ", elapsed=%" PRId64 " us\n",
reinterpret_cast<void*>(this->_s._read_buffer),
buf_len, std::chrono::duration_cast<std::chrono::microseconds>(end - alloc_begin).count());
}
}
std::thread *t = this->_threads[thread_id % this->_s._max_threads];
Expand Down Expand Up @@ -540,8 +545,14 @@ const uintptr_t nogds_file_reader::wait_read(const int thread_id) {
nogds_file_reader::~nogds_file_reader() {
std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
if (this->_s._read_buffer != nullptr) {
auto buf_len = this->_s._bbuf_size_kb * 1024 * this->_s._max_threads;
_fns->cudaFreeHost(this->_s._read_buffer);
if (debug_log) {
std::printf("[DEBUG] cudaFreeHost, addr=%p, size=%" PRIi64 "\n",
reinterpret_cast<void *>(this->_s._read_buffer), buf_len);
}
this->_s._read_buffer = nullptr;
mc.bounce_buffer_bytes -= buf_len;
}
if (this->_threads != nullptr) {
for (uint64_t i = 0; i < this->_s._max_threads; ++i) {
Expand All @@ -556,7 +567,7 @@ nogds_file_reader::~nogds_file_reader() {
}
if (debug_log) {
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::printf("[DEBUG] ~nogds_file_reader: elapsed=%lld us\n",
std::printf("[DEBUG] ~nogds_file_reader: elapsed=%" PRId64 " us\n",
std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count());
}
}
Expand Down Expand Up @@ -595,7 +606,7 @@ raw_gds_file_handle::raw_gds_file_handle(std::string filename, bool o_direct, bo
}
if (debug_log) {
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::printf("[DEBUG] raw_gds_file_handle: fd=%d, cf_handle=%p, elapsed=%lld us\n", fd, cf_handle,
std::printf("[DEBUG] raw_gds_file_handle: fd=%d, cf_handle=%p, elapsed=%" PRId64 " us\n", fd, cf_handle,
std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count());
}
this->_cf_handle = cf_handle;
Expand Down Expand Up @@ -650,7 +661,7 @@ void gds_file_reader::_thread(const int thread_id, ext_funcs_t *fns, const gds_f
}
if (debug_log) {
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::printf("[DEBUG] gds_file_reader._thread: fh=%p, offset=%" PRIu64 ", length=%" PRIu64 ", count=%zd, read=%lld us, notify=%lld us\n",
std::printf("[DEBUG] gds_file_reader._thread: fh=%p, offset=%" PRIu64 ", length=%" PRIu64 ", count=%zd, read=%" PRId64" us, notify=%" PRId64 " us\n",
fh._get_cf_handle(), offset, length, count,
std::chrono::duration_cast<std::chrono::microseconds>(begin_notify - begin).count(),
std::chrono::duration_cast<std::chrono::microseconds>(end - begin_notify).count());
Expand Down Expand Up @@ -699,6 +710,10 @@ const ssize_t gds_file_reader::wait_read(const int id) {
return ret;
}

// Hands back a by-value copy of the file-level metrics record `mc`.
// Later changes to `mc` are not reflected in a copy that was already returned.
cpp_metrics_t get_cpp_metrics() {
    cpp_metrics_t snapshot = mc;
    return snapshot;
}

// Bindings

PYBIND11_MODULE(__MOD_NAME__, m)
Expand All @@ -718,13 +733,15 @@ PYBIND11_MODULE(__MOD_NAME__, m)
m.def("gpu_malloc", &gpu_malloc);
m.def("gpu_free", &gpu_free);
m.def("load_nvidia_functions", &load_nvidia_functions);
m.def("get_cpp_metrics", &get_cpp_metrics);

pybind11::class_<gds_device_buffer>(m, "gds_device_buffer")
.def(pybind11::init<const uintptr_t, const uint64_t, bool>())
.def("cufile_register", &gds_device_buffer::cufile_register)
.def("cufile_deregister", &gds_device_buffer::cufile_deregister)
.def("memmove", &gds_device_buffer::memmove)
.def("get_base_address", &gds_device_buffer::get_base_address);
.def("get_base_address", &gds_device_buffer::get_base_address)
.def("get_length", &gds_device_buffer::get_length);

pybind11::class_<nogds_file_reader>(m, "nogds_file_reader")
.def(pybind11::init<const bool, const uint64_t, const int, bool>())
Expand All @@ -738,4 +755,8 @@ PYBIND11_MODULE(__MOD_NAME__, m)
.def(pybind11::init<const int, bool>())
.def("submit_read", &gds_file_reader::submit_read)
.def("wait_read", &gds_file_reader::wait_read);

pybind11::class_<cpp_metrics_t>(m, "cpp_metrics")
.def(pybind11::init<>())
.def_readwrite("bounce_buffer_bytes", &cpp_metrics_t::bounce_buffer_bytes);
}
7 changes: 7 additions & 0 deletions fastsafetensors/cpp/ext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ class gds_device_buffer {
// Returns the integer address reported by the object behind _devPtr_base.
const uintptr_t get_base_address() const {
    const uintptr_t base_addr = this->_devPtr_base->get_uintptr();
    return base_addr;
}
const uint64_t get_length() const {
return _length;
}
};

class nogds_file_reader {
Expand Down Expand Up @@ -198,4 +201,8 @@ typedef struct ext_funcs {
int (*numa_run_on_node)(int);
} ext_funcs_t;

// Plain record of counters kept by the native extension.
struct cpp_metrics {
    // Running byte total; ext.cpp adds to it after a successful cudaHostAlloc in
    // nogds_file_reader::submit_read and subtracts in ~nogds_file_reader.
    size_t bounce_buffer_bytes;
};

// Alias kept so existing users of the `_t` spelling continue to compile.
using cpp_metrics_t = cpp_metrics;

#endif //__EXT_HPP__
13 changes: 7 additions & 6 deletions fastsafetensors/frameworks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,16 @@ def randn(self, s: tuple, device: Device, dtype: DType) -> T:
def support_fp8(self) -> bool:
pass

@abstractmethod
def get_mem_used(self) -> int:
pass


def get_framework_op(name: str) -> FrameworkOpBase:
if name == "pt" or name == "pytorch" or name == "torch":
from ._torch import TorchOp

return TorchOp()
from ._torch import get_framework_op as op
elif name == "paddle":
from ._paddle import PaddleOp

return PaddleOp()
from ._paddle import get_framework_op as op
else:
raise Exception(f"Unknown framework name: {name}")
return op()
19 changes: 19 additions & 0 deletions fastsafetensors/frameworks/_paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def recv(


class PaddleOp(FrameworkOpBase[PaddleTensor, PaddleProcessGroup]):
def __init__(self) -> None:
self.mem_used = 0

def get_name(self) -> str:
return "paddle"

Expand Down Expand Up @@ -173,13 +176,16 @@ def alloc_tensor_memory(self, length: int, dev: Device) -> gds_device_buffer:
rbuf = gpu_malloc(length)
else:
rbuf = cpu_malloc(length)
self.mem_used += length
return gds_device_buffer(rbuf, length, dev.type == DeviceType.GPU)

def free_tensor_memory(self, gbuf: gds_device_buffer, dev: Device) -> None:
    """Free the memory behind ``gbuf`` and subtract its length from ``mem_used``.

    Args:
        gbuf: Buffer whose base address is handed to the matching free routine.
        dev: Device descriptor; ``DeviceType.GPU`` selects ``gpu_free``,
            anything else selects ``cpu_free``.
    """
    # Capture the length up front so the counter update does not depend on
    # querying the buffer after its memory has been released.
    freed_len = gbuf.get_length()
    base_addr = gbuf.get_base_address()
    release = gpu_free if dev.type == DeviceType.GPU else cpu_free
    release(base_addr)
    self.mem_used -= freed_len

def get_empty_tensor(
self, shape: List[int], dtype: DType, device: Device
Expand Down Expand Up @@ -248,3 +254,16 @@ def randn(self, s: tuple, device: Device, dtype: DType) -> PaddleTensor:

def support_fp8(self) -> bool:
return DType.F8_E5M2 in dtype_convert

def get_mem_used(self) -> int:
return self.mem_used


# Module-wide PaddleOp instance; remains None until get_framework_op() is first called.
_op: Optional[PaddleOp] = None


def get_framework_op() -> FrameworkOpBase:
    """Return the module-level PaddleOp, constructing it on the first call.

    Every later call returns that same object, so all callers observe the
    same instance (and therefore the same ``mem_used`` attribute).
    """
    global _op
    if _op is not None:
        return _op
    _op = PaddleOp()
    return _op
19 changes: 19 additions & 0 deletions fastsafetensors/frameworks/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def recv(


class TorchOp(FrameworkOpBase[TorchTensor, TorchProcessGroup]):
def __init__(self) -> None:
self.mem_used = 0

def get_name(self) -> str:
return "pytorch"

Expand All @@ -149,13 +152,16 @@ def alloc_tensor_memory(self, length: int, dev: Device) -> gds_device_buffer:
rbuf = torch.cuda.caching_allocator_alloc(length)
else:
rbuf = cpu_malloc(length)
self.mem_used += length
return gds_device_buffer(rbuf, length, dev.type == DeviceType.CUDA)

def free_tensor_memory(self, gbuf: gds_device_buffer, dev: Device) -> None:
    """Free the memory behind ``gbuf`` and subtract its length from ``mem_used``.

    Args:
        gbuf: Buffer whose base address is passed to the matching free routine.
        dev: Device descriptor; ``DeviceType.CUDA`` routes the address to
            ``torch.cuda.caching_allocator_delete``, anything else to ``cpu_free``.
    """
    # Read the length before releasing anything so the accounting below does
    # not rely on the buffer object after its memory is gone.
    length = gbuf.get_length()
    if dev.type == DeviceType.CUDA:
        torch.cuda.caching_allocator_delete(gbuf.get_base_address())
    else:
        cpu_free(gbuf.get_base_address())
    self.mem_used -= length

def get_empty_tensor(
self, shape: List[int], dtype: DType, device: Device
Expand Down Expand Up @@ -218,3 +224,16 @@ def randn(self, s: tuple, device: Device, dtype: DType) -> TorchTensor:

def support_fp8(self) -> bool:
return DType.F8_E5M2 in dtype_convert

def get_mem_used(self) -> int:
    """Return the current value of ``self.mem_used``.

    The counter is increased in ``alloc_tensor_memory`` and decreased in
    ``free_tensor_memory``.
    """
    return self.mem_used


# Module-wide TorchOp instance; remains None until get_framework_op() is first called.
_op: Optional[TorchOp] = None


def get_framework_op() -> FrameworkOpBase:
    """Return the module-level TorchOp, constructing it on the first call.

    Every later call returns that same object, so all callers observe the
    same instance (and therefore the same ``mem_used`` attribute).
    """
    global _op
    if _op is not None:
        return _op
    _op = TorchOp()
    return _op
1 change: 1 addition & 0 deletions fastsafetensors/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def reset(self):

def close(self):
    """Reset the loader and drop its ``reader`` attribute.

    Safe to call more than once: the attribute is only deleted while it
    still exists, so a repeated ``close()`` no longer raises AttributeError.
    """
    self.reset()
    # NOTE(review): removing the attribute presumably lets the native reader
    # be destroyed once no other references remain - confirm against the
    # reader's owner(s).
    if hasattr(self, "reader"):
        del self.reader

def get_keys(self) -> List[str]:
return list(self.frames.keys())
Expand Down
6 changes: 6 additions & 0 deletions fastsafetensors/tensor_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def __init__(
def submit_io(self, use_buf_register: bool, max_copy_block_size: int):
if self.copier is not None:
self.gbuf = self.copier.submit_io(use_buf_register, max_copy_block_size)
if self.gbuf and self.debug_log:
print(f"submit_io: new buf, addr={self.gbuf.get_base_address():#x}")

def wait_io(self, dtype: DType = DType.AUTO, noalign: bool = False):
if self.copier is not None and self.gbuf is not None:
Expand Down Expand Up @@ -218,4 +220,8 @@ def free_dev_ptrs(self):
self.tensors = {}
if self.gbuf is not None:
self.framework.free_tensor_memory(self.gbuf, self.device)
if self.debug_log:
print(
f"free_dev_ptrs: delete buf, addr={self.gbuf.get_base_address():#x}"
)
self.gbuf = None
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "fastsafetensors"
version = "0.1.14"
version = "0.1.15"
description = "High-performance safetensors model loader"
authors = [{name = "Takeshi Yoshimura", email = "[email protected]"}]
maintainers = [{name = "Takeshi Yoshimura", email = "[email protected]"}]
Expand Down
Loading