diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 09cb7114c..c43ae2319 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -253,3 +253,71 @@ jobs: run: | sudo rm -f /workspace/test_results/${{ steps.version.outputs.version }}-online-inference-sparse.xml || true + test-e2e-sparse-offline-gpu: + runs-on: ["gpu-test-in-docker"] + needs: lint-and-unit-tests + permissions: + checks: write + pull-requests: write + steps: + - name: Clean repo + run: | + if [ -d "${{github.workspace}}" ]; then + cd ${{github.workspace}} + rm -rf ./* + rm -rf .[!.]* + fi + - uses: actions/checkout@v4 + - name: Install Docker CLI + run: | + if ! command -v docker &> /dev/null; then + echo "Docker CLI not found, installing..." + sudo apt-get update + sudo apt-get install -y docker.io + else + echo "Docker CLI already installed" + fi + - name: Generate Docker Image Version + id: version + run: | + DATE=$(date +%Y%m%d) + SHORT_SHA=$(echo '${{ github.sha }}' | cut -c1-7) + REF_NAME=$(echo '${{ github.ref_name }}' | tr '/' '-') + VERSION="${REF_NAME}-${DATE}-${{ github.run_number }}-${SHORT_SHA}" + echo "version=${VERSION}" >> $GITHUB_OUTPUT + echo "Docker image version: ${VERSION}" + - name: Build + run: | + cd ${{github.workspace}} + sudo -E docker build --network=host \ + --build-arg http_proxy="${http_proxy:-}" \ + --build-arg https_proxy="${https_proxy:-}" \ + --build-arg ENABLE_SPARSE=true \ + -t ucm-e2etest-gpu-sparse:${{ steps.version.outputs.version }} \ + -f ./docker/Dockerfile.vllm_gpu ./ + - name: Test E2E in Docker + run: | + sudo chmod -R 777 /workspace/test_results/ + sudo docker run --rm \ + --gpus all \ + --ipc=host \ + -v /home/models:/home/models \ + -v /home/yanzhao/pipeline_results:/workspace/test_results \ + ucm-e2etest-gpu-sparse:${{ steps.version.outputs.version }} \ + -c "cd /workspace/unified-cache-management/test && pip install -r requirements.txt && python3 -m pytest -x --stage=1 
def get_platform_specific_module():
    """Return platform-specific modules for inference.

    Imports are deferred to call time so that importing this module does
    not itself require ``transformers``/``vllm`` to be installed.

    Returns:
        SimpleNamespace exposing ``AutoTokenizer`` and ``SamplingParams``.
    """
    from types import SimpleNamespace

    from transformers import AutoTokenizer
    from vllm import SamplingParams

    modules = SimpleNamespace()
    modules.AutoTokenizer = AutoTokenizer
    modules.SamplingParams = SamplingParams
    return modules


def match_any_answer(output: str, answers: List[str]) -> bool:
    """Check whether ``output`` matches any of the standard answers.

    Both sides are compared after ``remove_punc`` normalization, so
    full-width vs half-width punctuation differences do not break a match.

    Args:
        output: Generated output text.
        answers: List of acceptable answers.

    Returns:
        True if the normalized output equals any normalized answer.
    """
    return any(remove_punc(output) == remove_punc(answer) for answer in answers)


def remove_punc(text: str) -> str:
    """Strip ASCII and CJK punctuation from ``text`` for comparison.

    Args:
        text: Text to normalize.

    Returns:
        ``text`` with surrounding whitespace and all punctuation removed;
        "" when the stripped input is empty.
    """
    import string

    text = text.strip()
    if not text:
        return ""
    # NOTE(review): the quote characters in this literal were garbled in
    # transit (the removed copy of this helper used curly quotes and also
    # included '@'); verify this set against the repository copy.
    cn_punctuation = (
        "!?。。\u201c#$%&'()*+,-/:;<=>[\\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”"
        "„‟…‧﹏."
    )
    all_punctuation = set(string.punctuation + cn_punctuation)
    return "".join(ch for ch in text if ch not in all_punctuation)


def match_sparse_answer(sparse_output: List[str], standard_answers: List[str]) -> bool:
    """Check whether sparse outputs equal standard answers, ignoring punctuation.

    Args:
        sparse_output: List of generated outputs.
        standard_answers: List of expected answers (same order).

    Returns:
        True only when both arguments are lists of strings and the
        punctuation-stripped lists compare equal element-wise.
    """
    # Defensive type checks: anything other than list-of-str is a mismatch.
    if not isinstance(sparse_output, list) or not isinstance(standard_answers, list):
        return False
    if not all(isinstance(item, str) for item in sparse_output) or not all(
        isinstance(item, str) for item in standard_answers
    ):
        return False

    norm_output = [remove_punc(item) for item in sparse_output]
    norm_standard = [remove_punc(item) for item in standard_answers]
    return norm_output == norm_standard
def extract_answers(generated_text_list: List[str]) -> List[str]:
    """Extract final answers from generated text by removing thinking tags.

    Args:
        generated_text_list: List of generated texts.

    Returns:
        List of extracted answers, one per input; non-string entries
        map to "".
    """
    results = []
    for text in generated_text_list:
        if not isinstance(text, str):
            results.append("")
            continue

        # Keep only the text after the last closing think tag, if present.
        # NOTE(review): the tag literal was stripped to "" in transit, which
        # would make rsplit raise ValueError; "</think>" matches the stated
        # intent ("removing thinking tags") -- confirm against the repo copy.
        if "</think>" in text:
            answer = text.rsplit("</think>", 1)[-1].strip()
        else:
            answer = text.strip()

        # Drop surrounding quote characters the model sometimes emits.
        answer = answer.strip("'").strip('"').strip()
        results.append(answer)

    return results


def batch_chat(
    client,
    requests: List[LLMRequest],
    max_workers: Optional[int] = None,
) -> List[LLMResponse]:
    """Send multiple requests to the LLM server in parallel.

    Sends the requests concurrently via a thread pool, waits for all of
    them, and returns responses in the same order as the input requests.

    Args:
        client: An LLM client implementing the LLMConnection protocol
            (e.g. OpenAIConn); only ``client.chat(request)`` is used.
        requests: List of LLMRequest objects to send.
        max_workers: Maximum number of worker threads
            (default: number of requests).

    Returns:
        List of LLMResponse objects in the same order as input requests.

    Raises:
        RuntimeError: if any request slot never received a response.
    """
    if not requests:
        return []

    if max_workers is None:
        max_workers = len(requests)

    results: List[Optional[LLMResponse]] = [None] * len(requests)

    def _send_request(index: int, request: LLMRequest) -> tuple[int, LLMResponse]:
        """Send one request; return (index, response) to preserve ordering."""
        return index, client.chat(request)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(_send_request, i, req): i for i, req in enumerate(requests)
        }
        for future in as_completed(future_to_index):
            index, response = future.result()
            results[index] = response

    # BUG FIX: the original post-loop check iterated the *request* objects
    # ("if req is None"), which can never trigger; inspect the collected
    # results so a slot that never received a response is actually reported.
    for i, response in enumerate(results):
        if response is None:
            raise RuntimeError(f"Request {i} failed to complete")

    return results
use_chat_template=True) - # Format prompt with chat template - system_content = "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:\u201c全国美国文学研究会的第十八届年会在哪所大学举办的?\u201d\n回答应该为:\u201cxx大学\u201d。\n\n" - try: - messages = [ - {"role": "system", "content": system_content}, - {"role": "user", "content": test_prompt}, - ] - formatted_full_prompt = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - add_special_tokens=True, - ) - except Exception: - formatted_full_prompt = test_prompt - # Split prompt for Phase prompt_first_part, _ = split_prompt_by_tokens( - formatted_full_prompt, tokenizer, split_ratio=prompt_split_ratio + test_prompt, tokenizer, split_ratio=prompt_split_ratio ) # Prepare messages + system_content = "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:\u201c全国美国文学研究会的第十八届年会在哪所大学举办的?\u201d\n回答应该为:\u201cxx大学\u201d。\n\n" + phase1_messages = [ {"role": "system", "content": system_content}, {"role": "user", "content": test_prompt}, @@ -464,21 +517,6 @@ def hbm_ssd_mixed_test( # ===== Accuracy Test Results ===== print(f"\n[INFO] ===== Accuracy Test Results =====") - def normalize_text(text: str) -> str: - text = text.replace("\uff0c", ",") - text = text.replace("\u3002", ".") - text = text.replace("\uff01", "!") - text = text.replace("\uff1f", "?") - text = text.replace("\uff1a", ":") - text = text.replace("\uff1b", ";") - return text.strip() - - def match_any_answer(output: str, answers: List[str]) -> bool: - for answer in answers: - if normalize_text(output) == normalize_text(answer): - return True - return False - # Phase accuracy check phase1_correct = match_any_answer( phase1_1_output, standard_answers diff --git a/test/suites/E2E/test_offline_inference.py b/test/suites/E2E/test_offline_inference.py index 4150752f6..a06f0045b 100644 --- a/test/suites/E2E/test_offline_inference.py +++ b/test/suites/E2E/test_offline_inference.py @@ -7,6 +7,7 @@ ensure_storage_dir, get_platform_specific_module, 
load_prompt_from_file, + match_any_answer, serialize_sample_params, split_prompt_by_tokens, ) @@ -16,12 +17,11 @@ ) from common.path_utils import get_path_relative_to_test_root, get_path_to_model -os.environ["ENABLE_UCM_PATCH"] = "1" - class TestBasicOfflineInference: """Test basic offline inference functionality.""" + @pytest.mark.skip(reason="covered by online test") @pytest.mark.stage(1) @pytest.mark.feature("offline_inference") @pytest.mark.gpu_mem(6000) @@ -190,25 +190,6 @@ def test_offline_accuracy_hbm_ssd_mixed( print(f"\n[INFO] ===== Accuracy Test Results =====") - # Note: Small numerical precision differences in KV cache loading can cause - # punctuation token selection differences (e.g., full-width vs half-width comma) - def normalize_text(text: str) -> str: - """Normalize text for comparison by replacing similar punctuation.""" - text = text.replace(",", ",") - text = text.replace("。", ".") - text = text.replace("!", "!") - text = text.replace("?", "?") - text = text.replace(":", ":") - text = text.replace(";", ";") - return text.strip() - - def match_any_answer(output: str, answers: list[str]) -> bool: - """Check if output matches any of the standard answers.""" - for answer in answers: - if normalize_text(output) == normalize_text(answer): - return True - return False - # Compare Phase 1.1 vs Phase 1.2 (SSD load accuracy) phase1_correct = match_any_answer( phase1_1_output, standard_answers diff --git a/test/suites/E2E/test_offline_inference_sparse.py b/test/suites/E2E/test_offline_inference_sparse.py index be2b7c3d5..72fb5722a 100644 --- a/test/suites/E2E/test_offline_inference_sparse.py +++ b/test/suites/E2E/test_offline_inference_sparse.py @@ -1,6 +1,5 @@ import os import re -import string from pathlib import Path from typing import List @@ -8,9 +7,12 @@ import yaml from common.common_inference_utils import ( ensure_storage_dir, + extract_answers, get_platform_specific_module, load_prompt_from_file, load_prompt_list_from_file, + match_any_answer, + 
match_sparse_answer, serialize_sample_params, split_prompt_by_tokens, ) @@ -20,13 +22,11 @@ ) from common.path_utils import get_path_relative_to_test_root, get_path_to_model -os.environ["ENABLE_UCM_PATCH"] = "1" - class TestBasicOfflineInferenceSparse: """Test basic offline inference functionality.""" - @pytest.mark.skip(reason="refine this code and re-enable later") + @pytest.mark.skip(reason="covered by online test") @pytest.mark.stage(1) @pytest.mark.feature("offline_inference_sparse") @pytest.mark.gpu_mem(6000) @@ -181,26 +181,6 @@ def test_offline_accuracy_hbm_ssd_mixed_nosparse( print(f"[INFO] Phase 2.2 output: {phase2_full_output}") print(f"\n[INFO] ===== Accuracy Test Results =====") - - # Note: Small numerical precision differences in KV cache loading can cause - # punctuation token selection differences (e.g., full-width vs half-width comma) - def normalize_text(text: str) -> str: - """Normalize text for comparison by replacing similar punctuation.""" - text = text.replace(",", ",") - text = text.replace("。", ".") - text = text.replace("!", "!") - text = text.replace("?", "?") - text = text.replace(":", ":") - text = text.replace(";", ";") - return text.strip() - - def match_any_answer(output: str, answers: list[str]) -> bool: - """Check if output matches any of the standard answers.""" - for answer in answers: - if normalize_text(output) == normalize_text(answer): - return True - return False - # Compare Phase 1.1 vs Phase 1.2 (SSD load accuracy) phase1_correct = match_any_answer( phase1_1_output, standard_answers @@ -226,33 +206,8 @@ def match_any_answer(output: str, answers: list[str]) -> bool: """Test GSA sparse attention.""" - def remove_punc(self, text): - text = text.strip() - if not text: - return "" - cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." 
- all_punctuation = set(string.punctuation + cn_punctuation) - return "".join(ch for ch in text if ch not in all_punctuation) - - def match_sparse_answer( - self, sparse_output: List[str], standard_answers: List[str] - ) -> bool: - - if not isinstance(sparse_output, list) or not isinstance( - standard_answers, list - ): - return False - if not all(isinstance(item, str) for item in sparse_output) or not all( - isinstance(item, str) for item in standard_answers - ): - return False - - norm_output = [self.remove_punc(item) for item in sparse_output] - norm_standard = [self.remove_punc(item) for item in standard_answers] - return norm_output == norm_standard - @pytest.mark.stage(1) - @pytest.mark.feature("online_inference_sparse") + @pytest.mark.feature("offline_inference_sparse") @pytest.mark.gpu_mem(70000) @pytest.mark.parametrize("model_name", ["DeepSeek-V2-Lite-Chat"]) @pytest.mark.parametrize("max_tokens", [16]) @@ -353,7 +308,7 @@ def test_offline_gsa_mla( ) print(f'GsaOnDevice output: "{phase_sparse_output}"') print(f'Standard answers: "{standard_answers}"') - phase_sparse_correct = self.match_sparse_answer( + phase_sparse_correct = match_sparse_answer( phase_sparse_output, standard_answers ) if not phase_sparse_correct: @@ -363,7 +318,7 @@ def test_offline_gsa_mla( pytest.fail("GsaOnDevice Test Failed!") @pytest.mark.stage(1) - @pytest.mark.feature("online_inference_sparse") + @pytest.mark.feature("offline_inference_sparse") @pytest.mark.gpu_mem(30000) @pytest.mark.parametrize("model_name", ["Qwen3-4B"]) @pytest.mark.parametrize("max_tokens", [2048]) @@ -459,32 +414,13 @@ def test_offline_gsa_gqa( timeout=1800, ) - def extract_answers(generated_text_list: List[str]) -> List[str]: - results = [] - - for text in generated_text_list: - if not isinstance(text, str): - results.append("") - continue - - if "" in text: - answer = text.rsplit("", 1)[-1].strip() - else: - answer = text.strip() - - answer = answer.strip("'").strip('"').strip() - - results.append(answer) 
- - return results - phase_sparse_output = extract_answers(phase_sparse_output) print( f" GsaOnDevice inference for a GQA-based model is completed in a subprocess." ) print(f'GsaOnDevice output: "{phase_sparse_output}"') print(f'Standard answers: "{standard_answers}"') - phase_sparse_correct = self.match_sparse_answer( + phase_sparse_correct = match_sparse_answer( phase_sparse_output, standard_answers ) if not phase_sparse_correct: @@ -492,116 +428,3 @@ def extract_answers(generated_text_list: List[str]) -> List[str]: print(f"GsaOnDevice output:\n{phase_sparse_output}") print(f"Standard answers:\n{standard_answers}") pytest.fail("GsaOnDevice Test Failed!") - - """Test ESA sparse attention.""" - - @pytest.mark.skip(reason="refine this code and re-enable later") - @pytest.mark.stage(1) - @pytest.mark.feature("offline_inference_sparse") - @pytest.mark.gpu_mem(6000) - @pytest.mark.parametrize("model_name", ["Qwen2.5-1.5B-Instruct"]) - @pytest.mark.parametrize("max_tokens", [200]) - @pytest.mark.parametrize("enforce_eager", [False]) - @pytest.mark.parametrize("max_num_batched_tokens", [2047]) - def test_offline_esa( - self, - model_name: str, - max_tokens: int, - enforce_eager: bool, - max_num_batched_tokens: int, - ): - config_file = get_path_relative_to_test_root("config.yaml") - with open(config_file, "r", encoding="utf-8") as f: - config = yaml.safe_load(f) - - model_path = get_path_to_model(model_name, config) - - assert os.path.exists(model_path), f"Model path does not exist: {model_path}" - - ucm_storage_dir = "/tmp/ucm_cache" - - # make sure UCM storage directory exists and is empty - ensure_storage_dir(ucm_storage_dir, clear_existing=True) - - try: - test_prompt, standard_answers = load_prompt_from_file( - get_path_relative_to_test_root( - "suites/E2E/prompts/test_offline_inference.json" - ) - ) - if not standard_answers: - pytest.fail(f"No standard answers found in prompt.json") - except Exception as e: - pytest.fail(f"Failed to load prompt from prompt.json: 
{e}") - - print(f"Standard answers: {standard_answers}") - - tokenizer = get_platform_specific_module().AutoTokenizer.from_pretrained( - model_path, use_chat_template=True - ) - - try: - messages = [ - { - "role": "system", - "content": "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:“全国美国文学研究会的第十八届年会在哪所大学举办的?”\n回答应该为:“xx大学”。\n\n", - }, - {"role": "user", "content": test_prompt}, - ] - formatted_full_prompt = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - add_special_tokens=True, - ) - except Exception: - formatted_full_prompt = test_prompt - - ucm_config = { - "ucm_connectors": [ - { - "ucm_connector_name": "UcmNfsStore", - "ucm_connector_config": { - "storage_backends": ucm_storage_dir, - "use_direct": False, - }, - } - ], - "ucm_sparse_config": { - "ESA": { - "init_window_sz": 1, - "local_window_sz": 2, - "min_blocks": 4, - "sparse_ratio": 0.3, - "retrieval_stride": 5, - } - }, - } - - sampling_params = get_platform_specific_module().SamplingParams( - temperature=0.0, - top_p=1, - max_tokens=max_tokens, - ignore_eos=False, - ) - - # Convert SamplingParams to dict for serialization, as non-picklable objects cannot be passed to subprocess - sampling_params_dict = serialize_sample_params(sampling_params) - - phase1_outputs = run_in_spawn_subprocess( - run_offline_inference, - model_path, - ucm_config, - [formatted_full_prompt, formatted_full_prompt], - sampling_params_dict, - False, # enable_prefix_caching=False - enforce_eager, - "ESA", - max_num_batched_tokens, - timeout=180, - ) - phase1_1_output = phase1_outputs[0] # Phase 1.1: SSD save - phase1_2_output = phase1_outputs[1] # Phase 1.2: SSD load - print(f"ESA inference completed in subprocess") - print(f'Phase 1.1 output: "{phase1_1_output}"') - print(f'Phase 1.2 output: "{phase1_2_output}"') diff --git a/test/suites/E2E/test_online_inference.py b/test/suites/E2E/test_online_inference.py index c1fd43b86..49312de95 100644 --- 
a/test/suites/E2E/test_online_inference.py +++ b/test/suites/E2E/test_online_inference.py @@ -22,8 +22,6 @@ from common.online_inference_utils import hbm_ssd_mixed_test from common.path_utils import get_path_relative_to_test_root, get_path_to_model -os.environ["ENABLE_UCM_PATCH"] = "1" - class TestBasicOnlineInference: """Test basic online inference functionality.""" diff --git a/test/suites/E2E/test_online_inference_sparse.py b/test/suites/E2E/test_online_inference_sparse.py index f8177c769..c978ee713 100644 --- a/test/suites/E2E/test_online_inference_sparse.py +++ b/test/suites/E2E/test_online_inference_sparse.py @@ -20,17 +20,20 @@ import pytest import yaml from common.common_inference_utils import ( - ensure_storage_dir, - load_prompt_from_file, + extract_answers, + load_prompt_list_from_file, + match_sparse_answer, ) from common.llm_connection.LLMBase import LLMRequest from common.llm_connection.openai_connector import OpenAIConn from common.llm_connection.token_counter import HuggingFaceTokenizer -from common.online_inference_utils import VLLMServerManager, hbm_ssd_mixed_test +from common.online_inference_utils import ( + VLLMServerManager, + batch_chat, + hbm_ssd_mixed_test, +) from common.path_utils import get_path_relative_to_test_root, get_path_to_model -os.environ["ENABLE_UCM_PATCH"] = "1" - class TestBasicOnlineInference: """Test basic online inference functionality.""" @@ -107,22 +110,21 @@ def test_online_accuracy_hbm_ssd_mixed( vllm_server_startup_args, ) - @pytest.mark.skip(reason="refine this code and re-enable later") @pytest.mark.stage(1) - @pytest.mark.gpu_mem(10000) + @pytest.mark.gpu_mem(70000) @pytest.mark.feature("online_inference_sparse") - @pytest.mark.parametrize("model_name", ["Qwen3-4B"]) - @pytest.mark.parametrize("max_tokens", [200]) - def test_online_gsa( + @pytest.mark.parametrize("model_name", ["DeepSeek-V2-Lite-Chat"]) + @pytest.mark.parametrize("max_tokens", [16]) + def test_online_gsa_mla( self, model_name: str, max_tokens: int, 
): - """Test GSA sparse attention via online inference. + """Test GSA sparse attention via online inference for MLA-based model. - Mirrors test_offline_inference_sparse.py::test_offline_gsa. - Starts vLLM with GSA sparse config, sends full prompt twice, - verifies SSD save/load works. + Mirrors test_offline_inference_sparse.py::test_offline_gsa_mla. + Loads prompts from test_offline_gsaondevice_inference.json, + sends them in parallel using batch_chat, verifies using match_sparse_answer. """ os.environ["ENABLE_SPARSE"] = "1" os.environ["VLLM_HASH_ATTENTION"] = "1" @@ -131,40 +133,56 @@ def test_online_gsa( with open(config_file, "r", encoding="utf-8") as f: config = yaml.safe_load(f) - ucm_storage_dir = "/tmp/ucm_cache" - ensure_storage_dir(ucm_storage_dir, clear_existing=True) - - served_model_name = model_name - tokenizer_path = f"/home/models/{model_name}" model_path = get_path_to_model(model_name, config) + tokenizer_path = f"/home/models/{model_name}" + served_model_name = model_name - test_prompt, _ = load_prompt_from_file( - get_path_relative_to_test_root( - "suites/E2E/prompts/test_offline_inference.json" + # Load prompts and answers + try: + test_prompts, standard_answers = load_prompt_list_from_file( + get_path_relative_to_test_root( + "suites/E2E/prompts/test_offline_gsaondevice_inference.json" + ) ) - ) + if not standard_answers: + pytest.fail(f"No standard answers found in prompt.json") + except Exception as e: + pytest.fail(f"Failed to load prompt from prompt.json: {e}") + + print(f"Standard answers: {standard_answers}") + + tokenizer = HuggingFaceTokenizer(tokenizer_path) system_content = "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:\u201c全国美国文学研究会的第十八届年会在哪所大学举办的?\u201d\n回答应该为:\u201cxx大学\u201d。\n\n" + # Create LLMRequest list + requests = [ + LLMRequest( + messages=[ + {"role": "system", "content": system_content}, + {"role": "user", "content": prompt}, + ], + max_tokens=max_tokens, + temperature=0.0, + ) + for prompt in test_prompts + ] + + 
# UCM config with UcmPipelineStore ucm_config = { "ucm_connectors": [ { - "ucm_connector_name": "UcmNfsStore", + "ucm_connector_name": "UcmPipelineStore", "ucm_connector_config": { - "storage_backends": ucm_storage_dir, - "use_direct": False, + "store_pipeline": "Empty", + "share_buffer_enable": True, }, } ], "ucm_sparse_config": {"GSAOnDevice": {}}, } - phase1_messages = [ - {"role": "system", "content": system_content}, - {"role": "user", "content": test_prompt}, - ] - - print(f"\n===== Online GSA Sparse Test =====") + print(f"\n===== Online GSA MLA Sparse Test =====") print(f"Model: {model_path}") print(f"Starting vLLM server with GSA sparse config") @@ -172,136 +190,151 @@ def test_online_gsa( model_path=model_path, port=8000, ucm_config=ucm_config, - max_model_len=12000, + max_model_len=70000, served_model_name=served_model_name, enable_prefix_caching=False, ) as server: client = OpenAIConn( base_url=server.url, - tokenizer=HuggingFaceTokenizer(tokenizer_path), + tokenizer=tokenizer, model=served_model_name, ) assert client.health_check() print(f"server models: {client.list_models()}") - # Phase 1.1: SSD save - phase1_1_output = client.chat( - LLMRequest( - messages=phase1_messages, max_tokens=max_tokens, temperature=0.0 - ) - ).text - print(f'Phase 1.1 output: "{phase1_1_output}"') + # Send requests in parallel using batch_chat + responses = batch_chat(client, requests) + outputs = [resp.text for resp in responses] + + print(f"GSA MLA online inference completed.") + print(f'GSA MLA output: "{outputs}"') + print(f'Standard answers: "{standard_answers}"') + + # Verify + phase_sparse_correct = match_sparse_answer(outputs, standard_answers) + + if not phase_sparse_correct: + print(f"Incorrect answer in GSA MLA online inference output!") + print(f"GSA MLA output:\n{outputs}") + print(f"Standard answers:\n{standard_answers}") + pytest.fail("GSA MLA Online Test Failed!") - # Phase 1.2: SSD load - phase1_2_output = client.chat( - LLMRequest( - 
messages=phase1_messages, max_tokens=max_tokens, temperature=0.0 - ) - ).text - print(f'Phase 1.2 output: "{phase1_2_output}"') client.close() - print("GSA inference completed.") + print("GSA MLA online inference completed.") - @pytest.mark.skip(reason="refine this code and re-enable later") @pytest.mark.stage(1) - @pytest.mark.gpu_mem(6000) + @pytest.mark.gpu_mem(30000) @pytest.mark.feature("online_inference_sparse") - @pytest.mark.parametrize("model_name", ["Qwen2.5-1.5B-Instruct"]) - @pytest.mark.parametrize("max_tokens", [200]) - def test_online_esa( + @pytest.mark.parametrize("model_name", ["Qwen3-4B"]) + @pytest.mark.parametrize("max_tokens", [2048]) + def test_online_gsa_gqa( self, model_name: str, max_tokens: int, ): - """Test ESA sparse attention via online inference. + """Test GSA sparse attention via online inference. - Mirrors test_offline_inference_sparse.py::test_offline_esa. - Starts vLLM with ESA sparse config, sends full prompt twice, - verifies SSD save/load works. + Mirrors test_offline_inference_sparse.py::test_offline_gsa_gqa. + Loads prompts from test_offline_gsaondevice_inference.json, + sends them in parallel using batch_chat, verifies using match_sparse_answer. 
""" + os.environ["ENABLE_SPARSE"] = "1" + os.environ["VLLM_HASH_ATTENTION"] = "1" + config_file = get_path_relative_to_test_root("config.yaml") with open(config_file, "r", encoding="utf-8") as f: config = yaml.safe_load(f) - ucm_storage_dir = "/tmp/ucm_cache" - ensure_storage_dir(ucm_storage_dir, clear_existing=True) - - served_model_name = model_name - tokenizer_path = f"/home/models/{model_name}" model_path = get_path_to_model(model_name, config) + tokenizer_path = f"/home/models/{model_name}" + served_model_name = model_name - test_prompt, _ = load_prompt_from_file( - get_path_relative_to_test_root( - "suites/E2E/prompts/test_offline_inference.json" + # Load prompts and answers + try: + test_prompts, standard_answers = load_prompt_list_from_file( + get_path_relative_to_test_root( + "suites/E2E/prompts/test_offline_gsaondevice_inference.json" + ) ) - ) + if not standard_answers: + pytest.fail(f"No standard answers found in prompt.json") + except Exception as e: + pytest.fail(f"Failed to load prompt from prompt.json: {e}") + + print(f"Standard answers: {standard_answers}") + + tokenizer = HuggingFaceTokenizer(tokenizer_path) system_content = "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:\u201c全国美国文学研究会的第十八届年会在哪所大学举办的?\u201d\n回答应该为:\u201cxx大学\u201d。\n\n" + # Create LLMRequest list + requests = [ + LLMRequest( + messages=[ + {"role": "system", "content": system_content}, + {"role": "user", "content": prompt}, + ], + max_tokens=max_tokens, + temperature=0.0, + ) + for prompt in test_prompts + ] + + # UCM config with UcmPipelineStore ucm_config = { "ucm_connectors": [ { - "ucm_connector_name": "UcmNfsStore", + "ucm_connector_name": "UcmPipelineStore", "ucm_connector_config": { - "storage_backends": ucm_storage_dir, - "use_direct": False, + "store_pipeline": "Empty", + "share_buffer_enable": True, }, } ], - "ucm_sparse_config": { - "ESA": { - "init_window_sz": 1, - "local_window_sz": 2, - "min_blocks": 4, - "sparse_ratio": 0.3, - "retrieval_stride": 5, - } - }, 
+ "ucm_sparse_config": {"GSAOnDevice": {}}, } - phase1_messages = [ - {"role": "system", "content": system_content}, - {"role": "user", "content": test_prompt}, - ] - - print(f"\n===== Online ESA Sparse Test =====") + print(f"\n===== Online GSA Sparse Test =====") print(f"Model: {model_path}") - print(f"Starting vLLM server with ESA sparse config") + print(f"Starting vLLM server with GSA sparse config") with VLLMServerManager( model_path=model_path, port=8000, ucm_config=ucm_config, - max_model_len=12000, + max_model_len=30000, served_model_name=served_model_name, enable_prefix_caching=False, ) as server: client = OpenAIConn( base_url=server.url, - tokenizer=HuggingFaceTokenizer(tokenizer_path), + tokenizer=tokenizer, model=served_model_name, ) assert client.health_check() print(f"server models: {client.list_models()}") - # Phase 1.1: SSD save - phase1_1_output = client.chat( - LLMRequest( - messages=phase1_messages, max_tokens=max_tokens, temperature=0.0 - ) - ).text - print(f'Phase 1.1 output: "{phase1_1_output}"') + # Send requests in parallel using batch_chat + responses = batch_chat(client, requests) + outputs = [resp.text for resp in responses] + + print(f"GSA online inference completed.") + print(f'GSA output: "{outputs}"') + print(f'Standard answers: "{standard_answers}"') + + # Extract answers and verify + outputs = extract_answers(outputs) + phase_sparse_correct = match_sparse_answer(outputs, standard_answers) + + if not phase_sparse_correct: + print(f"Incorrect answer in GSA online inference output!") + print(f"GSA output:\n{outputs}") + print(f"Standard answers:\n{standard_answers}") + pytest.fail("GSA Online Test Failed!") - # Phase 1.2: SSD load - phase1_2_output = client.chat( - LLMRequest( - messages=phase1_messages, max_tokens=max_tokens, temperature=0.0 - ) - ).text - print(f'Phase 1.2 output: "{phase1_2_output}"') client.close() - print("ESA inference completed.") + print("GSA online inference completed.")