Skip to content

Commit 438e784

Browse files
Restored convert_weights.py
1 parent 12ea6d3 commit 438e784

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

convert_weights.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import os
2+
import numpy as np
3+
import tensorflow as tf # Needed to handle TF tensors
4+
import h5py
5+
6+
def convert_to_numpy(value):
7+
"""
8+
Convert TensorFlow tensors/variables to NumPy arrays (float32).
9+
Ensures we remove any TensorFlow-specific data.
10+
"""
11+
if isinstance(value, tf.Tensor) or isinstance(value, tf.Variable):
12+
return value.numpy().astype(np.float32)
13+
elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.number):
14+
return value.astype(np.float32)
15+
else:
16+
return None # Ignore non-numeric data
17+
18+
def convert_weights(npy_path):
19+
"""Convert `weights.npy` to a TensorFlow-free HDF5 format."""
20+
if not os.path.exists(npy_path):
21+
print(f"❌ Error: {npy_path} not found!")
22+
return
23+
24+
h5_path = npy_path.replace(".npy", "_tf_free.h5")
25+
26+
# Load the weights
27+
print(f"🔍 Loading {npy_path}...")
28+
weights = np.load(npy_path, allow_pickle=True)
29+
30+
# Convert all elements to NumPy arrays (remove TensorFlow dtypes)
31+
converted_weights = [convert_to_numpy(w) for w in weights if convert_to_numpy(w) is not None]
32+
33+
# Save to HDF5 format
34+
with h5py.File(h5_path, "w") as hf:
35+
for i, w in enumerate(converted_weights):
36+
hf.create_dataset(f"weight_{i}", data=w)
37+
38+
print(f"✅ Converted: {npy_path} -> {h5_path}")
39+
40+
if __name__ == "__main__":
41+
# Search for all `weights.npy` files and convert them
42+
for root, _, files in os.walk("."):
43+
for file in files:
44+
if file == "weights.npy":
45+
convert_weights(os.path.join(root, file))
46+
47+
print("🚀 All weight files converted successfully!")
48+

0 commit comments

Comments
 (0)