update test code
Tlntin committed Jul 29, 2024
commit 92180e8 · 1 parent 7ae6095
Showing 2 changed files with 46 additions and 9 deletions.
export/test_onnx_run.py (15 additions, 1 deletion)
@@ -9,6 +9,13 @@
project_dir = os.path.dirname(now_dir)

parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--dtype",
+    type=str,
+    help="float16 or float32",
+    choices=["float16", "float32"],
+    default="float32",
+)
parser.add_argument(
    '--hf_model_dir',
    type=str,
@@ -23,6 +30,13 @@
)
args = parser.parse_args()

+if args.dtype == "float16":
+    np_dtype = np.float16
+elif args.dtype == "float32":
+    np_dtype = np.float32
+else:
+    raise Exception("unsupported dtype; only float16/float32 are supported")


def create_kv_cache(config: Qwen2Config, kv_cache_length=1024):
    return np.zeros(
@@ -34,7 +48,7 @@ def create_kv_cache(config: Qwen2Config, kv_cache_length=1024):
            kv_cache_length,
            config.hidden_size // config.num_attention_heads
        ],
-        dtype=np.float16
+        dtype=np_dtype
    )


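As a quick reference, here is a standalone sketch of the new --dtype handling (the flag name and the np_dtype variable match the diff; the NP_DTYPES dict is our shorthand, equivalent to the if/elif chain above):

import argparse

import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dtype",
    type=str,
    choices=["float16", "float32"],
    default="float32",
)
args = parser.parse_args(["--dtype", "float16"])  # simulate CLI input

# argparse's choices= already rejects anything outside these two values.
NP_DTYPES = {"float16": np.float16, "float32": np.float32}
np_dtype = NP_DTYPES[args.dtype]
print(np_dtype)  # <class 'numpy.float16'>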
export/test_pytorch_run.py (31 additions, 8 deletions)
@@ -4,19 +4,42 @@
from modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.qwen2 import Qwen2Tokenizer, Qwen2Config

-device = "cpu"
now_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(now_dir)

parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--device_str",
+    type=str,
+    choices=["npu", "cuda", "cpu"],
+    help="support npu, cuda, cpu",
+    default="cpu",
+)
+parser.add_argument(
+    "--dtype",
+    type=str,
+    help="float16 or float32",
+    choices=["float16", "float32"],
+    default="float32",
+)
parser.add_argument(
    '--hf_model_dir',
    type=str,
    help="model and tokenizer path, only support huggingface model",
    default=os.path.join(project_dir, "download", "Qwen2-1.5B-Instruct")
)


args = parser.parse_args()
+device_str = args.device_str
+if device_str == "cpu" and args.dtype == "float16":
+    raise Exception("CPU does not support float16")
+if args.dtype == "float16":
+    torch_dtype = torch.float16
+elif args.dtype == "float32":
+    torch_dtype = torch.float32
+else:
+    raise Exception("unsupported dtype; only float16/float32 are supported")


def create_kv_cache(config: Qwen2Config, kv_cache_length=1024):
@@ -29,8 +52,8 @@ def create_kv_cache(config: Qwen2Config, kv_cache_length=1024):
            kv_cache_length,
            config.hidden_size // config.num_attention_heads
        ],
-        dtype=torch.float16
-    ).to(device)
+        dtype=torch_dtype
+    ).to(device_str)


def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = 1024):
@@ -58,22 +81,22 @@ def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = 1024):
    )
    """
    cache = kv_cache[:, :, :, :, :past_kv_size]
-    mask = torch.ones((1, past_kv_size + seq_len), dtype=torch.long).to(device)
+    mask = torch.ones((1, past_kv_size + seq_len), dtype=torch.long).to(device_str)
    mask[:, real_kv_size: past_kv_size] = 0
    pos_id = torch.arange(
        input_pos,
        input_pos + seq_len,
        dtype=torch.long
-    ).reshape(1, -1).to(device)
+    ).reshape(1, -1).to(device_str)
    return cache, mask, pos_id


tokenizer = Qwen2Tokenizer.from_pretrained(args.hf_model_dir)
model_config = Qwen2Config.from_pretrained(args.hf_model_dir)
model = Qwen2ForCausalLM.from_pretrained(
    args.hf_model_dir,
-    torch_dtype=torch.float16
-).to(device)
+    torch_dtype=torch_dtype
+).to(device_str)
prompt = "你好"
system_prompt: str = "You are a helpful assistant."
history = []
@@ -89,7 +112,7 @@ def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = 1024):
print("raw_text", text)
input_ids = tokenizer(
    [text], return_tensors="pt"
-)["input_ids"].to(device)
+)["input_ids"].to(device_str)
print("input_ids", input_ids)
kv_cache1 = create_kv_cache(model_config)
now_kv_cache, attn_mask, position_ids = get_inputs(kv_cache1, 2, )
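The CPU guard this commit adds can be exercised on its own. A minimal sketch, assuming the same flag semantics (resolve_dtype and ValueError are our choices here; the commit itself raises a plain Exception):

import torch

def resolve_dtype(device_str: str, dtype: str) -> torch.dtype:
    # Reject half precision on CPU up front, mirroring the check in the diff;
    # many PyTorch CPU ops lack float16 kernels, so it would fail later anyway.
    if device_str == "cpu" and dtype == "float16":
        raise ValueError("CPU does not support float16")
    return {"float16": torch.float16, "float32": torch.float32}[dtype]

print(resolve_dtype("cuda", "float16"))  # torch.float16
print(resolve_dtype("cpu", "float32"))   # torch.float32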
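For the mask and position ids that get_inputs builds, a tiny illustration with made-up sizes: positions [0, real_kv_size) are valid cache entries, [real_kv_size, past_kv_size) are unused cache slots that get masked out, and the trailing seq_len positions are the new tokens.

import torch

past_kv_size, real_kv_size, seq_len, input_pos = 8, 3, 2, 3
mask = torch.ones((1, past_kv_size + seq_len), dtype=torch.long)
mask[:, real_kv_size:past_kv_size] = 0  # hide the empty cache slots
print(mask)  # tensor([[1, 1, 1, 0, 0, 0, 0, 0, 1, 1]])

# Position ids for the new tokens continue from input_pos, as in the diff.
pos_id = torch.arange(input_pos, input_pos + seq_len, dtype=torch.long).reshape(1, -1)
print(pos_id)  # tensor([[3, 4]])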
