From e01ec30dc85186451043117d069b24dbd29a2936 Mon Sep 17 00:00:00 2001
From: Tlntin
Date: Fri, 22 Dec 2023 20:18:16 +0800
Subject: [PATCH] Fix rope base assignment: rename rotary_emb_base to
 rotary_base.

---
 examples/qwen/build.py                       |   4 +-
 examples/qwen/test/test_dynamic_ntk.py       | 227 ++++++++++++++++++
 examples/qwen/test/test_logn.py              | 196 +++++++++++++++
 examples/qwen/{ => test}/test_rms_norm.py    |   0
 .../{ => test}/test_smooth_quant_rms_norm.py |   0
 examples/qwen/weight.py                      |   7 +-
 6 files changed, 429 insertions(+), 5 deletions(-)
 create mode 100644 examples/qwen/test/test_dynamic_ntk.py
 create mode 100644 examples/qwen/test/test_logn.py
 rename examples/qwen/{ => test}/test_rms_norm.py (100%)
 rename examples/qwen/{ => test}/test_smooth_quant_rms_norm.py (100%)

diff --git a/examples/qwen/build.py b/examples/qwen/build.py
index 321daf11..d5fc7041 100644
--- a/examples/qwen/build.py
+++ b/examples/qwen/build.py
@@ -431,7 +431,7 @@ def parse_arguments():
         args.vocab_size = hf_config.vocab_size
         args.hidden_act = "silu"
         args.kv_channels = hf_config.kv_channels
-        args.rotary_emb_base = hf_config.rotary_emb_base
+        args.rotary_base = hf_config.rotary_emb_base
         args.seq_length = hf_config.seq_length
     assert (
         args.use_gpt_attention_plugin is not None
@@ -581,7 +581,7 @@ def build_rank_engine(
            # args.world_size,
            max_position_embeddings=args.n_positions,
            kv_channels=args.kv_channels,
-           rotary_emb_base=args.rotary_emb_base,
+           rotary_base=args.rotary_base,
            dtype=args.dtype,
            multi_query_mode=multi_query_mode,
        )
diff --git a/examples/qwen/test/test_dynamic_ntk.py b/examples/qwen/test/test_dynamic_ntk.py
new file mode 100644
index 00000000..f5a5ee05
--- /dev/null
+++ b/examples/qwen/test/test_dynamic_ntk.py
@@ -0,0 +1,227 @@
+import unittest
+from collections import OrderedDict
+import numpy as np
+import torch
+from polygraphy.backend.trt import EngineFromNetwork, TrtRunner, CreateConfig, Profile
+import tensorrt_llm
+from tensorrt_llm import Tensor
+import math
+import tensorrt as trt
+import numpy as np
+from tensorrt_llm.layers import Embedding
+from tensorrt_llm import str_dtype_to_trt
+from parameterized import parameterized
+from tensorrt_llm.functional import (
+    Tensor, shape, concat, constant, arange, outer, unary,
+    partial, expand, elementwise_binary, shape, pow, cos, sin, slice, maximum
+)
+log = partial(unary, op=trt.UnaryOperation.LOG)
+ceil = partial(unary, op=trt.UnaryOperation.CEIL)
+div = partial(elementwise_binary, op=trt.ElementWiseOperation.DIV)
+gt = partial(elementwise_binary, op=trt.ElementWiseOperation.GREATER)
+
+
+class RotaryEmbedding(tensorrt_llm.Module):
+    def __init__(self, per_head_dim=128, seq_length=8192, base=10000.0) -> None:
+        self.per_head_dim = per_head_dim
+        self.seq_length = seq_length
+        self.base = base
+        super().__init__()
+        # self.position_embedding_cos = Embedding(
+        #     seq_length,
+        #     per_head_dim,
+        #     dtype=trt.float32
+        # )
+        # self.position_embedding_sin = Embedding(
+        #     seq_length,
+        #     per_head_dim,
+        #     dtype=trt.float32
+        # )
+
+    def forward(self, input_ids):
+        # implement for old
+        batch_size = shape(input_ids, 0)
+        input_len = shape(input_ids, 1)
+        # pytorch impl
+        # context_value = math.log(true_seq_len / self.seq_length, 2) + 1
+        # ntk_alpha = 2 ** math.ceil(context_value) - 1
+        # ntk_alpha = max(ntk_alpha, 1)
+
+        # trt impl
+        # with tensorrt_llm.precision("float32"):
+        context_value = log(input_len.cast(trt.float32) / float(self.seq_length)) / math.log(2) + 1.0
+        ntk_alpha = pow(constant(np.array(2, dtype=np.float32)), ceil(context_value)) - 1.0
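+        # dynamic NTK rule: ntk_alpha = max(2 ** ceil(log2(input_len / seq_length) + 1) - 1, 1);
+        # the clamp below keeps alpha at 1 for sequences within the 8192-token training length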
+        ntk_alpha = maximum(ntk_alpha, constant(np.array(1.0, dtype=np.float32)))
+        base = constant(np.array(self.base, dtype=np.float32))
+        base = base * pow(ntk_alpha, (self.per_head_dim / (self.per_head_dim - 2)))
+        temp1 = constant(np.arange(0, self.per_head_dim, 2, dtype=np.float32) / self.per_head_dim)
+        temp2 = pow(base, temp1)
+        inv_freq = div(
+            constant(np.array(1, dtype=np.float32)),
+            temp2
+        )
+        # temp_length = f_max(2 * input_len, 16)
+        seq = arange(constant(np.array(0, dtype=np.int32)), input_len * 2, dtype="int32")
+        # with tensorrt_llm.precision("float32"):
+        freqs = outer(seq.cast(trt.float32), inv_freq)
+        emb = concat([freqs, freqs], dim=1)
+        # emb = rearrange(emb, "n d -> 1 n 1 d")
+        emb = emb.view(concat([1, input_len * 2, 1, self.per_head_dim]))
+        emb = expand(emb, concat([batch_size, input_len * 2, 1, self.per_head_dim]))
+
+        # with tensorrt_llm.precision("float32"):
+        # cos, sin = emb.cos(), emb.sin()
+        cos_res = cos(emb)
+        sin_res = sin(emb)
+        # position_embedding_cos = cos[:, :input_len]
+        # position_embedding_sin = sin[:, :input_len]
+        position_embedding_cos = slice(
+            input=cos_res,
+            starts=concat([0, 0, 0, 0]),
+            sizes=concat([batch_size, input_len, 1, self.per_head_dim]),
+        )
+        position_embedding_sin = slice(
+            input=sin_res,
+            starts=concat([0, 0, 0, 0]),
+            sizes=concat([batch_size, input_len, 1, self.per_head_dim]),
+        )
+
+        # self.register_network_output("my_cos", identity_op(position_embedding_cos))
+        # self.register_network_output("my_sin", identity_op(position_embedding_sin))
+        # expand_dims(position_embedding_cos, [batch_size, 1, 1, 1])
+        rotary_pos_emb = [
+            (position_embedding_cos, position_embedding_sin),
+            (position_embedding_cos, position_embedding_sin),
+        ]
+        return rotary_pos_emb
+
+
+class TestFunctional(unittest.TestCase):
+
+    per_head_dim = 128
+    seq_length = 8192
+    base = 10000.0
+    vocab_size = 151936
+
+    def setUp(self):
+        tensorrt_llm.logger.set_level('error')
+
+    @parameterized.expand([('float32', 9886), ('float32', 1886), ('float16', 1886), ('float16', 9886)])
+    def test_case(self, dtype, input_length):
+
+        def test_trt(feed_dict: dict):
+            # construct trt network
+            builder = tensorrt_llm.Builder()
+            net = builder.create_network()
+            with tensorrt_llm.net_guard(net):
+                input_ids = Tensor(
+                    name='input_ids',
+                    shape=[-1, -1],
+                    dtype=trt.int32,
+                    dim_range=OrderedDict([
+                        ("batch_size", [[1, 1, 1]]),
+                        ("seq_length", [[1, 10 * 1024, 32 * 1024]])
+                    ])
+                )
+                # position_ids = Tensor(
+                #     name='position_ids',
+                #     shape=[-1, -1],
+                #     dtype=trt.int32,
+                #     dim_range=OrderedDict([
+                #         ("batch_size", [[1, 1, 1]]),
+                #         ("seq_length", [[1, 10 * 1024, 32 * 1024]])
+                #     ])
+                # )
+                model = RotaryEmbedding(per_head_dim=self.per_head_dim, seq_length=self.seq_length)
+                outputs = model.forward(input_ids=input_ids)
+                # net._mark_output(outputs[0][0], 'cos', tensorrt_llm.str_dtype_to_trt(dtype))
+                # net._mark_output(outputs[0][1], 'sin', tensorrt_llm.str_dtype_to_trt(dtype))
+                net._mark_output(outputs[0][0], 'cos', trt.float32)
+                net._mark_output(outputs[0][1], 'sin', trt.float32)
+
+                for k, v in model.named_network_outputs():
+                    # net._mark_output(v, k, tensorrt_llm.str_dtype_to_trt(dtype))
+                    net._mark_output(v, k, trt.float32)
+            # for build and run
+            profile = Profile().add(
+                "input_ids", min=(1, 1), opt=(1, 1), max=(2, 16 * 1024)
+            )
+            build_engine = EngineFromNetwork(
+                (builder.trt_builder, net.trt_network),
+                config=CreateConfig(
+                    fp16=(dtype == 'float16'),
+                    precision_constraints="obey",
+                    profiles=[profile]
+                )
+            )
+            with TrtRunner(build_engine) as runner:
+                outputs = runner.infer(feed_dict=feed_dict)
+            return outputs
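+        # pure-PyTorch reference below: the same dynamic NTK math as the TRT network
+        # above, used to produce the expected cos/sin tables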
+        def test_pytorch(input_tensor: torch.tensor):
+            pt_input_len = input_tensor.shape[1]
+            # upper for old
+            # lower for pure pytorch, for fp32 consistency (the code above used fp64 via python)
+            pt_context_value = math.log(pt_input_len / self.seq_length, 2) + 1
+            # pt_context_value = torch.log(torch.Tensor([input_seq_len * 1. / self.seq_length]).cuda()) / torch.log(torch.Tensor([2.]).cuda()) + 1
+
+            pt_ntk_alpha = 2 ** math.ceil(pt_context_value) - 1
+            # pt_ntk_alpha = torch.Tensor([2]).cuda() ** torch.ceil(pt_context_value) - 1
+
+            pt_ntk_alpha = max(pt_ntk_alpha, 1.0)
+
+            pt_ntk_alpha = pt_ntk_alpha ** (self.per_head_dim / (self.per_head_dim - 2))
+
+            pt_base = torch.Tensor([self.base]).cuda()
+            pt_base = pt_base * pt_ntk_alpha
+            pt_temp1 = (torch.arange(0, self.per_head_dim, 2).float() / self.per_head_dim).cuda()
+            pt_temp2 = torch.pow(pt_base, pt_temp1)  # base ** temp1
+            pt_inv_freq = 1.0 / pt_temp2
+            pt_seq = torch.arange(0, pt_input_len * 2).int().cuda()
+            pt_freqs = torch.outer(pt_seq.type_as(pt_inv_freq), pt_inv_freq)
+            pt_emb = torch.cat((pt_freqs, pt_freqs), dim=-1)
+            # emb = rearrange(emb, "n d -> 1 n 1 d")
+            pt_emb = pt_emb.unsqueeze(0).unsqueeze(2)
+            pt_cos, pt_sin = pt_emb.cos(), pt_emb.sin()
+            pt_cos = pt_cos[:, :pt_input_len]
+            pt_sin = pt_sin[:, :pt_input_len]
+            print("pt_cos shape/mean/sum/dtype", pt_cos.shape, pt_cos.mean(), pt_cos.sum(), pt_cos.dtype)
+            print("pt_sin shape/mean/sum/dtype", pt_sin.shape, pt_sin.mean(), pt_sin.sum(), pt_sin.dtype)
+            return pt_cos, pt_sin
+
+        pt_batch_size = 1
+        # pt_input_len = 9886
+        pt_input_len = input_length
+        print("\ndtype", dtype, "input_length", input_length)
+        input_tensor = torch.randint(1, self.vocab_size, [pt_batch_size, pt_input_len], dtype=torch.int32)
+        # position_tensor = torch.arange(0, pt_input_len, dtype=torch.int32).unsqueeze(0).expand([pt_batch_size, pt_input_len])
+        # print("position_tensor shape", position_tensor.shape)
+        pt_cos, pt_sin = test_pytorch(input_tensor)
+        outputs = test_trt(
+            feed_dict={
+                "input_ids": input_tensor.numpy(),
+            }
+        )
+
+        # import pdb; pdb.set_trace()
+
+        # np.testing.assert_allclose(ntk_alpha.cpu().numpy(), outputs['ntk_alpha'], rtol=0, atol=0)
+        # np.testing.assert_allclose(base.cpu().numpy(), outputs['base'], rtol=0, atol=0)
+        # np.testing.assert_allclose(temp1.cpu().numpy(), outputs['temp1'], rtol=0, atol=0)
+        # np.testing.assert_allclose(temp2.cpu().numpy(), outputs['temp2'], rtol=0, atol=0)
+        # np.testing.assert_allclose(seq.cpu().numpy(), outputs['seq'], rtol=1e-9, atol=1e-9)
+        # np.testing.assert_allclose(inv_freq.cpu().numpy(), outputs['inv_freq'], rtol=1e-9, atol=1e-9)
+        # np.testing.assert_allclose(pt_freqs.cpu().numpy(), outputs['freqs'], rtol=1e-9, atol=1e-9)
+        print("cos shape/mean/sum/dtype", outputs["cos"].shape, outputs["cos"].mean(), outputs["cos"].sum(), outputs["cos"].dtype)
+        print("sin shape/mean/sum/dtype", outputs["sin"].shape, outputs["sin"].mean(), outputs["sin"].sum(), outputs["sin"].dtype)
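+        # 1e-5 tolerance: the TRT engine runs the NTK math in float32 while the
+        # reference above uses float64 math.log / math.ceil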
+        np.testing.assert_allclose(pt_cos.cpu().numpy(), outputs['cos'], rtol=1e-5, atol=1e-5)
+        np.testing.assert_allclose(pt_sin.cpu().numpy(), outputs['sin'], rtol=1e-5, atol=1e-5)
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/examples/qwen/test/test_logn.py b/examples/qwen/test/test_logn.py
new file mode 100644
index 00000000..630fc440
--- /dev/null
+++ b/examples/qwen/test/test_logn.py
@@ -0,0 +1,196 @@
+import unittest
+
+import numpy as np
+import torch
+from polygraphy.backend.trt import EngineFromNetwork, TrtRunner, CreateConfig
+import tensorrt_llm
+from tensorrt_llm import Tensor
+import math
+import tensorrt as trt
+import numpy as np
+from parameterized import parameterized
+from tensorrt_llm.parameter import Parameter
+from tensorrt_llm.functional import (
+    Tensor, shape, concat, constant, arange, outer, unary,
+    partial, expand, elementwise_binary, shape, pow, cos, sin, slice, expand_dims_like, repeat_interleave, str_dtype_to_trt
+)
+log = partial(unary, op=trt.UnaryOperation.LOG)
+ceil = partial(unary, op=trt.UnaryOperation.CEIL)
+div = partial(elementwise_binary, op=trt.ElementWiseOperation.DIV)
+
+
+class MyLogn(tensorrt_llm.Module):
+    def __init__(self, dtype, seq_length, head_size, per_head_dim) -> None:
+        super().__init__()
+        self.dtype = dtype
+        self.seq_length = seq_length
+        self.head_size = head_size
+        self.per_head_dim = per_head_dim
+        logn_array = np.array([
+            np.log(i) / np.log(self.seq_length) if i > self.seq_length else 1
+            for i in range(1, 32768)
+            ],
+            dtype=np.float32
+        ).reshape(1, -1, 1, 1)
+        self.logn_tensor = Parameter(
+            value=logn_array,
+            dtype=trt.float32,
+            shape=[1, 32767, 1, 1],
+        )
+
+    def forward(self, key, query):
+        seq_start = slice(shape(key), [1], [1]) - slice(shape(query), [1], [1])
+        seq_end = slice(shape(key), [1], [1])
+
+        logn_shape = self.logn_tensor.value.shape
+        logn_tensor = slice(
+            input=self.logn_tensor.value,
+            starts=concat([0, seq_start, 0, 0]),
+            sizes=concat([logn_shape[0], seq_end - seq_start, logn_shape[2], logn_shape[3]]),
+        )
+        # logn_tensor2 = repeat_interleave(logn_tensor, self.head_size, 2)
+        # logn_tensor2 = repeat_interleave(logn_tensor2, self.per_head_dim, 3)
+        logn_tensor2 = expand(
+            logn_tensor,
+            concat([logn_shape[0], seq_end - seq_start, self.head_size, self.per_head_dim])
+        )
+        query2 = query.cast(trt.float32) * logn_tensor2
+        query2 = query2.cast(self.dtype)
+        return [logn_tensor2, query2]
+
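+# logn scaling: positions beyond the seq_length training window scale the query by
+# log(position) / log(seq_length); positions inside the window use a factor of 1.0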
+
+class TestFunctional(unittest.TestCase):
+
+    head_size = 16
+    per_head_dim = 128
+    seq_length = 8192
+    base = 10000.0
+    dtype = 'float16'
+
+    def setUp(self):
+        tensorrt_llm.logger.set_level('error')
+
+    @parameterized.expand([('float32', 9886), ('float32', 1886), ("float16", 9886), ("float16", 1886)])
+    def test_case(self, dtype, input_length):
+        self.dtype = dtype
+        batch_size = 1
+        # input_seq_len = 13727
+        input_seq_len = input_length
+        print("\ndtype", dtype, "input_length", input_length)
+        if dtype == "float32":
+            pt_key = torch.rand(
+                [batch_size, input_seq_len, self.head_size, self.per_head_dim],
+                dtype=torch.float32
+            )
+            pt_query = torch.rand(
+                [batch_size, input_seq_len, self.head_size, self.per_head_dim],
+                dtype=torch.float32
+            )
+        else:
+            pt_key = torch.rand(
+                [batch_size, input_seq_len, self.head_size, self.per_head_dim],
+                dtype=torch.float16
+            )
+            pt_query = torch.rand(
+                [batch_size, input_seq_len, self.head_size, self.per_head_dim],
+                dtype=torch.float16
+            )
+
+        def test_trt(feed_dict: dict):
+            builder = tensorrt_llm.Builder()
+            net = builder.create_network()
+            with tensorrt_llm.net_guard(net):
+                key = Tensor(name='key',
+                             shape=pt_key.shape,
+                             dtype=tensorrt_llm.str_dtype_to_trt(self.dtype))
+
+                query = Tensor(name='query',
+                               shape=pt_query.shape,
+                               dtype=tensorrt_llm.str_dtype_to_trt(self.dtype))
+                model = MyLogn(
+                    dtype=dtype,
+                    seq_length=self.seq_length,
+                    head_size=self.head_size,
+                    per_head_dim=self.per_head_dim,
+                )
+                outputs = model.forward(query=query, key=key)
+                net._mark_output(outputs[0], 'logn', str_dtype_to_trt(dtype))
+                net._mark_output(outputs[1], 'query_output', str_dtype_to_trt(dtype))
+                # net._mark_output(outputs[0], 'logn', trt.float32)
+                # net._mark_output(outputs[1], 'query_output', trt.float32)
+
+                for k, v in model.named_network_outputs():
+                    net._mark_output(v, k, tensorrt_llm.str_dtype_to_trt(dtype))
+                    # net._mark_output(v, k, trt.float32)
+            # for new
+            build_engine = EngineFromNetwork(
+                (builder.trt_builder, net.trt_network),
+                config=CreateConfig(
+                    fp16=(dtype == 'float16'),
+                    precision_constraints="obey",
+                )
+            )
+            with TrtRunner(build_engine) as runner:
+                outputs = runner.infer(feed_dict=feed_dict)
+                # {"key": pt_key.numpy(), "query": pt_query.numpy()}
+            return outputs
+
+        def test_pytorch(pt_query, pt_key):
+            # torch impl
+            pt_logn_list = [
+                math.log(i, self.seq_length) if i > self.seq_length else 1
+                for i in range(1, 32768)
+            ]
+            pt_logn_tensor = torch.tensor(pt_logn_list, dtype=torch.float32)[None, :, None, None]
+            pt_seq_start = pt_key.size(1) - pt_query.size(1)
+            pt_seq_end = pt_key.size(1)
+            pt_logn_tensor = pt_logn_tensor[:, pt_seq_start: pt_seq_end, :, :].type_as(pt_query)
+            pt_logn_tensor2 = pt_logn_tensor.expand_as(pt_query)
+            pt_logn_tensor2 = pt_logn_tensor2.to(torch.float32)
+            raw_type = pt_query.dtype
+            pt_query2 = pt_query.to(torch.float32) * pt_logn_tensor2
+            pt_logn_tensor2 = pt_logn_tensor2.to(raw_type)
+            pt_query2 = pt_query2.to(raw_type)
+            print(
+                "pt_logn2 shape/mean/sum/dtype",
+                pt_logn_tensor2.shape,
+                pt_logn_tensor2.to(torch.float32).mean().item(),
+                pt_logn_tensor2.to(torch.float32).sum().item(),
+                pt_logn_tensor2.dtype
+            )
+            print(
+                "pt_query2 shape/mean/sum/dtype",
+                pt_query2.shape,
+                pt_query2.to(torch.float32).mean(),
+                pt_query2.to(torch.float32).sum(),
+                pt_query2.dtype
+            )
+            return [pt_logn_tensor2, pt_query2]
+
+        (pt_logn2, pt_query2) = test_pytorch(pt_query=pt_query, pt_key=pt_key)
+        outputs = test_trt(feed_dict={"key": pt_key.numpy(), "query": pt_query.numpy()})
+        rtol = atol = 1e-9
+        print(
+            "logn shape/mean/sum/dtype",
+            outputs['logn'].shape,
+            outputs['logn'].astype(np.float32).mean(),
+            outputs['logn'].astype(np.float32).sum(),
+            outputs['logn'].dtype
+        )
+        print(
+            "query_output shape/mean/sum/dtype",
+            outputs['query_output'].shape,
+            outputs['query_output'].astype(np.float32).mean(),
+            outputs['query_output'].astype(np.float32).sum(),
+            outputs['query_output'].dtype
+        )
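+        # expect a near-exact match: both paths apply the logn multiply in float32
+        # before casting back to the requested dtype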
+        np.testing.assert_allclose(pt_logn2.cpu().numpy(), outputs['logn'], rtol=rtol, atol=atol)
+        np.testing.assert_allclose(pt_query2.cpu().numpy(), outputs['query_output'], rtol=rtol, atol=atol)
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/examples/qwen/test_rms_norm.py b/examples/qwen/test/test_rms_norm.py
similarity index 100%
rename from examples/qwen/test_rms_norm.py
rename to examples/qwen/test/test_rms_norm.py
diff --git a/examples/qwen/test_smooth_quant_rms_norm.py b/examples/qwen/test/test_smooth_quant_rms_norm.py
similarity index 100%
rename from examples/qwen/test_smooth_quant_rms_norm.py
rename to examples/qwen/test/test_smooth_quant_rms_norm.py
diff --git a/examples/qwen/weight.py b/examples/qwen/weight.py
index d636e420..35f797d2 100644
--- a/examples/qwen/weight.py
+++ b/examples/qwen/weight.py
@@ -501,7 +501,7 @@ def load_from_hf_qwen(
     # rank=0,
     # tensor_parallel=1,
     max_position_embeddings=8192,
-    rotary_emb_base=10000,
+    rotary_base=10000,
     kv_channels=128,
     dtype="float32",
     multi_query_mode=False,
@@ -544,7 +544,7 @@ def load_from_hf_qwen(
     torch_dtype = str_dtype_to_torch(dtype)

     # set for rope embedding
-    # inv_freq = 1.0 / (rotary_emb_base ** (
+    # inv_freq = 1.0 / (rotary_base ** (
     #     torch.arange(0, kv_channels, 2).float() / kv_channels)
     # )
     # value_table = torch.matmul(
@@ -808,7 +808,8 @@ def preprocess_groupwise_weight_params(
         qkv_part = qkv_part.reshape(model_emb, 3, q_emb)
         split_qkv = split(qkv_part, mapping.tp_size, mapping.rank, dim=2)
         split_qkv = split_qkv.reshape(model_emb, 3 * (q_emb // mapping.tp_size))
-        split_qkv = torch.from_numpy(split_qkv)
+        if isinstance(split_qkv, np.ndarray):
+            split_qkv = torch.from_numpy(split_qkv)
         # dype: int32, int32, float16
         split_qkv_suf.append(split_qkv)