-
Notifications
You must be signed in to change notification settings - Fork 0
update parse meta #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -5,7 +5,8 @@ | |||||||||||||||
|
||||||||||||||||
def parse_safetensors(folder: str, output_csv: str): | ||||||||||||||||
rows = [] | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
fp8_tensors_name = set() | ||||||||||||||||
for root, _, files in os.walk(folder): | ||||||||||||||||
for file in files: | ||||||||||||||||
if file.endswith(".safetensors"): | ||||||||||||||||
|
@@ -15,26 +16,46 @@ def parse_safetensors(folder: str, output_csv: str): | |||||||||||||||
with safe_open(file_path, framework="pt") as f: | ||||||||||||||||
keys = list(f.keys()) | ||||||||||||||||
print(f" Found {len(keys)} tensors") | ||||||||||||||||
|
||||||||||||||||
for tensor_name in keys: | ||||||||||||||||
tensor = f.get_tensor(tensor_name) | ||||||||||||||||
dtype = str(tensor.dtype) | ||||||||||||||||
shape = list(tensor.shape) | ||||||||||||||||
|
||||||||||||||||
if dtype == "torch.float8_e4m3fn": | ||||||||||||||||
fp8_tensors_name.add(tensor_name) | ||||||||||||||||
tensor_min, tensor_max = "N/A", "N/A" | ||||||||||||||||
try: | ||||||||||||||||
tensor_min, tensor_max = tensor.min().item(), tensor.max().item() | ||||||||||||||||
except: | ||||||||||||||||
pass | ||||||||||||||||
tensor_min, tensor_max = str(tensor_min), str(tensor_max) | ||||||||||||||||
rows.append({ | ||||||||||||||||
"file": os.path.relpath(file_path, folder), | ||||||||||||||||
"tensor": tensor_name, | ||||||||||||||||
"dtype": dtype, | ||||||||||||||||
"shape": str(shape) | ||||||||||||||||
"shape": str(shape), | ||||||||||||||||
"min": tensor_min, | ||||||||||||||||
"max": tensor_max | ||||||||||||||||
}) | ||||||||||||||||
# breakpoint() | ||||||||||||||||
except Exception as e: | ||||||||||||||||
print(f" Error reading {file_path}: {e}") | ||||||||||||||||
|
||||||||||||||||
print(f"\nSaving results to {output_csv} ...") | ||||||||||||||||
with open(output_csv, "w", newline="") as csvfile: | ||||||||||||||||
writer = csv.DictWriter(csvfile, fieldnames=["file", "tensor", "dtype", "shape"]) | ||||||||||||||||
writer = csv.DictWriter(csvfile, fieldnames=["file", "tensor", "dtype", "shape", "min", "max"]) | ||||||||||||||||
writer.writeheader() | ||||||||||||||||
writer.writerows(rows) | ||||||||||||||||
print("Done ✅") | ||||||||||||||||
# show all fp8 tensors, ignore the layer index | ||||||||||||||||
filtered_fp8_tensors = set() | ||||||||||||||||
for name in fp8_tensors_name: | ||||||||||||||||
filtered_name = name | ||||||||||||||||
for idx in range(200): | ||||||||||||||||
filtered_name = filtered_name.replace(f".{idx}.", ".*.") | ||||||||||||||||
filtered_fp8_tensors.add(filtered_name) | ||||||||||||||||
Comment on lines
+51
to
+56
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current implementation with a nested loop to replace layer indices is inefficient, as it iterates 200 times for each tensor name. A more efficient and robust approach is to use a regular expression; this can be made even more concise and Pythonic using a set comprehension. This also removes the hardcoded limit of 200 for layer indices. Note: you will need to add `import re` at the top of the file.
Suggested change
|
||||||||||||||||
for name in sorted(filtered_fp8_tensors): | ||||||||||||||||
print(name) | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import os | ||
import sys | ||
import tempfile | ||
import shutil | ||
from safetensors import safe_open | ||
from safetensors.torch import save_file | ||
import torch | ||
|
||
def reshape_weight_scale(file_path: str):
    """Update all tensors ending with 'weight_scale' from [out_features] to [out_features, 1].

    Rewrites *file_path* in place when at least one matching 1-D tensor is
    found; otherwise leaves the file untouched.

    Args:
        file_path: Path to a .safetensors file.
    """
    new_tensors = {}
    updated = False

    with safe_open(file_path, framework="pt", device="cpu") as f:
        # Preserve the file-level metadata (e.g. {"format": "pt"}). The previous
        # version dropped it on re-save, which can break loaders that require it.
        metadata = f.metadata()
        for k in f.keys():
            t = f.get_tensor(k)
            if k.endswith("weight_scale") and t.ndim == 1:
                print(f" -> Reshaping {k}: {list(t.shape)} -> {[t.shape[0], 1]}")
                new_tensors[k] = t.unsqueeze(1)
                updated = True
            else:
                new_tensors[k] = t

    if updated:
        # Save to a temporary file first, then replace the original, so a crash
        # mid-write cannot corrupt the source file.
        tmp_file = file_path + ".tmp"
        try:
            save_file(new_tensors, tmp_file, metadata=metadata)
            shutil.move(tmp_file, file_path)
        except BaseException:
            # Don't leave a stale .tmp behind when saving/moving fails.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
            raise
        print(f"✅ Updated file saved: {file_path}")
    else:
        print(f"ℹ️ No changes needed for {file_path}")
|
||
|
||
def process_folder(folder: str):
    """Apply the weight-scale reshape to every .safetensors file in *folder*."""
    candidates = [
        name for name in os.listdir(folder) if name.endswith(".safetensors")
    ]

    # Nothing to do — report and bail out early.
    if not candidates:
        print("No .safetensors files found in the folder.")
        return

    for name in candidates:
        print(f"\nProcessing {name} ...")
        reshape_weight_scale(os.path.join(folder, name))
|
||
|
||
if __name__ == "__main__":
    # CLI entry point: expects exactly one argument, the folder to process.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python update_weight_scale.py <folder_path>")
        sys.exit(1)

    target = cli_args[0]
    if not os.path.isdir(target):
        print(f"Error: {target} is not a valid folder.")
        sys.exit(1)

    process_folder(target)
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
This block for getting min/max values can be simplified. Also, using a bare `except` is discouraged, as it can catch system-level exceptions like `KeyboardInterrupt`. It's better to handle the success case and the exception case more explicitly to improve readability and robustness. This suggestion refactors the logic to be cleaner and uses `except Exception:`.