Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions .github/workflows/pull-request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,71 @@ jobs:
run: |
sudo rm -f /workspace/test_results/${{ steps.version.outputs.version }}-online-inference-sparse.xml || true

  # End-to-end sparse-attention offline-inference test, run inside Docker on a
  # self-hosted GPU runner. Mirrors the online-inference sparse job above.
  test-e2e-sparse-offline-gpu:
    runs-on: ["gpu-test-in-docker"]
    needs: lint-and-unit-tests
    permissions:
      checks: write          # needed to publish the check run with test results
      pull-requests: write   # needed to comment results on the PR
    steps:
      # Self-hosted runner: wipe residue (including dotfiles) from previous jobs.
      - name: Clean repo
        run: |
          if [ -d "${{github.workspace}}" ]; then
            cd ${{github.workspace}}
            rm -rf ./*
            rm -rf .[!.]*
          fi
      - uses: actions/checkout@v4
      # Self-hosted runner may not have Docker preinstalled; install on demand.
      - name: Install Docker CLI
        run: |
          if ! command -v docker &> /dev/null; then
            echo "Docker CLI not found, installing..."
            sudo apt-get update
            sudo apt-get install -y docker.io
          else
            echo "Docker CLI already installed"
          fi
      # Unique, traceable image tag: <branch>-<yyyymmdd>-<run#>-<short sha>.
      - name: Generate Docker Image Version
        id: version
        run: |
          DATE=$(date +%Y%m%d)
          SHORT_SHA=$(echo '${{ github.sha }}' | cut -c1-7)
          REF_NAME=$(echo '${{ github.ref_name }}' | tr '/' '-')
          VERSION="${REF_NAME}-${DATE}-${{ github.run_number }}-${SHORT_SHA}"
          echo "version=${VERSION}" >> $GITHUB_OUTPUT
          echo "Docker image version: ${VERSION}"
      # Build the test image with sparse attention enabled; proxy args are
      # passed through for restricted-network runners.
      - name: Build
        run: |
          cd ${{github.workspace}}
          sudo -E docker build --network=host \
            --build-arg http_proxy="${http_proxy:-}" \
            --build-arg https_proxy="${https_proxy:-}" \
            --build-arg ENABLE_SPARSE=true \
            -t ucm-e2etest-gpu-sparse:${{ steps.version.outputs.version }} \
            -f ./docker/Dockerfile.vllm_gpu ./
      # NOTE(review): "/home/yanzhao/pipeline_results" is a user-specific host
      # path — consider a runner-neutral location. TODO confirm with runner owner.
      - name: Test E2E in Docker
        run: |
          sudo chmod -R 777 /workspace/test_results/
          sudo docker run --rm \
            --gpus all \
            --ipc=host \
            -v /home/models:/home/models \
            -v /home/yanzhao/pipeline_results:/workspace/test_results \
            ucm-e2etest-gpu-sparse:${{ steps.version.outputs.version }} \
            -c "cd /workspace/unified-cache-management/test && pip install -r requirements.txt && python3 -m pytest -x --stage=1 --feature=offline_inference_sparse --junitxml=/workspace/test_results/${{ steps.version.outputs.version }}-offline-inference-sparse.xml"
      # Publish results even when tests failed, but not when the job was cancelled.
      - name: Upload pytest results
        uses: EnricoMi/publish-unit-test-result-action/linux@v2
        if: (!cancelled())
        with:
          files: |
            /workspace/test_results/${{ steps.version.outputs.version }}-offline-inference-sparse.xml
          check_name: Sparse attention test results
      # Best-effort cleanup ("|| true") so teardown never fails the job.
      - name: Cleanup Docker Image
        if: always()
        run: |
          sudo docker rmi ucm-e2etest-gpu-sparse:${{ steps.version.outputs.version }} || true
      - name: Cleanup Test Results
        if: always()
        run: |
          sudo rm -f /workspace/test_results/${{ steps.version.outputs.version }}-offline-inference-sparse.xml || true

139 changes: 79 additions & 60 deletions test/common/common_inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,89 +185,108 @@ def deserialize_sample_params(json_str: str) -> Any:
)


def get_platform_specific_module():
    """Get platform-specific modules for inference.

    Imports are deferred to call time so merely importing this module does
    not require transformers/vllm to be installed.

    Returns:
        SimpleNamespace with AutoTokenizer and SamplingParams
    """
    from types import SimpleNamespace

    from transformers import AutoTokenizer
    from vllm import SamplingParams

    # Bundle both classes into a single namespace object.
    return SimpleNamespace(
        AutoTokenizer=AutoTokenizer,
        SamplingParams=SamplingParams,
    )

Supports:
- dataclass objects
- regular objects with __dict__
- vLLM SamplingParams and other custom classes

def match_any_answer(output: str, answers: List[str]) -> bool:
    """Check if output matches any of the standard answers.

    The comparison is punctuation-insensitive: both sides are normalized
    with ``remove_punc`` before an exact equality check.

    Args:
        output: Generated output text
        answers: List of acceptable answers

    Returns:
        True if output matches any answer
    """
    # remove_punc is pure, so normalize the output once up front.
    normalized_output = remove_punc(output)
    return any(normalized_output == remove_punc(candidate) for candidate in answers)


def remove_punc(text: str) -> str:
    """Remove ASCII and CJK punctuation from text for comparison.

    Args:
        text: Text to remove punctuation from

    Returns:
        Text without punctuation; empty string for blank input
    """
    import string

    text = text.strip()
    if not text:
        return ""
    # CJK / fullwidth punctuation, merged with string.punctuation below.
    # FIX: the original literal contained the invalid escape sequence "\]"
    # (a SyntaxWarning on modern Python that silently yields backslash + "]");
    # the backslash is now written explicitly as "\\".
    # NOTE(review): this block was recovered from a mangled paste — the exact
    # curly-quote codepoints (\u2018/\u2019/\u201b/\u201c/\u201e/\u201f) are a
    # best-effort reconstruction; verify against the original file.
    cn_punctuation = (
        "!?。。\u201c#$%&'()*+,-/:;<=>[\\]^_`{|}~⦅⦆「」、、〃》「」『』"
        "【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—\u2018\u2019\u201b\u201c"
        "\u201e\u201f…‧﹏."
    )
    all_punctuation = set(string.punctuation + cn_punctuation)
    return "".join(ch for ch in text if ch not in all_punctuation)

if "_type" not in serialized:
# Not a serialized object, return as-is
return serialized

type_str = serialized["_type"]
obj_data = serialized.get("_data", {})
def match_sparse_answer(sparse_output: List[str], standard_answers: List[str]) -> bool:
"""Check if sparse output matches standard answers after removing punctuation.

try:
# Parse module and class name
import importlib
Args:
sparse_output: List of generated outputs
standard_answers: List of expected answers

module_name, class_name = type_str.rsplit(".", 1)
module = importlib.import_module(module_name)
cls = getattr(module, class_name)
Returns:
True if outputs match after normalization
"""
if not isinstance(sparse_output, list) or not isinstance(standard_answers, list):
return False
if not all(isinstance(item, str) for item in sparse_output) or not all(
isinstance(item, str) for item in standard_answers
):
return False

# Reconstruct object
return cls(**obj_data)
except Exception as e:
logging.warning(f"Deserialization failed for {type_str}: {e}")
raise
norm_output = [remove_punc(item) for item in sparse_output]
norm_standard = [remove_punc(item) for item in standard_answers]
return norm_output == norm_standard


def extract_answers(generated_text_list: List[str]) -> List[str]:
    """Extract answers from generated text by removing thinking tags.

    Non-string entries map to "" so the result always has the same length
    as the input.

    Args:
        generated_text_list: List of generated texts

    Returns:
        List of extracted answers
    """

    def _clean(raw: str) -> str:
        # Keep only the text after the last closing </think> tag, then drop
        # surrounding whitespace and stray single/double quotes.
        tail = raw.rsplit("</think>", 1)[-1] if "</think>" in raw else raw
        return tail.strip().strip("'").strip('"').strip()

    return [
        _clean(entry) if isinstance(entry, str) else ""
        for entry in generated_text_list
    ]
2 changes: 2 additions & 0 deletions test/common/offline_inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def run_offline_inference(
gpu_memory_utilization,
)

logging.info(f"Running offline inference with sampling_params: {sampling_params}")

with build_llm_with_uc(
model_path=model_path,
ucm_config=ucm_config,
Expand Down
102 changes: 70 additions & 32 deletions test/common/online_inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,14 @@
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional

import requests
from common.common_inference_utils import (
match_any_answer,
)
from common.llm_connection.LLMBase import LLMRequest, LLMResponse

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -292,6 +297,68 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.stop()


def batch_chat(
    client,
    requests: List[LLMRequest],
    max_workers: Optional[int] = None,
) -> List[LLMResponse]:
    """Send multiple requests to the LLM server in parallel and return all responses.

    This function sends multiple requests concurrently using a thread pool,
    waits for all requests to complete, and returns the results in the same
    order as the input requests. If any worker raised, the exception propagates
    from ``future.result()``.

    Args:
        client: An LLM client that implements the LLMConnection protocol (e.g., OpenAIConn)
        requests: List of LLMRequest objects to send
        max_workers: Maximum number of worker threads (default: number of requests)

    Returns:
        List of LLMResponse objects in the same order as input requests

    Raises:
        RuntimeError: If any slot in the result list was never filled.

    Example:
        from common.llm_connection.openai_connector import OpenAIConn
        from common.llm_connection.LLMRequest import LLMRequest

        client = OpenAIConn(base_url="http://localhost:8000", model="qwen")
        requests = [
            LLMRequest(messages=[{"role": "user", "content": "Hello"}], max_tokens=100),
            LLMRequest(messages=[{"role": "user", "content": "Hi"}], max_tokens=100),
        ]
        responses = batch_chat(client, requests)
        for resp in responses:
            print(resp.text)
    """
    if not requests:
        return []

    if max_workers is None:
        max_workers = len(requests)

    results: List[Optional[LLMResponse]] = [None] * len(requests)

    def _send_request(index, request):
        # Send a single request and return (index, response) so the caller
        # can place the response at the input position.
        response = client.chat(request)
        return index, response

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all requests
        future_to_index = {
            executor.submit(_send_request, i, req): i for i, req in enumerate(requests)
        }

        # Collect results as they complete
        for future in as_completed(future_to_index):
            index, response = future.result()
            results[index] = response

    # BUG FIX: validate the *results*, not the input requests — the original
    # iterated `requests` checking `req is None`, which can never detect a
    # missing/None response.
    for i, resp in enumerate(results):
        if resp is None:
            raise RuntimeError(f"Request {i} failed to complete")

    return results


def hbm_ssd_mixed_test(
model_name: str,
tokenizer_path: str,
Expand Down Expand Up @@ -356,28 +423,14 @@ def hbm_ssd_mixed_test(

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_chat_template=True)

# Format prompt with chat template
system_content = "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:\u201c全国美国文学研究会的第十八届年会在哪所大学举办的?\u201d\n回答应该为:\u201cxx大学\u201d。\n\n"
try:
messages = [
{"role": "system", "content": system_content},
{"role": "user", "content": test_prompt},
]
formatted_full_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
add_special_tokens=True,
)
except Exception:
formatted_full_prompt = test_prompt

# Split prompt for Phase
prompt_first_part, _ = split_prompt_by_tokens(
formatted_full_prompt, tokenizer, split_ratio=prompt_split_ratio
test_prompt, tokenizer, split_ratio=prompt_split_ratio
)

# Prepare messages
system_content = "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:\u201c全国美国文学研究会的第十八届年会在哪所大学举办的?\u201d\n回答应该为:\u201cxx大学\u201d。\n\n"

phase1_messages = [
{"role": "system", "content": system_content},
{"role": "user", "content": test_prompt},
Expand Down Expand Up @@ -464,21 +517,6 @@ def hbm_ssd_mixed_test(
# ===== Accuracy Test Results =====
print(f"\n[INFO] ===== Accuracy Test Results =====")

def normalize_text(text: str) -> str:
text = text.replace("\uff0c", ",")
text = text.replace("\u3002", ".")
text = text.replace("\uff01", "!")
text = text.replace("\uff1f", "?")
text = text.replace("\uff1a", ":")
text = text.replace("\uff1b", ";")
return text.strip()

def match_any_answer(output: str, answers: List[str]) -> bool:
for answer in answers:
if normalize_text(output) == normalize_text(answer):
return True
return False

# Phase accuracy check
phase1_correct = match_any_answer(
phase1_1_output, standard_answers
Expand Down
Loading
Loading