Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions fastsafetensors/copier/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Dict

from .. import cpp as fstcpp
from ..frameworks import TensorBase
from ..st_types import DType


class CopierInterface(ABC):
    """Abstract interface that every file copier must implement.

    A concrete copier provides two operations: ``submit_io``, which returns
    a ``gds_device_buffer``, and ``wait_io``, which accepts such a buffer
    and returns tensors keyed by name.
    """

    @abstractmethod
    def submit_io(
        self, use_buf_register: bool, max_copy_block_size: int
    ) -> fstcpp.gds_device_buffer:
        """Return the device buffer associated with this copy.

        NOTE(review): presumably this starts the I/O and the returned buffer
        is the one later passed to ``wait_io`` -- confirm against the caller.
        """

    @abstractmethod
    def wait_io(
        self,
        gbuf: fstcpp.gds_device_buffer,
        dtype: DType = DType.AUTO,
        noalign: bool = False,
    ) -> Dict[str, TensorBase]:
        """Return the tensors obtained from ``gbuf`` as a name-to-tensor dict.

        Args:
            gbuf: Device buffer to read the tensors from.
            dtype: Requested tensor dtype; defaults to ``DType.AUTO``.
            noalign: Alignment switch whose exact effect is defined by the
                implementation.
        """


class DummyDeviceBuffer(fstcpp.gds_device_buffer):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the new code does not use this class. Is it required for your future changes? What kind of code do you expect to use it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will implement an example copier that handles GPU memory allocation within the reader.
This is managed in C++ code, which releases the Python GIL and thus improves overall performance.

The DummyDeviceBuffer helps me integrate that copier seamlessly into the framework's workflow.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Can you move this change to the future change?

def __init__(self):
    # Build the underlying gds_device_buffer with fixed placeholder
    # arguments; this subclass adds no state of its own.
    # NOTE(review): the meaning of (0, 0, False) is defined by the C++
    # binding and is not visible here -- presumably a null address, a zero
    # length and a "not CUDA" flag; confirm against fstcpp.
    super().__init__(0, 0, False)
56 changes: 56 additions & 0 deletions fastsafetensors/copier/example_copier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict

from .. import cpp as fstcpp
from ..common import SafeTensorsMetadata
from ..frameworks import FrameworkOpBase, TensorBase
from ..st_types import Device, DeviceType, DType
from .base import CopierInterface, DummyDeviceBuffer


class ExampleCopier(CopierInterface):
    """Skeleton copier showing the minimum needed to satisfy CopierInterface.

    Nothing is read or copied: ``submit_io`` returns a placeholder buffer
    and ``wait_io`` returns an empty mapping.
    """

    def __init__(
        self,
        metadata: SafeTensorsMetadata,
        device: Device,
        reader,
        framework: FrameworkOpBase,
        debug_log: bool = False,
    ):
        """Accept the standard copier arguments; none of them are stored."""

    def submit_io(
        self, use_buf_register: bool, max_copy_block_size: int
    ) -> fstcpp.gds_device_buffer:
        """Return a new ``DummyDeviceBuffer``; both arguments are ignored."""
        placeholder = DummyDeviceBuffer()
        return placeholder

    def wait_io(
        self,
        gbuf: fstcpp.gds_device_buffer,
        dtype: DType = DType.AUTO,
        noalign: bool = False,
    ) -> Dict[str, TensorBase]:
        """Return an empty name-to-tensor mapping; all arguments are ignored."""
        # A real copier would populate this mapping with tensors here.
        return {}


def new_gds_file_copier(
    device: Device,
    bbuf_size_kb: int = 16 * 1024,
    max_threads: int = 16,
    nogds: bool = False,
):
    """Build a factory that creates ``ExampleCopier`` instances.

    None of the arguments are used by this example implementation.

    Returns:
        A callable ``(metadata, device, framework, debug_log=False)`` that
        returns a new ``ExampleCopier``. Every copier it creates receives
        the same placeholder reader object.
    """
    # Stand-in for a real reader (e.g. the result of ``example_reader()``);
    # created once and shared by all copiers built by the factory below.
    shared_reader: Any = {}

    def construct_copier(
        metadata: SafeTensorsMetadata,
        device: Device,
        framework: FrameworkOpBase,
        debug_log: bool = False,
    ) -> CopierInterface:
        return ExampleCopier(
            metadata,
            device,
            reader=shared_reader,
            framework=framework,
            debug_log=debug_log,
        )

    return construct_copier
