Commit: change default seq_length to 8192
Tlntin committed Dec 18, 2023
1 parent e262c12 commit 29b0832
Showing 2 changed files with 6 additions and 12 deletions.
8 changes: 4 additions & 4 deletions examples/qwen/build.py
@@ -186,6 +186,7 @@ def parse_arguments():
     parser.add_argument("--ffn_dim_multiplier", type=int, default=1)
     parser.add_argument("--inter_size", type=int, default=11008)
     parser.add_argument("--hidden_act", type=str, default="silu")
+    parser.add_argument("--seq_length", type=int, default=8192)
     parser.add_argument(
         "--max_batch_size", type=int, default=default_config.trt_max_batch_size
     )
@@ -420,9 +421,7 @@ def parse_arguments():
         args.hf_model_dir,
         trust_remote_code=True,
     )
-    args.inter_size = (
-        hf_config.intermediate_size
-    ) # override the inter_size for QWen
+    args.inter_size = hf_config.intermediate_size # override the inter_size for QWen
     args.n_embd = hf_config.hidden_size
     args.n_head = hf_config.num_attention_heads
     if hasattr(hf_config, "num_key_value_heads"):
@@ -433,6 +432,7 @@
     args.hidden_act = "silu"
     args.kv_channels = hf_config.kv_channels
     args.rotary_emb_base = hf_config.rotary_emb_base
+    args.seq_length = hf_config.seq_length
     assert (
         args.use_gpt_attention_plugin is not None
     ), "QWen must use gpt attention plugin"
@@ -493,7 +493,7 @@ def build_rank_engine(
         num_heads=args.n_head,
         num_kv_heads=args.n_kv_head,
         hidden_size=args.n_embd,
-        seq_length=default_config.seq_length,
+        seq_length=args.seq_length,
         vocab_size=args.vocab_size,
         hidden_act=args.hidden_act,
         max_position_embeddings=args.n_positions,
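
Net effect in build.py: the maximum sequence length is no longer read from default_config; it now comes from a --seq_length flag (default 8192) and, when a Hugging Face checkpoint directory is supplied, is overridden by hf_config.seq_length before the engine is built. A minimal sketch of that precedence, assuming the config is loaded via transformers.AutoConfig (the loading call and the absence of the rest of build.py are illustrative, not copied from the repo):

import argparse

from transformers import AutoConfig

parser = argparse.ArgumentParser()
parser.add_argument("--hf_model_dir", type=str, default=None)
parser.add_argument("--seq_length", type=int, default=8192)  # new default from this commit
args = parser.parse_args()

if args.hf_model_dir is not None:
    # Assumed loading call; build.py's own loader may differ.
    hf_config = AutoConfig.from_pretrained(args.hf_model_dir, trust_remote_code=True)
    # The diff reads hf_config.seq_length, so the checkpoint value wins over the CLI default.
    args.seq_length = hf_config.seq_length

print(f"building engine with seq_length={args.seq_length}")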
10 changes: 2 additions & 8 deletions examples/qwen/default_config.py
@@ -13,24 +13,18 @@ class DefaultConfig:
     hf_max_batch_size = 1
 
     # Maximum batch size for TRT-LLM backend.
-    trt_max_batch_size = 2
+    trt_max_batch_size = 1
 
     # choice the model format, base or chat
     # choices=["chatml", "raw"],
     chat_format = "chatml"
 
     # Maximum input length.
-    max_input_len = 2048
+    max_input_len = 6144
 
     # Maximum number of generate new tokens.
     max_new_tokens = 2048
 
-    # Maximum sequence length.
-    # for Qwen-7B-Chat V1.0, the seq_length is 2048
-    # for Qwen-7B-Chat V1.1, the seq_length is 8192
-    # for Qwen-14B-Chat, the seq_length is 2048
-    seq_length = 2048
-
     # Top p for sampling.
     top_p = 0.8
 
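
The new defaults in default_config.py line up with the 8192-token window: max_input_len + max_new_tokens = 6144 + 2048 = 8192, so a maximum-length prompt plus the full generation budget exactly fills the new default seq_length, and trt_max_batch_size drops to 1, presumably to keep memory usage in check at the longer context. An illustrative sanity check (not repo code; values mirror this commit):

# Illustrative check: the defaults should fit inside the model window.
seq_length = 8192      # new build.py --seq_length default
max_input_len = 6144   # default_config.max_input_len after this commit
max_new_tokens = 2048  # default_config.max_new_tokens (unchanged)

assert max_input_len + max_new_tokens <= seq_length, \
    "prompt budget plus generation budget must not exceed seq_length"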
