diff --git a/README.md b/README.md
index edead901..24058be8 100644
--- a/README.md
+++ b/README.md
@@ -75,7 +75,7 @@
 cd tensorrt_llm_july-release-v1/examples/qwen/
 ```
 
-7. Convert the Huggingface-format weights into the data format FT (FastTransformer) needs (optional: building directly without converting also works, both paths are supported, and a direct build saves disk space, but it does not support smooth quant)
+7. Convert the Huggingface-format weights into the data format FT (FastTransformer) needs (optional: building directly without converting also works, both paths are supported, and a direct build saves disk space, but it does not support smooth quant; by default this script loads the CUDA version of the Huggingface model before converting, so users with less than 24 GB of VRAM are advised to skip this step.)
 ```bash
 python3 hf_qwen_convert.py
 ```
@@ -83,7 +83,7 @@
 8. Adjust the build parameters (optional)
 
-    - The default build parameters, including batch_size, max_input_len, and max_new_tokens, are stored in `default_config.py`
+    - The default build parameters, including batch_size, max_input_len, max_new_tokens, and seq_length, are stored in `default_config.py`
     - For users with 24 GB of VRAM, just build directly; the defaults are the fp16 data type and max_batch_size=2
     - For users with less VRAM, lower max_batch_size to 1, or further reduce max_input_len and max_new_tokens
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 2496650e..1d76afa1 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -8,7 +8,7 @@ onnx==1.12.0
 mpi4py
 tensorrt>=8.6.0
 numpy
-cuda-python==12.1.0
+cuda-python==12.2.0
 mypy
 pytest-cov
 pytest-xdist
diff --git a/tensorrt_llm_july-release-v1/examples/qwen/build.py b/tensorrt_llm_july-release-v1/examples/qwen/build.py
index db610fe9..b39f562c 100644
--- a/tensorrt_llm_july-release-v1/examples/qwen/build.py
+++ b/tensorrt_llm_july-release-v1/examples/qwen/build.py
@@ -323,7 +323,7 @@ def build_rank_engine(builder: Builder,
         num_layers=args.n_layer,
         num_heads=args.n_head,
         hidden_size=args.n_embd,
-        seq_length=2048,
+        seq_length=default_config.seq_length,
         vocab_size=args.vocab_size,
         hidden_act=args.hidden_act,
         max_position_embeddings=args.n_positions,
@@ -360,7 +360,8 @@
             QuantMode.use_weight_only(use_int4_weights=True)
         )
 
-    if args.hf_model_dir is not None and args.ft_dir_path is None:
+    if args.hf_model_dir is not None and \
+            (args.ft_dir_path is None or not os.path.exists(args.ft_dir_path)):
         logger.info(f'Loading HF QWen ... from {args.hf_model_dir}')
         tik = time.time()
         hf_qwen = AutoModelForCausalLM.from_pretrained(
@@ -381,7 +382,6 @@
             rank,
             args.world_size,
             max_position_embeddings=args.n_positions,
-            seq_length=args.max_input_len,
             kv_channels=args.kv_channels,
             rotary_emb_base=args.rotary_emb_base,
             dtype=args.dtype,
diff --git a/tensorrt_llm_july-release-v1/examples/qwen/default_config.py b/tensorrt_llm_july-release-v1/examples/qwen/default_config.py
index accebae6..9d9a2d74 100644
--- a/tensorrt_llm_july-release-v1/examples/qwen/default_config.py
+++ b/tensorrt_llm_july-release-v1/examples/qwen/default_config.py
@@ -24,6 +24,11 @@ class DefaultConfig:
     # Maximum number of generate new tokens.
     max_new_tokens = 2048
 
+    # Maximum sequence length.
+    # for Qwen-7B-Chat V1.0, the seq_length is 2048
+    # for Qwen-7B-Chat V1.1, the seq_length is 8192
+    # for Qwen-14B-Chat, the seq_length is 8192
+    seq_length = 2048
     # Top p for sampling.
     top_p = 0.5
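Since the correct `seq_length` now depends on which checkpoint is being built, one way to avoid editing `default_config.py` by hand is to read the value from the checkpoint itself. Below is a minimal sketch, not part of this patch: the helper name is ours, and it assumes the Huggingface checkpoint's `config.json` carries a `seq_length` field, as Qwen checkpoints typically do.

```python
import json
import os

def detect_seq_length(hf_model_dir: str, fallback: int = 2048) -> int:
    """Return the checkpoint's native sequence length (2048 for
    Qwen-7B-Chat V1.0, 8192 for V1.1 and Qwen-14B-Chat)."""
    config_path = os.path.join(hf_model_dir, "config.json")
    try:
        with open(config_path) as f:
            # Fall back to the default_config.py value if the field is absent.
            return json.load(f).get("seq_length", fallback)
    except OSError:
        return fallback
```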
diff --git a/tensorrt_llm_july-release-v1/examples/qwen/weight.py b/tensorrt_llm_july-release-v1/examples/qwen/weight.py
index 22dc4bda..a05fc90c 100644
--- a/tensorrt_llm_july-release-v1/examples/qwen/weight.py
+++ b/tensorrt_llm_july-release-v1/examples/qwen/weight.py
@@ -391,7 +391,6 @@ def load_from_hf_qwen(tensorrt_llm_qwen: QWenForCausalLM,
                       hf_qwen,
                       rank=0,
                       tensor_parallel=1,
-                      seq_length=2048,
                       max_position_embeddings=8192,
                       rotary_emb_base=10000,
                       kv_channels=128,
@@ -463,7 +462,7 @@
         if layer_idx is None:
             continue
         idx = int(layer_idx)
-        if idx >= tensorrt_llm_qwen.num_layers:
+        if idx >= tensorrt_llm_qwen._num_layers:
             continue
         if 'ln_1.weight' in k:
             tensorrt_llm_qwen.layers[idx].ln_1.weight.value = v
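For reference, the reworked condition in the build.py hunk above reads naturally as a standalone predicate. A minimal sketch follows; the helper name is ours, but the logic is taken verbatim from the diff.

```python
import os

def should_load_hf_checkpoint(hf_model_dir, ft_dir_path) -> bool:
    """True when build_rank_engine should fall back to the Huggingface
    checkpoint: either no FT directory was given, or the given one does
    not exist on disk (e.g. the optional conversion step was skipped)."""
    return hf_model_dir is not None and (
        ft_dir_path is None or not os.path.exists(ft_dir_path)
    )
```

Before this change, passing an ft_dir_path that was never actually created skipped the Huggingface branch entirely; now the build falls back to loading the HF weights instead.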