Skip to content

Commit

Permalink
v0.2.2
Browse files Browse the repository at this point in the history
  • Loading branch information
cgisky1980 committed Oct 18, 2023
1 parent a36d374 commit 6d48a47
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 166 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@



[English](README.md) | [中文](README_zh.md) | [日本語](README_jp.md)
[English](README.md) | [中文](README_zh.md)

---

Expand Down
152 changes: 0 additions & 152 deletions README_jp.md

This file was deleted.

2 changes: 1 addition & 1 deletion README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
[![All Contributors](https://img.shields.io/badge/all_contributors-4-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->

[English](README.md) | [中文](README_zh.md) | [日本語](README_jp.md)
[English](README.md) | [中文](README_zh.md)


---
Expand Down
42 changes: 30 additions & 12 deletions convert_safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,44 @@
from safetensors.torch import load_file, save_file

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, help='Path to input pth model')
parser.add_argument('--output', type=str, default='./converted.st',
help='Path to output safetensors model')
parser.add_argument("--input", type=str, help="Path to input pth model")
parser.add_argument(
"--output",
type=str,
default="./converted.st",
help="Path to output safetensors model",
)
args = parser.parse_args()


def convert_file(
pt_filename: str,
sf_filename: str,
):
def rename_key(rename: dict, name: str) -> str:
    """Return *name* with every key of *rename* found in it replaced.

    Substitutions are applied one after another, in the mapping's iteration
    order, so text produced by an earlier substitution can be matched by a
    later key.

    Args:
        rename: Mapping of substring to look for -> replacement text.
        name: The string to rewrite.

    Returns:
        The rewritten string; ``name`` unchanged when no key occurs in it.
    """
    for old, new in rename.items():
        # str.replace is already a no-op when ``old`` is absent, so a
        # separate membership test is unnecessary.
        name = name.replace(old, new)
    return name


def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename={}):
loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded:
loaded = loaded["state_dict"]

loaded = {k: v.clone().half() for k, v in loaded.items()}
for k, v in loaded.items():
print(f'{k}\t{v.shape}\t{v.dtype}')
# for k, v in loaded.items():
# print(f'{k}\t{v.shape}\t{v.dtype}')

# For tensors to be contiguous
loaded = {k: v.contiguous() for k, v in loaded.items()}
for k, v in loaded.items():
for transpose_name in transpose_names:
if transpose_name in k:
loaded[k] = v.transpose(0, 1)
loaded = {rename_key(rename, k).lower(): v.contiguous()
for k, v in loaded.items()}

for k, v in loaded.items():
print(f"{k}\t{v.shape}\t{v.dtype}")

dirname = os.path.dirname(sf_filename)
os.makedirs(dirname, exist_ok=True)
Expand All @@ -40,5 +57,6 @@ def convert_file(


if __name__ == "__main__":
convert_file(args.input, args.output)
print(f"Saved to {args.output}")
convert_file(args.input, args.output, ["lora_A"], {
"time_faaaa": "time_first", "lora_A": "lora.0", "lora_B": "lora.1"})
print(f"Saved to {args.output}")

0 comments on commit 6d48a47

Please sign in to comment.