diff --git a/examples/parse_metadata.py b/examples/parse_metadata.py
index 1346449..5a178ab 100644
--- a/examples/parse_metadata.py
+++ b/examples/parse_metadata.py
@@ -5,7 +5,8 @@
 def parse_safetensors(folder: str, output_csv: str):
 
     rows = []
-
+
+    fp8_tensors_name = set()
     for root, _, files in os.walk(folder):
         for file in files:
             if file.endswith(".safetensors"):
@@ -15,26 +16,46 @@
                     with safe_open(file_path, framework="pt") as f:
                         keys = list(f.keys())
                         print(f"  Found {len(keys)} tensors")
+
                         for tensor_name in keys:
                             tensor = f.get_tensor(tensor_name)
                             dtype = str(tensor.dtype)
                             shape = list(tensor.shape)
-
+                            if dtype == "torch.float8_e4m3fn":
+                                fp8_tensors_name.add(tensor_name)
+                            tensor_min, tensor_max = "N/A", "N/A"
+                            try:
+                                tensor_min, tensor_max = tensor.min().item(), tensor.max().item()
+                            except Exception:  # leave "N/A" when min/max is unsupported for the dtype
+                                pass
+                            tensor_min, tensor_max = str(tensor_min), str(tensor_max)
                             rows.append({
                                 "file": os.path.relpath(file_path, folder),
                                 "tensor": tensor_name,
                                 "dtype": dtype,
-                                "shape": str(shape)
+                                "shape": str(shape),
+                                "min": tensor_min,
+                                "max": tensor_max
                             })
+                    # breakpoint()
                 except Exception as e:
                     print(f"  Error reading {file_path}: {e}")
 
     print(f"\nSaving results to {output_csv} ...")
     with open(output_csv, "w", newline="") as csvfile:
-        writer = csv.DictWriter(csvfile, fieldnames=["file", "tensor", "dtype", "shape"])
+        writer = csv.DictWriter(csvfile, fieldnames=["file", "tensor", "dtype", "shape", "min", "max"])
         writer.writeheader()
         writer.writerows(rows)
 
     print("Done ✅")
+    # show all fp8 tensors, ignore the layer index
+    filtered_fp8_tensors = set()
+    for name in fp8_tensors_name:
+        filtered_name = name
+        for idx in range(200):
+            filtered_name = filtered_name.replace(f".{idx}.", ".*.")
+        filtered_fp8_tensors.add(filtered_name)
+    for name in sorted(filtered_fp8_tensors):
+        print(name)
 
 if __name__ == "__main__":
diff --git a/examples/update_shape.py b/examples/update_shape.py
new file mode 100644
index 0000000..b7f1e90
--- /dev/null
+++ b/examples/update_shape.py
@@ -0,0 +1,58 @@
+import os
+import sys
+import tempfile
+import shutil
+from safetensors import safe_open
+from safetensors.torch import save_file
+import torch
+
+def reshape_weight_scale(file_path: str):
+    """Update all tensors ending with 'weight_scale' from [out_features] to [out_features, 1]."""
+    new_tensors = {}
+    updated = False
+
+    with safe_open(file_path, framework="pt", device="cpu") as f:
+        for k in f.keys():
+            t = f.get_tensor(k)
+            if k.endswith("weight_scale") and t.ndim == 1:
+                print(f"  -> Reshaping {k}: {list(t.shape)} -> {[t.shape[0], 1]}")
+                new_tensors[k] = t.unsqueeze(1)
+                updated = True
+            else:
+                new_tensors[k] = t
+
+    if updated:
+        # Save to a temporary file first, then replace the original
+        tmp_file = file_path + ".tmp"
+        save_file(new_tensors, tmp_file)  # torch version
+        shutil.move(tmp_file, file_path)
+        print(f"✅ Updated file saved: {file_path}")
+    else:
+        print(f"ℹ️ No changes needed for {file_path}")
+
+
+def process_folder(folder: str):
+    """Process all .safetensors files in a folder."""
+    files = [f for f in os.listdir(folder) if f.endswith(".safetensors")]
+
+    if not files:
+        print("No .safetensors files found in the folder.")
+        return
+
+    for fname in files:
+        file_path = os.path.join(folder, fname)
+        print(f"\nProcessing {fname} ...")
+        reshape_weight_scale(file_path)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) < 2:
+        print("Usage: python update_shape.py <folder>")
+        sys.exit(1)
+
+    folder_path = sys.argv[1]
+    if not os.path.isdir(folder_path):
+        print(f"Error: {folder_path} is not a valid folder.")
+        sys.exit(1)
+
+    process_folder(folder_path)