Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ htmlcov/
.coverage_*
.pytest_cache/
.vscode
.idea
*.log
*.pyc
examples/paddle_case/log
3 changes: 3 additions & 0 deletions fastsafetensors/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
max_threads: int = 16,
nogds: bool = False,
set_numa: bool = True,
disable_cache: bool = True,
debug_log: bool = False,
framework="pytorch",
):
Expand All @@ -55,6 +56,7 @@ def __init__(
self.debug_log = debug_log
self.meta: Dict[str, Tuple[SafeTensorsMetadata, int]] = {}
self.frames = OrderedDict[str, TensorFrame]()
self.disable_cache = disable_cache
global loaded_nvidia
if not loaded_nvidia:
fstcpp.load_nvidia_functions()
Expand Down Expand Up @@ -154,6 +156,7 @@ def copy_files_to_device(
self.reader,
self.framework,
self.debug_log,
disable_cache=self.disable_cache,
)
factory.submit_io(use_buf_register, max_copy_block_size)
factories[rank].append(factory)
Expand Down
8 changes: 7 additions & 1 deletion fastsafetensors/tensor_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
reader: Union[fstcpp.gds_file_reader, fstcpp.nogds_file_reader],
framework: FrameworkOpBase,
debug_log: bool = False,
disable_cache=True,
):
self.framework = framework
self.metadata = metadata
Expand All @@ -46,6 +47,7 @@ def __init__(
self.factory_idx_bits = factory_idx_bits
self.lidx = lidx
self.next_tag = 1
self.disable_cache = disable_cache

def submit_io(self, use_buf_register: bool, max_copy_block_size: int):
if self.copier is not None:
Expand Down Expand Up @@ -160,7 +162,11 @@ def shuffle(self, pg: ProcessGroupBase, tensor_name: str, dim: int) -> TensorBas
f"shuffle: scatter, tensor_name={tensor_name}, shape={frame.shape}->{new_frame.shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, rank_slices={rank_slices}, len(scatter_list)={len(scatter_list)}"
)
pg.scatter(dst, scatter_list=scatter_list, src=self.rank)
self.shuffled[tensor_name] = dst
if not self.disable_cache:
# Cache tensor for reuse within the same batch to improve performance.
# Note: This requires additional (GPU) memory to store the cached tensors.
# Enable this only if you have sufficient (GPU) memory and need caching.
self.shuffled[tensor_name] = dst
return dst

def shuffle_multi_cols(
Expand Down