Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions examples/parse_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

def parse_safetensors(folder: str, output_csv: str):
rows = []


fp8_tensors_name = set()
for root, _, files in os.walk(folder):
for file in files:
if file.endswith(".safetensors"):
Expand All @@ -15,26 +16,46 @@ def parse_safetensors(folder: str, output_csv: str):
with safe_open(file_path, framework="pt") as f:
keys = list(f.keys())
print(f" Found {len(keys)} tensors")

for tensor_name in keys:
tensor = f.get_tensor(tensor_name)
dtype = str(tensor.dtype)
shape = list(tensor.shape)

if dtype == "torch.float8_e4m3fn":
fp8_tensors_name.add(tensor_name)
tensor_min, tensor_max = "N/A", "N/A"
try:
tensor_min, tensor_max = tensor.min().item(), tensor.max().item()
except:
pass
tensor_min, tensor_max = str(tensor_min), str(tensor_max)
Comment on lines +26 to +31
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block for getting the min/max values can be simplified. Also, a bare `except:` is discouraged because it catches every `BaseException`, including system-level ones such as `KeyboardInterrupt` and `SystemExit`. It is better to handle the success case and the failure case explicitly, which improves both readability and robustness. This suggestion refactors the logic to be cleaner and uses `except Exception:` instead.

Suggested change
tensor_min, tensor_max = "N/A", "N/A"
try:
tensor_min, tensor_max = tensor.min().item(), tensor.max().item()
except:
pass
tensor_min, tensor_max = str(tensor_min), str(tensor_max)
try:
tensor_min = str(tensor.min().item())
tensor_max = str(tensor.max().item())
except Exception:
tensor_min, tensor_max = "N/A", "N/A"

rows.append({
"file": os.path.relpath(file_path, folder),
"tensor": tensor_name,
"dtype": dtype,
"shape": str(shape)
"shape": str(shape),
"min": tensor_min,
"max": tensor_max
})
# breakpoint()
except Exception as e:
print(f" Error reading {file_path}: {e}")

print(f"\nSaving results to {output_csv} ...")
with open(output_csv, "w", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=["file", "tensor", "dtype", "shape"])
writer = csv.DictWriter(csvfile, fieldnames=["file", "tensor", "dtype", "shape", "min", "max"])
writer.writeheader()
writer.writerows(rows)
print("Done ✅")
# show all fp8 tensors, ignore the layer index
filtered_fp8_tensors = set()
for name in fp8_tensors_name:
filtered_name = name
for idx in range(200):
filtered_name = filtered_name.replace(f".{idx}.", ".*.")
filtered_fp8_tensors.add(filtered_name)
Comment on lines +51 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

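The current implementation uses a nested loop to replace layer indices, which is inefficient: it calls `str.replace` 200 times for every tensor name. A more efficient and robust approach is a regular expression, which can be written concisely as a set comprehension and also removes the hardcoded limit of 200 layer indices. One caveat: the pattern `\.\d+\.` consumes the trailing dot, so in a name with two adjacent numeric segments (e.g. `a.1.2.b`) only the first is rewritten; if that case matters, use a lookahead instead, i.e. replace `\.\d+(?=\.)` with `.*`.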

Note: you will need to add `import re` at the top of the file.

Suggested change
filtered_fp8_tensors = set()
for name in fp8_tensors_name:
filtered_name = name
for idx in range(200):
filtered_name = filtered_name.replace(f".{idx}.", ".*.")
filtered_fp8_tensors.add(filtered_name)
filtered_fp8_tensors = {re.sub(r'\.\d+\.', '.*.', name) for name in fp8_tensors_name}

for name in sorted(filtered_fp8_tensors):
print(name)


if __name__ == "__main__":
Expand Down