Skip to content

Commit e6a7e02

Browse files
committed
add parallel loader
Signed-off-by: yuanyuxing.yyx <[email protected]>
1 parent 40a5d58 commit e6a7e02

File tree

3 files changed

+464
-2
lines changed

3 files changed

+464
-2
lines changed

fastsafetensors/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77

88
from .common import SafeTensorsMetadata, SingleGroup, TensorFrame, get_device_numa_node
99
from .file_buffer import FilesBufferOnDevice
10-
from .loader import SafeTensorsFileLoader, fastsafe_open
10+
from .loader import BaseSafeTensorsFileLoader, SafeTensorsFileLoader, fastsafe_open
11+
from .parallel_loader import ParallelLoader

fastsafetensors/loader.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
disable_cache: bool = True,
4141
debug_log: bool = False,
4242
framework="pytorch",
43+
**kwargs,
4344
):
4445
self.framework = get_framework_op(framework)
4546
self.pg = self.framework.get_process_group(pg)
@@ -174,6 +175,7 @@ def __init__(
174175
disable_cache: bool = True,
175176
debug_log: bool = False,
176177
framework="pytorch",
178+
**kwargs,
177179
):
178180
self.framework = get_framework_op(framework)
179181
self.pg = self.framework.get_process_group(pg)
@@ -191,7 +193,14 @@ def __init__(
191193

192194
copier = new_gds_file_copier(self.device, bbuf_size_kb, max_threads, nogds)
193195
super().__init__(
194-
pg, self.device, copier, set_numa, disable_cache, debug_log, framework
196+
pg,
197+
self.device,
198+
copier,
199+
set_numa,
200+
disable_cache,
201+
debug_log,
202+
framework,
203+
**kwargs,
195204
)
196205

197206

0 commit comments

Comments
 (0)