-
Notifications
You must be signed in to change notification settings - Fork 0
update parse meta #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,7 +5,8 @@ | |||||||||||||||
|
|
||||||||||||||||
| def parse_safetensors(folder: str, output_csv: str): | ||||||||||||||||
| rows = [] | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| fp8_tensors_name = set() | ||||||||||||||||
| for root, _, files in os.walk(folder): | ||||||||||||||||
| for file in files: | ||||||||||||||||
| if file.endswith(".safetensors"): | ||||||||||||||||
|
|
@@ -15,26 +16,46 @@ def parse_safetensors(folder: str, output_csv: str): | |||||||||||||||
| with safe_open(file_path, framework="pt") as f: | ||||||||||||||||
| keys = list(f.keys()) | ||||||||||||||||
| print(f" Found {len(keys)} tensors") | ||||||||||||||||
|
|
||||||||||||||||
| for tensor_name in keys: | ||||||||||||||||
| tensor = f.get_tensor(tensor_name) | ||||||||||||||||
| dtype = str(tensor.dtype) | ||||||||||||||||
| shape = list(tensor.shape) | ||||||||||||||||
|
|
||||||||||||||||
| if dtype == "torch.float8_e4m3fn": | ||||||||||||||||
| fp8_tensors_name.add(tensor_name) | ||||||||||||||||
| tensor_min, tensor_max = "N/A", "N/A" | ||||||||||||||||
| try: | ||||||||||||||||
| tensor_min, tensor_max = tensor.min().item(), tensor.max().item() | ||||||||||||||||
| except: | ||||||||||||||||
| pass | ||||||||||||||||
| tensor_min, tensor_max = str(tensor_min), str(tensor_max) | ||||||||||||||||
| rows.append({ | ||||||||||||||||
| "file": os.path.relpath(file_path, folder), | ||||||||||||||||
| "tensor": tensor_name, | ||||||||||||||||
| "dtype": dtype, | ||||||||||||||||
| "shape": str(shape) | ||||||||||||||||
| "shape": str(shape), | ||||||||||||||||
| "min": tensor_min, | ||||||||||||||||
| "max": tensor_max | ||||||||||||||||
| }) | ||||||||||||||||
| # breakpoint() | ||||||||||||||||
| except Exception as e: | ||||||||||||||||
| print(f" Error reading {file_path}: {e}") | ||||||||||||||||
|
|
||||||||||||||||
| print(f"\nSaving results to {output_csv} ...") | ||||||||||||||||
| with open(output_csv, "w", newline="") as csvfile: | ||||||||||||||||
| writer = csv.DictWriter(csvfile, fieldnames=["file", "tensor", "dtype", "shape"]) | ||||||||||||||||
| writer = csv.DictWriter(csvfile, fieldnames=["file", "tensor", "dtype", "shape", "min", "max"]) | ||||||||||||||||
| writer.writeheader() | ||||||||||||||||
| writer.writerows(rows) | ||||||||||||||||
| print("Done ✅") | ||||||||||||||||
| # show all fp8 tensors, ignore the layer index | ||||||||||||||||
| filtered_fp8_tensors = set() | ||||||||||||||||
| for name in fp8_tensors_name: | ||||||||||||||||
| filtered_name = name | ||||||||||||||||
| for idx in range(200): | ||||||||||||||||
| filtered_name = filtered_name.replace(f".{idx}.", ".*.") | ||||||||||||||||
| filtered_fp8_tensors.add(filtered_name) | ||||||||||||||||
|
Comment on lines
+51
to
+56
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current implementation with a nested loop to replace layer indices is inefficient as it iterates 200 times for each tensor name. A more efficient and robust approach is to use a regular expression. This can be made even more concise and Pythonic using a set comprehension. This also removes the hardcoded limit of 200 for layer indices. Note: you will need to add `import re` at the top of the file.
Suggested change
|
||||||||||||||||
| for name in sorted(filtered_fp8_tensors): | ||||||||||||||||
| print(name) | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||
|
|
||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block for getting min/max values can be simplified. Also, using a bare `except` is discouraged as it can catch system-level exceptions like `KeyboardInterrupt`. It's better to handle the success case and the exception case more explicitly to improve readability and robustness. This suggestion refactors the logic to be cleaner and uses `except Exception:`.