We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 3f2bde1 + 3c81845 commit 988774dCopy full SHA for 988774d
exo/inference/tinygrad/tinygrad_helpers.py
@@ -7,6 +7,7 @@
7
from exo.helpers import DEBUG
8
from exo.download.hf.hf_helpers import get_allow_patterns
9
from fnmatch import fnmatch
10
+import re
11
12
13
# **** helper functions ****
@@ -42,6 +43,10 @@ def load(fn: str, shard: Shard):
42
43
if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
44
return {k: parts[n][k] for k, n in filtered_weight_map.items()}
45
elif fn.endswith(".safetensors"):
- return safe_load(fn)
46
+ weight_map = safe_load(fn)
47
+ for k in list(weight_map):
48
+ if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
49
+ del weight_map[k]
50
+ return weight_map
51
else:
52
return torch_load(fn)
0 commit comments