diff --git a/examples/offline_inference_gsaondevice.py b/examples/offline_inference_gsaondevice.py index 89e8ba7b5..5c073c995 100644 --- a/examples/offline_inference_gsaondevice.py +++ b/examples/offline_inference_gsaondevice.py @@ -7,11 +7,6 @@ from transformers import AutoTokenizer -# Third Party -from vllm import LLM, SamplingParams -from vllm.config import KVTransferConfig -from vllm.engine.arg_utils import EngineArgs - from ucm.logger import init_logger logger = init_logger(__name__) @@ -61,6 +56,14 @@ def setup_environment_variables(): tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True) +# ENABLE_SPARSE must be set before import vllm to make sure monkey patch works +setup_environment_variables() +# Third Party +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig +from vllm.engine.arg_utils import EngineArgs + + @contextlib.contextmanager def build_llm_with_uc(module_path: str, name: str, model: str): ktc = KVTransferConfig( @@ -124,7 +127,6 @@ def print_output( def main(): module_path = "ucm.integration.vllm.ucm_connector" name = "UCMConnector" - setup_environment_variables() def get_prompt(prompt): messages = [ diff --git a/ucm/integration/vllm/patch/apply_patch.py b/ucm/integration/vllm/patch/apply_patch.py index f3abf69d0..e58ff7db7 100644 --- a/ucm/integration/vllm/patch/apply_patch.py +++ b/ucm/integration/vllm/patch/apply_patch.py @@ -129,9 +129,11 @@ def apply_all_patches() -> None: # vllm patches match version: case "0.11.0": + logger.info("UCM patching vllm for pc...") import ucm.integration.vllm.patch.v0110.vllm.pc_patch if ENABLE_SPARSE: + logger.info("UCM patching vllm for sparse...") import ucm.integration.vllm.patch.v0110.vllm.sparse_patch case _: pass @@ -140,9 +142,11 @@ def apply_all_patches() -> None: ascend_version = get_vllm_ascend_version() match ascend_version: case "0.11.0": + logger.info("UCM patching vllm-ascend for pc...") import ucm.integration.vllm.patch.v0110.vllm_ascend.pc_ascend_patch if ENABLE_SPARSE: + logger.info("UCM patching vllm-ascend for sparse...") import ucm.integration.vllm.patch.v0110.vllm_ascend.sparse_ascend_patch case _: pass