Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions examples/parse_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

def parse_safetensors(folder: str, output_csv: str):
rows = []


fp8_tensors_name = set()
for root, _, files in os.walk(folder):
for file in files:
if file.endswith(".safetensors"):
Expand All @@ -15,26 +16,46 @@ def parse_safetensors(folder: str, output_csv: str):
with safe_open(file_path, framework="pt") as f:
keys = list(f.keys())
print(f" Found {len(keys)} tensors")

for tensor_name in keys:
tensor = f.get_tensor(tensor_name)
dtype = str(tensor.dtype)
shape = list(tensor.shape)

if dtype == "torch.float8_e4m3fn":
fp8_tensors_name.add(tensor_name)
tensor_min, tensor_max = "N/A", "N/A"
try:
tensor_min, tensor_max = tensor.min().item(), tensor.max().item()
except:
pass
tensor_min, tensor_max = str(tensor_min), str(tensor_max)
Comment on lines +26 to +31
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block for getting min/max values can be simplified. Also, using a bare except is discouraged as it can catch system-level exceptions like KeyboardInterrupt. It's better to handle the success case and the exception case more explicitly to improve readability and robustness. This suggestion refactors the logic to be cleaner and uses except Exception:.

Suggested change
tensor_min, tensor_max = "N/A", "N/A"
try:
tensor_min, tensor_max = tensor.min().item(), tensor.max().item()
except:
pass
tensor_min, tensor_max = str(tensor_min), str(tensor_max)
try:
tensor_min = str(tensor.min().item())
tensor_max = str(tensor.max().item())
except Exception:
tensor_min, tensor_max = "N/A", "N/A"

rows.append({
"file": os.path.relpath(file_path, folder),
"tensor": tensor_name,
"dtype": dtype,
"shape": str(shape)
"shape": str(shape),
"min": tensor_min,
"max": tensor_max
})
# breakpoint()
except Exception as e:
print(f" Error reading {file_path}: {e}")

print(f"\nSaving results to {output_csv} ...")
with open(output_csv, "w", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=["file", "tensor", "dtype", "shape"])
writer = csv.DictWriter(csvfile, fieldnames=["file", "tensor", "dtype", "shape", "min", "max"])
writer.writeheader()
writer.writerows(rows)
print("Done ✅")
# show all fp8 tensors, ignore the layer index
filtered_fp8_tensors = set()
for name in fp8_tensors_name:
filtered_name = name
for idx in range(200):
filtered_name = filtered_name.replace(f".{idx}.", ".*.")
filtered_fp8_tensors.add(filtered_name)
Comment on lines +51 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation with a nested loop to replace layer indices is inefficient as it iterates 200 times for each tensor name. A more efficient and robust approach is to use a regular expression. This can be made even more concise and Pythonic using a set comprehension. This also removes the hardcoded limit of 200 for layer indices.

Note: you will need to add import re at the top of the file.

Suggested change
filtered_fp8_tensors = set()
for name in fp8_tensors_name:
filtered_name = name
for idx in range(200):
filtered_name = filtered_name.replace(f".{idx}.", ".*.")
filtered_fp8_tensors.add(filtered_name)
filtered_fp8_tensors = {re.sub(r'\.\d+\.', '.*.', name) for name in fp8_tensors_name}

for name in sorted(filtered_fp8_tensors):
print(name)


if __name__ == "__main__":
Expand Down
58 changes: 58 additions & 0 deletions examples/update_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
import sys
import tempfile
import shutil
from safetensors import safe_open
from safetensors.torch import save_file
import torch

def reshape_weight_scale(file_path: str) -> None:
    """Update all tensors ending with 'weight_scale' from [out_features] to [out_features, 1].

    Loads every tensor from the .safetensors file at *file_path* onto the CPU.
    Any 1-D tensor whose key ends with ``weight_scale`` is turned into a
    column vector via ``unsqueeze(1)``. If at least one tensor changed, the
    file is rewritten via a temp file that replaces the original; otherwise
    the file is left untouched.

    Args:
        file_path: Path to a .safetensors file.
    """
    new_tensors = {}
    updated = False

    with safe_open(file_path, framework="pt", device="cpu") as f:
        # Preserve the file-level metadata header; without this, save_file
        # would rewrite the file with no metadata and silently drop it.
        metadata = f.metadata()
        for k in f.keys():
            t = f.get_tensor(k)
            if k.endswith("weight_scale") and t.ndim == 1:
                print(f"  -> Reshaping {k}: {list(t.shape)} -> {[t.shape[0], 1]}")
                new_tensors[k] = t.unsqueeze(1)
                updated = True
            else:
                new_tensors[k] = t

    if updated:
        # Save to a temporary file first, then replace the original, so a
        # failure mid-write cannot leave the source file corrupted.
        tmp_file = file_path + ".tmp"
        save_file(new_tensors, tmp_file, metadata=metadata)  # torch version
        shutil.move(tmp_file, file_path)
        print(f"✅ Updated file saved: {file_path}")
    else:
        print(f"ℹ️ No changes needed for {file_path}")


def process_folder(folder: str):
    """Apply reshape_weight_scale to every .safetensors file directly inside *folder*."""
    candidates = [
        name for name in os.listdir(folder) if name.endswith(".safetensors")
    ]

    # Nothing to do — report and bail out early.
    if not candidates:
        print("No .safetensors files found in the folder.")
        return

    for name in candidates:
        target = os.path.join(folder, name)
        print(f"\nProcessing {name} ...")
        reshape_weight_scale(target)


if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: python update_weight_scale.py <folder_path>")
sys.exit(1)

folder_path = sys.argv[1]
if not os.path.isdir(folder_path):
print(f"Error: {folder_path} is not a valid folder.")
sys.exit(1)

process_folder(folder_path)