Skip to content

Commit 988774d

Browse files
authored
Merge pull request exo-explore#511 from roryclear/load_shard_only
only load layers in shard in tinygrad
2 parents 3f2bde1 + 3c81845 commit 988774d

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

exo/inference/tinygrad/tinygrad_helpers.py

+6 -1
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from exo.helpers import DEBUG
88
from exo.download.hf.hf_helpers import get_allow_patterns
99
from fnmatch import fnmatch
10+
import re
1011

1112

1213
# **** helper functions ****
@@ -42,6 +43,10 @@ def load(fn: str, shard: Shard):
4243
if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
4344
return {k: parts[n][k] for k, n in filtered_weight_map.items()}
4445
elif fn.endswith(".safetensors"):
45-
return safe_load(fn)
46+
weight_map = safe_load(fn)
47+
for k in list(weight_map):
48+
if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
49+
del weight_map[k]
50+
return weight_map
4651
else:
4752
return torch_load(fn)

0 commit comments

Comments
 (0)