Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions fastsafetensors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@

__version__ = version(__name__)

from .common import SafeTensorsMetadata, SingleGroup, TensorFrame, get_device_numa_node
from .common import (
SafeTensorsMetadata,
SingleGroup,
TensorFrame,
get_device_numa_node,
)
from .file_buffer import FilesBufferOnDevice
from .loader import SafeTensorsFileLoader, fastsafe_open
from .loader import BaseSafeTensorsFileLoader, SafeTensorsFileLoader, fastsafe_open
from .parallel_loader import ParallelLoader
2 changes: 2 additions & 0 deletions fastsafetensors/cpp.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,5 @@ def gpu_malloc(length: int) -> int: ...
def gpu_free(addr: int) -> None: ...
def load_nvidia_functions() -> None: ...
def get_cpp_metrics() -> cpp_metrics: ...
def set_gil_release(gil_release: bool) -> None: ...
def get_gil_release() -> bool: ...
75 changes: 71 additions & 4 deletions fastsafetensors/cpp/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
#include <sys/mman.h>
#include <chrono>
#include <dlfcn.h>
#include <cstdlib>
#include <algorithm>

#include "cuda_compat.h"
#include "ext.hpp"

#define ALIGN 4096

static bool debug_log = false;
static bool enable_gil_release = false;

static cpp_metrics_t mc = {.bounce_buffer_bytes = 0};

Expand Down Expand Up @@ -266,6 +269,28 @@ void set_debug_log(bool _debug_log)
debug_log = _debug_log;
}

void set_gil_release(bool enable) {
enable_gil_release = enable;
}

bool get_gil_release() {
return enable_gil_release;
}

void init_gil_release_from_env() {
const char* env_val = std::getenv("FASTSAFETENSORS_ENABLE_GIL_RELEASE");
if (env_val != nullptr) {
std::string env_str(env_val);
// Convert to lowercase for case-insensitive comparison
std::transform(env_str.begin(), env_str.end(), env_str.begin(), ::tolower);
enable_gil_release = (env_str == "1" || env_str == "true" || env_str == "yes" || env_str == "on");
if (debug_log) {
std::printf("[DEBUG] GIL release %s via environment variable FASTSAFETENSORS_ENABLE_GIL_RELEASE=%s\n",
enable_gil_release ? "enabled" : "disabled", env_val);
}
}
}

int init_gds()
{
CUfileError_t err;
Expand Down Expand Up @@ -741,6 +766,8 @@ cpp_metrics_t get_cpp_metrics() {

PYBIND11_MODULE(__MOD_NAME__, m)
{
// Initialize GIL release setting from environment variable on module load
init_gil_release_from_env();
// Export both is_cuda_found and is_hip_found on all platforms
// Use string concatenation to prevent hipify from converting the export names
#ifdef USE_ROCM
Expand Down Expand Up @@ -771,6 +798,8 @@ PYBIND11_MODULE(__MOD_NAME__, m)
m.def("gpu_free", &gpu_free);
m.def("load_nvidia_functions", &load_nvidia_functions);
m.def("get_cpp_metrics", &get_cpp_metrics);
m.def("set_gil_release", &set_gil_release);
m.def("get_gil_release", &get_gil_release);

pybind11::class_<gds_device_buffer>(m, "gds_device_buffer")
.def(pybind11::init<const uintptr_t, const uint64_t, bool>())
Expand All @@ -780,18 +809,56 @@ PYBIND11_MODULE(__MOD_NAME__, m)
.def("get_base_address", &gds_device_buffer::get_base_address)
.def("get_length", &gds_device_buffer::get_length);

// Helper lambdas to conditionally apply GIL release
auto nogds_submit_read = [](nogds_file_reader& self, const int fd, const gds_device_buffer& dst, const int64_t offset, const int64_t length, const uint64_t ptr_off) {
if (enable_gil_release) {
pybind11::gil_scoped_release release;
return self.submit_read(fd, dst, offset, length, ptr_off);
} else {
return self.submit_read(fd, dst, offset, length, ptr_off);
}
};

auto nogds_wait_read = [](nogds_file_reader& self, const int thread_id) {
if (enable_gil_release) {
pybind11::gil_scoped_release release;
return self.wait_read(thread_id);
} else {
return self.wait_read(thread_id);
}
};

pybind11::class_<nogds_file_reader>(m, "nogds_file_reader")
.def(pybind11::init<const bool, const uint64_t, const int, bool>())
.def("submit_read", &nogds_file_reader::submit_read)
.def("wait_read", &nogds_file_reader::wait_read);
.def("submit_read", nogds_submit_read)
.def("wait_read", nogds_wait_read);

pybind11::class_<gds_file_handle>(m, "gds_file_handle")
.def(pybind11::init<std::string, bool, bool>());

// Helper lambdas for gds_file_reader to conditionally apply GIL release
auto gds_submit_read = [](gds_file_reader& self, const gds_file_handle &fh, const gds_device_buffer &dst, const uint64_t offset, const uint64_t length, const uint64_t ptr_off, const uint64_t file_length) {
if (enable_gil_release) {
pybind11::gil_scoped_release release;
return self.submit_read(fh, dst, offset, length, ptr_off, file_length);
} else {
return self.submit_read(fh, dst, offset, length, ptr_off, file_length);
}
};

auto gds_wait_read = [](gds_file_reader& self, const int id) {
if (enable_gil_release) {
pybind11::gil_scoped_release release;
return self.wait_read(id);
} else {
return self.wait_read(id);
}
};

pybind11::class_<gds_file_reader>(m, "gds_file_reader")
.def(pybind11::init<const int, bool>())
.def("submit_read", &gds_file_reader::submit_read)
.def("wait_read", &gds_file_reader::wait_read);
.def("submit_read", gds_submit_read)
.def("wait_read", gds_wait_read);

pybind11::class_<cpp_metrics_t>(m, "cpp_metrics")
.def(pybind11::init<>())
Expand Down
3 changes: 3 additions & 0 deletions fastsafetensors/cpp/ext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ typedef struct CUfileDrvProps {

int get_alignment_size();
void set_debug_log(bool _debug_log);
void set_gil_release(bool enable);
bool get_gil_release();
void init_gil_release_from_env();
int init_gds();
int close_gds();
std::string get_device_pci_bus(int deviceId);
Expand Down
11 changes: 10 additions & 1 deletion fastsafetensors/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
disable_cache: bool = True,
debug_log: bool = False,
framework="pytorch",
**kwargs,
):
self.framework = get_framework_op(framework)
self.pg = self.framework.get_process_group(pg)
Expand Down Expand Up @@ -174,6 +175,7 @@ def __init__(
disable_cache: bool = True,
debug_log: bool = False,
framework="pytorch",
**kwargs,
):
self.framework = get_framework_op(framework)
self.pg = self.framework.get_process_group(pg)
Expand All @@ -191,7 +193,14 @@ def __init__(

copier = new_gds_file_copier(self.device, bbuf_size_kb, max_threads, nogds)
super().__init__(
pg, self.device, copier, set_numa, disable_cache, debug_log, framework
pg,
self.device,
copier,
set_numa,
disable_cache,
debug_log,
framework,
**kwargs,
)


Expand Down
Loading