49 changes: 48 additions & 1 deletion fastsafetensors/copier/gds.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# Copyright 2024 IBM Inc. All rights reserved
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Dict, Optional

from .. import cpp as fstcpp
from ..common import SafeTensorsMetadata
from ..frameworks import FrameworkOpBase, TensorBase
from ..st_types import Device, DeviceType, DType
from .base import CopierInterface
from .nogds import NoGdsFileCopier


class GdsFileCopier:
class GdsFileCopier(CopierInterface):
def __init__(
self,
metadata: SafeTensorsMetadata,
Expand Down Expand Up @@ -139,3 +142,47 @@ def wait_io(
return self.metadata.get_tensors(
gbuf, self.device, self.aligned_offset, dtype=dtype
)


def new_gds_file_copier(
    device: Device,
    bbuf_size_kb: int = 16 * 1024,
    max_threads: int = 16,
    nogds: bool = False,
):
    """Create a copier factory backed by a GDS or a non-GDS file reader.

    A single reader is created up front and shared by every copier that the
    returned factory builds.

    Args:
        device: Target device. Any type other than ``DeviceType.CPU``
            requires the CUDA runtime library to be found.
        bbuf_size_kb: Forwarded to ``nogds_file_reader``; only used on the
            non-GDS path.
        max_threads: Forwarded to whichever reader is created.
        nogds: Use the non-GDS reader. Forced to True, with a
            ``UserWarning``, when the cuFile library is not found.

    Returns:
        A callable ``(metadata, device, framework, debug_log=False)`` that
        returns a ``GdsFileCopier`` or ``NoGdsFileCopier``.

    Raises:
        Exception: If the device is not a CPU and the CUDA runtime library
            is not found.
    """
    needs_cuda = device.type != DeviceType.CPU
    if needs_cuda and not fstcpp.is_cuda_found():
        raise Exception("[FAIL] libcudart.so does not exist")

    # Fall back to the non-GDS path when cuFile is unavailable.
    if not (fstcpp.is_cufile_found() or nogds):
        warnings.warn(
            "libcufile.so does not exist but nogds is False. use nogds=True",
            UserWarning,
        )
        nogds = True

    if not nogds:
        reader = fstcpp.gds_file_reader(max_threads, needs_cuda)

        def construct_copier(
            metadata: SafeTensorsMetadata,
            device: Device,
            framework: FrameworkOpBase,
            debug_log: bool = False,
        ) -> CopierInterface:
            return GdsFileCopier(metadata, device, reader, framework, debug_log)

        return construct_copier

    # NOTE(review): the meaning of the leading ``False`` is defined by the
    # C++ binding and is not visible here.
    nogds_reader = fstcpp.nogds_file_reader(
        False, bbuf_size_kb, max_threads, needs_cuda
    )

    def construct_nogds_copier(
        metadata: SafeTensorsMetadata,
        device: Device,
        framework: FrameworkOpBase,
        debug_log: bool = False,
    ) -> CopierInterface:
        return NoGdsFileCopier(metadata, device, nogds_reader, framework, debug_log)

    return construct_nogds_copier
5 changes: 3 additions & 2 deletions fastsafetensors/copier/nogds.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from .. import cpp as fstcpp
from ..common import SafeTensorsMetadata
from ..frameworks import FrameworkOpBase, TensorBase
from ..st_types import Device, DType
from ..st_types import Device, DeviceType, DType
from .base import CopierInterface


class NoGdsFileCopier:
class NoGdsFileCopier(CopierInterface):
def __init__(
self,
metadata: SafeTensorsMetadata,
Expand Down
135 changes: 79 additions & 56 deletions fastsafetensors/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,100 +2,70 @@
# SPDX-License-Identifier: Apache-2.0

import math
import warnings
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union

from . import cpp as fstcpp
from .common import SafeTensorsMetadata, TensorFrame, get_device_numa_node
from .copier.gds import new_gds_file_copier
from .file_buffer import FilesBufferOnDevice
from .frameworks import TensorBase, get_framework_op
from .st_types import DeviceType, DType
from .frameworks import FrameworkOpBase, TensorBase, get_framework_op
from .st_types import Device, DeviceType, DType
from .tensor_factory import LazyTensorFactory

gl_set_numa = False

loaded_nvidia = False


class SafeTensorsFileLoader:
r"""Load .safetensors files lazily.
class BaseSafeTensorsFileLoader:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to split the class into two? It looks like BaseSafeTensorsFileLoader is not reusable anywhere.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach provides a clean way to implement a new copier and loader without modifying any code in the BaseSafeTensorsFileLoader, while also ensuring the existing interfaces in fastsafetensors remain unchanged.

r"""Base class for loading .safetensors files lazily.

Args:
devcie (str): target device.
pg (Optional[Any]): process group-like objects for distributed. None for single GPU use-cases.
bbuf_size_kb (int): bounce buffer size for file copies.
max_threads (int): maximum number of threads for memory copies.
nogds (bool): if True, trun off GDS and fallback to pread with bounce buffer.
debug_log (bool): enable debug logs.

Examples:
>> from fastsafetensors import SafeTensorsFileLoader
>> src_files = download(target_dir, "gpt2")
>> loader = SafeTensorsFileLoader(Device("cpu"), nogds=True, debug_log=True)
>> loader.add_filenames({0: src_files})
>> bufs = loader.copy_files_to_device()
>> print(bufs.get_tensor(loader.get_keys()[0]))
>> loader.close()
pg (Optional[Any]): Process group-like objects for distributed loading.
Use None for single device use-cases.
device (Device): Target device where tensors will be loaded (CPU, CUDA, etc.).
copier_constructor: Constructor function for creating file copier objects.
set_numa (bool): Whether to set NUMA node affinity for optimized memory access.
disable_cache (bool): Whether to disable caching of loaded tensors.
debug_log (bool): Enable detailed debug logging.
framework (str): Deep learning framework to use ("pytorch" or "paddle").
"""

def __init__(
self,
pg: Optional[Any],
device: str = "cpu",
bbuf_size_kb: int = 16 * 1024,
max_threads: int = 16,
nogds: bool = False,
device: Device,
copier_constructor,
set_numa: bool = True,
disable_cache: bool = True,
debug_log: bool = False,
framework="pytorch",
):
self.framework = get_framework_op(framework)
self.pg = self.framework.get_process_group(pg)
self.device = self.framework.get_device(device, self.pg)
self.device = device
self.debug_log = debug_log
self.meta: Dict[str, Tuple[SafeTensorsMetadata, int]] = {}
self.frames = OrderedDict[str, TensorFrame]()
self.disable_cache = disable_cache
global loaded_nvidia
if not loaded_nvidia:
fstcpp.load_nvidia_functions()
if not nogds:
# no need to init gds and consume 10s+ in none-gds case
if fstcpp.init_gds() != 0:
raise Exception(f"[FAIL] init_gds()")
loaded_nvidia = True
self.init_numa(set_numa)
self.copier_constructor = copier_constructor

def init_numa(self, set_numa: bool = True):
global gl_set_numa
if not gl_set_numa and set_numa:
node = get_device_numa_node(self.device.index)
if node is not None:
fstcpp.set_numa_node(node)
gl_set_numa = True
fstcpp.set_debug_log(debug_log)
device_is_not_cpu = self.device.type != DeviceType.CPU
if device_is_not_cpu and not fstcpp.is_cuda_found():
raise Exception("[FAIL] libcudart.so does not exist")
if not fstcpp.is_cufile_found() and not nogds:
warnings.warn(
"libcufile.so does not exist but nogds is False. use nogds=True",
UserWarning,
)
nogds = True
self.reader: Union[fstcpp.nogds_file_reader, fstcpp.gds_file_reader]
if nogds:
self.reader = fstcpp.nogds_file_reader(
False, bbuf_size_kb, max_threads, device_is_not_cpu
)
else:
self.reader = fstcpp.gds_file_reader(max_threads, device_is_not_cpu)

def reset(self):
    """Forget all tracked metadata and tensor frames.

    Both attributes are rebound to new empty dicts; the previous containers
    are not cleared in place.
    """
    self.meta = {}
    self.frames = {}

def close(self):
self.reset()
del self.reader
del self.copier_constructor

def get_keys(self) -> List[str]:
    """Return the keys of ``self.frames`` as a list, in iteration order."""
    return list(self.frames)
Expand Down Expand Up @@ -145,8 +115,10 @@ def copy_files_to_device(

factory_idx_bits = math.ceil(math.log2(len(self.meta) + 1))
lidx = 1

for _, (meta, rank) in sorted(self.meta.items(), key=lambda x: x[0]):
copier = self.copier_constructor(
meta, self.device, self.framework, self.debug_log
)
self_rank = self.pg.rank() == rank
factory = LazyTensorFactory(
meta,
Expand All @@ -155,7 +127,7 @@ def copy_files_to_device(
self_rank,
factory_idx_bits,
lidx,
self.reader,
copier,
self.framework,
self.debug_log,
disable_cache=self.disable_cache,
Expand All @@ -166,12 +138,63 @@ def copy_files_to_device(
need_wait.append(factory)
lidx += 1
for factory in need_wait:
factory.wait_io(
dtype=dtype, noalign=isinstance(self.reader, fstcpp.nogds_file_reader)
)
factory.wait_io(dtype=dtype, noalign=False)
return FilesBufferOnDevice(factories, pg=self.pg, framework=self.framework)


class SafeTensorsFileLoader(BaseSafeTensorsFileLoader):
    r"""Load .safetensors files lazily.

    Args:
        pg (Optional[Any]): process group-like objects for distributed. None for single GPU use-cases.
        device (str): target device.
        bbuf_size_kb (int): bounce buffer size for file copies.
        max_threads (int): maximum number of threads for memory copies.
        nogds (bool): if True, turn off GDS and fallback to pread with bounce buffer.
        set_numa (bool): whether to set NUMA node affinity.
        disable_cache (bool): whether to disable caching of loaded tensors.
        debug_log (bool): enable debug logs.
        framework (str): deep learning framework to use ("pytorch" or "paddle").

    Raises:
        Exception: if ``init_gds()`` reports a failure.

    Examples:
        >> from fastsafetensors import SafeTensorsFileLoader
        >> src_files = download(target_dir, "gpt2")
        >> loader = SafeTensorsFileLoader(None, "cpu", nogds=True, debug_log=True)
        >> loader.add_filenames({0: src_files})
        >> bufs = loader.copy_files_to_device()
        >> print(bufs.get_tensor(loader.get_keys()[0]))
        >> loader.close()
    """

    def __init__(
        self,
        pg: Optional[Any],
        device: str = "cpu",
        bbuf_size_kb: int = 16 * 1024,
        max_threads: int = 16,
        nogds: bool = False,
        set_numa: bool = True,
        disable_cache: bool = True,
        debug_log: bool = False,
        framework="pytorch",
    ):
        # The device string can only be resolved through the framework and
        # process group, so these are set up before the base initializer runs.
        self.framework = get_framework_op(framework)
        self.pg = self.framework.get_process_group(pg)
        self.device = self.framework.get_device(device, self.pg)

        fstcpp.set_debug_log(debug_log)
        global loaded_nvidia
        if not loaded_nvidia:
            fstcpp.load_nvidia_functions()
            if not nogds:
                # no need to init gds and consume 10s+ in none-gds case
                if fstcpp.init_gds() != 0:
                    raise Exception("[FAIL] init_gds()")
            # NOTE(review): the flag is set even when init_gds() was skipped,
            # so a later loader created with nogds=False in the same process
            # never calls init_gds(). Confirm whether the GDS reader relies
            # on that explicit initialization.
            loaded_nvidia = True

        copier = new_gds_file_copier(self.device, bbuf_size_kb, max_threads, nogds)
        super().__init__(
            pg, self.device, copier, set_numa, disable_cache, debug_log, framework
        )


class fastsafe_open:
"""
Opens a safetensors lazily and returns tensors as asked
Expand Down
Loading