diff --git a/docs/configuration/env_vars.md b/docs/configuration/env_vars.md index 4f37c185..59c90ceb 100644 --- a/docs/configuration/env_vars.md +++ b/docs/configuration/env_vars.md @@ -20,15 +20,15 @@ - `VLLM_GRAPH_PROMPT_STRATEGY`: strategy determining order of prompt graph capture, `min_tokens` or `max_bs`. The default is `min_tokens`. - `VLLM_GRAPH_DECODE_STRATEGY`: strategy determining order of decode graph capture, `min_tokens` or `max_bs`. The default is `max_bs`. - `VLLM_EXPONENTIAL_BUCKETING`: if `true`, enables exponential bucket spacing instead of linear. The default is `true`. -- `VLLM_{phase}_{dim}_BUCKET_{param}`: collection of 12 environment variables configuring ranges of bucketing mechanism (linear bucketing only). +- `VLLM_PROMPT_BS_BUCKET_MAX`: `(VLLM_PROMPT_BS_BUCKET_MAX * query) <= max_num_batched_tokens` - prefill batch size max. The default is `1`. +- `VLLM_{phase}_{dim}_BUCKET_{param}`: collection of environment variables configuring ranges of bucketing mechanism (linear bucketing only). 
- `{phase}` is either `PROMPT` or `DECODE` - - `{dim}` is either `BS`, `SEQ` or `BLOCK` + - `{dim}` is either `BS`, `SEQ`, `CTX` or `BLOCK` - `{param}` is either `MIN`, `STEP` or `MAX` - Default values: - Prompt: - batch size min (`VLLM_PROMPT_BS_BUCKET_MIN`): `1` - batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `min(max_num_seqs, 32)` - - batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`): `min(max_num_seqs, 64)` - sequence length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`): `block_size` - sequence length step (`VLLM_PROMPT_SEQ_BUCKET_STEP`): `block_size` - sequence length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`): `1024` diff --git a/vllm_gaudi/extension/bucketing/exponential.py b/vllm_gaudi/extension/bucketing/exponential.py index 27c57e77..cbe90eeb 100644 --- a/vllm_gaudi/extension/bucketing/exponential.py +++ b/vllm_gaudi/extension/bucketing/exponential.py @@ -17,8 +17,9 @@ def check_for_user_flags(self, phase): params = ['min', 'step', 'max', 'limit'] env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for dim in dim for p in params] user_flags = [] + overwritten_user_flags = ["VLLM_PROMPT_BS_BUCKET_MAX"] for e in env_vars: - if getattr(get_config(), e) is not None: + if getattr(get_config(), e) is not None and e not in overwritten_user_flags: user_flags.append(e) if len(user_flags) > 0: logger().warning("*******************************************************")