diff --git a/.buildkite/run-hpu-test.sh b/.buildkite/run-hpu-test.sh index 4505dc7a9373c..fa4f74fca7a11 100644 --- a/.buildkite/run-hpu-test.sh +++ b/.buildkite/run-hpu-test.sh @@ -13,4 +13,4 @@ trap remove_docker_container EXIT remove_docker_container # Run the image and launch offline inference -docker run --runtime=habana --name=hpu-test --network=host -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference.py \ No newline at end of file +docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference.py \ No newline at end of file diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index cd721971d01d6..3cb91fc0f8232 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,13 +3,16 @@ # This lists cover the "core" components of vLLM that require careful review /vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/core @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/engine/llm_engine.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/executor/executor_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/worker/worker_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/worker/worker.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/model_executor/layers/sampler.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -CMakeLists.txt @tlrmchlsmth @WoosukKwon +/vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +CMakeLists.txt @tlrmchlsmth + +# vLLM V1 +/vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic # Test ownership /tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 71f4e520135d4..d1f6105a47166 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,2 +1,2 @@ github: [vllm-project] -open_collective: [vllm] +open_collective: vllm diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000000..51a73c857ccb2 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,5 @@ +FILL IN THE PR DESCRIPTION HERE + +FIX #xxxx (*link existing issues this PR will resolve*) + +**BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html ** diff --git a/.github/scripts/cleanup_pr_body.sh b/.github/scripts/cleanup_pr_body.sh index 3b2da7b9f8966..3246c6f9bc4b7 100755 --- a/.github/scripts/cleanup_pr_body.sh +++ b/.github/scripts/cleanup_pr_body.sh @@ -15,19 +15,36 @@ NEW=/tmp/new_pr_body.txt gh pr view --json body --template "{{.body}}" "${PR_NUMBER}" > "${OLD}" cp "${OLD}" "${NEW}" -# Remove all lines after and including "**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**" -sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE\*\*/,$d' "${NEW}" - # Remove "FIX #xxxx (*link existing issues this PR will resolve*)" sed -i '/FIX #xxxx.*$/d' "${NEW}" # Remove "FILL IN THE PR DESCRIPTION HERE" sed -i '/FILL IN THE PR DESCRIPTION HERE/d' "${NEW}" +# Remove all lines after and including "**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**" +sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}" + +# Remove HTML
section that includes text of "PR Checklist (Click to Expand)" +python3 - <.*?.*?PR Checklist \(Click to Expand\).*?.*?
', re.DOTALL) +content = re.sub(pattern, '', content) + +with open("${NEW}", "w") as file: + file.write(content) +EOF + # Run this only if ${NEW} is different than ${OLD} if ! cmp -s "${OLD}" "${NEW}"; then - echo "Updating PR body" gh pr edit --body-file "${NEW}" "${PR_NUMBER}" + echo + echo "Updated PR body:" + echo + cat "${NEW}" else echo "No changes needed" fi diff --git a/Dockerfile.ubi b/Dockerfile.ubi index faf3f498c1d2c..527c6b1557b12 100644 --- a/Dockerfile.ubi +++ b/Dockerfile.ubi @@ -201,7 +201,6 @@ WORKDIR /home/vllm ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] - FROM vllm-openai as vllm-grpc-adapter USER root @@ -209,7 +208,12 @@ USER root RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=build,src=/workspace/dist,target=/workspace/dist \ - HOME=/root uv pip install "$(echo /workspace/dist/*.whl)[tensorizer]" vllm-tgis-adapter==0.5.3 + uv pip install $(echo /workspace/dist/*.whl)'[tensorizer]' --verbose && \ + uv pip install \ + "git+https://github.com/opendatahub-io/vllm-tgis-adapter@ibm-20241106-adapter" --verbose + +RUN --mount=type=bind,from=build,src=/workspace/dist,target=/workspace/dist \ + echo "Local dir and dist:" && pwd && ls -l /workspace/dist ENV GRPC_PORT=8033 \ PORT=8000 \ diff --git a/README.md b/README.md index 6530886ed7de2..0ef073210d070 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,7 @@ vLLM is a community project. Our compute resources for development and testing a - Dropbox - Google Cloud - Lambda Lab +- Nebius - NVIDIA - Replicate - Roblox diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index bdb8ea8e2a5dc..e9fc037a46965 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -251,6 +251,19 @@ def sample_hf_requests( "url": f"data:image/jpeg;base64,{image_base64}" }, } + elif "image" in data and isinstance(data["image"], str): + if (data["image"].startswith("http://") or \ + data["image"].startswith("file://")): + image_url = data["image"] + else: + image_url = f"file://{data['image']}" + + mm_content = { + "type": "image_url", + "image_url": { + "url": image_url + }, + } else: mm_content = None diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 665b50bf18cf0..a0342d08f1db8 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -2,8 +2,10 @@ import copy import itertools import math +import os import pickle as pkl import time +from dataclasses import dataclass from itertools import product from typing import Callable, Iterable, List, Optional, Tuple @@ -15,11 +17,12 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales) + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales, + marlin_zero_points) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, pack_rows, quantize_weights) + pack_rows, quantize_weights) from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import FlexibleArgumentParser @@ -27,149 +30,349 @@ DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] DEFAULT_TP_SIZES = [1] +NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False) + +if NVTX_PROFILE: + import nvtx + + +def terse_type_name(dt): + return { + torch.bfloat16: "bf16", + torch.float16: "fp16", + torch.int8: "int8", + torch.float8_e4m3fn: "fp8", + torch.bfloat16: "bf16", + torch.float: "float", + torch.int: "int", + }[dt] + + +@dataclass +class BenchmarkTensors: + w_ref: torch.Tensor + a: torch.Tensor + + w_q: torch.Tensor + group_size: Optional[int] + wtype: ScalarType + w_g_s: torch.Tensor + w_g_zp: Optional[torch.Tensor] + w_ch_s: Optional[torch.Tensor] + w_tok_s: Optional[torch.Tensor] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + group_zero_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +def rand_data(shape, dtype=torch.float16, scale=1): + if dtype.is_floating_point: + return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype) + else: + return torch.randint(-15, 15, shape, dtype=dtype, device="cuda") + + +def quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, + wtype, + group_size=group_size, + zero_points=zero_points, + # to match how the kernel applies zps + ref_zero_points_after_scales=True) -def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor: w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) - w_q = w_q.t().contiguous().t() # make col major - return ops.machete_prepack_B(w_q, wtype) + return w_ref, w_q, w_s, w_zp -def make_bench_tensors( - atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int, - k: int -) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor, - torch.tensor]]]: - assert wtype.is_integer(), "TODO: support floating point weights" +def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig, + group_size: Optional[int]) -> List[BenchmarkTensors]: + m, n, k = shape # we want to make sure that weights don't fit into L2 cache between runs so # we construct enough weights to exceed L2 cache, which is 50mb on a H100 # so we target total weight size > 2*50mb - num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits)) - - a = torch.randn((m, k), device="cuda", dtype=atype) * 5 - weights = [ - torch.randn((k, n), device="cuda", dtype=atype) - for _ in range(num_weights) - ] - quanitized_weights = [ - quantize_weights(w, wtype, group_size) for w in weights - ] - - return a, quanitized_weights + num_weights = math.ceil(2 * 50 * 1024**2 * 8 / + (k * n * types.weight_type.size_bits)) + + a = rand_data((m, k), types.act_type, scale=5) + + benchmark_tensors: List[BenchmarkTensors] = [] + for _ in range(num_weights): + w = rand_data((k, n), types.act_type, scale=5) + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, w_zp = quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + types.group_zero_type is not None) + + if not a.dtype.is_floating_point: + aiinfo = torch.iinfo(a.dtype) + w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max) + + w_ref = w_ref.to(torch.float32) + + w_ch_s = None if types.channel_scale_type is None else\ + rand_data((n,), types.channel_scale_type) + w_tok_s = None if types.token_scale_type is None else\ + rand_data((m,), types.token_scale_type) + + benchmark_tensors.append( + BenchmarkTensors(w_ref=w_ref, + a=a, + w_q=w_q_packed, + wtype=types.weight_type, + w_g_s=w_s, + w_g_zp=w_zp, + group_size=group_size, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s)) + + return benchmark_tensors + + +def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable: + a = bt.a + w = bt.w_ref.to(bt.a.dtype) # use float reference tensor + if a.dtype not in [torch.float16, torch.bfloat16]: + a = a.to(torch.float16) + w = w.to(torch.float16) + return lambda: torch.matmul(a, w) + + +def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable: + if bt.w_ch_s is not None and bt.w_tok_s is not None: + scale_a = bt.w_tok_s.to(torch.float32) + scale_b = bt.w_ch_s.to(torch.float32) + else: + scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) + scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) + w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t() + return lambda: ops.cutlass_scaled_mm( + bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16) + + +def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: + device = bt.a.device + + workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + if bt.w_g_zp is None: + w_zp = torch.empty(0, dtype=torch.int, device=device) + else: + w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0], + bt.w_ref.shape[1], bt.wtype.size_bits) + + if bt.group_size is None: + w_s = torch.tensor([], device="cuda", dtype=torch.half) + else: + w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0], + bt.w_ref.shape[1], bt.group_size) + + sort_indices = torch.empty(0, dtype=torch.int, device=device) + g_idx = torch.empty(0, dtype=torch.int, device=device) + w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0], + bt.w_ref.shape[1], bt.wtype.size_bits) + + if bt.a.dtype.is_floating_point: + assert bt.w_ch_s is None + assert bt.w_tok_s is None + assert bt.group_size is not None + + fn = lambda: ops.gptq_marlin_gemm(a=bt.a, + b_q_weight=w_q, + b_scales=w_s, + b_zeros=w_zp, + g_idx=g_idx, + perm=sort_indices, + workspace=workspace.scratch, + b_q_type=bt.wtype, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0], + is_k_full=True) + else: + assert bt.a.dtype == torch.int8 + assert bt.wtype == scalar_types.uint4b8 + + if bt.w_ch_s is not None: + s_ch = bt.w_ch_s.to(torch.float32) + else: + s_ch = torch.ones(bt.w_ref.shape[1], + dtype=torch.float32, + device=device) + + if bt.w_tok_s is not None: + s_tok = bt.w_tok_s.to(torch.float32) + else: + s_tok = torch.ones(bt.a.shape[0], + dtype=torch.float32, + device=device) + + fn = lambda: ops.marlin_qqq_gemm(a=bt.a, + b_q_weight=w_q, + s_group=w_s, + s_tok=s_tok, + s_ch=s_ch, + workspace=workspace.scratch, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0]) + + return fn + + +def machete_create_bench_fn(bt: BenchmarkTensors, + out_type=torch.dtype, + schedule=None) -> Callable: + w_q = bt.w_q.t().contiguous().t() # make col major + w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype, + None if bt.w_g_s is None else bt.w_g_s.dtype) + + w_g_zp = bt.w_g_zp + if w_g_zp is not None: + w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype)) + + return lambda: ops.machete_mm( + a=bt.a, + b_q=bt.w_q, + b_type=bt.wtype, + b_group_scales=bt.w_g_s, + b_group_zeros=w_g_zp, + b_group_size=bt.group_size, + b_channel_scales=bt.w_ch_s, + a_token_scales=bt.w_tok_s, + out_type=out_type, + schedule=schedule, + ) # impl - # bench -def bench_fn(label: str, sub_label: str, description: str, - fn: Callable) -> TMeasurement: - min_run_time = 1 - return TBenchmark.Timer( - stmt="fn()", + +def bench_fns(label: str, sub_label: str, description: str, + fns: List[Callable]): + + min_run_time = 1 if not NVTX_PROFILE else 0.1 + res = TBenchmark.Timer( + stmt=""" + for fn in fns: + fn() + """, globals={ - "fn": fn + "fns": fns }, label=label, sub_label=sub_label, description=description, ).blocked_autorange(min_run_time=min_run_time) + if NVTX_PROFILE: + with nvtx.annotate("mm-bench"), nvtx.annotate( + f"{label}|{sub_label}|{description}"): + fns[0]() -def loop_over_weights( - a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor, - torch.tensor, torch.tensor]], - fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor], - None]): - for w_ref, w_q, w_s, _ in weights: - fn(a, w_ref, w_q, w_s) + return res _SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None _SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None -def bench(atype: torch.dtype, - wtype: ScalarType, +def bench(types: TypeConfig, group_size: int, m: int, k: int, n: int, label: str, sub_label: str, - benchmark_marlinv1: bool = True, - sweep_schedules: bool = True) -> Iterable[TMeasurement]: - global _SWEEP_SCHEDULES_RESULTS - - a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k) - sub_label += f", L={len(weights)}" - - weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp) - for w_ref, w_q, w_s, w_zp in weights] + sweep_schedules: bool = True) -> List[TMeasurement]: + benchmark_tensors = create_bench_tensors((m, n, k), types, group_size) + sub_label += f", L={len(benchmark_tensors)}" + + name_type_string = f"W{types.weight_type}"+\ + f"-A{terse_type_name(types.act_type)}" + if types.group_scale_type is not None: + name_type_string += f"-GS{terse_type_name(types.group_scale_type)}" + if types.group_zero_type is not None: + name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}" + if group_size is not None: + name_type_string += f"-G{group_size}" + if types.channel_scale_type is not None: + name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}" + if types.token_scale_type is not None: + name_type_string += f"-TS{terse_type_name(types.token_scale_type)}" timers = [] # pytorch impl timers.append( - bench_fn( - label, sub_label, "torch.matmul", lambda: loop_over_weights( - a, - weights, - lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref), - ))) + bench_fns( + label, sub_label, "torch.matmul (fp16)", + [torch_matmul_f16_create_bench_fn(bt) + for bt in benchmark_tensors])) - if benchmark_marlinv1: - w_ref = weights[0][0] - - w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device) - sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device) - g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device) - - def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor: - w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape) - return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape, - wtype.size_bits) - - def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: - return marlin_permute_scales(w_s, *w_ref.shape, group_size) - - weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q), - marlinv1_permute_scales(w_s), w_zp) - for w_ref, w_q, w_s, w_zp in weights] - - workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) - - # marlinv1 + if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn: + timers.append( + bench_fns( + label, sub_label, + f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [ + cutlass_scaled_mm_create_bench_fn(bt) + for bt in benchmark_tensors + ])) + + if types.act_type != torch.float8_e4m3fn: timers.append( - bench_fn( - label, sub_label, "marlin_orig", lambda: loop_over_weights( - a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops. - gptq_marlin_gemm(a, - w_q, - w_s, - w_zp_empty, - g_idx, - sort_indices, - workspace.scratch, - wtype, - size_m=a.shape[0], - size_n=w_ref.shape[1], - size_k=w_ref.shape[0], - is_k_full=True)))) + bench_fns(label, sub_label, f"marlin ({name_type_string})", + [marlin_create_bench_fn(bt) + for bt in benchmark_tensors])) # machete timers.append( - bench_fn( - label, sub_label, "machete_heuristic", lambda: loop_over_weights( - a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm( - a, w_q, wtype, b_scales=w_s, b_group_size=group_size)))) + bench_fns(label, sub_label, f"machete ({name_type_string})", [ + machete_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ])) if sweep_schedules: + global _SWEEP_SCHEDULES_RESULTS + print("Finding best schedule for machete") best = None best_schedule = None - schedules = ops.machete_supported_schedules(wtype) + schedules = ops.machete_supported_schedules( + a_type=types.act_type, + b_type=types.weight_type, + group_scales_type=types.group_scale_type, + group_zeros_type=types.group_zero_type, + token_scales_type=types.token_scale_type, + channel_scales_type=types.channel_scale_type, + out_type=types.output_type) + + if schedules is None or len(schedules) == 0: + raise ValueError("No schedules found to sweep") + for schedule in reversed(schedules): schedule_M = int(schedule.split("_")[0].split("x")[1]) @@ -177,16 +380,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: continue - def run(a, _, w_q, w_s, schedule=schedule): - ops.machete_gemm(a, - w_q, - wtype, - w_s, - b_group_size=group_size, - schedule=schedule) - - res = bench_fn(label, sub_label, "machete_best", - lambda: loop_over_weights(a, weights_machete, run)) + res = bench_fns(label, sub_label, "machete_best", [ + machete_create_bench_fn( + bt, out_type=types.output_type, schedule=schedule) + for bt in benchmark_tensors + ]) results_row = { "M": m, @@ -213,25 +411,33 @@ def run(a, _, w_q, w_s, schedule=schedule): # runner -def print_timers(timers: Iterable[TMeasurement]): +def print_timers(timers: List[TMeasurement]): compare = TBenchmark.Compare(timers) compare.print() -def run(dtype: torch.dtype, sweep_schedules: bool, - MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: +def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + types = TypeConfig( + act_type=args.act_type, + weight_type=scalar_types.uint4b8 if args.group_zero_type is None \ + else scalar_types.uint4, + output_type=args.out_type, + group_scale_type=args.group_scale_type, + group_zero_type=args.group_zero_type, + channel_scale_type=args.channel_scale_type, + token_scale_type=args.token_scale_type, + ) - results = [] + results: List[TMeasurement] = [] for m, k, n in MKNs: - timers = bench(dtype, - scalar_types.uint4b8, - 128, + timers = bench(types, + args.group_size, m, k, n, - f"{dtype}-gemm", + f"{args.act_type}-gemm", f"MKN=({m}x{k}x{n})", - sweep_schedules=sweep_schedules) + sweep_schedules=args.sweep_schedules) print_timers(timers) results.extend(timers) @@ -240,7 +446,7 @@ def run(dtype: torch.dtype, sweep_schedules: bool, # output makers def make_output( - data: Iterable[TMeasurement], + data: List[TMeasurement], MKNs: Iterable[Tuple[int, int, int]], base_description: str, timestamp=None, @@ -262,7 +468,6 @@ def run_square_bench(args): dim_sizes = list( range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) - data = run(args.dtype, args.sweep_schedules, MKNs) make_output(data, MKNs, f"square_bench-{args.dtype}") @@ -306,33 +511,49 @@ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: for k, n in KNs: MKNs.append((m, k, n)) - data = run(args.dtype, args.sweep_schedules, MKNs) + data = run(args, MKNs) model_bench_data.append(data) + type_string = f"{args.act_type}" + # Print all results for data, model_tp in zip(model_bench_data, models_tps): model, tp_size = model_tp - print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print(f"== Results {type_string} {model}-TP{tp_size} ====") print_timers(data) - timestamp = int(time.time()) + timestr = time.strftime("%Y%m%d-%H%M%S") - all_data = [] + all_results = [] for d in model_bench_data: - all_data.extend(d) + all_results.extend(d) + # pickle all data - with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: - pkl.dump(all_data, f) + with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f: + args_dict = vars(args) + args_dict.pop("func") + pkl.dump({ + "args": args_dict, + "results": all_results, + }, f) if __name__ == "__main__": def to_torch_dtype(dt): - if dt == "bfloat16": - return torch.bfloat16 - if dt == "float16": - return torch.float16 - raise ValueError("unsupported dtype") + return { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "int8": torch.int8, + "float8_e4m3fn": torch.float8_e4m3fn, + "int": torch.int, + "float": torch.float, + }[dt] + + class ToTorchDtype(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, to_torch_dtype(values)) parser = FlexibleArgumentParser( description=""" @@ -352,12 +573,42 @@ def to_torch_dtype(dt): """, # noqa: E501 formatter_class=argparse.RawTextHelpFormatter, ) - parser.add_argument( - "--dtype", - type=to_torch_dtype, + "--act-type", + action=ToTorchDtype, required=True, - help="Available options are ['bfloat16', 'float16']", + choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'], + ) + parser.add_argument( + "--group-scale-type", + action=ToTorchDtype, + choices=['bfloat16', 'float16'], + ) + parser.add_argument( + "--group-zero-type", + type=to_torch_dtype, + choices=['bfloat16', 'float16'], + ) + parser.add_argument( + "--channel-scale-type", + action=ToTorchDtype, + choices=['float'], + ) + parser.add_argument( + "--token-scale-type", + action=ToTorchDtype, + choices=['float'], + ) + parser.add_argument( + "--out-type", + action=ToTorchDtype, + choices=['bfloat16', 'float16'], + ) + parser.add_argument( + "--group-size", + type=int, + help="Available options are ['None', '-1', '128'], default=128", + default=128, ) parser.add_argument( "--sweep-schedules", diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py index de608fd05af70..7d0bd84150a27 100644 --- a/benchmarks/kernels/graph_machete_bench.py +++ b/benchmarks/kernels/graph_machete_bench.py @@ -20,10 +20,11 @@ args = parser.parse_args() with open(args.filename, 'rb') as f: - data: List[TMeasurement] = pickle.load(f) + data = pickle.load(f) + raw_results: List[TMeasurement] = data["results"] results = defaultdict(lambda: list()) - for v in data: + for v in raw_results: result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label) if result is not None: KN = result.group(1) diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index 25ec9d6028627..51f24f3ba1774 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -40,4 +40,10 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], + "meta-llama/Llama-3.1-405b-hf": [ + ([16384, 18432], 1), + ([16384, 16384], 0), + ([16384, 106496], 1), + ([53248, 16384], 0), + ], } diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh index 1842fab8b2cac..f61fe3ceb978a 100644 --- a/csrc/cutlass_extensions/cute_utils.cuh +++ b/csrc/cutlass_extensions/cute_utils.cuh @@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { // is the layout f(x) = x template CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) { return true; - else { + } else { constexpr auto coalesced_layout = coalesce(Layout{}); if constexpr (rank(coalesced_layout) == 1 && stride<0>(coalesced_layout) == 1) { diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp similarity index 99% rename from csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp rename to csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp index d407d66ab2aa6..7aa87feb4cce2 100644 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp @@ -52,6 +52,7 @@ // clang-format off #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cute/tensor.hpp" namespace cutlass::epilogue::threadblock { diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp similarity index 100% rename from csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp rename to csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp new file mode 100644 index 0000000000000..c69e87999ae71 --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -0,0 +1,317 @@ +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs. + + Epilogues must contain a public type named EVTCompute of type Sm80EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c2x { + +using namespace cute; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + template + using ColOrScalarLoad = + cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = + cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using RowOrZeroLoad = + cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + // it would technically work but no use case as data_ptr is never nullptr + static_assert(!std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(c10::optional const& tensor) { + static_assert(std::is_same_v>); + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch._scaled_mm. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : protected ScaledEpilogueBase { + protected: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +}; // namespace vllm::c2x \ No newline at end of file diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp new file mode 100644 index 0000000000000..95764ecddc79f --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -0,0 +1,315 @@ +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c3x { + +using namespace cute; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + template + using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>>; + + // Don't want to support nullptr by default + template + using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // Don't want to support nullptr by default + template + using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + static_assert(!std::is_same_v> && + !std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(c10::optional const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch.scaled_mm_. + + A and B may be both either int8 or fp8_e4m3. A can be + quantized per-tensor or per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +}; // namespace vllm::c3x \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index 4fcfcd311aa91..a5beea1a35e49 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -35,6 +35,35 @@ class MixedInputKernelScheduleType(enum.Enum): } } +VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = { + **DataTypeSize, # type: ignore + **{ + VLLMDataType.u4b8: 4, + VLLMDataType.u8b128: 8, + } +} + +VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = { + VLLMDataType.u4b8: "vllm::kU4B8", + VLLMDataType.u8b128: "vllm::kU8B128", + DataType.u4: "vllm::kU4", + DataType.u8: "vllm::kU8", + DataType.s4: "vllm::kS4", + DataType.s8: "vllm::kS8", + DataType.f16: "vllm::kFloat16", + DataType.bf16: "vllm::kBfloat16", +} + +VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { + DataType.u8: "at::ScalarType::Byte", + DataType.s8: "at::ScalarType::Char", + DataType.e4m3: "at::ScalarType::Float8_e4m3fn", + DataType.s32: "at::ScalarType::Int", + DataType.f16: "at::ScalarType::Half", + DataType.bf16: "at::ScalarType::BFloat16", + DataType.f32: "at::ScalarType::Float", +} + VLLMKernelScheduleTag: Dict[Union[ MixedInputKernelScheduleType, KernelScheduleType], str] = { **KernelScheduleTag, # type: ignore diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh index 2ad914f8e9868..90f226cf64c0a 100644 --- a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh +++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -3,6 +3,7 @@ #include "cutlass/numeric_conversion.h" #include "cutlass_extensions/vllm_custom_types.cuh" #include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/vllm_type_utils.cuh" // this file extends: // https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h @@ -28,8 +29,19 @@ struct InterleavedNumericArrayConverter { CUTLASS_DEVICE static result_type convert(source_type const& source) { - CUTE_INVALID_CONTROL_PATH( - "InterleavedNumericArrayConverter not implemented\n"); + if (cute::elect_one_sync()) { + if constexpr (std::is_same_v) { + printf( + "Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n", + nameof_v, nameof_v, N); + } else { + printf( + "Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not " + "implemented\n", + nameof_v, nameof_v, N, size(IlvBlkLayout{})); + } + __brkpt(); + } return {}; } @@ -56,11 +68,6 @@ struct InterleavedNumericArrayConverter< result_type operator()(source_type const& s) const { return convert(s); } }; -// TODO (LucasWilkinson): Implement -// for Array <= Array - -// .... - template struct ArrayConverterPacked32Bit { using result_type = Array; @@ -86,14 +93,16 @@ struct ArrayConverterPacked32Bit { using ScalarConverter = NumericConverter; template - CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) { + CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) { if constexpr (sizeof(PackedSrc) == 1) { - return static_cast(reinterpret_cast(source)); + return Array{reinterpret_cast(src)}; } else if constexpr (sizeof(PackedSrc) == 2) { - return static_cast(reinterpret_cast(source)); + return Array{reinterpret_cast(src)}; + } else if constexpr (sizeof(PackedSrc) == 4) { + return Array{reinterpret_cast(src)}; } else { - static_assert(sizeof(PackedSrc) == 4); - return reinterpret_cast(source); + static_assert(sizeof(PackedSrc) == 8); + return reinterpret_cast const&>(src); } } @@ -110,7 +119,7 @@ struct ArrayConverterPacked32Bit { static_assert(std::is_same_v); static_assert(std::is_same_v); - return RegConvert32bit::template convert(to_reg(source)); + return RegConvert32bit::template convert(to_regs(source)); } friend class detail::VectorizedConverter; @@ -140,6 +149,131 @@ struct ArrayConverterPacked32Bit { } }; +// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed +// into 2 32bit register. +template +CUTLASS_DEVICE cutlass::AlignedArray lut_4bit_to_8bit_convert( + uint32_t src) { + cutlass::AlignedArray r; + // Determines if the value is in the top half of the LUT if set or + // (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move + // into bit position 0x4 of each nibble so when or'd with final_prmt_base it + // selects the correct candidate. When elements in final_prmt_base + // are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements + // are < 0x4, the low candidate is selected (i.e. LUT[0:7]) + uint32_t high_bit = (src & 0x88888888) >> 1; + + // `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT + // (selects correct high or low candidate) + const uint32_t final_prmt_base = 0x32103210; + + // Ignore the high bit when indexing into LUT, for each 4bit value + // we index into both the high and low candidates then use + // high_bit | final_prmt_base to select the correct candidate + uint32_t lut_idx = (src & 0x77777777); + + auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) { + return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) | + (uint32_t(d) << 24); + }; + + static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3); + static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7); + static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11); + static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) { + uint32_t final_prmt_idx = final_prmt_base | high_bit; + + // This uses a look up table to convert packed int4s to packed int8s, + // using the int4 value as the index to prmt. It first select both the + // high and low candidates, then uses the high bit (i.e. `high_bit`) to + // select the correct candidate. + asm volatile( + "{\n" + " .reg .b32 low, high;\n" + " prmt.b32 low, %1, %2, %5;\n" + " prmt.b32 high, %3, %4, %5;\n" + " prmt.b32 %0, low, high, %6;\n" + "}\n" + : "=r"(r[ii]) + : "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx), + "r"(final_prmt_idx)); + } + + return r; +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s + auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, // + 0xFC, 0xFD, 0xFE, 0xFF, // + 0x00, 0x01, 0x02, 0x03, // + 0x04, 0x05, 0x06, 0x07>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s + auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, // + 0xC8, 0xC4, 0xC0, 0xB8, // + 0x00, 0x38, 0x40, 0x44, // + 0x48, 0x4A, 0x4C, 0x4E>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + // for Array <= Array template struct NumericArrayConverter { @@ -148,7 +282,8 @@ struct NumericArrayConverter { struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -249,7 +384,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -338,7 +474,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -417,7 +554,8 @@ struct NumericArrayConverter { struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; // Hold output FP16s in reg. We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray { private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; PackedResultType r; // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of @@ -513,7 +652,8 @@ struct NumericArrayConverter { private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src_reg = src_[0]; // Hold output BF16s in reg. We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -671,7 +812,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -788,6 +930,61 @@ struct NumericArrayConverter { #endif +// for Array <= Array +// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + // FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 + template + CUTLASS_DEVICE static PackedResultType convert( + Array src) { + // Hold output int8s in reg. We need 1 reg for every 4 elements + using RegArray = cutlass::AlignedArray< + uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>; + RegArray r; + + static constexpr uint32_t MAGIC_BIAS_ = 0x64806480; + auto MAGIC_BIAS = *reinterpret_cast(&MAGIC_BIAS_); + + *reinterpret_cast(&src[0]) = + __hadd2(*reinterpret_cast(&src[0]), MAGIC_BIAS); + + if constexpr (src_regs > 1) { + *reinterpret_cast(&src[1]) = + __hadd2(*reinterpret_cast(&src[1]), MAGIC_BIAS); + } + + static_assert(PackedResultType::kElements <= 4); + uint32_t uint8s; + static constexpr uint32_t MASK_0246 = 0x6420; + static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(uint8s) + : "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]), + "n"(MASK_0246)); + + uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK); + + return reinterpret_cast(int8s); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/csrc/cutlass_extensions/vllm_type_utils.cuh b/csrc/cutlass_extensions/vllm_type_utils.cuh new file mode 100644 index 0000000000000..500ed508c8303 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_type_utils.cuh @@ -0,0 +1,42 @@ +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" +#include "cuda_bf16.h" + +#include "cutlass_extensions/vllm_custom_types.cuh" + +namespace cutlass { + +template +struct nameof { + static constexpr char const* value = "unknown"; +}; + +template +inline constexpr auto nameof_v = nameof::value; + +#define NAMEOF_TYPE(T) \ + template <> \ + struct nameof { \ + static constexpr char const* value = #T; \ + }; + +NAMEOF_TYPE(float_e4m3_t) +NAMEOF_TYPE(float_e5m2_t) +NAMEOF_TYPE(half_t) +NAMEOF_TYPE(nv_bfloat16) +NAMEOF_TYPE(bfloat16_t) +NAMEOF_TYPE(float) + +NAMEOF_TYPE(int4b_t) +NAMEOF_TYPE(int8_t) +NAMEOF_TYPE(int32_t) +NAMEOF_TYPE(int64_t) + +NAMEOF_TYPE(vllm_uint4b8_t) +NAMEOF_TYPE(uint4b_t) +NAMEOF_TYPE(uint8_t) +NAMEOF_TYPE(vllm_uint8b128_t) +NAMEOF_TYPE(uint32_t) +NAMEOF_TYPE(uint64_t) + +}; // namespace cutlass \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index ee801e16573d4..dbb72e8bbd3f5 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -8,6 +8,10 @@ #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" #include "scaled_mm_c2x_sm89_int8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp" + +using namespace vllm; + /* This file defines quantized GEMM operations using the CUTLASS 2.x API, for NVIDIA GPUs with SM versions prior to sm90 (Hopper). @@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm75_dispatch( + return cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm75_dispatch( + return cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales); } } @@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } @@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm80_dispatch( + return cutlass_gemm_sm80_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm80_dispatch( + return cutlass_gemm_sm80_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales); } } @@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } @@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm89_int8_dispatch( + return cutlass_gemm_sm89_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { assert(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm89_int8_dispatch( + return cutlass_gemm_sm89_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } } else { @@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm89_fp8_dispatch< - cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>( + return cutlass_gemm_sm89_fp8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm89_fp8_dispatch( + return cutlass_gemm_sm89_fp8_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales); } } @@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index 6329ff63623e2..d03242f44ab1d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -21,7 +21,6 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" -#include "broadcast_load_epilogue_c2x.hpp" #include "common.hpp" // clang-format on @@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel { #endif } }; - -/* - * This class provides the common load descriptors for the - * ScaledEpilogue[...] classes - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; - - template - using ColOrScalarLoad = - cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; - - template - using RowOrScalarLoad = - cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - template - using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; - - template - using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - template - using RowOrZeroLoad = - cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - // This utility function constructs the arguments for the load descriptors - // from a tensor. It can handle both row and column, as well as row/column or - // scalar cases. - template - static auto args_from_tensor(torch::Tensor const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = static_cast(tensor.data_ptr()); - if constexpr (std::is_same_v> || - std::is_same_v>) { - return Arguments{data_ptr, tensor.numel() != 1}; - } else { - // it would technically work but no use case as data_ptr is never nullptr - static_assert(!std::is_same_v>); - return Arguments{data_ptr}; - } - } - - // This overload handles the case where there might not be a tensor, in which - // case a nullptr is passed and a constant (0) is used. - template - static auto args_from_tensor(c10::optional const& tensor) { - static_assert(std::is_same_v>); - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; - return Arguments{data_ptr}; - } -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch._scaled_mm. - - A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or - per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. -*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; - } -}; - -/* - * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. - * This bias can also be used in the per-tensor azp case, where the activation - * zero point (azp) is used to compute an azp correction term, - * which is folded into the bias. - * - * The bias tensor must be per-output channel. - * ScaleA and ScaleB can be per-tensor or per-token/per-channel. - */ -template -struct ScaledEpilogueBias - : protected ScaledEpilogueBase { - protected: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; - } -}; - -/* - * This epilogue directly supports per-tensor azp in int32 form. - * As opposed to the per-token epilogue below, this epilogue only has an azp_adj - * term, which should already be multiplied with the scalar azp. - * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzp - : protected ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowOrZeroLoad; - - // This is the full AZP term, azp * J @ B, shape (1,n) - using AzpWithAdj = typename SUPER::template RowLoad; - - // Compute float(accum - azp_adj), both operands are int32_t - using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -/* - * This epilogue supports per-token azp by computing and applying - * the correction term using a rank-1 update. If the term were materialized, - * it would require O(m*n) space, and this way it only requires O(m+n) space. - * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero - * point for each row of A. - * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzpToken - : protected ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowOrZeroLoad; - - // Per-token azp term, shape (m,1) - using Azp = typename SUPER::template ColLoad; - - // This is the AZP adjustment term, J @ B, shape (1,n) - using AzpAdj = typename SUPER::template RowLoad; - - // Compute azp * azp_adj - using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, int32_t, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::threadblock::Sm80EVT; - - // Compute float(accum - azp*azp_adj), all operands are int32_t - using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAcc = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_args = SUPER::template args_from_tensor(azp); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - template typename ArchGuard, typename ElementAB_, typename ElementD_, template typename Epilogue_, typename TileShape, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 292c9e4b34e1c..33581a63d4c3d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -23,11 +23,12 @@ #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "broadcast_load_epilogue_c3x.hpp" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "common.hpp" // clang-format on using namespace cute; +using namespace vllm; /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for @@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel { #endif } }; - -/* - * This class provides the common load descriptors for the - * ScaledEpilogue[...] classes - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - - template - using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>>; - - template - using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>>; - - // Don't want to support nullptr by default - template - using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; - - // Don't want to support nullptr by default - template - using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; - - // This utility function constructs the arguments for the load descriptors - // from a tensor. It can handle both row and column, as well as row/column or - // scalar cases. - template - static auto args_from_tensor(torch::Tensor const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = static_cast(tensor.data_ptr()); - if constexpr (std::is_same_v> || - std::is_same_v>) { - return Arguments{data_ptr, tensor.numel() != 1}; - } else { - static_assert(!std::is_same_v> && - !std::is_same_v>); - return Arguments{data_ptr}; - } - } - - // This overload handles the case where there might not be a tensor, in which - // case a nullptr is passed and a constant (0) is used. - template - static auto args_from_tensor(c10::optional const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; - static_assert(std::is_same_v> || - std::is_same_v>); - return Arguments{data_ptr}; - } -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch.scaled_mm_. - - A and B may be both either int8 or fp8_e4m3. A can be - quantized per-tensor or per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. -*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; - } -}; - -/* - * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. - * This bias can also be used in the per-tensor azp case, where the activation - * zero point (azp) is used to compute an azp correction term, - * which is folded into the bias. - * - * The bias tensor must be per-output channel. - * ScaleA and ScaleB can be per-tensor or per-token/per-channel. - */ -template -struct ScaledEpilogueBias - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - - using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; - } -}; - -/* - * This epilogue directly supports per-tensor azp in int32 form. - * As opposed to the per-token epilogue below, this epilogue only has an azp_adj - * term, which should already be multiplied with the scalar azp. - * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzp - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - // This is the full AZP term, azp * J @ B, shape (1,n) - using AzpWithAdj = typename SUPER::template RowLoad; - - // Compute float(accum - azp_adj), both operands are int32_t - using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -/* - * This epilogue supports per-token azp by computing and applying - * the correction term using a rank-1 update. If the term were materialized, - * it would require O(m*n) space, and this way it only requires O(m+n) space. - * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero - * point for each row of A. - * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzpToken - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - // Per-token azp term, shape (m,1) - using Azp = typename SUPER::template ColLoad; - - // This is the AZP adjustment term, J @ B, shape (1,n) - using AzpAdj = typename SUPER::template RowLoad; - - // Compute azp * azp_adj - using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, int32_t, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::fusion::Sm90EVT; - - // Compute float(accum - azp*azp_adj), all operands are int32_t - using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAcc = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_args = SUPER::template args_from_tensor(azp); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, @@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == c.dtype(), "currently bias dtype must match output dtype ", c.dtype()); - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( c, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm90_epilogue(c, a, b, a_scales, - b_scales); + return cutlass_scaled_mm_sm90_epilogue( + c, a, b, a_scales, b_scales); } } @@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index d126af1849024..ac63afe79a255 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -3,8 +3,10 @@ import os import shutil from collections.abc import Iterable -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from copy import deepcopy +from dataclasses import dataclass, fields +from functools import reduce +from typing import Dict, List, Optional, Tuple, Union import jinja2 # yapf conflicts with isort for this block @@ -14,7 +16,10 @@ MixedInputKernelScheduleType, TileSchedulerTag, TileSchedulerType, VLLMDataType, - VLLMDataTypeNames, VLLMDataTypeTag, + VLLMDataTypeNames, + VLLMDataTypeSize, VLLMDataTypeTag, + VLLMDataTypeTorchDataTypeTag, + VLLMDataTypeVLLMScalarTypeTag, VLLMKernelScheduleTag) # yapf: enable @@ -27,49 +32,125 @@ #include "../machete_mm_launcher.cuh" namespace machete { -using GemmDispatcher_ = GemmDispatcher< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - {{DataTypeTag[type_config.element_b_scale]}}, // Scales - {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints - -{% for s in schedules %}extern torch::Tensor -impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args); -{% endfor %} -template <> -torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) { + +{% for impl_config in impl_configs %} +{% set type_sig = gen_type_sig(impl_config.types) -%} +{% for s in impl_config.schedules %} +extern torch::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs); +{%- endfor %} + +torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) { [[maybe_unused]] auto M = args.A.size(0); [[maybe_unused]] auto N = args.B.size(1); [[maybe_unused]] auto K = args.A.size(1); - if (!args.schedule) { - {%- for cond, s in heuristic %} + if (!args.maybe_schedule) { + {%- for cond, s in impl_config.heuristic %} {%if cond is not none%}if ({{cond}}) {%- else %}else {%- endif %} - return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %} + return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %} } - {% for s in schedules %} - if (*args.schedule == "{{ gen_sch_name(s) }}") { - return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args); - } - {% endfor %} + {%- for s in impl_config.schedules %} + if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}") + return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args); + {%- endfor %} TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for " - "schedule = ", *args.schedule); + "schedule = ", *args.maybe_schedule); } +{%- endfor %} + -template <> -std::vector GemmDispatcher_::supported_schedules() { - return { - {% for s in schedules -%} - "{{ gen_sch_name(s) }}"{{ ", - " if not loop.last }}{%- endfor %} - }; +static inline std::optional maybe_scalartype( + c10::optional const& t) { + if (!t) { + return std::nullopt; + } else { + return t->scalar_type(); + }; +} + +torch::Tensor mm_dispatch(MMArgs args) { + auto out_type = args.maybe_out_type.value_or(args.A.scalar_type()); + auto a_type = args.A.scalar_type(); + auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales); + auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros); + auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales); + auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales); + + {% for impl_config in impl_configs %} + {% set t = impl_config.types -%} + {% set type_sig = gen_type_sig(t) -%} + if (args.b_type == {{VLLMScalarTypeTag[t.b]}} + && a_type == {{TorchTypeTag[t.a]}} + && out_type == {{TorchTypeTag[t.out]}} + && {%if t.b_group_scale != void -%} + maybe_g_scales_type == {{TorchTypeTag[t.b_group_scale]}} + {%- else %}!maybe_g_scales_type{%endif%} + && {%if t.b_group_zeropoint != void -%} + maybe_g_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}} + {%- else %}!maybe_g_zeros_type{%endif%} + && {%if t.b_channel_scale != void -%} + maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}} + {%- else %}!maybe_ch_scales_type{%endif%} + && {%if t.a_token_scale != void -%} + maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}} + {%- else %}!maybe_tok_scales_type{%endif%} + ) { + return mm_dispatch_{{type_sig}}(args); + } + {%- endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED( + false, "machete_mm(..) is not implemented for " + "a_type=", args.A.scalar_type(), + ", b_type=", args.b_type.str(), + ", out_type=", out_type, + ", with_group_scale_type=", maybe_g_scales_type + ? toString(*maybe_g_scales_type) : "None", + ", with_group_zeropoint_type=", maybe_g_zeros_type + ? toString(*maybe_g_zeros_type) : "None", + ", with_channel_scale_type=", maybe_ch_scales_type + ? toString(*maybe_ch_scales_type) : "None", + ", with_token_scale_type=", maybe_tok_scales_type + ? toString(*maybe_tok_scales_type) : "None", + "; implemented types are: \\n", + {%- for impl_config in impl_configs %} + {% set t = impl_config.types -%} + "\\t{{gen_type_option_name(t)}}\\n", + {%- endfor %} + ""); } +std::vector supported_schedules_dispatch( + SupportedSchedulesArgs args) { + auto out_type = args.maybe_out_type.value_or(args.a_type); + + {% for impl_config in impl_configs %} + {% set t = impl_config.types -%} + {% set schs = impl_config.schedules -%} + if (args.b_type == {{VLLMScalarTypeTag[t.b]}} + && args.a_type == {{TorchTypeTag[t.a]}} + && out_type == {{TorchTypeTag[t.out]}} + && {%if t.b_group_scale != void -%} + args.maybe_group_scales_type == {{TorchTypeTag[t.b_group_scale]}} + {%- else %}!args.maybe_group_scales_type{%endif%} + && {%if t.b_group_zeropoint != void-%} + args.maybe_group_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}} + {%- else %}!args.maybe_group_zeros_type{%endif%} + ) { + return { + {%- for s in impl_config.schedules %} + "{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %} + {%- endfor %} + }; + } + {%- endfor %} + + return {}; +}; + }; // namespace machete """ @@ -77,20 +158,10 @@ #include "../machete_mm_launcher.cuh" namespace machete { -template -using Kernel = MacheteKernelTemplate< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - {{DataTypeTag[type_config.element_b_scale]}}, // Scales - {{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, - Config, with_C, with_scales, with_zeropoints>; - -{% for sch in schedules %} -{% set schedule_name = gen_sch_name(sch) -%} -struct sch_{{schedule_name}} { + +{% for sch in unique_schedules(impl_configs) %} +{% set sch_sig = gen_sch_sig(sch) -%} +struct sch_{{sch_sig}} { using TileShapeNM = Shape<{{ to_cute_constant(sch.tile_shape_mn)|join(', ')}}>; using ClusterShape = Shape<{{ @@ -101,27 +172,34 @@ using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}}; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; }; - +{% endfor %} + +{% for impl_config in impl_configs %} +{% set t = impl_config.types -%} +{% set schs = impl_config.schedules -%} +{% set type_sig = gen_type_sig(t) -%} + +template +using Kernel_{{type_sig}} = MacheteKernelTemplate< + {{DataTypeTag[t.a]}}, // ElementA + {{DataTypeTag[t.b]}}, // ElementB + {{DataTypeTag[t.out]}}, // ElementD + {{DataTypeTag[t.accumulator]}}, // Accumulator + {{DataTypeTag[t.b_group_scale]}}, // GroupScaleT + {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT + {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT + {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, + Sch>; + +{% for sch in schs %} +{% set sch_sig = gen_sch_sig(sch) -%} torch::Tensor -impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) { - bool with_C = args.C.has_value(), with_scales = args.scales.has_value(), - with_zeropoints = args.zeros.has_value(); - - {% for s in specializations %} - if (with_C == {{s.with_C|lower}} - && with_zeropoints == {{s.with_zeropoints|lower}} - && with_scales == {{s.with_scales|lower}}) { - return run_impl>(args); - }{% endfor %} - - TORCH_CHECK_NOT_IMPLEMENTED( - false, "for the sake of compile times and binary size machete_mm(..) is " - " not implemented for with_C=", with_C, ", with_scales=", with_scales, - ", with_zeropoints=", with_zeropoints, - " (for {{type_name}}_sch_{{schedule_name}})"); +impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) { + return run_impl>(args); } -{% endfor %} +{%- endfor %} +{%- endfor %} }; // namespace machete """ @@ -130,26 +208,34 @@ #include "../machete_prepack_launcher.cuh" namespace machete { -using PrepackBDispatcher_ = PrepackBDispatcher< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - {{DataTypeTag[type_config.element_b_scale]}}, // Scales - {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints - -using PrepackedLayoutB = PrepackedLayoutBTemplate< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - cutlass::layout::ColumnMajor, - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>; - -template <> -torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) { - return prepack_impl(B); + +torch::Tensor prepack_B_dispatch(PrepackBArgs args) { + auto convert_type = args.maybe_group_scales_type.value_or(args.a_type); + {%- for t in types %} + {% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %} + if (args.a_type == {{TorchTypeTag[t.a]}} + && args.b_type.size_bits() == {{t.b_num_bits}} + && convert_type == {{TorchTypeTag[t.convert]}}) { + return prepack_impl< + PrepackedLayoutBTemplate< + {{DataTypeTag[t.a]}}, // ElementA + {{DataTypeTag[b_type]}}, // ElementB + {{DataTypeTag[t.convert]}}, // ElementConvert + {{DataTypeTag[t.accumulator]}}, // Accumulator + cutlass::layout::ColumnMajor, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput> + >(args.B); + } + {%- endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED(false, + "prepack_B_dispatch(..) is not implemented for " + "atype = ", args.a_type, + ", b_type = ", args.b_type.str(), + ", with_group_scales_type= ", args.maybe_group_scales_type ? + toString(*args.maybe_group_scales_type) : "None"); } + }; // namespace machete """ @@ -166,32 +252,34 @@ class ScheduleConfig: tile_scheduler: TileSchedulerType -@dataclass +@dataclass(frozen=True) class TypeConfig: - element_a: DataType - element_b: Union[DataType, VLLMDataType] - element_b_scale: DataType - element_b_zeropoint: DataType - element_d: DataType + a: DataType + b: Union[DataType, VLLMDataType] + b_group_scale: DataType + b_group_zeropoint: DataType + b_channel_scale: DataType + a_token_scale: DataType + out: DataType accumulator: DataType -@dataclass -class Specialization: - with_C: bool - with_zeropoints: bool - with_scales: bool +@dataclass(frozen=True) +class PrepackTypeConfig: + a: DataType + b_num_bits: int + convert: DataType + accumulator: DataType @dataclass class ImplConfig: - type_config: TypeConfig - schedule_configs: List[ScheduleConfig] - specializations: List[Specialization] + types: TypeConfig + schedules: List[ScheduleConfig] heuristic: List[Tuple[Optional[str], ScheduleConfig]] -def generate_schedule_name(schedule_config: ScheduleConfig) -> str: +def generate_sch_sig(schedule_config: ScheduleConfig) -> str: tile_shape = ( f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" ) @@ -209,40 +297,34 @@ def generate_schedule_name(schedule_config: ScheduleConfig) -> str: f"_{epilogue_schedule}_{tile_scheduler}") -# mostly unique shorter schedule_name -def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str: +# mostly unique shorter sch_sig +def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: kernel_terse_names_replace = { "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_", "TmaWarpSpecializedCooperative_": "TmaCoop_", "StreamKScheduler": "streamK", } - schedule_name = generate_schedule_name(schedule_config) + sch_sig = generate_sch_sig(schedule_config) for orig, terse in kernel_terse_names_replace.items(): - schedule_name = schedule_name.replace(orig, terse) - return schedule_name + sch_sig = sch_sig.replace(orig, terse) + return sch_sig # unique type_name -def generate_type_signature(kernel_type_config: TypeConfig): - element_a = VLLMDataTypeNames[kernel_type_config.element_a] - element_b = VLLMDataTypeNames[kernel_type_config.element_b] - element_d = VLLMDataTypeNames[kernel_type_config.element_d] - accumulator = VLLMDataTypeNames[kernel_type_config.accumulator] - element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale] - element_zeropoint = VLLMDataTypeNames[ - kernel_type_config.element_b_zeropoint] - - return (f"{element_a}{element_b}{element_d}" - f"{accumulator}{element_scale}{element_zeropoint}") - +def generate_type_signature(kernel_types: TypeConfig): + return str("".join([ + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ])) -# non-unique shorter type_name -def generate_terse_type_signature(kernel_type_config: TypeConfig): - element_a = VLLMDataTypeNames[kernel_type_config.element_a] - element_b = VLLMDataTypeNames[kernel_type_config.element_b] - return f"{element_a}{element_b}" +def generate_type_option_name(kernel_types: TypeConfig): + return ", ".join([ + f"{field.name.replace('b_', 'with_')+'_type'}=" + + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ]) def is_power_of_two(n): @@ -263,13 +345,36 @@ def _to_cute_constant(value: int): return _to_cute_constant(value) +def unique_schedules(impl_configs: List[ImplConfig]): + return list( + set(sch for impl_config in impl_configs + for sch in impl_config.schedules)) + + +def unsigned_type_with_bitwidth(num_bits): + return { + 4: DataType.u4, + 8: DataType.u8, + 16: DataType.u16, + 32: DataType.u32, + 64: DataType.u64, + }[num_bits] + + template_globals = { + "void": DataType.void, "DataTypeTag": VLLMDataTypeTag, + "VLLMScalarTypeTag": VLLMDataTypeVLLMScalarTypeTag, + "TorchTypeTag": VLLMDataTypeTorchDataTypeTag, "KernelScheduleTag": VLLMKernelScheduleTag, "EpilogueScheduleTag": EpilogueScheduleTag, "TileSchedulerTag": TileSchedulerTag, "to_cute_constant": to_cute_constant, - "gen_sch_name": generate_terse_schedule_name, + "gen_sch_sig": generate_terse_sch_sig, + "gen_type_sig": generate_type_signature, + "unique_schedules": unique_schedules, + "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth, + "gen_type_option_name": generate_type_option_name } @@ -284,42 +389,82 @@ def create_template(template_str): prepack_dispatch_template = create_template(PREPACK_TEMPLATE) -def create_sources(impl_config: ImplConfig, num_impl_files=1): +def create_sources(impl_configs: List[ImplConfig], num_impl_files=8): sources = [] - type_name = generate_type_signature(impl_config.type_config) - terse_type_name = generate_terse_type_signature(impl_config.type_config) - sources.append(( - f"machete_mm_{terse_type_name}", - mm_dispatch_template.render(type_name=type_name, - type_config=impl_config.type_config, - schedules=impl_config.schedule_configs, - heuristic=impl_config.heuristic), + "machete_mm_dispatch", + mm_dispatch_template.render(impl_configs=impl_configs), )) + prepack_types = [] + for impl_config in impl_configs: + convert_type = impl_config.types.a \ + if impl_config.types.b_group_scale == DataType.void \ + else impl_config.types.b_group_scale + prepack_types.append( + PrepackTypeConfig( + a=impl_config.types.a, + b_num_bits=VLLMDataTypeSize[impl_config.types.b], + convert=convert_type, + accumulator=impl_config.types.accumulator, + )) + + def prepacked_type_key(prepack_type: PrepackTypeConfig): + # For now we we can just use the first accumulator type seen since + # the tensor core shapes/layouts don't vary based on accumulator + # type so we can generate less code this way + return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert) + + unique_prepack_types = [] + prepack_types_seen = set() + for prepack_type in prepack_types: + key = prepacked_type_key(prepack_type) + if key not in prepack_types_seen: + unique_prepack_types.append(prepack_type) + prepack_types_seen.add(key) + sources.append(( - f"machete_prepack_{terse_type_name}", - prepack_dispatch_template.render( - type_name=type_name, - type_config=impl_config.type_config, - ), + "machete_prepack", + prepack_dispatch_template.render(types=unique_prepack_types, ), )) - num_schedules = len(impl_config.schedule_configs) - schedules_per_file = math.ceil(num_schedules / num_impl_files) - for part, i in enumerate(range(0, num_schedules, schedules_per_file)): - file_schedules = impl_config.schedule_configs[i:i + schedules_per_file] + # Split up impls across files + num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0) + num_impls_per_file = math.ceil(num_impls / num_impl_files) + + files_impls: List[List[ImplConfig]] = [[]] + + curr_num_impls_assigned = 0 + curr_impl_in_file = 0 + curr_impl_configs = deepcopy(list(reversed(impl_configs))) + + while curr_num_impls_assigned < num_impls: + room_left_in_file = num_impls_per_file - curr_impl_in_file + if room_left_in_file == 0: + files_impls.append([]) + room_left_in_file = num_impls_per_file + curr_impl_in_file = 0 + + curr_ic = curr_impl_configs[-1] + if len(curr_ic.schedules) >= room_left_in_file: + # Break apart the current impl config + tmp_ic = deepcopy(curr_ic) + tmp_ic.schedules = curr_ic.schedules[:room_left_in_file] + curr_ic.schedules = curr_ic.schedules[room_left_in_file:] + files_impls[-1].append(tmp_ic) + else: + files_impls[-1].append(curr_ic) + curr_impl_configs.pop() + curr_num_impls_assigned += len(files_impls[-1][-1].schedules) + curr_impl_in_file += len(files_impls[-1][-1].schedules) + for part, file_impls in enumerate(files_impls): sources.append(( - f"machete_mm_{terse_type_name}_impl_part{part}", - mm_impl_template.render( - type_name=type_name, - type_config=impl_config.type_config, - schedules=file_schedules, - specializations=impl_config.specializations, - ), + f"machete_mm_impl_part{part+1}", + mm_impl_template.render(impl_configs=file_impls), )) + return sources @@ -328,187 +473,169 @@ def generate(): # about how this works SCRIPT_DIR = os.path.dirname(__file__) - schedule_common_params = dict( + sch_common_params = dict( kernel_schedule=TmaMI, epilogue_schedule=TmaCoop, tile_scheduler=TileSchedulerType.StreamK, ) - # For now we use the same heuristic for all types - # Heuristic is currently tuned for H100s - default_heuristic = [ + # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + default_tile_heuristic_config = { #### M = 257+ - ( - "M > 256 && K <= 16384 && N <= 4096", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 256", - ScheduleConfig( - tile_shape_mn=(128, 256), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + "M > 256": ((128, 256), (2, 1, 1)), #### M = 129-256 - ( - "M > 128 && K <= 4096 && N <= 4096", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 128 && K <= 8192 && N <= 8192", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 128", - ScheduleConfig( - tile_shape_mn=(128, 256), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + "M > 128": ((128, 256), (2, 1, 1)), #### M = 65-128 - ( - "M > 64 && K <= 4069 && N <= 4069", - ScheduleConfig( - tile_shape_mn=(128, 32), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 64 && K <= 4069 && N <= 8192", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 64 && K >= 8192 && N >= 12288", - ScheduleConfig( - tile_shape_mn=(256, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 64", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + "M > 64": ((128, 128), (2, 1, 1)), #### M = 33-64 - ( - "M > 32 && K <= 6144 && N <= 6144", - ScheduleConfig( - tile_shape_mn=(128, 16), - cluster_shape_mnk=(1, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 32 && K >= 16384 && N >= 12288", - ScheduleConfig( - tile_shape_mn=(256, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 32", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + "M > 32": ((128, 64), (2, 1, 1)), #### M = 17-32 - ( - "M > 16 && K <= 12288 && N <= 8192", - ScheduleConfig( - tile_shape_mn=(128, 32), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 16", - ScheduleConfig( - tile_shape_mn=(256, 32), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + "M > 16": ((256, 32), (2, 1, 1)), #### M = 1-16 - ( - "N >= 26624", - ScheduleConfig( - tile_shape_mn=(256, 16), - cluster_shape_mnk=(1, 1, 1), - **schedule_common_params # type: ignore - )), - ( - None, - ScheduleConfig( - tile_shape_mn=(128, 16), - cluster_shape_mnk=(1, 1, 1), - **schedule_common_params # type: ignore - )), + "N >= 26624": ((256, 16), (1, 1, 1)), + None: ((128, 16), (1, 1, 1)), + } + + # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s + default_heuristic = [ + (cond, ScheduleConfig(*tile_config, + **sch_common_params)) # type: ignore + for cond, tile_config in default_tile_heuristic_config.items() ] - # Do not use schedules = list(set(...)) because we need to make sure - # the output list is deterministic; otherwise the generated kernel file - # will be non-deterministic and causes ccache miss. - schedules = [] - for _, schedule_config in default_heuristic: - if schedule_config not in schedules: - schedules.append(schedule_config) + def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]): + # Do not use schedules = list(set(...)) because we need to make sure + # the output list is deterministic; otherwise the generated kernel file + # will be non-deterministic and causes ccache miss. + schedules = [] + for _, schedule_config in heuristic: + if schedule_config not in schedules: + schedules.append(schedule_config) + return schedules impl_configs = [] GPTQ_kernel_type_configs = list( TypeConfig( - element_a=element_a, - element_b=element_b, - element_b_scale=element_a, - element_b_zeropoint=element_a, - element_d=element_a, + a=a, + b=b, + b_group_scale=a, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.void, + a_token_scale=DataType.void, + out=a, accumulator=DataType.f32, - ) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128) - for element_a in (DataType.f16, DataType.bf16)) - - GPTQ_kernel_specializations = [ - Specialization(with_C=False, with_zeropoints=False, with_scales=True) - ] + ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) + for a in (DataType.f16, DataType.bf16)) impl_configs += [ - ImplConfig(x[0], x[1], x[2], x[3]) - for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules), - itertools.repeat(GPTQ_kernel_specializations), + ImplConfig(x[0], x[1], x[2]) + for x in zip(GPTQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), itertools.repeat(default_heuristic)) ] AWQ_kernel_type_configs = list( TypeConfig( - element_a=element_a, - element_b=element_b, - element_b_scale=element_a, - element_b_zeropoint=element_a, - element_d=element_a, + a=a, + b=b, + b_group_scale=a, + b_group_zeropoint=a, + b_channel_scale=DataType.void, + a_token_scale=DataType.void, + out=a, accumulator=DataType.f32, - ) for element_b in (DataType.u4, DataType.u8) - for element_a in (DataType.f16, DataType.bf16)) + ) for b in (DataType.u4, DataType.u8) + for a in (DataType.f16, DataType.bf16)) + + impl_configs += [ + ImplConfig(x[0], x[1], x[2]) + for x in zip(AWQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic)) + ] - AWQ_kernel_specializations = [ - Specialization(with_C=False, with_zeropoints=True, with_scales=True) + # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + # TODO (LucasWilkinson): Further tuning required + qqq_tile_heuristic_config = { + #### M = 257+ + # ((128, 256), (2, 1, 1)) Broken for QQQ types + # TODO (LucasWilkinson): Investigate further + # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + # "M > 256": ((128, 256), (2, 1, 1)), + "M > 256": ((128, 128), (2, 1, 1)), + #### M = 129-256 + "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + # ((128, 256), (2, 1, 1)) Broken for QQQ types + # TODO (LucasWilkinson): Investigate further + # "M > 128": ((128, 256), (2, 1, 1)), + "M > 128": ((128, 128), (2, 1, 1)), + #### M = 65-128 + "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + "M > 64": ((128, 128), (2, 1, 1)), + #### M = 33-64 + "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + # Broken for QQQ types + # TODO (LucasWilkinson): Investigate further + #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + "M > 32": ((128, 64), (2, 1, 1)), + #### M = 17-32 + "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + "M > 16": ((256, 32), (2, 1, 1)), + #### M = 1-16 + "N >= 26624": ((256, 16), (1, 1, 1)), + None: ((128, 16), (1, 1, 1)), + } + + # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s + qqq_heuristic = [ + (cond, ScheduleConfig(*tile_config, + **sch_common_params)) # type: ignore + for cond, tile_config in qqq_tile_heuristic_config.items() + ] + + QQQ_kernel_types = [ + *(TypeConfig( + a=DataType.s8, + b=VLLMDataType.u4b8, + b_group_scale=b_group_scale, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.f32, + a_token_scale=DataType.f32, + out=DataType.f16, + accumulator=DataType.s32, + ) for b_group_scale in (DataType.f16, DataType.void)), + *(TypeConfig( + a=DataType.e4m3, + b=VLLMDataType.u4b8, + b_group_scale=b_group_scale, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.f32, + a_token_scale=DataType.f32, + out=DataType.f16, + accumulator=DataType.f32, + ) for b_group_scale in (DataType.f16, DataType.void)), ] impl_configs += [ - ImplConfig(x[0], x[1], x[2], x[3]) - for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules), - itertools.repeat(AWQ_kernel_specializations), - itertools.repeat(default_heuristic)) + ImplConfig(x[0], x[1], x[2]) + for x in zip(QQQ_kernel_types, + itertools.repeat(get_unique_schedules(qqq_heuristic)), + itertools.repeat(qqq_heuristic)) ] output_dir = os.path.join(SCRIPT_DIR, "generated") @@ -521,12 +648,11 @@ def generate(): os.makedirs(output_dir) # Render each group of configurations into separate files - for impl_config in impl_configs: - for filename, code in create_sources(impl_config): - filepath = os.path.join(output_dir, f"{filename}.cu") - with open(filepath, "w") as output_file: - output_file.write(code) - print(f"Rendered template to {filepath}") + for filename, code in create_sources(impl_configs): + filepath = os.path.join(output_dir, f"{filename}.cu") + with open(filepath, "w") as output_file: + output_file.write(code) + print(f"Rendered template to {filepath}") if __name__ == "__main__": diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index e8e7b14de0da1..816f33a1078e5 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -171,6 +171,10 @@ struct MacheteCollectiveMma { make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy( + make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), + Int{}))); + using SmemLayoutAtomARowMajor = decltype(rs_smem_selector(TileShape_MNK{})), @@ -288,14 +292,7 @@ struct MacheteCollectiveMma { static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); - // Tile along modes in a way that maximizes the TMA box size. - using SmemLayoutACopy = decltype(tile_to_shape( - SmemLayoutAtomARowMajor{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), - Int{}), - conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), - Step<_2, _1, _3>, Step<_1, _2, _3>>{})); - + // Tile along modes in a way that maximizes the TMA box size using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), @@ -428,12 +425,12 @@ struct MacheteCollectiveMma { // clang-format on // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) - using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset( + using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset_copy( make_shape(int32_t(0), int32_t(0), int32_t(0))))); using ATensor = decltype(make_tensor( get_logical_ptr(static_cast(nullptr)), - shape(GmemLayoutA::TVbNbKL_to_offset( + shape(GmemLayoutA::TVbNbKL_to_offset_copy( make_shape(int32_t(0), int32_t(0), int32_t(0)))), PrepackedStrideA{})); @@ -450,8 +447,8 @@ struct MacheteCollectiveMma { static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) { return make_tma_copy( - GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), - shape(SmemLayoutA{}(_, _, cute::Int<0>{})), + GmemTiledCopyA{}, tensor_a, SmemLayoutACopy{}(_, _, cute::Int<0>{}), + shape(SmemLayoutACopy{}(_, _, cute::Int<0>{})), size<1>(ClusterShape{})); // mcast along N mode for this M load, if any } @@ -584,7 +581,7 @@ struct MacheteCollectiveMma { typename Params::TMA_Scale tma_load_scale; typename Params::TMA_Zero tma_load_zero; - auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L)); tma_load_a = make_tma_copy_A( make_logical_tensor(ptr_A, shape(layout), stride(layout))); @@ -722,7 +719,7 @@ struct MacheteCollectiveMma { // (TILE_V,TILE_B,m,k,l) auto make_gA_mkl = [&]() { // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) - auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L)); Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout)); return local_tile(mA_mkl, make_shape(size<0>(layout), PPBlocksPerTile_MK{}), diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index 4d41b8d291484..d4d19ae5deec7 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -21,6 +21,8 @@ #include "cutlass_extensions/cute_utils.cuh" #include "cutlass_extensions/vllm_numeric_conversion.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "cutlass_extensions/torch_utils.hpp" #include "machete_collective_builder.cuh" #include "machete_prepacked_layout.cuh" #include "machete_interleaving_utils.cuh" @@ -37,27 +39,42 @@ using namespace cute; // W is quantized, in this situation or right-hand operand is quantized so // we compute the transpose to move it to the left-hand side. template + typename AccumulatorT, typename GroupScaleT, typename GroupZeroT, + typename ChannelScaleT, typename TokenScaleT, class KernelSchedule, + typename ScheduleConfig> struct MacheteKernelTemplate { + static constexpr bool with_C = false; // not ever used + static constexpr bool with_group_scales = !std::is_same_v; + static constexpr bool with_group_zeropoints = + !std::is_same_v; + static constexpr bool with_channel_scales = + !std::is_same_v; + static constexpr bool with_token_scales = !std::is_same_v; + using MmaType = ElementA_; using ElementA = ElementA_; using ElementB = ElementB_; using ElementD = ElementD_; using ElementC = cute::conditional_t; - using ElementZ = ZeroT; - using ElementS = ScaleT; - - using ElementAccumulator = - AccumulatorT; // Element type for internal accumulation + using ElementAccumulator = AccumulatorT; using ElementCompute = AccumulatorT; // For Epilogue + // Use dummy values when we don't have scales or zeropoints + using ElementZGroup = + cute::conditional_t; + using ElementSGroup = + cute::conditional_t; + using ElementConvertGroup = + cute::conditional_t; + using ElementSChannel = + cute::conditional_t; + using ElementSToken = + cute::conditional_t; using BTypeTuple = cute::conditional_t< - with_scales, - cute::conditional_t, - cute::tuple>, + with_group_scales, + cute::conditional_t, + cute::tuple>, ElementB>; using LayoutA = cutlass::layout::RowMajor; @@ -71,8 +88,8 @@ struct MacheteKernelTemplate { using StrideA = cutlass::detail::TagToStrideA_t; using StrideC = cutlass::detail::TagToStrideA_t; using StrideD = cutlass::detail::TagToStrideA_t; - using StrideS = cutlass::detail::TagToStrideA_t; - using StrideZ = StrideS; + using StrideSGroup = cutlass::detail::TagToStrideA_t; + using StrideZGroup = StrideSGroup; using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; @@ -85,8 +102,8 @@ struct MacheteKernelTemplate { using OperatorClass = cutlass::arch::OpClassTensorOp; using PrepackedLayoutB = - PrepackedLayoutBTemplate; + PrepackedLayoutBTemplate; static int constexpr TileShapeK = 128 * 8 / cutlass::sizeof_bits::value; @@ -103,12 +120,42 @@ struct MacheteKernelTemplate { using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; using TileScheduler = typename ScheduleConfig::TileScheduler; + static_assert( + (!with_channel_scales && !with_token_scales) || + ((with_channel_scales && with_token_scales) && + std::is_same_v), + "Currently token and channel scales (if present) must be the same type"); + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + // Currently only supports float scales + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogue; + static_assert((with_channel_scales || with_token_scales) || + (std::is_same_v && + std::is_same_v), + "Currently token and channel scales (if present) must be float " + "(and if one is present the other must be too)"); + + using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90AccFetch>; + + using EVTCompute = + std::conditional_t; + + // EVTCompute using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, - ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose, - AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, - EpilogueSchedule>::CollectiveOp; + ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose, + AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule, + EVTCompute>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::VLLMCollectiveBuilder< @@ -131,26 +178,44 @@ struct MacheteKernelTemplate { using MainloopArguments = typename GemmKernel::MainloopArguments; using EpilogueArguments = typename GemmKernel::EpilogueArguments; - template static Arguments create_arguments( cudaStream_t stream, - ElementA const* A_ptr, // A is an MxK matrix - Layout const& layout_A, - ElementB const* B_ptr, // B is an KxN prepacked matrix - ElementD* D_ptr, // D is an MxN matrix - Layout const& layout_D, - ElementC const* C_ptr, // C is an MxN matrix - std::optional> const& layout_C, - ElementS const* S_ptr, // S is an scale_KxN matrix - std::optional> const& layout_S, - ElementZ const* Z_ptr, // Z is an scale_KxN matrix - std::optional> const& layout_Z, - ElementCompute alpha, ElementCompute beta, - std::optional maybe_group_size) { - static_assert(!with_zeropoints || with_scales); - - int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); + torch::Tensor const& A, // MxK matrix + torch::Tensor const& B, // KxN prepacked matrix + torch::Tensor& D, // MxN matrix + c10::optional const& maybe_g_scales, // scale_KxN matrix + c10::optional const& maybe_g_zeros, // scale_KxN matrix + c10::optional maybe_group_size, + c10::optional const& maybe_ch_scales, // len N vector + c10::optional const& maybe_tok_scales) // len M vector + { + static_assert(!with_group_zeropoints || with_group_scales); + + int M = A.size(0), N = B.size(1), K = A.size(1); + TORCH_CHECK(D.size(0) == M && D.size(1) == N); + + auto layout_A = make_cute_layout(A, "A"); + auto layout_D = make_cute_layout(D, "D"); + auto layout_S_group = + maybe_make_cute_layout(maybe_g_scales, "group_scales"); + auto layout_Z_group = + maybe_make_cute_layout(maybe_g_zeros, "group_zeros"); + int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0; + int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0; + + auto unwrap = [](auto const& t) { + return t ? t->const_data_ptr() : nullptr; + }; + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.mutable_data_ptr()); + auto S_group_ptr = + static_cast(unwrap(maybe_g_scales)); + auto Z_group_ptr = static_cast(unwrap(maybe_g_zeros)); + auto S_channel_ptr = + static_cast(unwrap(maybe_ch_scales)); + auto S_token_ptr = + static_cast(unwrap(maybe_tok_scales)); int const group_size = maybe_group_size == -1 ? K : maybe_group_size.value_or(K); @@ -159,26 +224,28 @@ struct MacheteKernelTemplate { TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N); - if constexpr (with_C) { - TORCH_CHECK(C_ptr && layout_C); + if constexpr (with_group_scales) { + TORCH_CHECK(S_group_ptr && layout_S_group); + TORCH_CHECK((size<0>(*layout_S_group) == scale_k && + size<1>(*layout_S_group) == N)); } else { - TORCH_CHECK(!C_ptr, "C not supported"); + TORCH_CHECK(!S_group_ptr, "Scales not supported"); } - if constexpr (with_scales) { - TORCH_CHECK(S_ptr && layout_S); - TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N)); + if constexpr (with_group_zeropoints) { + TORCH_CHECK(Z_group_ptr && layout_Z_group); + TORCH_CHECK((size<0>(*layout_Z_group) == scale_k && + size<1>(*layout_Z_group) == N)); + TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group, + "Scales and zeros must have the same layout"); } else { - TORCH_CHECK(!S_ptr, "Scales not supported"); + TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported"); } - if constexpr (with_zeropoints) { - TORCH_CHECK(Z_ptr && layout_Z); - TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N)); - TORCH_CHECK(layout_S && *layout_Z == *layout_S, - "Scales and zeros must have the same layout"); - } else { - TORCH_CHECK(!Z_ptr, "Zeropoints not supported"); + if constexpr (with_channel_scales || with_token_scales) { + TORCH_CHECK( + (maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) && + (maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1)); } // Transpose A and D @@ -186,24 +253,33 @@ struct MacheteKernelTemplate { // for B (which is At) auto stride_At = layout_A.stride(); auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride(); - auto stride_Ct = stride_Dt; - if (layout_C) { - stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride(); - } MainloopArguments mainloop_arguments{}; - EpilogueArguments epilogue_arguments{ - {alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt}; + // {Accum, C, C_layout, D, D} + EpilogueArguments epilogue_arguments{}; + + if constexpr (with_channel_scales || with_token_scales) { + epilogue_arguments = + EpilogueArguments{ChTokScalesEpilogue::prepare_args( + *maybe_ch_scales, *maybe_tok_scales), + nullptr, + {}, + D_ptr, + stride_Dt}; + } else { + epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt}; + } - if constexpr (with_scales && with_zeropoints) { - auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); - mainloop_arguments = - MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, - S_ptr, stride_S, group_size, Z_ptr}; - } else if constexpr (with_scales) { - auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); + if constexpr (with_group_scales && with_group_zeropoints) { + auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride(); mainloop_arguments = MainloopArguments{ - B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size}; + B_ptr, _StrideB{}, A_ptr, stride_At, + S_group_ptr, stride_S_group, group_size, Z_group_ptr}; + } else if constexpr (with_group_scales) { + auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride(); + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, + S_group_ptr, stride_S_group, group_size}; } else { mainloop_arguments = MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At}; diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh index 60a4ed60535b7..4b0da5b303e0c 100644 --- a/csrc/quantization/machete/machete_mm_launcher.cuh +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -5,73 +5,61 @@ #include "machete_mm_kernel.cuh" #include "cutlass_extensions/torch_utils.hpp" +#include "core/scalar_type.hpp" namespace machete { -struct PyTorchArguments { +struct MMArgs { torch::Tensor const& A; torch::Tensor const& B; - c10::optional const& scales; - c10::optional const& zeros; - c10::optional group_size; - c10::optional const& C; - c10::optional alpha; - c10::optional beta; - c10::optional schedule; + vllm::ScalarType const& b_type; + c10::optional const& maybe_out_type; + c10::optional const& maybe_group_scales; + c10::optional const& maybe_group_zeros; + c10::optional maybe_group_size; + c10::optional const& maybe_channel_scales; + c10::optional const& maybe_token_scales; + c10::optional maybe_schedule; }; +struct SupportedSchedulesArgs { + at::ScalarType a_type; + vllm::ScalarType b_type; + c10::optional maybe_group_scales_type; + c10::optional maybe_group_zeros_type; + c10::optional maybe_channel_scales_type; + c10::optional maybe_token_scales_type; + c10::optional maybe_out_type; +}; + +torch::Tensor mm_dispatch(MMArgs args); + +std::vector supported_schedules_dispatch( + SupportedSchedulesArgs args); + template -torch::Tensor run_impl(PyTorchArguments args) { +torch::Tensor run_impl(MMArgs args) { const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A)); auto device = args.A.device(); auto stream = at::cuda::getCurrentCUDAStream(device.index()); - using EleA = typename MacheteKernel::ElementA; - using EleB = typename MacheteKernel::ElementB; - using EleC = typename MacheteKernel::ElementC; - using EleD = typename MacheteKernel::ElementD; - using EleScale = typename MacheteKernel::ElementS; - using EleZero = typename MacheteKernel::ElementZ; - - using StrideA = typename MacheteKernel::StrideA; - using StrideC = typename MacheteKernel::StrideC; - using StrideD = typename MacheteKernel::StrideD; - using StrideS = typename MacheteKernel::StrideS; - using StrideZ = typename MacheteKernel::StrideZ; - int M = args.A.size(0); int N = args.B.size(1); int K = args.A.size(1); // Allocate output - torch::Tensor D = - torch::empty({M, N}, torch::TensorOptions() - .dtype(equivalent_scalar_type_v) - .device(device)); - - auto const &A = args.A, &B = args.B; - auto const &C = args.C, &scales = args.scales, &zeros = args.zeros; - - auto layout_A = make_cute_layout(A, "A"); - auto layout_D = make_cute_layout(D, "D"); - auto layout_C = maybe_make_cute_layout(C, "C"); - auto layout_S = maybe_make_cute_layout(scales, "scales"); - auto layout_Z = maybe_make_cute_layout(zeros, "zeros"); - - auto A_ptr = static_cast(A.const_data_ptr()); - auto B_ptr = static_cast(B.const_data_ptr()); - auto D_ptr = static_cast(D.mutable_data_ptr()); - auto C_ptr = static_cast(C ? C->const_data_ptr() : nullptr); - auto S_ptr = - static_cast(scales ? scales->const_data_ptr() : nullptr); - auto Z_ptr = - static_cast(zeros ? zeros->const_data_ptr() : nullptr); + torch::Tensor D = torch::empty( + {M, N}, + torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); auto arguments = MacheteKernel::create_arguments( - stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr, - layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0), - args.group_size); + stream, // + args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros, + args.maybe_group_size, args.maybe_channel_scales, + args.maybe_token_scales); TORCH_CHECK(MacheteKernel::can_implement(arguments), "Machete kernel cannot be run with these arguments"); @@ -84,12 +72,4 @@ torch::Tensor run_impl(PyTorchArguments args) { return D; }; -template -struct GemmDispatcher { - static torch::Tensor dispatch(PyTorchArguments args); - static std::vector supported_schedules(); -}; - }; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepack_kernel.cuh b/csrc/quantization/machete/machete_prepack_kernel.cuh index f23483f928b47..d002355ca49d6 100644 --- a/csrc/quantization/machete/machete_prepack_kernel.cuh +++ b/csrc/quantization/machete/machete_prepack_kernel.cuh @@ -6,31 +6,49 @@ namespace machete { -template -static __global__ void prepack_B_kernel(BInTensor B_in, - BTiledOutTensor B_tiled_out) { - auto tB_in = local_tile(B_in, TileShapeNKL{}, - make_coord(blockIdx.x, blockIdx.y, blockIdx.z)); - auto tB_out = B_tiled_out(make_coord(_, _), - make_coord(blockIdx.x, blockIdx.y), blockIdx.z); +template +static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) { + auto constexpr block_size = + Int{}; + auto constexpr eles_per_thread = Int{}; + static_assert(block_size % threads == 0, + "block_size must be divisible by the number of threads"); - auto tiled_copy = make_tiled_copy(Copy_Atom{}, - Layout, Stride<_32, _1>>{}, - Layout>{}); + // Which pre-packed are we responsible for + auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z); + auto tB_in = local_tile( + B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}), + blk_coord); - auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + // Find the start offset in the output for this pre-packed block + auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in)); - Tensor thr_tile_S = thr_copy.partition_S(tB_in); - Tensor thr_tile_D = thr_copy.partition_D(tB_out); + // Tensor representing a 1:1 mapping to the output space in 1D + auto tB_out_linear = + make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord), + make_layout(make_shape(block_size))); + // Mapping from output space (1D) to input space + auto tB_in_linear = make_tensor( + tB_in.data(), + tB_in.layout() + .compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset())) + .with_shape(make_shape(block_size))); + + // Tile for this specific thread (could have used a TiledCopy but these work + // best with 2d layouts, this is a simple 1d layout so local_tile is enough, + // we are also not that concerned with performance for this kernel) + auto thr_tB_in_linear = + local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x); + auto thr_tB_out_linear = + local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x); // Construct a register-backed Tensor with the same shape as each thread's // partition - auto fragment = make_tensor(shape(thr_tile_D)); + auto fragment = make_tensor(shape(thr_tB_in_linear)); - // Copy from GMEM to RMEM and from RMEM to GMEM - copy(tiled_copy, thr_tile_S, fragment); - copy(Copy_Atom{}, fragment, thr_tile_D); + copy(thr_tB_in_linear, fragment); + copy(Copy_Atom{}, fragment, thr_tB_out_linear); } template @@ -44,18 +62,15 @@ static void prepack_B_template( TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0); TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0); - TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0); auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{}); auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{}); - auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{}); + auto L_tiles = size<2>(B_layout); auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout); - auto B_tiled_out = - make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset); - prepack_B_kernel - <<>>(B_in, B_tiled_out); + prepack_B_kernel<128, PrepackedLayoutB> + <<>>(B_in, B_out_ptr); } }; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh index a33d8f9484cfe..3486d28be2126 100644 --- a/csrc/quantization/machete/machete_prepack_launcher.cuh +++ b/csrc/quantization/machete/machete_prepack_launcher.cuh @@ -2,9 +2,17 @@ #include "machete_prepack_kernel.cuh" #include "cutlass_extensions/torch_utils.hpp" +#include "core/scalar_type.hpp" namespace machete { +struct PrepackBArgs { + torch::Tensor const& B; + at::ScalarType a_type; + vllm::ScalarType b_type; + c10::optional maybe_group_scales_type; +}; + template torch::Tensor prepack_impl(torch::Tensor const B) { const at::cuda::OptionalCUDAGuard device_guard(device_of(B)); @@ -61,11 +69,6 @@ torch::Tensor prepack_impl(torch::Tensor const B) { return D; }; -template -struct PrepackBDispatcher { - static torch::Tensor dispatch(torch::Tensor B); -}; +torch::Tensor prepack_B_dispatch(PrepackBArgs args); }; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepacked_layout.cuh b/csrc/quantization/machete/machete_prepacked_layout.cuh index 78e2cc5eec7d8..680a858a893c1 100644 --- a/csrc/quantization/machete/machete_prepacked_layout.cuh +++ b/csrc/quantization/machete/machete_prepacked_layout.cuh @@ -41,7 +41,7 @@ struct IlvBlkLayoutAuto {}; // The contract here is that the `TiledMma` determined below matches the one // ultimately used in the kernel. (this is also why the other element types are // required along with the kernel schedule) -template // clang-format on @@ -49,20 +49,27 @@ struct PrepackedLayoutBTemplate { using MmaType = ElementA_; using ElementA = ElementA_; using ElementB = ElementB_; - using ElementD = ElementD_; - using ElementAccumulator = - AccumulatorT; // Element type for internal accumulation + using ElementAccumulator = AccumulatorT; using ElementMma = MmaType; - // Only use interleaved layouts for subbyte weights, prmt instructions makes - // non-interleaved layouts for 8bit+ weights efficient enough we don't need - // iterleaved layouts + // Interleave for 4bit bit types when we are not upconverting to fp8 or int8, + // in those cases case we use a LUT using prmt instructions to upconvert and + // is more efficient if the data is not interleaved For 8bit+ prmt + // instructions makes non-interleaved layouts efficient enough we don't need + // iterleaved layouts (and can reuse more of the existing cutlass converts) + static constexpr bool should_interleave = + sizeof_bits_v <= 4 && + !std::is_same_v && + !std::is_same_v; + + // Only use interleaved layouts for subbyte weights, using IlvdBlkLayout = std::conditional_t< std::is_same_v, - std::conditional_t <= 4, - decltype(get_interleaved_blk_layout< - ElementB, sizeof_bits_v, 32>()), - void>, + std::conditional_t< + should_interleave, + decltype(get_interleaved_blk_layout< + ElementB, sizeof_bits_v, 32>()), + void>, IlvBlkLayout_>; // TODO (LucasWilkinson): compare the performance for other sizes @@ -135,7 +142,8 @@ struct PrepackedLayoutBTemplate { // then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H} auto frgV = get<1, 0>(layout_no_interleave); auto ilvdBlk = IlvdBlkLayout{}; - static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4"); + static_assert(size(frgV) % size(ilvdBlk) == 0, + "FrgV must be divisible by size(ilvdBlk)"); auto ilvd_FrgV = make_layout( make_shape(shape(ilvdBlk), Int{}), make_stride(stride(ilvdBlk), size(ilvdBlk))); @@ -175,6 +183,15 @@ struct PrepackedLayoutBTemplate { return group<1, 3>(result(_, repeat(result)>(_))); } + // ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L) + template + CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy( + Shape_NKL shape_mkl) { + auto layout = TVbNbKL_to_offset(shape_mkl); + return make_layout(coalesce(get<0>(layout)), get<1>(layout), + get<2>(layout)); + } + // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx) template CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset( @@ -197,6 +214,19 @@ struct PrepackedLayoutBTemplate { return group<1, 3>(result(_, repeat(result)>(_))); } + // (BlocksN, BlocksK, L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) { + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + auto stride = size(PPBlockShape_NK{}); + + // (BlocksN, BlocksK, L) -> (storage_idx) + return make_layout(blocks_shape, compact_col_major(blocks_shape, stride)); + } + // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L) template CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) { diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu index 9f9073ded6191..da2c2fb0d3e77 100644 --- a/csrc/quantization/machete/machete_pytorch.cu +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -8,89 +8,61 @@ namespace machete { using namespace vllm; -// -// Utils (type dispatching) -// - -template -static auto scalar_type_dispatch(ScalarType const& type, Fn fn) { - if (type == vllm::kU4) { - return fn(cutlass::uint4b_t{}); - } else if (type == vllm::kU8) { - return fn(cutlass::uint8_t{}); - } else if (type == vllm::kU4B8) { - return fn(cutlass::vllm_uint4b8_t{}); - } else if (type == vllm::kU8B128) { - return fn(cutlass::vllm_uint8b128_t{}); - } else { - TORCH_CHECK(false, "Unsupported type ", type.str()); - } -} - -#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \ - AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__) - -#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, \ - AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__)) - -// -// Interface -// - -std::vector supported_schedules(ScalarTypeId const btype_id) { -#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 - vllm::ScalarType b_type = ScalarType::from_id(btype_id); - return scalar_type_dispatch(b_type, [&](auto BType) { - return GemmDispatcher::supported_schedules(); +std::vector supported_schedules( + at::ScalarType a_type, int64_t b_type_id, + c10::optional maybe_group_scales_type, + c10::optional maybe_group_zeros_type, + c10::optional maybe_channel_scales_type, + c10::optional maybe_token_scales_type, + c10::optional maybe_out_type) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return supported_schedules_dispatch({ + .a_type = a_type, + .b_type = b_type, + .maybe_group_scales_type = maybe_group_scales_type, + .maybe_group_zeros_type = maybe_group_zeros_type, + .maybe_channel_scales_type = maybe_channel_scales_type, + .maybe_token_scales_type = maybe_token_scales_type, + .maybe_out_type = maybe_out_type, }); -#else - TORCH_CHECK(false, "Machete requires CUDA 12.0 or later"); -#endif } -torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, - ScalarTypeId const btype_id, - c10::optional const& scales, - c10::optional const& zeros, - c10::optional group_size, - c10::optional const& C, - c10::optional alpha, c10::optional beta, - c10::optional schedule) { -#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 - ScalarType const btype = ScalarType::from_id(btype_id); - auto args = PyTorchArguments{.A = A, - .B = B, - .scales = scales, - .zeros = zeros, - .group_size = group_size, - .C = C, - .alpha = alpha, - .beta = beta, - .schedule = schedule}; - - return scalar_type_dispatch(btype, [&](auto BType) { - return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES( - A.scalar_type(), "machete_gemm", [&] { - using ComputeType = equivalent_cutlass_type_t; - return GemmDispatcher::dispatch(args); - }); - }); -#else - TORCH_CHECK(false, "Machete requires CUDA 12.0 or later"); -#endif +torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B, + int64_t b_type_id, + c10::optional const& maybe_out_type, + c10::optional const& maybe_group_scales, + c10::optional const& maybe_group_zeros, + c10::optional maybe_group_size, + c10::optional const& maybe_channel_scales, + c10::optional const& maybe_token_scales, + c10::optional maybe_schedule) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return mm_dispatch({.A = A, + .B = B, + .b_type = b_type, + .maybe_out_type = maybe_out_type, + .maybe_group_scales = maybe_group_scales, + .maybe_group_zeros = maybe_group_zeros, + .maybe_group_size = maybe_group_size, + .maybe_channel_scales = maybe_channel_scales, + .maybe_token_scales = maybe_token_scales, + .maybe_schedule = maybe_schedule}); } -torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) { - ScalarType const btype = ScalarType::from_id(btype_id); - return scalar_type_dispatch(btype, [&](auto BType) { - return PrepackBDispatcher::dispatch(B); - }); +torch::Tensor prepack_B( + torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id, + c10::optional const& maybe_group_scales_type) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return prepack_B_dispatch( + {.B = B, + .a_type = a_type, + .b_type = b_type, + .maybe_group_scales_type = maybe_group_scales_type}); } TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("machete_prepack_B", &prepack_B); - m.impl("machete_gemm", &gemm); + m.impl("machete_mm", &mm); } // use CatchAll since supported_schedules has no tensor arguments diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 229fd554d3eee..e4cc7ec951848 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -203,13 +203,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // conditionally compiled so impl in source file // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. - ops.def("machete_supported_schedules(int btype) -> str[]"); ops.def( - "machete_gemm(Tensor A, Tensor B, int btype, " - " Tensor? scales, Tensor? zeros, int? group_size, " - " Tensor? C, float? alpha, float? beta, str? schedule)" - "-> Tensor"); - ops.def("machete_prepack_B(Tensor B, int btype) -> Tensor"); + "machete_supported_schedules(" + " ScalarType a_type," + " int b_type," + " ScalarType? maybe_group_scales_type," + " ScalarType? maybe_group_zeros_type," + " ScalarType? maybe_channel_scales_type," + " ScalarType? maybe_token_scales_type," + " ScalarType? maybe_out_type" + ") -> str[]"); + ops.def( + "machete_mm(" + " Tensor A," + " Tensor B," + " int b_type," + " ScalarType? out_type," + " Tensor? group_scales," + " Tensor? group_zeros," + " int? group_size," + " Tensor? channel_scales," + " Tensor? token_scales," + " str? schedule" + ") -> Tensor"); + ops.def( + "machete_prepack_B(" + " Tensor B," + " ScalarType a_type," + " int b_type," + " ScalarType? group_scales_type" + ") -> Tensor"); // conditionally compiled so impl registration is in source file ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); diff --git a/docs/source/community/sponsors.md b/docs/source/community/sponsors.md index 52fbf9a577c7e..c6f83b3a92ca0 100644 --- a/docs/source/community/sponsors.md +++ b/docs/source/community/sponsors.md @@ -15,6 +15,7 @@ vLLM is a community project. Our compute resources for development and testing a - Dropbox - Google Cloud - Lambda Lab +- Nebius - NVIDIA - Replicate - Roblox diff --git a/docs/source/contributing/overview.rst b/docs/source/contributing/overview.rst index ac2d2b2fe4103..4cea0afdaea74 100644 --- a/docs/source/contributing/overview.rst +++ b/docs/source/contributing/overview.rst @@ -41,15 +41,6 @@ Testing Contribution Guidelines ======================= -DCO and Signed-off-by ----------------------- - -When contributing changes to this project, you must agree to the `DCO `_. -Commits must include a ``Signed-off-by:`` header which certifies agreement with -the terms of the `DCO `_. - -Using ``-s`` with ``git commit`` will automatically add this header. - Issues ------ @@ -61,7 +52,110 @@ If you encounter a bug or have a feature request, please `search existing issues Pull Requests & Code Reviews ---------------------------- -Please check the PR checklist in the `PR template `_ for a detailed guide for contribution. +Thank you for your contribution to vLLM! Before submitting the pull request, +please ensure the PR meets the following criteria. This helps vLLM maintain the +code quality and improve the efficiency of the review process. + +DCO and Signed-off-by +^^^^^^^^^^^^^^^^^^^^^ + +When contributing changes to this project, you must agree to the `DCO `_. +Commits must include a ``Signed-off-by:`` header which certifies agreement with +the terms of the `DCO `_. + +Using ``-s`` with ``git commit`` will automatically add this header. + +PR Title and Classification +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Only specific types of PRs will be reviewed. The PR title is prefixed +appropriately to indicate the type of change. Please use one of the following: + +- ``[Bugfix]`` for bug fixes. +- ``[CI/Build]`` for build or continuous integration improvements. +- ``[Doc]`` for documentation fixes and improvements. +- ``[Model]`` for adding a new model or improving an existing model. Model name + should appear in the title. +- ``[Frontend]`` For changes on the vLLM frontend (e.g., OpenAI API server, + ``LLM`` class, etc.) +- ``[Kernel]`` for changes affecting CUDA kernels or other compute kernels. +- ``[Core]`` for changes in the core vLLM logic (e.g., ``LLMEngine``, + ``AsyncLLMEngine``, ``Scheduler``, etc.) +- ``[Hardware][Vendor]`` for hardware-specific changes. Vendor name should + appear in the prefix (e.g., ``[Hardware][AMD]``). +- ``[Misc]`` for PRs that do not fit the above categories. Please use this + sparingly. + +.. note:: + If the PR spans more than one category, please include all relevant prefixes. + +Code Quality +^^^^^^^^^^^^ + +The PR needs to meet the following code quality standards: + +- We adhere to `Google Python style guide + `_ and `Google C++ style guide + `_. +- Pass all linter checks. Please use `format.sh + `_ to format your + code. +- The code needs to be well-documented to ensure future contributors can easily + understand the code. +- Include sufficient tests to ensure the project stays correct and robust. This + includes both unit tests and integration tests. +- Please add documentation to ``docs/source/`` if the PR modifies the + user-facing behaviors of vLLM. It helps vLLM users understand and utilize the + new features or changes. + +Adding or Changing Kernels +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Each custom kernel needs a schema and one or more implementations to be registered with PyTorch. + +- Make sure custom ops are registered following PyTorch guidelines: + `Custom C++ and CUDA Operators `_ + and `The Custom Operators Manual `_. +- Custom operations that return ``Tensors`` require meta-functions. + Meta-functions should be implemented and registered in Python so that dynamic + dims can be handled automatically. See above documents for a description of + meta-functions. +- Use `torch.library.opcheck() `_ + to test the function registration and meta-function for any registered ops. + See ``tests/kernels`` for examples. +- When changing the C++ signature of an existing op, the schema must be updated + to reflect the changes. +- If a new custom type is needed, see the following document: + `Custom Class Support in PT2 `_. + +Notes for Large Changes +^^^^^^^^^^^^^^^^^^^^^^^ + +Please keep the changes as concise as possible. For major architectural changes +(>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue +(RFC) discussing the technical design and justification. Otherwise, we will tag +it with ``rfc-required`` and might not go through the PR. + +What to Expect for the Reviews +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The goal of the vLLM team is to be a *transparent reviewing machine*. We would +like to make the review process transparent and efficient and make sure no +contributor feels confused or frustrated. However, the vLLM team is small, so we +need to prioritize some PRs over others. Here is what you can expect from the +review process: + +- After the PR is submitted, the PR will be assigned to a reviewer. Every + reviewer will pick up the PRs based on their expertise and availability. +- After the PR is assigned, the reviewer will provide status updates every 2-3 + days. If the PR is not reviewed within 7 days, please feel free to ping the + reviewer or the vLLM team. +- After the review, the reviewer will put an ``action-required`` label on the PR + if there are changes required. The contributor should address the comments and + ping the reviewer to re-review the PR. +- Please respond to all comments within a reasonable time frame. If a comment + isn't clear or you disagree with a suggestion, feel free to ask for + clarification or discuss the suggestion. Thank You --------- diff --git a/docs/source/design/class_hierarchy.rst b/docs/source/design/class_hierarchy.rst index 15f0c8ccf77ee..58a888b17ba53 100644 --- a/docs/source/design/class_hierarchy.rst +++ b/docs/source/design/class_hierarchy.rst @@ -1,3 +1,5 @@ +.. _class_hierarchy: + vLLM's Class Hierarchy ======================= diff --git a/docs/source/design/plugin_system.rst b/docs/source/design/plugin_system.rst new file mode 100644 index 0000000000000..bfca702b9267a --- /dev/null +++ b/docs/source/design/plugin_system.rst @@ -0,0 +1,62 @@ +.. _plugin_system: + +vLLM's Plugin System +==================== + +The community frequently requests the ability to extend vLLM with custom features. To facilitate this, vLLM includes a plugin system that allows users to add custom features without modifying the vLLM codebase. This document explains how plugins work in vLLM and how to create a plugin for vLLM. + +How Plugins Work in vLLM +------------------------ + +Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see :ref:`class_hierarchy`), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the `load_general_plugins `__ function in the ``vllm.plugins`` module. This function is called for every process created by vLLM before it starts any work. + +How vLLM Discovers Plugins +-------------------------- + +vLLM's plugin system uses the standard Python ``entry_points`` mechanism. This mechanism allows developers to register functions in their Python packages for use by other packages. An example of a plugin: + +.. code-block:: python + + # inside `setup.py` file + from setuptools import setup + + setup(name='vllm_add_dummy_model', + version='0.1', + packages=['vllm_add_dummy_model'], + entry_points={ + 'vllm.general_plugins': + ["register_dummy_model = vllm_add_dummy_model:register"] + }) + + # inside `vllm_add_dummy_model.py` file + def register(): + from vllm import ModelRegistry + + if "MyLlava" not in ModelRegistry.get_supported_archs(): + ModelRegistry.register_model("MyLlava", + "vllm_add_dummy_model.my_llava:MyLlava") + +For more information on adding entry points to your package, please check the `official documentation `__. + +Every plugin has three parts: + +1. **Plugin group**: The name of the entry point group. vLLM uses the entry point group ``vllm.general_plugins`` to register general plugins. This is the key of ``entry_points`` in the ``setup.py`` file. Always use ``vllm.general_plugins`` for vLLM's general plugins. + +2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the ``entry_points`` dictionary. In the example above, the plugin name is ``register_dummy_model``. Plugins can be filtered by their names using the ``VLLM_PLUGINS`` environment variable. To load only a specific plugin, set ``VLLM_PLUGINS`` to the plugin name. + +3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is ``vllm_add_dummy_model:register``, which refers to a function named ``register`` in the ``vllm_add_dummy_model`` module. + +What Can Plugins Do? +-------------------- + +Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling ``ModelRegistry.register_model`` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM. + +Guidelines for Writing Plugins +------------------------------ + +- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes. + +Compatibility Guarantee +----------------------- + +vLLM guarantees the interface of documented plugins, such as ``ModelRegistry.register_model``, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, ``"vllm_add_dummy_model.my_llava:MyLlava"`` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development. \ No newline at end of file diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst index 75ab2b6ba02dc..22cc684a1c778 100644 --- a/docs/source/getting_started/tpu-installation.rst +++ b/docs/source/getting_started/tpu-installation.rst @@ -44,15 +44,18 @@ Requirements Provision Cloud TPUs ==================== -You can provision Cloud TPUs using the `Cloud TPU API `_` -or the `queued resources `_` -API. This section shows how to create TPUs using the queued resource API. -For more information about using the Cloud TPU API, see `Create a Cloud TPU using the Create Node API `_. -`Queued resources `_ -enable you to request Cloud TPU resources in a queued manner. When you request -queued resources, the request is added to a queue maintained by the Cloud TPU -service. When the requested resource becomes available, it's assigned to your -Google Cloud project for your immediate exclusive use. +You can provision Cloud TPUs using the `Cloud TPU API `_ +or the `queued resources `_ +API. This section shows how to create TPUs using the queued resource API. For +more information about using the Cloud TPU API, see `Create a Cloud TPU using the Create Node API `_. +Queued resources enable you to request Cloud TPU resources in a queued manner. +When you request queued resources, the request is added to a queue maintained by +the Cloud TPU service. When the requested resource becomes available, it's +assigned to your Google Cloud project for your immediate exclusive use. + +.. note:: + In all of the following commands, replace the ALL CAPS parameter names with + appropriate values. See the parameter descriptions table for more information. Provision a Cloud TPU with the queued resource API -------------------------------------------------- @@ -68,6 +71,7 @@ Create a TPU v5e with 4 TPU chips: --runtime-version RUNTIME_VERSION \ --service-account SERVICE_ACCOUNT + .. list-table:: Parameter descriptions :header-rows: 1 @@ -81,12 +85,13 @@ Create a TPU v5e with 4 TPU chips: * - PROJECT_ID - Your Google Cloud project * - ZONE - - The `zone `_ where you - want to create your Cloud TPU. + - The GCP zone where you want to create your Cloud TPU. The value you use + depends on the version of TPUs you are using. For more information, see + `TPU regions and zones `_ * - ACCELERATOR_TYPE - - The TPU version you want to use. Specify the TPU version, followed by a - '-' and the number of TPU cores. For example `v5e-4` specifies a v5e TPU - with 4 cores. For more information, see `TPU versions `_. + - The TPU version you want to use. Specify the TPU version, for example + `v5litepod-4` specifies a v5e TPU with 4 cores. For more information, + see `TPU versions `_. * - RUNTIME_VERSION - The TPU VM runtime version to use. For more information see `TPU VM images `_. * - SERVICE_ACCOUNT @@ -98,7 +103,15 @@ Connect to your TPU using SSH: .. code-block:: bash - gcloud compute tpus tpu-vm ssh TPU_NAME + gcloud compute tpus tpu-vm ssh TPU_NAME --zone ZONE + +Install Miniconda + +.. code-block:: bash + + wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh + bash Miniconda3-latest-Linux-x86_64.sh + source ~/.bashrc Create and activate a Conda environment for vLLM: @@ -162,9 +175,11 @@ Run the Docker image with the following command: .. note:: - Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape. - The compilation time may take 20~30 minutes in the first run. - However, the compilation time reduces to ~5 minutes afterwards because the XLA graphs are cached in the disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default). + Since TPU relies on XLA which requires static shapes, vLLM bucketizes the + possible input shapes and compiles an XLA graph for each shape. The + compilation time may take 20~30 minutes in the first run. However, the + compilation time reduces to ~5 minutes afterwards because the XLA graphs are + cached in the disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default). .. tip:: @@ -173,7 +188,8 @@ Run the Docker image with the following command: .. code-block:: console from torch._C import * # noqa: F403 - ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory + ImportError: libopenblas.so.0: cannot open shared object file: No such + file or directory Install OpenBLAS with the following command: diff --git a/docs/source/index.rst b/docs/source/index.rst index a2abd2995b1cc..b04acbbce4169 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -101,6 +101,7 @@ Documentation models/engine_args models/lora models/vlm + models/structured_outputs models/spec_decode models/performance @@ -158,6 +159,7 @@ Documentation design/class_hierarchy design/huggingface_integration + design/plugin_system design/input_processing/model_inputs_index design/kernel/paged_attention design/multimodal/multimodal_index diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index c6d88cc38e99b..a70ebf99c746f 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -102,11 +102,11 @@ This method should load the weights from the HuggingFace's checkpoint file and a Finally, register your :code:`*ForCausalLM` class to the :code:`_VLLM_MODELS` in `vllm/model_executor/models/registry.py `_. 6. Out-of-Tree Model Integration --------------------------------------------- +-------------------------------- -We also provide a way to integrate a model without modifying the vLLM codebase. Step 2, 3, 4 are still required, but you can skip step 1 and 5. +You can integrate a model without modifying the vLLM codebase. Steps 2, 3, and 4 are still required, but you can skip steps 1 and 5. Instead, write a plugin to register your model. For general introduction of the plugin system, see :ref:`plugin_system`. -Just add the following lines in your code: +To register the model, use the following code: .. code-block:: python @@ -114,7 +114,7 @@ Just add the following lines in your code: from your_code import YourModelForCausalLM ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) -If your model imports modules that initialize CUDA, consider instead lazy-importing it to avoid an error like :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`: +If your model imports modules that initialize CUDA, consider lazy-importing it to avoid errors like :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`: .. code-block:: python @@ -123,19 +123,8 @@ If your model imports modules that initialize CUDA, consider instead lazy-import ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCausalLM") .. important:: - If your model is a multimodal model, make sure the model class implements the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface. + If your model is a multimodal model, ensure the model class implements the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface. Read more about that :ref:`here `. -If you are running api server with :code:`vllm serve `, you can wrap the entrypoint with the following code: - -.. code-block:: python - - from vllm import ModelRegistry - from your_code import YourModelForCausalLM - ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) - - if __name__ == '__main__': - import runpy - runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') - -Save the above code in a file and run it with :code:`python your_file.py `. +.. note:: + Although you can directly put these code snippets in your script using ``vllm.LLM``, the recommended way is to place these snippets in a vLLM plugin. This ensures compatibility with various vLLM features like distributed inference and the API server. diff --git a/docs/source/models/structured_outputs.rst b/docs/source/models/structured_outputs.rst new file mode 100644 index 0000000000000..ff4ff7169fc5f --- /dev/null +++ b/docs/source/models/structured_outputs.rst @@ -0,0 +1,173 @@ +.. _structured_outputs: + +Structured Outputs +================== + +vLLM supports the generation of structured outputs using `outlines `_ or `lm-format-enforcer `_ as backends for the guided decoding. +This document shows you some examples of the different options that are available to generate structured outputs. + + +Online Inference (OpenAI API) +----------------------------- + +You can generate structured outputs using the OpenAI’s `Completions `_ and `Chat `_ API. + +The following parameters are supported, which must be added as extra parameters: + +- ``guided_choice``: the output will be exactly one of the choices. +- ``guided_regex``: the output will follow the regex pattern. +- ``guided_json``: the output will follow the JSON schema. +- ``guided_grammar``: the output will follow the context free grammar. +- ``guided_whitespace_pattern``: used to override the default whitespace pattern for guided json decoding. +- ``guided_decoding_backend``: used to select the guided decoding backend to use. + +You can see the complete list of supported parameters on the `OpenAI Compatible Server `_ page. + +Now let´s see an example for each of the cases, starting with the ``guided_choice``, as it´s the easiest one: + +.. code-block:: python + + from openai import OpenAI + client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="-", + ) + + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[ + {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} + ], + extra_body={"guided_choice": ["positive", "negative"]}, + ) + print(completion.choices[0].message.content) + + +The next example shows how to use the ``guided_regex``. The idea is to generate an email address, given a simple regex template: + +.. code-block:: python + + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[ + { + "role": "user", + "content": "Generate an example email address for Alan Turing, who works in Enigma. End in .com and new line. Example result: alan.turing@enigma.com\n", + } + ], + extra_body={"guided_regex": "\w+@\w+\.com\n", "stop": ["\n"]}, + ) + print(completion.choices[0].message.content) + +One of the most relevant features in structured text generation is the option to generate a valid JSON with pre-defined fields and formats. +For this we can use the ``guided_json`` parameter in two different ways: + +- Using directly a `JSON Schema `_ +- Defining a `Pydantic model `_ and then extracting the JSON Schema from it (which is normally an easier option). + +The next example shows how to use the ``guided_json`` parameter with a Pydantic model: + +.. code-block:: python + + from pydantic import BaseModel + from enum import Enum + + class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + + class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + + json_schema = CarDescription.model_json_schema() + + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[ + { + "role": "user", + "content": "Generate a JSON with the brand, model and car_type of the most iconic car from the 90's", + } + ], + extra_body={"guided_json": json_schema}, + ) + print(completion.choices[0].message.content) + +.. tip:: + While not strictly necessary, normally it´s better to indicate in the prompt that a JSON needs to be generated and which fields and how should the LLM fill them. + This can improve the results notably in most cases. + + +Finally we have the ``guided_grammar``, which probably is the most difficult one to use but it´s really powerful, as it allows us to define complete languages like SQL queries. +It works by using a context free EBNF grammar, which for example we can use to define a specific format of simplified SQL queries, like in the example below: + +.. code-block:: python + + simplified_sql_grammar = """ + ?start: select_statement + + ?select_statement: "SELECT " column_list " FROM " table_name + + ?column_list: column_name ("," column_name)* + + ?table_name: identifier + + ?column_name: identifier + + ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ + """ + + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[ + { + "role": "user", + "content": "Generate an SQL query to show the 'username' and 'email' from the 'users' table.", + } + ], + extra_body={"guided_grammar": simplified_sql_grammar}, + ) + print(completion.choices[0].message.content) + +The complete code of the examples can be found on `examples/openai_chat_completion_structured_outputs.py `_. + + +Offline Inference +----------------- + +Offline inference allows for the same types of guided decoding. +To use it, we´ll need to configure the guided decoding using the class ``GuidedDecodingParams`` inside ``SamplingParams``. +The main available options inside ``GuidedDecodingParams`` are: + +- ``json`` +- ``regex`` +- ``choice`` +- ``grammar`` +- ``backend`` +- ``whitespace_pattern`` + +These parameters can be used in the same way as the parameters from the Online Inference examples above. +One example for the usage of the ``choices`` parameter is shown below: + +.. code-block:: python + + from vllm import LLM, SamplingParams + from vllm.sampling_params import GuidedDecodingParams + + llm = LLM(model="HuggingFaceTB/SmolLM2-1.7B-Instruct") + + guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"]) + sampling_params = SamplingParams(guided_decoding=guided_decoding_params) + outputs = llm.generate( + prompts="Classify this sentiment: vLLM is wonderful!", + sampling_params=sampling_params, + ) + print(outputs[0].outputs[0].text) + +A complete example with all options can be found in `examples/offline_inference_structured_outputs.py `_. \ No newline at end of file diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 96a513d42753b..e902d393f2f70 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -446,7 +446,7 @@ Text Generation - GLM-4V - T + I - :code:`THUDM/glm-4v-9b` etc. - - + - ✅︎ - ✅︎ * - :code:`H2OVLChatModel` - H2OVL diff --git a/docs/source/quantization/supported_hardware.rst b/docs/source/quantization/supported_hardware.rst index 9bf0cdb80376d..09f8e7112cf0c 100644 --- a/docs/source/quantization/supported_hardware.rst +++ b/docs/source/quantization/supported_hardware.rst @@ -27,7 +27,7 @@ The table below shows the compatibility of various quantization implementations - ✅︎ - ✅︎ - ✗ - - ✗ + - ✅︎ - ✅︎ - ✗ - ✗ @@ -38,8 +38,8 @@ The table below shows the compatibility of various quantization implementations - ✅︎ - ✅︎ - ✗ - - ✗ - - ✗ + - ✅︎ + - ✅︎ - ✗ - ✗ * - Marlin (GPTQ/AWQ/FP8) @@ -129,4 +129,4 @@ Notes: Please note that this compatibility chart may be subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. -For the most up-to-date information on hardware support and quantization methods, please check the `quantization directory `_ or consult with the vLLM development team. \ No newline at end of file +For the most up-to-date information on hardware support and quantization methods, please check the `quantization directory `_ or consult with the vLLM development team. diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 78965813b1213..79d032bf8b211 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -172,12 +172,20 @@ completion = client.chat.completions.create( ] ) ``` -Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like -`meta-llama/Llama-Guard-3-1B` that expect the content to be parsed with the new OpenAI spec. In order to choose which -format the content needs to be parsed in by vLLM, please use the `--chat-template-text-format` argument to specify -between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match -this, unless explicitly specified. +Most chat templates for LLMs expect the `content` field to be a string, but there are some newer models like +`meta-llama/Llama-Guard-3-1B` that expect the content to be formatted according to the OpenAI schema in the +request. vLLM provides best-effort support to detect this automatically, which is logged as a string like +*"Detected the chat template content format to be..."*, and internally converts incoming requests to match +the detected format, which can be one of: + +- `"string"`: A string. + - Example: `"Hello world"` +- `"openai"`: A list of dictionaries, similar to OpenAI schema. + - Example: `[{"type": "text", "text": "Hello world!"}]` + +If the result is not what you expect, you can set the `--chat-template-content-format` CLI argument +to override which format to use. ## Command line arguments for the server diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f6..391ac6b9b6b03 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,22 +1,80 @@ +from dataclasses import asdict + from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def get_prompts(num_prompts: int): + # The default sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + if num_prompts != len(prompts): + prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts] + + return prompts + + +def main(args): + # Create prompts + prompts = get_prompts(args.num_prompts) + + # Create a sampling params object. + sampling_params = SamplingParams(n=args.n, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + max_tokens=args.max_tokens) + + # Create an LLM. + # The default model is 'facebook/opt-125m' + engine_args = EngineArgs.from_cli_args(args) + llm = LLM(**asdict(engine_args)) + + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == '__main__': + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + group = parser.add_argument_group("SamplingParams options") + group.add_argument("--num-prompts", + type=int, + default=4, + help="Number of prompts used for inference") + group.add_argument("--max-tokens", + type=int, + default=16, + help="Generated output length for sampling") + group.add_argument('--n', + type=int, + default=1, + help='Number of generated sequences per prompt') + group.add_argument('--temperature', + type=float, + default=0.8, + help='Temperature for text generation') + group.add_argument('--top-p', + type=float, + default=0.95, + help='top_p for text generation') + group.add_argument('--top-k', + type=int, + default=-1, + help='top_k for text generation') -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - -# Create an LLM. -llm = LLM(model="facebook/opt-125m") -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference_structured_outputs.py b/examples/offline_inference_structured_outputs.py new file mode 100644 index 0000000000000..00d864606eeff --- /dev/null +++ b/examples/offline_inference_structured_outputs.py @@ -0,0 +1,78 @@ +from enum import Enum + +from pydantic import BaseModel + +from vllm import LLM, SamplingParams +from vllm.sampling_params import GuidedDecodingParams + +llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100) + +# Guided decoding by Choice (list of possible options) +guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"]) +sampling_params = SamplingParams(guided_decoding=guided_decoding_params) +outputs = llm.generate( + prompts="Classify this sentiment: vLLM is wonderful!", + sampling_params=sampling_params, +) +print(outputs[0].outputs[0].text) + +# Guided decoding by Regex +guided_decoding_params = GuidedDecodingParams(regex="\w+@\w+\.com\n") +sampling_params = SamplingParams(guided_decoding=guided_decoding_params, + stop=["\n"]) +prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n") +outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) +print(outputs[0].outputs[0].text) + + +# Guided decoding by JSON using Pydantic schema +class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + +class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + +json_schema = CarDescription.model_json_schema() + +guided_decoding_params = GuidedDecodingParams(json=json_schema) +sampling_params = SamplingParams(guided_decoding=guided_decoding_params) +prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's") +outputs = llm.generate( + prompts=prompt, + sampling_params=sampling_params, +) +print(outputs[0].outputs[0].text) + +# Guided decoding by Grammar +simplified_sql_grammar = """ + ?start: select_statement + + ?select_statement: "SELECT " column_list " FROM " table_name + + ?column_list: column_name ("," column_name)* + + ?table_name: identifier + + ?column_name: identifier + + ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ +""" +guided_decoding_params = GuidedDecodingParams(grammar=simplified_sql_grammar) +sampling_params = SamplingParams(guided_decoding=guided_decoding_params) +prompt = ("Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table.") +outputs = llm.generate( + prompts=prompt, + sampling_params=sampling_params, +) +print(outputs[0].outputs[0].text) diff --git a/examples/openai_chat_completion_structured_outputs.py b/examples/openai_chat_completion_structured_outputs.py new file mode 100644 index 0000000000000..8c059c7ca07ce --- /dev/null +++ b/examples/openai_chat_completion_structured_outputs.py @@ -0,0 +1,94 @@ +from enum import Enum + +from openai import OpenAI +from pydantic import BaseModel + +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="-", +) + +# Guided decoding by Choice (list of possible options) +completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": "Classify this sentiment: vLLM is wonderful!" + }], + extra_body={"guided_choice": ["positive", "negative"]}, +) +print(completion.choices[0].message.content) + +# Guided decoding by Regex +prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n") + +completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={ + "guided_regex": "\w+@\w+\.com\n", + "stop": ["\n"] + }, +) +print(completion.choices[0].message.content) + + +# Guided decoding by JSON using Pydantic schema +class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + +class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + +json_schema = CarDescription.model_json_schema() + +prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's") +completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_json": json_schema}, +) +print(completion.choices[0].message.content) + +# Guided decoding by Grammar +simplified_sql_grammar = """ + ?start: select_statement + + ?select_statement: "SELECT " column_list " FROM " table_name + + ?column_list: column_name ("," column_name)* + + ?table_name: identifier + + ?column_name: identifier + + ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ +""" + +prompt = ("Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table.") +completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_grammar": simplified_sql_grammar}, +) +print(completion.choices[0].message.content) diff --git a/requirements-test.txt b/requirements-test.txt index a59b85023948b..9fe5effa1a241 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -217,7 +217,7 @@ mbstrdecoder==1.1.3 # dataproperty # pytablewriter # typepy -mistral-common[opencv]==1.5.1 +mistral_common[opencv]==1.5.0 # via # -r requirements-test.in # mistral-common diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index cc5bc2aca27c9..5f2b0d4fdf7c2 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -7,7 +7,6 @@ Run `pytest tests/models/test_chunked_prefill.py`. """ import os -from contextlib import nullcontext import pytest @@ -229,7 +228,6 @@ def test_with_prefix_caching( max_num_batched_tokens = max_num_seqs = chunk_size outputs = {} # type: ignore - check_result = True for enable in (True, False): with vllm_runner( model, @@ -241,22 +239,15 @@ def test_with_prefix_caching( enforce_eager=enforce_eager, max_num_seqs=max_num_seqs, ) as vllm_model: - # It should fail when prefix caching is enable and chunk - # size is not a multiple of block size (16). - should_fail = chunk_size % 16 != 0 and enable - check_result &= not should_fail outputs[enable] = [] # Send the request one-by-one to ensure the cache is populated. - with pytest.raises(ValueError) if should_fail else nullcontext(): - for prompt in full_prompts: - outputs[enable] += vllm_model.generate_greedy([prompt], - max_tokens) + for prompt in full_prompts: + outputs[enable] += vllm_model.generate_greedy([prompt], + max_tokens) - # Check results only if we did not expect a failure. - if check_result: - check_outputs_equal( - outputs_0_lst=outputs[False], - outputs_1_lst=outputs[True], - name_0="w/o prefix caching", - name_1="with prefix caching", - ) + check_outputs_equal( + outputs_0_lst=outputs[False], + outputs_1_lst=outputs[True], + name_0="w/o prefix caching", + name_1="with prefix caching", + ) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index c631850ecdedb..45f56cbbd4b16 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -11,8 +11,8 @@ from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig +from vllm.plugins import set_current_vllm_config from vllm.utils import direct_register_custom_op global_counter = 0 @@ -82,7 +82,9 @@ def test_simple_piecewise_compile(): os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) - model = SillyModel(vllm_config=VllmConfig(), prefix='') + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = SillyModel(vllm_config=vllm_config, prefix='') inputs = torch.randn(100).cuda() diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index c363a587a818e..8032304e95806 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -15,12 +15,10 @@ from torch.library import Library from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.config import CompilationConfig from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig -from vllm.plugins import set_compilation_config +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.plugins import set_compilation_config, set_current_vllm_config from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -272,9 +270,11 @@ def run_model(llama_config, CompilationLevel.NO_COMPILATION) set_compilation_config(None) - model = LlamaModel(config=llama_config, - vllm_config=VllmConfig(), - prefix="").eval().cuda() + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = LlamaModel(config=llama_config, + vllm_config=vllm_config, + prefix="").eval().cuda() B = 16 # max batch size input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() @@ -395,9 +395,11 @@ def benchmark(): else: set_compilation_config(None) - model = LlamaModel(config=llama_config, - vllm_config=VllmConfig(), - prefix="").eval().cuda().to(torch.bfloat16) + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = LlamaModel(config=llama_config, + vllm_config=vllm_config, + prefix="").eval().cuda().to(torch.bfloat16) B = 256 # max batch size input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 833589ba5dc9f..08747ebc58b75 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -3,7 +3,7 @@ import pytest -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from vllm.utils import cuda_device_count_stateless from ..utils import compare_all_settings diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index f00334934cb46..4dfdfe21a67df 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,6 +1,6 @@ import pytest -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from ..utils import fork_new_process_for_each_test from .utils import TEST_MODELS, check_full_graph_support diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index e4d3defafb951..4db79b070fd8d 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -3,10 +3,10 @@ from compressed_tensors.quantization import FP8_DTYPE import vllm.envs as envs -from vllm.compilation.config import CompilationConfig from vllm.compilation.fusion import (FusionPass, find_auto_fn, find_auto_fn_maybe) from vllm.compilation.reshapes import RedundantReshapesPass +from vllm.config import CompilationConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( apply_fp8_linear) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index 3668c1fab6b89..74f66baaa5ea1 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -3,6 +3,7 @@ import torch from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import CompilationLevel class MyMod(torch.nn.Module): @@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher): def __init__(self, model): self.model = model compiled_callable = torch.compile(self.forward, backend="eager") - super().__init__(compiled_callable) + super().__init__(compiled_callable, + compilation_level=CompilationLevel.DYNAMO_ONCE) def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): # this is the function to be compiled diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 222c63a342a4b..729f10676888b 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,7 +4,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from vllm.platforms import current_platform TEST_MODELS = [ diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index acd82065ae457..34ced55828bd7 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -5,6 +5,9 @@ from vllm.config import CacheConfig, SchedulerConfig from vllm.core.scheduler import Scheduler +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob, SequenceGroup from .utils import create_dummy_prompt @@ -14,7 +17,7 @@ def get_sequence_groups(scheduler_output): return [s.seq_group for s in scheduler_output.scheduled_seq_groups] -def append_new_token(seq_group, token_id: int): +def append_new_token(seq_group: SequenceGroup, token_id: int): for seq in seq_group.get_seqs(): seq.append_token_id(token_id, {token_id: Logprob(token_id)}) @@ -121,6 +124,214 @@ def test_chunk(): assert out.num_batched_tokens == 57 +def test_concurrent_chunking(): + """Verify prefills are chunked properly when + --max-num-partial-prefills is > 1""" + block_size = 4 + max_seqs = 60 + max_model_len = 2000 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + max_num_partial_prefills=2, # Up to 2 partial prefills at a time + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Verify both requests are chunked with half of max_num_batched_tokens each + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 32 + assert seq_group_meta[1].token_chunk_size == 32 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + + # After one iteration, both should have 60 - 32 = 28 tokens left to prefill + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 28 + assert seq_group_meta[1].token_chunk_size == 28 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 56 + + +def test_concurrent_chunking_large_requests(): + """Verify large prefill requests are run one at a time""" + block_size = 4 + max_seqs = 60 + max_model_len = 2000 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + max_num_partial_prefills=2, # Up to 2 partial prefills at a time + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests + cache_config.num_gpu_blocks = 3200 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt( + str(i), + prompt_length=1200, # Very large prompt + block_size=block_size) + scheduler.add_seq_group(seq_group) + + # Verify only a single request is chunked, and it gets all 64 tokens + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 1 + assert seq_group_meta[0].token_chunk_size == 64 + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 64 + + +def test_short_prompts_jump_long_prompts_in_queue(): + """Verify large prefill requests are punted behind smaller ones if + another large prefill request is already running""" + block_size = 4 + max_seqs = 60 + max_model_len = 2000 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + max_num_partial_prefills=2, # Up to 2 partial prefills at a time + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests + cache_config.num_gpu_blocks = 3200 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add 2 large seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt( + str(i), + prompt_length=1200, # Very large prompt + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + + # Add 2 small seq groups behind them + for i in range(2): + _, seq_group = create_dummy_prompt( + str(i + 2), + prompt_length=40, # Very small prompt + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + + # Verify one large req and 1 small req chunked + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens + assert seq_group_meta[1].token_chunk_size == 32 # small req gets 32 tokens + + # all 4 are prefilling + assert running[0].is_prefill() + assert running[1].is_prefill() + assert running[2].is_prefill() + assert running[3].is_prefill() + + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + + # in the second iteration, + # the first small request had only 8 tokens left + # so it went to decode + # The other small req is scheduled + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + # the new small req got 64 - (32+8) tokens + assert (seq_group_meta[0].token_chunk_size == 24) + assert seq_group_meta[1].token_chunk_size == 32 # large req still got 32 + # the other small request had only 8 tokens left + assert seq_group_meta[2].token_chunk_size == 8 # 40-32 + + # notice the small request got to decode now + # this is because of max_num_partial_prefills logic + assert running[0].is_prefill() + assert running[1].is_prefill() + assert not running[2].is_prefill() + assert running[3].is_prefill() + + assert out.num_prefill_groups == 3 + assert out.num_batched_tokens == 64 + # the small seq group has a new token appended. + append_new_token(running[2], 1) + + # in the third iteration, + # the first small request has entered decode + # and other small req had 16 tokens left + # so it went to decode + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == 32 # large still got 32 + # small req prefilled 40-24=16 tokens + assert (seq_group_meta[1].token_chunk_size == 16) + assert seq_group_meta[2].token_chunk_size == 1 # decode + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 49 # (32+16+1 decode) + + # both small requests have now reached decode + assert running[0].is_prefill() + assert running[1].is_prefill() + assert not running[2].is_prefill() + assert not running[3].is_prefill() + + # the small seq group has a new token appended. + append_new_token(running[2], 1) + + # in the fourth iteration, both small requests are decoding + # so large request gets all the budget + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + # large req gets 63 tokens (minus 1 for decode) + assert seq_group_meta[0].token_chunk_size == 63 + assert seq_group_meta[1].token_chunk_size == 1 # decode + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 64 + + assert running[0].is_prefill() + assert running[1].is_prefill() + assert not running[2].is_prefill() + assert not running[3].is_prefill() + + # both the small seq groups have a new token appended + append_new_token(running[2], 1) + append_new_token(running[3], 1) + + # in the fifth iteration, large request gets all the budget + # while both small requests are decoding + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == 62 + assert seq_group_meta[1].token_chunk_size == 1 # decode + assert seq_group_meta[2].token_chunk_size == 1 # decode + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 64 + + def test_complex(): block_size = 4 max_seqs = 60 @@ -467,7 +678,7 @@ def test_chunked_prefill_max_seqs(): assert not running[1].is_prefill() -def test_perfix_caching(): +def test_prefix_caching(): """Verify allocating full blocks when prefix caching is enabled.""" block_size = 4 max_seqs = 10 @@ -507,3 +718,86 @@ def test_perfix_caching(): assert seq_group_meta[1].token_chunk_size == 12 assert out.num_prefill_groups == 2 assert out.num_batched_tokens == 62 + + +def test_prefix_caching_with_concurrent_partial_prefills(): + """Verify allocating full blocks when prefix caching is enabled with + --max-num-partial-prefills > 1.""" + block_size = 4 + max_seqs = 10 + max_model_len = 8000 + max_num_batched_tokens = 60 # With two slots, each slot will get 30 tokens + scheduler_config = SchedulerConfig("generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + max_num_partial_prefills=2) + cache_config = CacheConfig(block_size, + 1.0, + 1, + "auto", + enable_prefix_caching=True) + cache_config.num_cpu_blocks = 0 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + block_size=block_size, + prompt_length=50) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + # To partially prefill both sequences, both can chunk up to 30 tokens + # But the next lowest multiple of the block size (4) is 28 + assert seq_group_meta[0].token_chunk_size == 28 + assert seq_group_meta[1].token_chunk_size == 28 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 56 + + # On the next iteration, both sequences should finish prefill + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + # Both sequences have 50 - 28 = 22 tokens left to prefill. + # This is not a multiple of the block size, but we don't care since we don't + # cache the final partial block of prefix sequences + assert seq_group_meta[0].token_chunk_size == 22 + assert seq_group_meta[1].token_chunk_size == 22 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 44 + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8]) +def test_chunked_prefill_with_actual_engine(model: str, + max_num_partial_prefills: int): + """Make sure the model can actually sample with concurrent + partial prefills + """ + + prompt = "hello" * 40 + + engine_args = EngineArgs( + model=model, + max_num_partial_prefills=max_num_partial_prefills, + max_num_batched_tokens=40, + max_num_seqs=8, + enable_chunked_prefill=True, + gpu_memory_utilization=0.8, + ) + + engine = LLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams(temperature=0) + + for req_num in range(max_num_partial_prefills): + engine.add_request(f"{req_num}", prompt, sampling_params) + # first step + request_outputs = engine.step() + # means all are prefilling + assert len(request_outputs) == 0 + assert len(engine.scheduler[0].running) == max_num_partial_prefills diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index e969d33775d86..93660e6118ca8 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -26,7 +26,6 @@ class MockModelConfig: tokenizer = MODEL_NAME trust_remote_code = False tokenizer_mode = "auto" - chat_template_text_format = "string" max_model_len = 100 tokenizer_revision = None multimodal_config = MultiModalConfig() @@ -49,6 +48,7 @@ async def _async_serving_chat_init(): BASE_MODEL_PATHS, response_role="assistant", chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", lora_modules=None, prompt_adapters=None, request_logger=None) @@ -70,6 +70,7 @@ def test_serving_chat_should_set_correct_max_tokens(): BASE_MODEL_PATHS, response_role="assistant", chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", lora_modules=None, prompt_adapters=None, request_logger=None) diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 5fa466f8f041f..72477e048eafa 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -6,15 +6,24 @@ from vllm.assets.image import ImageAsset from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import (parse_chat_messages, - parse_chat_messages_futures) +from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, + parse_chat_messages, + parse_chat_messages_futures, + resolve_chat_template_content_format) from vllm.entrypoints.llm import apply_hf_chat_template from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import encode_image_base64 from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from ..utils import VLLM_PATH + +EXAMPLES_DIR = VLLM_PATH / "examples" + PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" +ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3" +QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" +LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" @pytest.fixture(scope="function") @@ -26,7 +35,6 @@ def phi3v_model_config(): trust_remote_code=True, dtype="bfloat16", seed=0, - chat_template_text_format="string", limit_mm_per_prompt={ "image": 2, }) @@ -94,19 +102,24 @@ def test_parse_chat_messages_single_image( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in the image?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": "user", @@ -121,19 +134,24 @@ async def test_parse_chat_messages_single_image_async( phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_future = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in the image?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": "user", @@ -147,24 +165,29 @@ def test_parse_chat_messages_multiple_images( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": @@ -181,24 +204,29 @@ async def test_parse_chat_messages_multiple_images_async( phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_future = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": @@ -214,27 +242,31 @@ def test_parse_chat_messages_placeholder_already_in_prompt( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to <|image_2|>?" - }] - }], phi3v_model_config, phi3v_tokenizer) - + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to <|image_2|>?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": "user", @@ -249,26 +281,35 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to the other one?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to the other one?" # noqa: E501 + } + ] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": @@ -285,34 +326,39 @@ def test_parse_chat_messages_multiple_images_across_messages( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in this image?" + }] }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } + "role": "assistant", + "content": "Some stuff." }, { - "type": "text", - "text": "What about this one?" - }] - }], phi3v_model_config, phi3v_tokenizer) + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What about this one?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [ { @@ -335,7 +381,6 @@ def test_parse_chat_messages_context_text_format( phi3v_model_config, phi3v_tokenizer, ): - phi3v_model_config.chat_template_text_format = "openai" conversation, mm_data = parse_chat_messages( [{ "role": "user", @@ -349,7 +394,11 @@ def test_parse_chat_messages_context_text_format( }, { "role": "user", "content": "What about this one?" - }], phi3v_model_config, phi3v_tokenizer) + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="openai", + ) assert conversation == [ { @@ -389,29 +438,34 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( ValueError, match="At most 2 image\\(s\\) may be provided in one request\\." ): - parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] - }], phi3v_model_config, phi3v_tokenizer) + parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) def test_parse_chat_messages_rejects_too_many_images_across_messages( @@ -427,39 +481,44 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( ValueError, match="At most 2 image\\(s\\) may be provided in one request\\." ): - parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } + parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in this image?" + }] }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } + "role": "assistant", + "content": "Some stuff." }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What about these two?" - }] - }], phi3v_model_config, phi3v_tokenizer) + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What about these two?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) def test_parse_chat_messages_multiple_images_uncommon_input( @@ -467,17 +526,22 @@ def test_parse_chat_messages_multiple_images_uncommon_input( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [ - "What's in these images?", { - "image_url": image_url - }, { - "image_url": image_url - } - ] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [ + "What's in these images?", { + "image_url": image_url + }, { + "image_url": image_url + } + ] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": @@ -495,16 +559,21 @@ def test_mllama_single_image( image_url, ): """Ensures that a single image is parsed correctly mllama.""" - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - "image_url": image_url - }] - }], mllama_model_config, mllama_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + 'type': 'text', + 'text': 'The content of this image is:' + }, { + "image_url": image_url + }] + }], + mllama_model_config, + mllama_tokenizer, + content_format="openai", + ) _assert_mm_data_is_image_input(mm_data, 1) assert conversation == [{ 'role': @@ -524,26 +593,31 @@ def test_mllama_interleaved_images( image_url, ): """Ensures that multiple image are parsed as interleaved dicts.""" - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [ - { - 'type': 'text', - 'text': 'The content of the first image is:' - }, - { - "image_url": image_url - }, - { - 'type': 'text', - 'text': 'The content of the second image is:' - }, - { - "image_url": image_url - }, - ] - }], mllama_model_config, mllama_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + 'type': 'text', + 'text': 'The content of the first image is:' + }, + { + "image_url": image_url + }, + { + 'type': 'text', + 'text': 'The content of the second image is:' + }, + { + "image_url": image_url + }, + ] + }], + mllama_model_config, + mllama_tokenizer, + content_format="openai", + ) _assert_mm_data_is_image_input(mm_data, 2) assert conversation == [{ 'role': @@ -626,6 +700,7 @@ def get_conversation(is_hf: bool): vllm_conversation, model_config, tokenizer_group, + content_format="openai", ) vllm_result = apply_hf_chat_template( @@ -636,3 +711,89 @@ def get_conversation(is_hf: bool): ) assert hf_result == vllm_result + + +# yapf: disable +@pytest.mark.parametrize( + ("model", "expected_format"), + [(PHI3V_MODEL_ID, "string"), + (QWEN2VL_MODEL_ID, "openai"), + (ULTRAVOX_MODEL_ID, "string"), + (MLLAMA_MODEL_ID, "openai"), + (LLAMA_GUARD_MODEL_ID, "openai")], +) +# yapf: enable +def test_resolve_content_format_hf_defined(model, expected_format): + tokenizer_group = TokenizerGroup( + model, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + tokenizer = tokenizer_group.tokenizer + + chat_template = tokenizer.chat_template + assert isinstance(chat_template, str) + + print("[TEXT]") + print(chat_template) + print("[AST]") + print(_try_extract_ast(chat_template)) + + resolved_format = resolve_chat_template_content_format( + None, # Test detecting the tokenizer's chat_template + "auto", + tokenizer, + ) + + assert resolved_format == expected_format + + +# yapf: disable +@pytest.mark.parametrize( + ("template_path", "expected_format"), + [("template_alpaca.jinja", "string"), + ("template_baichuan.jinja", "string"), + ("template_blip2.jinja", "string"), + ("template_chatglm.jinja", "string"), + ("template_chatglm2.jinja", "string"), + ("template_chatml.jinja", "string"), + ("template_falcon_180b.jinja", "string"), + ("template_falcon.jinja", "string"), + ("template_inkbot.jinja", "string"), + ("template_llava.jinja", "string"), + ("template_vlm2vec.jinja", "openai"), + ("tool_chat_template_granite_20b_fc.jinja", "string"), + ("tool_chat_template_hermes.jinja", "string"), + ("tool_chat_template_internlm2_tool.jinja", "string"), + ("tool_chat_template_llama3.1_json.jinja", "string"), + ("tool_chat_template_llama3.2_json.jinja", "string"), + ("tool_chat_template_mistral_parallel.jinja", "string"), + ("tool_chat_template_mistral.jinja", "string")], +) +# yapf: enable +def test_resolve_content_format_examples(template_path, expected_format): + tokenizer_group = TokenizerGroup( + PHI3V_MODEL_ID, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + dummy_tokenizer = tokenizer_group.tokenizer + dummy_tokenizer.chat_template = None + + chat_template = load_chat_template(EXAMPLES_DIR / template_path) + assert isinstance(chat_template, str) + + print("[TEXT]") + print(chat_template) + print("[AST]") + print(_try_extract_ast(chat_template)) + + resolved_format = resolve_chat_template_content_format( + chat_template, + "auto", + dummy_tokenizer, + ) + + assert resolved_format == expected_format diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py deleted file mode 100644 index 59c0a24753c3b..0000000000000 --- a/tests/kernels/test_machete_gemm.py +++ /dev/null @@ -1,284 +0,0 @@ -"""Tests for the machete kernel. - -Run `pytest tests/kernels/test_machete_gemm.py`. -""" - -import math -from typing import Optional, Tuple - -import pytest -import torch - -from tests.kernels.utils import opcheck -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_rows, quantize_weights) -from vllm.platforms import current_platform -from vllm.scalar_type import ScalarType, scalar_types - -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - -MNK_SHAPES = [ - (1, 128, 128), - (1, 512, 1024), - (1, 4096, 4096), - (1, 8192, 28672), - (13, 8192, 4096), - (26, 4096, 8192), - (64, 4096, 4096), - (64, 8192, 28672), - (257, 128, 4096), - (257, 4224, 4160), - (257, 4096, 4096), - (1024, 4096, 8192), - (1024, 8192, 4096), -] - -ACT_TYPES = [torch.float16, torch.bfloat16] -WTYPE_ZEROPOINTS = [ - # GPTQ style - (scalar_types.uint4b8, False), - (scalar_types.uint8b128, False), - # AWQ style - (scalar_types.uint4, True), - (scalar_types.uint8, True), -] - -# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel -# unit tests to a common utility function. Currently the use of -# `is_quant_method_supported` conflates kernels with quantization methods -# an assumption which is breaking down as quantizations methods can have -# have kernels and some kernels support multiple quantization methods. -IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) - - -def rand_data(shape, dtype=torch.float16): - return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3) - - -def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): - return zps if zps is None else -1 * s * (zps.to(s.dtype)) - - -def machete_quantize_and_pack(w: torch.Tensor, - wtype: ScalarType, - group_size: int, - zero_points: bool = False): - assert wtype.is_integer(), "TODO: support floating point weights" - - w_ref, w_q, w_s, w_zp = quantize_weights( - w, - wtype, - group_size, - zero_points=zero_points, - # to match how the kernel applies zps - ref_zero_points_after_scales=True) - - w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) - w_q = w_q.t().contiguous().t() # convert to col major - w_q_machete = ops.machete_prepack_B(w_q, wtype) - - opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id)) - - return w_ref, w_q_machete, w_s, w_zp - - -def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor, - wtype: ScalarType, group_size: int, - zero_points: bool): - w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - b, wtype, group_size, zero_points) - - output_ref = torch.matmul(a, w_ref) - - output = ops.machete_gemm( - a=a, - b_q=w_q_packed, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - ) - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) -@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) -@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) -@pytest.mark.parametrize("group_size", [128, None]) -def test_machete_all_schedules(shape, atype: torch.dtype, - wtype_zeropoints: Tuple[ScalarType, bool], - group_size: Optional[int]): - m, n, k = shape - wtype, zero_points = wtype_zeropoints - - if group_size is not None and k % group_size != 0: - return - - print(f"MNK = {m} {n} {k}") - - # Normalize group_size - if group_size is None: - group_size = k - assert group_size <= k - - a = rand_data((m, k), atype) - w = rand_data((k, n), atype) - - w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack( - w, wtype, group_size, zero_points) - - output_ref = torch.matmul(a, w_ref) - - for schedule in ops.machete_supported_schedules(wtype): - print(f"Testing schedule {schedule}") - output = ops.machete_gemm( - a, - b_q=w_q_machete, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - schedule=schedule, - ) - - opcheck( - torch.ops._C.machete_gemm, - (a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints( - w_zp, w_s), group_size, None, None, None, schedule)) - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\ - f"Schedule failed {schedule}" - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) -@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) -@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) -@pytest.mark.parametrize("group_size", [128, None]) -def test_machete_heuristic(shape, atype: torch.dtype, - wtype_zeropoints: Tuple[ScalarType, bool], - group_size: Optional[int]): - m, n, k = shape - wtype, zero_points = wtype_zeropoints - - if group_size is not None and k % group_size != 0: - return - - # Normalize group_size - if group_size is None: - group_size = k - assert group_size <= k - - a = rand_data((m, k), atype) - b = rand_data((k, n), atype) - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test working on other devices -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_machete_devices(device: str): - m, n, k = 512, 4096, 4096 - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - print(f"MNK = {m} {n} {k}, device = {device}") - - a = rand_data((m, k), torch.float16).to(device) - b = rand_data((k, n), torch.float16).to(device) - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test working with a subset of A and B -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -def test_machete_subset(): - big_m, big_n, big_k = 1024, 1024, 1024 - m, n, k = 512, 512, 512 - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - whole_a = rand_data((big_m, big_k), torch.float16) - whole_b = rand_data((big_k, big_n), torch.float16) - - a = whole_a[0:m, 0:k] - b = whole_b[0:k, 0:n] - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test to make sure cuda graphs work -class MacheteLayer(torch.nn.Module): - - def __init__(self, **kwargs): - super().__init__() - self.kwargs = kwargs - - def forward(self, a): - return ops.machete_gemm(**self.kwargs) - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -def test_machete_cuda_graph(): - m, n, k = 512, 4096, 4096 - - a = rand_data((m, k), torch.float16) - b = rand_data((k, n), torch.float16) - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - b, wtype, group_size, zero_points) - - # Construct a trivial model with a single layer that calls a machete kernel - model = MacheteLayer( - a=a, - b_q=w_q_packed, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - ) - - output_ref = torch.matmul(a, w_ref) - - # Run the model with a cuda graph - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - output = model(a) - output.zero_() - g.replay() - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) diff --git a/tests/kernels/test_machete_mm.py b/tests/kernels/test_machete_mm.py new file mode 100644 index 0000000000000..1c6eb2dd9a228 --- /dev/null +++ b/tests/kernels/test_machete_mm.py @@ -0,0 +1,406 @@ +"""Tests for the machete kernel. + +Run `pytest tests/kernels/test_machete_mm.py`. +""" + +import math +from dataclasses import dataclass, fields +from typing import List, Optional, Tuple + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + +MNK_SHAPES = [ + (1, 128, 128), + (1, 512, 1024), + (1, 4096, 4096), + (1, 8192, 28672), + (13, 8192, 4096), + (26, 4096, 8192), + (64, 4096, 4096), + (64, 8192, 28672), + (257, 128, 4096), + (257, 4224, 4160), + (257, 4096, 4096), + (1024, 4096, 8192), + (1024, 8192, 4096), +] + +GROUP_SIZES_TO_TEST: List[Optional[int]] = [128, -1] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + group_zero_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +@dataclass +class Tensors: + w_ref: torch.Tensor + a_ref: torch.Tensor + a: torch.Tensor + w_q: torch.Tensor + w_g_s: Optional[torch.Tensor] + w_g_zp: Optional[torch.Tensor] + w_ch_s: Optional[torch.Tensor] + w_tok_s: Optional[torch.Tensor] + + +# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints, +# Ch Scales Type, Tok Scales Type) +# NOTE: None "Scale Type" means the act type is floating point +# None "Output Type" means the output type is the same as the act type +TestTypeTuple = Tuple[List[torch.dtype], ScalarType, Optional[torch.dtype], + Optional[torch.dtype], bool] +TEST_TYPES = [ + # GPTQ style + *(TypeConfig(act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None) + for w_type in [scalar_types.uint4b8, scalar_types.uint8b128] + for a_type in [torch.float16, torch.bfloat16]), + # AWQ style + *(TypeConfig(act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=a_type, + channel_scale_type=None, + token_scale_type=None) + for w_type in [scalar_types.uint4, scalar_types.uint8] + for a_type in [torch.float16, torch.bfloat16]), + # QQQ style + *(TypeConfig(act_type=torch.int8, + weight_type=scalar_types.uint4b8, + output_type=torch.float16, + group_scale_type=group_scale_type, + group_zero_type=None, + channel_scale_type=torch.float, + token_scale_type=torch.float) + for group_scale_type in [None, torch.float16]), + *(TypeConfig(act_type=torch.float8_e4m3fn, + weight_type=scalar_types.uint4b8, + output_type=torch.float16, + group_scale_type=group_scale_type, + group_zero_type=None, + channel_scale_type=torch.float, + token_scale_type=torch.float) + for group_scale_type in [None, torch.float16]), +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) + + +def rand_data(shape, dtype=torch.float16, scale=1, offset=0): + if dtype.is_floating_point: + return (scale * torch.rand(shape, device="cuda") - offset).to(dtype) + else: + return torch.randint(-8, 7, shape, dtype=dtype, device="cuda") + + +def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): + return zps if zps is None else -1 * s * (zps.to(s.dtype)) + + +def group_size_valid(shape: Tuple[int, int, int], + group_size: Optional[int]) -> bool: + return group_size is None or group_size == -1 or group_size % shape[2] == 0 + + +def machete_quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, + wtype, + group_size=group_size, + zero_points=zero_points, + # to match how the kernel applies zps + ref_zero_points_after_scales=True) + + w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + + w_q_machete = ops.machete_prepack_B(w_q, atype, wtype, stype) + opcheck(torch.ops._C.machete_prepack_B, (w_q, atype, wtype.id, stype)) + + return w_ref, w_q_machete, w_s, w_zp + + +def create_test_tensors(shape: Tuple[int, int, int], + types: TypeConfig, + group_size: Optional[int], + subset_stride_factor: Optional[int] = None) -> Tensors: + m, n, k = shape + factor = subset_stride_factor or 1 + + print("create_test_tensors, shape:", shape, "types:", types, "group_size:", + group_size) + + a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2) + w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1) + + if factor > 1: + a = a[0:m, 0:k] + w = w[0:k, 0:n] + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + types.group_zero_type is not None) + + if not a.dtype.is_floating_point: + aiinfo = torch.iinfo(a.dtype) + w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max) + + a_ref = a.to(torch.float32) + w_ref = w_ref.to(torch.float32) + + w_ch_s = None if types.channel_scale_type is None else\ + rand_data((n,), types.channel_scale_type) + w_tok_s = None if types.token_scale_type is None else\ + rand_data((m,), types.token_scale_type) + + return Tensors(w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_g_zp=maybe_convert_zeropoints(w_zp, w_s), + w_ch_s=w_ch_s, + w_tok_s=w_tok_s) + + +# None stype means scales use the same dtype as a +def machete_mm_test_helper(types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None): + output_ref = torch.matmul(tensors.a_ref, tensors.w_ref) + output_ref_type = output_ref.dtype + + if tensors.w_ch_s is not None: + output_ref = (output_ref.to(tensors.w_ch_s.dtype) * + tensors.w_ch_s.unsqueeze(0)).to(output_ref_type) + if tensors.w_tok_s is not None: + output_ref = (output_ref.to(tensors.w_tok_s.dtype) * + tensors.w_tok_s.unsqueeze(1)).to(output_ref_type) + + output = ops.machete_mm( + a=tensors.a, + b_q=tensors.w_q, + b_type=types.weight_type, + b_group_scales=tensors.w_g_s, + b_group_zeros=tensors.w_g_zp, + b_group_size=group_size, + b_channel_scales=tensors.w_ch_s, + a_token_scales=tensors.w_tok_s, + out_type=types.output_type, + schedule=schedule, + ) + + print(output) + print(output_ref) + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if tensors.w_g_zp is not None\ + else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1) + rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1 + torch.testing.assert_close(output, + output_ref.to(output.dtype), + rtol=rtol, + atol=atol) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +def test_machete_all_schedules(shape, types: TypeConfig): + + group_sizes: List[Optional[int]] = [] + if types.group_scale_type is None: + group_sizes = [None] + else: + group_sizes = GROUP_SIZES_TO_TEST + + for group_size in group_sizes: + if not group_size_valid(shape, group_size): + continue + + tensors = create_test_tensors(shape, types, group_size) + print(f"MNK = {shape}") + for schedule in ops.machete_supported_schedules( + types.act_type, + types.weight_type, + group_scales_type=types.group_scale_type, + group_zeros_type=types.group_scale_type, + out_type=types.output_type): + print(f"Testing schedule {schedule}") + machete_mm_test_helper(types, tensors, group_size, schedule) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +def test_machete_heuristic(shape, types: TypeConfig): + group_sizes: List[Optional[int]] = [] + if types.group_scale_type is None: + group_sizes = [None] + else: + group_sizes = GROUP_SIZES_TO_TEST + + for group_size in group_sizes: + if not group_size_valid(shape, group_size): + continue + + tensors = create_test_tensors(shape, types, group_size) + machete_mm_test_helper(types, tensors, group_size) + + +# Test working on other devices +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_machete_devices(device: str): + group_size = 128 + + type_config = TypeConfig(act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None) + + tensors = create_test_tensors((512, 4096, 4096), type_config, group_size) + + for field in fields(Tensors): + tensor = getattr(tensors, field.name) + if isinstance(tensor, torch.Tensor): + setattr(tensors, field.name, tensor.to(device)) + + machete_mm_test_helper(type_config, tensors, group_size) + + +# Test working with a subset of A and B +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_subset(): + group_size = 128 + + type_config = TypeConfig(act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None) + + tensors = create_test_tensors((512, 4096, 4096), + type_config, + group_size, + subset_stride_factor=2) + machete_mm_test_helper(type_config, tensors, group_size) + + +# Test to make sure cuda graphs work +class MacheteLayer(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.machete_mm(a=a, **self.kwargs) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = rand_data((m, k), torch.float16) + b = rand_data((k, n), torch.float16) + wtype = scalar_types.uint4b8 + stype = torch.float16 + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + a.dtype, b, wtype, stype, group_size, zero_points) + + # Construct a trivial model with a single layer that calls a machete kernel + model = MacheteLayer( + b_q=w_q_packed, + b_type=wtype, + b_group_scales=w_s, + b_group_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + ) + + output_ref = torch.matmul(a, w_ref) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + output.zero_() + g.replay() + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index af267f804ffa7..c3219bc50646b 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -3,11 +3,13 @@ import pytest +from vllm.config import CompilationConfig, VllmConfig from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, ReLUSquaredActivation, SiluAndMul) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.plugins import set_current_vllm_config # Registered subclass for test @@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation): ]) def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], default_on: bool): - os.environ["VLLM_CUSTOM_OPS"] = env os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level) + vllm_config = VllmConfig(compilation_config=CompilationConfig( + custom_ops=env.split(","))) + with set_current_vllm_config(vllm_config): + assert CustomOp.default_on() == default_on - # Reset default_on (computed once): - CustomOp.default_on.cache_clear() + ops_enabled = [bool(x) for x in ops_enabled] - assert CustomOp.default_on() == default_on + assert RMSNorm(1024).enabled() == ops_enabled[0] + assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] - ops_enabled = [bool(x) for x in ops_enabled] + assert SiluAndMul().enabled() == ops_enabled[1] + assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] - assert RMSNorm(1024).enabled() == ops_enabled[0] - assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] + assert GeluAndMul().enabled() == ops_enabled[2] + assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] - assert SiluAndMul().enabled() == ops_enabled[1] - assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] + # If registered, subclasses should follow their own name + assert Relu3().enabled() == ops_enabled[3] + assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] - assert GeluAndMul().enabled() == ops_enabled[2] - assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] + # Unregistered subclass + class SiluAndMul2(SiluAndMul): + pass - # If registered, subclasses should follow their own name - assert Relu3().enabled() == ops_enabled[3] - assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] - - # Unregistered subclass - class SiluAndMul2(SiluAndMul): - pass - - # Subclasses should not require registration - assert SiluAndMul2().enabled() == SiluAndMul().enabled() + # Subclasses should not require registration + assert SiluAndMul2().enabled() == SiluAndMul().enabled() @pytest.mark.parametrize( "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]) def test_enabled_ops_invalid(env: str): - os.environ["VLLM_CUSTOM_OPS"] = env - CustomOp.default_on.cache_clear() - - with pytest.raises(AssertionError): - RMSNorm(1024).enabled() + with pytest.raises(Exception): # noqa + vllm_config = VllmConfig(compilation_config=CompilationConfig( + custom_ops=env.split(","))) + with set_current_vllm_config(vllm_config): + RMSNorm(1024).enabled() diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index d8a98a0f84d3b..6233860747b9c 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -8,13 +8,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import pytest +from mistral_common.multimodal import download_image from mistral_common.protocol.instruct.messages import ImageURLChunk from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.multimodal import image_from_chunk +from transformers import AutoProcessor -from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt +from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams, + TextPrompt, TokensPrompt) from vllm.multimodal import MultiModalDataBuiltins +from vllm.multimodal.inputs import PlaceholderRange from vllm.sequence import Logprob, SampleLogprobs from ....utils import VLLM_PATH, large_gpu_test @@ -49,6 +53,20 @@ def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]: }] +def _create_msg_format_hf(urls: List[str]) -> List[Dict[str, Any]]: + return [{ + "role": + "user", + "content": [{ + "type": "text", + "content": PROMPT, + }, *({ + "type": "image", + "image": download_image(url) + } for url in urls)], + }] + + def _create_engine_inputs(urls: List[str]) -> TokensPrompt: msg = _create_msg_format(urls) @@ -70,6 +88,23 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt: return engine_inputs +def _create_engine_inputs_hf(urls: List[str]) -> TextPrompt: + msg = _create_msg_format_hf(urls) + + tokenizer = AutoProcessor.from_pretrained("mistral-community/pixtral-12b") + prompt = tokenizer.apply_chat_template(msg) + + images = [] + for chunk in msg[0]["content"]: + if chunk["type"] == "image": + images.append(chunk["image"]) + + mm_data = MultiModalDataBuiltins(image=images) + engine_inputs = TextPrompt(prompt=prompt, multi_modal_data=mm_data) + + return engine_inputs + + MSGS = [ _create_msg_format(IMG_URLS[:1]), _create_msg_format(IMG_URLS[:2]), @@ -191,3 +226,45 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None: outputs_1_lst=logprobs, name_0="h100_ref", name_1="output") + + +@large_gpu_test(min_gb=24) +@pytest.mark.parametrize( + "prompt,expected_ranges", + [(_create_engine_inputs_hf(IMG_URLS[:1]), [{ + "offset": 10, + "length": 494 + }]), + (_create_engine_inputs_hf(IMG_URLS[1:4]), [{ + "offset": 10, + "length": 266 + }, { + "offset": 276, + "length": 1056 + }, { + "offset": 1332, + "length": 418 + }])]) +def test_multi_modal_placeholders( + vllm_runner, prompt, expected_ranges: list[PlaceholderRange]) -> None: + with vllm_runner( + "mistral-community/pixtral-12b", + max_model_len=8192, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + ) as vllm_model: + outputs = vllm_model.model.generate(prompt) + + assert len(outputs) == 1, f"{len(outputs)=}" + output: RequestOutput = outputs[0] + assert hasattr(output, + "multi_modal_placeholders"), f"{output.__dict__=}" + assert "image" in output.multi_modal_placeholders, \ + f"{output.multi_modal_placeholders.keys()=}" + image_placeholder_ranges: list[ + PlaceholderRange] = output.multi_modal_placeholders["image"] + assert len(image_placeholder_ranges) == len( + expected_ranges), f"{image_placeholder_ranges=}" + for real_range, expected_range in zip(image_placeholder_ranges, + expected_ranges): + assert real_range == expected_range, \ + f"{real_range=} {expected_range=}" diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/test_qwen2_vl.py index 718c675b86fb4..71b6ba4dca435 100644 --- a/tests/models/decoder_only/vision_language/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py @@ -18,6 +18,7 @@ IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>" VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>" +MODEL_HIDDEN_SIZE = 1536 def qwen2_vl_chat_template(*query): @@ -230,7 +231,7 @@ def batch_make_video_embeddings( return result -def run_test( +def run_embedding_input_test( vllm_runner: Type[VllmRunner], inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]], model: str, @@ -326,7 +327,7 @@ def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model, [], ) for image, prompt in zip(images, IMAGE_PROMPTS)] - run_test( + run_embedding_input_test( vllm_runner, inputs_per_case, model, @@ -371,7 +372,7 @@ def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets, [], )] - run_test( + run_embedding_input_test( vllm_runner, inputs_per_case, model, @@ -416,7 +417,134 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model, [rescale_video_size(video, factor) for factor in size_factors], ) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)] - run_test( + run_embedding_input_test( + vllm_runner, + inputs_per_case, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) + + +def run_chunked_prefill_test( + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + mm_limit: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Compare inference result between + chunked prefill disabled and chunked prefill enabled + """ + + # NOTE: + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + task="generate", + max_model_len=4000, + max_num_seqs=4, + dtype=dtype, + limit_mm_per_prompt={ + "image": mm_limit, + "video": mm_limit + }, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend + ) as vllm_model: + + outputs_per_case = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images or None, + videos=videos or None) + for prompts, images, videos in inputs + ] + + with vllm_runner( + model, + task="generate", + max_model_len=4000, + max_num_seqs=4, + dtype=dtype, + limit_mm_per_prompt={ + "image": mm_limit, + "video": mm_limit + }, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_chunked_prefill=True, + # should be small enough to ensure prefilling is chunked + max_num_batched_tokens=32, + mm_processor_kwargs={ + "max_pixels": 16 * 28 * 28, + }) as vllm_model_chunked: + outputs_per_case_chunked = [ + vllm_model_chunked.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images or None, + videos=videos or None) for prompts, images, videos in inputs + ] + + for outputs, \ + outputs_chunked \ + in zip(outputs_per_case, + outputs_per_case_chunked): + check_logprobs_close( + outputs_0_lst=outputs, + outputs_1_lst=outputs_chunked, + name_0="non_chunked", + name_1="chunked", + ) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [1]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_qwen2_vl_mrope_chunked_prefill(vllm_runner, example_prompts, + model: str, dtype: str, + max_tokens: int, + num_logprobs: int) -> None: + """ + Test Qwen2-VL's chunked prefill with M-RoPE + """ + prompts = [ + qwen2_vl_chat_template(IMAGE_PLACEHOLDER, prompt) + for prompt in example_prompts[:1] + ] + + # 1. Qwen2-VL's M-RoPE works only when there are some multi-modal inputs, + # so an image is included in the inputs + # 2. however, Qwen2-VL currently won't work properly + # when chunked prefill is enabled and there are some multi-modal inputs, + # here use a hacky way: provide a **zero-length** image to make it happy + # + # and finally we achieved: + # (1) chunked_prefill enabled; (2) M-RoPE works; to continue our tests + zero_len_image = { + "image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)), + "image_grid_thw": torch.tensor([[0, 0, 0]]) + } + images = [zero_len_image] * len(prompts) + + inputs_per_case: List[Tuple[List[str], PromptImageInput, + PromptVideoInput]] = [ + (prompts, images, []), + ] + + run_chunked_prefill_test( vllm_runner, inputs_per_case, model, diff --git a/tests/quantization/test_ipex_quant.py b/tests/quantization/test_ipex_quant.py index d541efcefcac3..68a73f0f8ab48 100644 --- a/tests/quantization/test_ipex_quant.py +++ b/tests/quantization/test_ipex_quant.py @@ -1,5 +1,5 @@ """Test model set-up and inference for quantized HF models supported - on the CPU backend using IPEX (including AWQ). + on the CPU/GPU backend using IPEX (including AWQ/GPTQ). Validating the configuration and printing results for manual checking. @@ -11,13 +11,15 @@ from vllm.platforms import current_platform MODELS = [ - "casperhansen/llama-3-8b-instruct-awq", + "AMead10/Llama-3.2-1B-Instruct-AWQ", + "shuyuej/Llama-3.2-1B-Instruct-GPTQ", # with g_idx ] DTYPE = ["bfloat16"] -@pytest.mark.skipif(not current_platform.is_cpu(), - reason="only supports the CPU backend.") +@pytest.mark.skipif(not current_platform.is_cpu() + and not current_platform.is_xpu(), + reason="only supports Intel CPU/XPU backend.") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", DTYPE) def test_ipex_quant(vllm_runner, model, dtype): diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 86d9af88e49ea..941abe17a3378 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -5,7 +5,7 @@ import depyf -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel # disable custom dispatcher, let Dynamo takes over # all the control diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 923d0f1680802..53b10c06135a1 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -1,6 +1,6 @@ import os -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from ..utils import compare_two_settings diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b276b8fc25473..aa89010ca8ecd 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -444,18 +444,18 @@ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, size_k: torch.SymInt) -> torch.Tensor: return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake("_C::machete_gemm") - def machete_gemm_fake( + @register_fake("_C::machete_mm") + def machete_mm_fake( a: torch.Tensor, - # Should be the tensor returned by machete_prepack_B + # b_q Should be the tensor returned by machete_prepack_B b_q: torch.Tensor, b_type: ScalarType, - b_scales: Optional[torch.Tensor] = None, - b_zeros: Optional[torch.Tensor] = None, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, b_group_size: Optional[int] = None, - c: Optional[torch.Tensor] = None, - alpha: Optional[float] = None, - beta: Optional[float] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, schedule: Optional[str] = None, ) -> torch.Tensor: m = a.size(0) @@ -463,8 +463,9 @@ def machete_gemm_fake( return torch.empty((m, n), device=a.device, dtype=a.dtype) @register_fake("_C::machete_prepack_B") - def machete_prepack_B_fake(b_q_weight: torch.Tensor, - b_type: ScalarType) -> torch.Tensor: + def machete_prepack_B_fake( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) @@ -617,29 +618,41 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # machete -def machete_supported_schedules(b_type: ScalarType) -> List[str]: - return torch.ops._C.machete_supported_schedules(b_type.id) - - -def machete_gemm( - a: torch.Tensor, - b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B - b_type: ScalarType, - b_scales: Optional[torch.Tensor] = None, - b_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - c: Optional[torch.Tensor] = None, - alpha: Optional[float] = None, - beta: Optional[float] = None, - schedule: Optional[str] = None, -) -> torch.Tensor: - return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros, - b_group_size, c, alpha, beta, schedule) +def machete_supported_schedules( + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], + group_zeros_type: Optional[torch.dtype] = None, + channel_scales_type: Optional[torch.dtype] = None, + token_scales_type: Optional[torch.dtype] = None, + out_type: Optional[torch.dtype] = None) -> List[str]: + return torch.ops._C.machete_supported_schedules( + a_type, b_type.id, group_scales_type, group_zeros_type, + channel_scales_type, token_scales_type, out_type) -def machete_prepack_B(b_q_weight: torch.Tensor, - b_type: ScalarType) -> torch.Tensor: - return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id) +def machete_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + schedule: Optional[str] = None) -> torch.Tensor: + return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales, + b_group_zeros, b_group_size, + b_channel_scales, a_token_scales, schedule) + + +def machete_prepack_B( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: + return torch.ops._C.machete_prepack_B(b_q_weight, a_type, b_type.id, + group_scales_type) if hasattr(torch.ops._C, "permute_cols"): diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py index ec1c37c5bcb0e..727a470ba6d0e 100644 --- a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py +++ b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py @@ -157,19 +157,22 @@ def _fwd_kernel_inner( k = tl.load( k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < k_seqlen, + other=0.0, ) else: k = tl.load( k_ptrs + start_n * stride_kt, mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD), + other=0.0, ) else: if EVEN_D: k = tl.load(k_ptrs + start_n * stride_kt) else: k = tl.load(k_ptrs + start_n * stride_kt, - mask=offs_d[:, None] < D_HEAD) + mask=offs_d[:, None] < D_HEAD, + other=0.0) qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -200,19 +203,22 @@ def _fwd_kernel_inner( v = tl.load( v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < k_seqlen, + other=0.0, ) else: v = tl.load( v_ptrs + start_n * stride_vt, mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD), + other=0.0, ) else: if EVEN_D: v = tl.load(v_ptrs + start_n * stride_vt) else: v = tl.load(v_ptrs + start_n * stride_vt, - mask=offs_d[None, :] < D_HEAD) + mask=offs_d[None, :] < D_HEAD, + other=0.0) acc += tl.dot(p, v) @@ -318,12 +324,13 @@ def _fwd_kernel_batch_inference( q = tl.load( Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, mask=offs_m[:, None] < q_seqlen, + other=0.0, ) else: q = tl.load( Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), - other=0, + other=0.0, ) sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 5682faa158069..0cf1e3a95fcba 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -2,21 +2,19 @@ import dataclasses import operator from contextlib import ExitStack -from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, - Union) +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from unittest.mock import patch import torch import torch.fx as fx import vllm.envs as envs +from vllm.config import CompilationConfig from vllm.logger import init_logger from vllm.utils import combine_fx_passes, weak_ref_tensors -from .config import CompilationConfig from .counter import compilation_counter from .fusion import FusionPass -from .levels import CompilationLevel from .reshapes import RedundantReshapesPass logger = init_logger(__name__) @@ -392,7 +390,10 @@ class VllmBackend: sym_tensor_indices: List[int] input_buffers: List[torch.Tensor] - def __init__(self, post_grad_passes: Sequence[Callable] = ()): + def __init__( + self, + compilation_configs: CompilationConfig, + ): global global_graph_pool if global_graph_pool is None: global_graph_pool = torch.cuda.graph_pool_handle() @@ -401,11 +402,13 @@ def __init__(self, post_grad_passes: Sequence[Callable] = ()): # streams, it might not be safe to share a global pool. # only investigate this when we use multiple streams self.graph_pool = global_graph_pool - self.post_grad_passes = post_grad_passes + self.post_grad_passes = [] self.sym_tensor_indices = [] self.input_buffers = [] + self.compilation_configs = compilation_configs + # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -437,10 +440,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: assert not self._called, "VllmBackend can only be called once" self.graph = graph - # config is read now, because only here can + # config is updated now, because only here can # we get the sizes to capture for cudagraph # from compilation context - self.compilation_configs = CompilationConfig.select_and_init_config() + self.compilation_configs.init_during_runtime() self.add_passes_to_config() self.split_gm, self.piecewise_graphs = split_graph( @@ -680,12 +683,3 @@ def __call__(self, *args) -> Any: entry.cudagraph.replay() return entry.output - - -def select_default_backend(level: int) -> Union[str, Callable]: - if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]: - backend_str = "eager" - return backend_str - assert level == CompilationLevel.PIECEWISE - - return VllmBackend() diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py deleted file mode 100644 index 3e663505c627d..0000000000000 --- a/vllm/compilation/config.py +++ /dev/null @@ -1,159 +0,0 @@ -import copy -from pathlib import Path -from typing import Any, Dict, List, Optional - -from pydantic import BaseModel, Field, PrivateAttr - -import vllm.envs as envs -from vllm.logger import init_logger - -from .compile_context import get_compile_context - -logger = init_logger(__name__) - - -class CompilationConfig(BaseModel): - """ - Configuration for compilation. - It has two parts: - - CudaGraph capture: - - use_cudagraph: whether to use cudagraph inside compilation. - - False: cudagraph inside compilation is not used. - - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses. - Note that this is orthogonal to the cudagraph capture out - side of compilation. - TODO: move outside cudagraph logic into compilation. - torch.compile will handle cudagraph capture logic in the future. - - cudagraph_capture_sizes: sizes to capture cudagraph. - - None: capture sizes are inferred from compilation context. - - List[int]: capture sizes are specified. - - cudagraph_num_of_warmups: number of warmup runs for cudagraph. - It means the first several runs will be treated as warmup runs. - Only after that, the execution will be recorded, and the recorded - cudagraph will be used for subsequent runs. - - cudagraph_copy_inputs: whether to copy input tensors for - cudagraph. If the caller can guarantee that the same input buffers - are always used, it can set this to False. Otherwise, it should - set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False. - - Inductor compilation: - - use_inductor: whether to use inductor compilation. - - False: inductor compilation is not used. graph runs in eager. - - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for different sizes specified - in inductor_compile_sizes, using configurations - in inductor_compile_config. - - inductor_compile_sizes: sizes to compile for inductor. - - inductor_specialize_for_cudagraph_no_more_than: an optional integer - to specialize inductor for cudagraph sizes no more than the - specified size. It is useful when we want to specialize inductor - with a subset of cudagraph sizes. - - inductor_compile_config: additional configurations for inductor. - - None: use default configurations. - - inductor_passes: additional passes for inductor. It is a dictionary - from pass name to pass function qualified name. We use function - name because the config uses json format. If we pass the config - from Python, functions can also be passed directly via Python object - constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - - Custom inductor passes: - - dump_graph_stages: list of stages for which we want to dump the graph. - Each pass defines its own stages (before, after, maybe in-between). - - dump_graph_dir: directory to dump the graph. Default is . - - enable_fusion: whether to enable the custom fusion pass. - TODO better pass enabling system. - - Why we have different sizes for cudagraph and inductor: - - cudagraph: a cudagraph captured for a specific size can only be used - for the same size. We need to capture all the sizes we want to use. - - inductor: a graph compiled by inductor for a general shape can be used - for different sizes. Inductor can also compile for specific sizes, - where it can have more information to optimize the graph with fully - static shapes. However, we find the general shape compilation is - sufficient for most cases. It might be beneficial to compile for - certain small batchsizes, where inductor is good at optimizing. - """ - use_inductor: bool = True - inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None - inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict) - inductor_compile_config: Dict = Field(default_factory=dict) - inductor_passes: Dict[str, str] = Field(default_factory=dict) - - use_cudagraph: bool = False - non_cudagraph_ops: List[str] = Field(default_factory=list) - cudagraph_num_of_warmups: int = 0 - cudagraph_capture_sizes: Optional[List[int]] = None - cudagraph_copy_inputs: bool = False - - dump_graph_stages: List[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) - enable_fusion: bool = True - - # not configurable, computed after init - compile_sizes: List[int] = PrivateAttr - capture_sizes: List[int] = PrivateAttr - - def model_post_init(self, __context: Any) -> None: - for k, v in self.inductor_passes.items(): - if not isinstance(v, str): - assert callable(v), ( - f"pass {k} should be a function or a qualified name") - self.inductor_compile_config[k] = v - continue - - # resolve function from qualified name - names = v.split(".") - module = ".".join(names[:-1]) - func_name = names[-1] - func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = func - - def init_during_runtime(self): - """To complete the initialization of config, - we need to know the compile context, which is only available - during the first run of the model. - """ - context = get_compile_context() - context = copy.deepcopy(context) if context is not None else [] - sizes_to_specialize: List[int] = context - if self.cudagraph_capture_sizes is None: - self.capture_sizes = sizes_to_specialize - else: - self.capture_sizes = self.cudagraph_capture_sizes - logger.info(("cudagraph sizes specified by model runner" - " %s is overridden by config %s"), - sizes_to_specialize, self.cudagraph_capture_sizes) - if self.inductor_specialize_for_cudagraph_no_more_than is not None: - assert self.inductor_compile_sizes is None, ( - "inductor_compile_sizes should be None when " - "inductor_specialize_for_cudagraph_no_more_than is not None") - self.compile_sizes = [ - x for x in self.capture_sizes - if x <= self.inductor_specialize_for_cudagraph_no_more_than - ] - else: - assert self.inductor_compile_sizes is not None, ( - "inductor_compile_sizes should not be None when " - "inductor_specialize_for_cudagraph_no_more_than is None") - self.compile_sizes = self.inductor_compile_sizes - - @staticmethod - def select_and_init_config() -> "CompilationConfig": - """The order of selecting config is: - 1. Use the config specified in environment variable. - 2. Use the config specified in plugins. - 3. Use the default config. - """ - config_path = envs.VLLM_TORCH_COMPILE_CONFIG - if config_path is not None: - with open(config_path) as json_file: - config = CompilationConfig.model_validate_json( - json_file.read()) - else: - from vllm.plugins import get_compilation_config - predefined_config = get_compilation_config() - config = predefined_config if predefined_config is not None else ( - CompilationConfig()) - - config.init_during_runtime() - return config diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index ca1e96a33c014..4b78491bc5a48 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -3,10 +3,8 @@ import torch -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils import supports_dynamo @@ -126,12 +124,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. - self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [ + self.do_not_compile = \ + vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS ] or not supports_dynamo() if self.do_not_compile: return - TorchCompileWrapperWithCustomDispatcher.__init__(self) + TorchCompileWrapperWithCustomDispatcher.__init__( + self, compilation_level=vllm_config.compilation_config.level) cls.__init__ = __init__ # type: ignore diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index eb43604b1399b..e6a3afef85e1b 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -6,8 +6,8 @@ from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) -from vllm.compilation.config import CompilationConfig from vllm.compilation.inductor_pass import InductorPass +from vllm.config import CompilationConfig from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index b23351fa19759..8082a08b40019 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -2,7 +2,7 @@ import torch -from vllm.compilation.config import CompilationConfig +from vllm.config import CompilationConfig # yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( diff --git a/vllm/compilation/levels.py b/vllm/compilation/levels.py deleted file mode 100644 index 19a3a2b526870..0000000000000 --- a/vllm/compilation/levels.py +++ /dev/null @@ -1,8 +0,0 @@ -# constants for the levels of the compilation process - - -class CompilationLevel: - NO_COMPILATION = 0 - DYNAMO_AS_IS = 1 - DYNAMO_ONCE = 2 - PIECEWISE = 3 diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 7366ed4d16b0b..0143d0301ca1a 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -8,8 +8,7 @@ import torch import vllm.envs as envs - -from .levels import CompilationLevel +from vllm.config import CompilationLevel class TorchCompileWrapperWithCustomDispatcher: @@ -25,20 +24,17 @@ class TorchCompileWrapperWithCustomDispatcher: `torch.compile` over the forward method. """ - def __init__(self, compiled_callable: Optional[Callable] = None): + def __init__(self, + compiled_callable: Optional[Callable] = None, + compilation_level: int = 0): if compiled_callable is None: # default compilation settings # compiling the forward method - # choose the compile backend - - # if the user has set the backend, use it - from vllm.plugins import get_torch_compile_backend - backend = get_torch_compile_backend() - if backend is None: - from vllm.compilation.backends import select_default_backend - backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) + from vllm.plugins import get_current_vllm_config + backend = get_current_vllm_config( + ).compilation_config.init_backend() compiled_callable = torch.compile( self.forward, @@ -54,7 +50,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): # subclasses can use this to switch between the custom dispatcher # and the default Dynamo guard mechanism. self.use_custom_dispatcher: bool = \ - envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE + compilation_level >= CompilationLevel.DYNAMO_ONCE def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. diff --git a/vllm/config.py b/vllm/config.py index 1c190da1d327e..3093ca535cdda 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,10 +3,12 @@ import json import warnings from dataclasses import dataclass, field, replace +from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List, Literal, Mapping, Optional, Set, Tuple, Type, Union) import torch +from pydantic import BaseModel, Field, PrivateAttr from transformers import PretrainedConfig import vllm.envs as envs @@ -20,7 +22,7 @@ get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - identity, print_warning_once) + identity, print_warning_once, resolve_obj_by_qualname) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -155,7 +157,6 @@ def __init__( limit_mm_per_prompt: Optional[Mapping[str, int]] = None, use_async_output_proc: bool = True, config_format: ConfigFormat = ConfigFormat.AUTO, - chat_template_text_format: str = "string", hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, override_neuron_config: Optional[Dict[str, Any]] = None, @@ -216,7 +217,6 @@ def __init__( self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc - self.chat_template_text_format = chat_template_text_format self.mm_processor_kwargs = mm_processor_kwargs # Set enforce_eager to False if the value is unset. @@ -696,6 +696,11 @@ def uses_mrope(self) -> bool: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None + @property + def is_cross_encoder(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.is_cross_encoder_model(architectures) + class CacheConfig: """Configuration for the KV cache. @@ -1091,6 +1096,9 @@ def __init__(self, max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, + max_num_partial_prefills: int = 1, + max_long_partial_prefills: int = 1, + long_prefill_threshold: float = 0.04, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, @@ -1131,6 +1139,10 @@ def __init__(self, ) self.max_num_batched_tokens = max_num_batched_tokens + self.max_num_partial_prefills = max_num_partial_prefills + self.max_long_partial_prefills = max_long_partial_prefills + self.long_prefill_token_threshold = int(max_model_len * + long_prefill_threshold) if enable_chunked_prefill: logger.info( @@ -2054,6 +2066,214 @@ def __post_init__(self): f"installed. Original error:\n{otel_import_error_traceback}") +class CompilationLevel: + # constants for the levels of the compilation process + NO_COMPILATION = 0 + DYNAMO_AS_IS = 1 + DYNAMO_ONCE = 2 + PIECEWISE = 3 + + +class CompilationConfig(BaseModel): + """ + Configuration for compilation. + It has three parts: + - Top-level Compilation control: + - level: the level of compilation. + - 0: no compilation. + - 1: dynamo as is. + - 2: dynamo once. + - 3: piecewise compilation. + - backend: the backend for compilation. It needs to be a string. + - "" (empty string): use the default backend. + - "eager"/"openxla"/...: use the specified backend registered in PyTorch. + - "full.module.name": a qualified name which can be used to import the backend function. + We use string to avoid serialization issues when using compilation in a distributed setting. + When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). + When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph). + - custom_ops: fine-grained control over which custom ops to enable/disable. + Use 'all' to enable all, 'none' to disable all. + Also specify a list of custom op names to enable (prefixed with a '+'), + or disable (prefixed with a '-'). + Examples: + - 'all,-op1' to enable all except op1 + - 'none,+op1,+op2' to enable only op1 and op2 + By default, all custom ops are enabled when running without Inductor + and disabled when running with Inductor (compile_level >= Inductor). + - CudaGraph capture: + - use_cudagraph: whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses. + Note that this is orthogonal to the cudagraph capture out + side of compilation. + TODO: move outside cudagraph logic into compilation. + torch.compile will handle cudagraph capture logic in the future. + - cudagraph_capture_sizes: sizes to capture cudagraph. + - None: capture sizes are inferred from compilation context. + - List[int]: capture sizes are specified. + - cudagraph_num_of_warmups: number of warmup runs for cudagraph. + It means the first several runs will be treated as warmup runs. + Only after that, the execution will be recorded, and the recorded + cudagraph will be used for subsequent runs. + - cudagraph_copy_inputs: whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. Default is False. + - Inductor compilation: + - use_inductor: whether to use inductor compilation. + - False: inductor compilation is not used. graph runs in eager. + - True: inductor compilation is used. one graph for symbolic shape + is compiled. In addition, compile for different sizes specified + in inductor_compile_sizes, using configurations + in inductor_compile_config. + - inductor_compile_sizes: sizes to compile for inductor. + - inductor_specialize_for_cudagraph_no_more_than: an optional integer + to specialize inductor for cudagraph sizes no more than the + specified size. It is useful when we want to specialize inductor + with a subset of cudagraph sizes. + - inductor_compile_config: additional configurations for inductor. + - None: use default configurations. + - inductor_passes: additional passes for inductor. It is a dictionary + from pass name to pass function qualified name. We use function + name because the config uses json format. If we pass the config + from Python, functions can also be passed directly via Python object + constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` + - custom inductor passes: + - dump_graph_stages: list of stages for which we want to dump the graph. + Each pass defines its own stages (before, after, maybe in-between). + - dump_graph_dir: directory to dump the graph. Default is . + - enable_fusion: whether to enable the custom fusion pass. + TODO better pass enabling system. + + Why we have different sizes for cudagraph and inductor: + - cudagraph: a cudagraph captured for a specific size can only be used + for the same size. We need to capture all the sizes we want to use. + - inductor: a graph compiled by inductor for a general shape can be used + for different sizes. Inductor can also compile for specific sizes, + where it can have more information to optimize the graph with fully + static shapes. However, we find the general shape compilation is + sufficient for most cases. It might be beneficial to compile for + certain small batchsizes, where inductor is good at optimizing. + """ # noqa + level: int = 0 + backend: str = "" + custom_ops: List[str] = Field(default_factory=list) + + use_inductor: bool = True + inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None + inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict) + inductor_compile_config: Dict = Field(default_factory=dict) + inductor_passes: Dict[str, str] = Field(default_factory=dict) + + use_cudagraph: bool = False + non_cudagraph_ops: List[str] = Field(default_factory=list) + cudagraph_num_of_warmups: int = 0 + cudagraph_capture_sizes: Optional[List[int]] = None + cudagraph_copy_inputs: bool = False + + dump_graph_stages: List[str] = Field(default_factory=list) + dump_graph_dir: Path = Field(default=Path(".")) + enable_fusion: bool = True + + # not configurable, computed after init + compile_sizes: List[int] = PrivateAttr + capture_sizes: List[int] = PrivateAttr + + def model_post_init(self, __context: Any) -> None: + self.level = envs.VLLM_TORCH_COMPILE_LEVEL + + count_none = self.custom_ops.count("none") + count_all = self.custom_ops.count("all") + assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" + + for k, v in self.inductor_passes.items(): + if not isinstance(v, str): + assert callable(v), ( + f"pass {k} should be a function or a qualified name") + self.inductor_compile_config[k] = v + continue + + # resolve function from qualified name + names = v.split(".") + module = ".".join(names[:-1]) + func_name = names[-1] + func = __import__(module).__dict__[func_name] + self.inductor_compile_config[k] = func + + def init_backend(self) -> Union[str, Callable]: + if self.level == CompilationLevel.NO_COMPILATION: + raise ValueError("No compilation level is set.") + + from torch._dynamo.backends.registry import list_backends + torch_backends = list_backends(exclude_tags=tuple()) + if self.level in [ + CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE + ]: + if self.backend == "": + return "eager" + if self.backend in torch_backends: + return self.backend + return resolve_obj_by_qualname(self.backend) + + # TODO: pass user-specified backend to piecewise compilation + # merge with the config use_inductor + assert self.level == CompilationLevel.PIECEWISE + from vllm.compilation.backends import VllmBackend + return VllmBackend(self) + + def init_during_runtime(self): + """To complete the initialization of config, + we need to know the compile context, which is only available + during the first run of the model. + """ + from vllm.compilation.compile_context import get_compile_context + context = get_compile_context() + context = copy.deepcopy(context) if context is not None else [] + sizes_to_specialize: List[int] = context + if self.cudagraph_capture_sizes is None: + self.capture_sizes = sizes_to_specialize + else: + self.capture_sizes = self.cudagraph_capture_sizes + logger.info(("cudagraph sizes specified by model runner" + " %s is overridden by config %s"), + sizes_to_specialize, self.cudagraph_capture_sizes) + if self.inductor_specialize_for_cudagraph_no_more_than is not None: + assert self.inductor_compile_sizes is None, ( + "inductor_compile_sizes should be None when " + "inductor_specialize_for_cudagraph_no_more_than is not None") + self.compile_sizes = [ + x for x in self.capture_sizes + if x <= self.inductor_specialize_for_cudagraph_no_more_than + ] + else: + assert self.inductor_compile_sizes is not None, ( + "inductor_compile_sizes should not be None when " + "inductor_specialize_for_cudagraph_no_more_than is None") + self.compile_sizes = self.inductor_compile_sizes + + @staticmethod + def select_and_init_config() -> "CompilationConfig": + """The order of selecting config is: + 1. Use the config specified in environment variable. + 2. Use the config specified in plugins. + 3. Use the default config. + """ + config_path = envs.VLLM_TORCH_COMPILE_CONFIG + if config_path is not None: + with open(config_path) as json_file: + config = CompilationConfig.model_validate_json( + json_file.read()) + else: + from vllm.plugins import get_compilation_config + predefined_config = get_compilation_config() + config = predefined_config if predefined_config is not None else ( + CompilationConfig()) + + return config + + @dataclass class VllmConfig: """Dataclass which contains all vllm-related configuration. This @@ -2075,6 +2295,8 @@ class VllmConfig: observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None quant_config: Optional[QuantizationConfig] = None + compilation_config: CompilationConfig = field(default=None, + init=True) # type: ignore @staticmethod def _get_quantization_config( @@ -2135,6 +2357,12 @@ def __post_init__(self): self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) + if self.compilation_config is None: + self.compilation_config = CompilationConfig.select_and_init_config( + ) + + current_platform.check_and_update_config(self) + def __str__(self): return ("model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index af4671ec29be9..0c62de78c9f0c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -15,7 +15,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, - SequenceStatus) + SequenceStage, SequenceStatus) from vllm.utils import Device, PyObjectCache logger = init_logger(__name__) @@ -295,6 +295,97 @@ def scheduled_seq_group_builder(): # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) +@dataclass +class PartialPrefillMetadata: + """Holds information about the partial prefills that are currently running + during a single iteration of the Scheduler. + + When chunked prefill is enabled, we allow a certain number of seqs to be + partially prefilled during each iteration. Having multiple partial prefills + in flight allows us to minimize TTFT and avoid decode starvation in cases + where a single sequence group with a very large prompt blocks the queue for + too many iterations. + + The number of long prefill requests is limited so that smaller + requests may jump the queue in front of them and get to the decode + phase faster. + """ + + # A minimum bound on the total number of prefills running during this + # scheduling step + partial_prefills: int + + # The number of long prefill requests currently running + long_partial_prefills: int + + scheduler_config: SchedulerConfig + + def cannot_schedule(self, seq_group: SequenceGroup) -> bool: + """When concurrent partial prefills are enabled, + we limit the number of long requests and only accept + shorter requests from the queue while running them + concurrently""" + return seq_group.first_seq.get_num_new_tokens() > \ + self.scheduler_config.long_prefill_token_threshold \ + and self.long_partial_prefills >= \ + self.scheduler_config.max_long_partial_prefills \ + and self.scheduler_config.max_num_partial_prefills > 1 + + def increment_partial_prefills(self, seq_group: SequenceGroup) -> None: + # When a new prefill is scheduled, we need to know if it is a + # long request + if (seq_group.first_seq.get_num_new_tokens() > + self.scheduler_config.long_prefill_token_threshold): + self.long_partial_prefills += 1 + + @classmethod + def from_queues( + cls, running: Deque[SequenceGroup], waiting: Deque[SequenceGroup], + scheduler_config: SchedulerConfig) -> "PartialPrefillMetadata": + """Create a PartialPrefillMetadata object from the current state of + the scheduler's queues. + + This accounts for the currently running prefill requests, and peeks into + the waiting queue to see if there are more prefills to potentially be + scheduled during this iteration.""" + partial_prefills = 0 + long_partial_prefills = 0 + + waiting_partial_prefills = 0 + waiting_long_prefills = 0 + + for sg in running: + # TODO: Check if this stage is correctly updated before scheduling + if sg.first_seq.data.stage == SequenceStage.PREFILL: + partial_prefills += 1 + if sg.first_seq.get_num_new_tokens( + ) > scheduler_config.long_prefill_token_threshold: + long_partial_prefills += 1 + + for sg in waiting: + # Don't bother looping through the rest of the queue if we know + # there are already at + # least max_partial_prefills requests to fill + if partial_prefills + waiting_partial_prefills \ + >= scheduler_config.max_num_partial_prefills: + break + + # Don't count long requests from the waiting queue if we aren't + # going to schedule them anyway + if sg.first_seq.get_num_new_tokens( + ) > scheduler_config.long_prefill_token_threshold: + if long_partial_prefills + waiting_long_prefills \ + >= scheduler_config.max_long_partial_prefills: + continue + waiting_long_prefills += 1 + waiting_partial_prefills += 1 + + return PartialPrefillMetadata(partial_prefills + + waiting_partial_prefills, + long_partial_prefills, + scheduler_config=scheduler_config) + + class Scheduler: def __init__( @@ -395,6 +486,18 @@ def __init__( # for processing and deallocation by the free_finished_seq_groups() self._async_stopped: List[SequenceGroup] = [] + # List with the chunk sizes to hand out to each sequence depending + # on how many partial prefills are running. This is slightly faster than + # running an integer division every time a prefill is scheduled. + # This splits the budget evenly among all prefills. + self.partial_prefill_budget_lookup_list = [0] * ( + self.scheduler_config.max_num_partial_prefills + 1) + self.partial_prefill_budget_lookup_list[ + 0] = scheduler_config.max_num_batched_tokens + for i in range(1, self.scheduler_config.max_num_partial_prefills + 1): + self.partial_prefill_budget_lookup_list[i] = \ + scheduler_config.max_num_batched_tokens // i + @property def next_cache_id(self): return (self.cache_id + 1) % self.num_cache_iters @@ -494,6 +597,7 @@ def _schedule_running( budget: SchedulingBudget, curr_loras: Optional[Set[int]], enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, ) -> SchedulerRunningOutputs: """Schedule sequence groups that are running. @@ -508,7 +612,9 @@ def _schedule_running( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. - + partial_prefill_metadata: information about the partial prefills + that are currently running + Returns: SchedulerRunningOutputs. """ @@ -542,7 +648,12 @@ def _schedule_running( while running_queue: seq_group = running_queue[0] num_running_tokens = self._get_num_new_tokens( - seq_group, SequenceStatus.RUNNING, enable_chunking, budget) + seq_group, + SequenceStatus.RUNNING, + enable_chunking, + budget, + partial_prefill_metadata, + ) if num_running_tokens == 0: # No budget => Stop @@ -807,17 +918,17 @@ def _schedule_priority_preemption( SequenceStatus.WAITING, False, budget) - #Only preempt if priority inversion exists + # Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): - #Only preempt if waiting sequence cannot be allocated + # Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) if (num_new_tokens and can_allocate == AllocStatus.OK and budget.can_schedule(num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs)): break - #Adjust budget to remove the victim sequence group + # Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() num_running_tokens = self._get_num_new_tokens( vseq_group, SequenceStatus.RUNNING, False, budget) @@ -827,11 +938,11 @@ def _schedule_priority_preemption( budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) - #Preempt out the victim sequence group + # Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out) waiting_queue.appendleft(vseq_group) force_preemption_count += 1 - #Put the sequence back into the waiting queue + # Put the sequence back into the waiting queue waiting_queue.appendleft(seq_group) waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) @@ -845,6 +956,7 @@ def _schedule_prefills( budget: SchedulingBudget, curr_loras: Optional[Set[int]], enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, ) -> SchedulerPrefillOutputs: """Schedule sequence groups that are in prefill stage. @@ -865,10 +977,20 @@ def _schedule_prefills( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. + partial_prefill_metadata: information about the partial prefills + that are currently running Returns: SchedulerPrefillOutputs. """ + if budget.remaining_token_budget() == 0: + # Do nothing: Can't add any more prefill anyway + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking)) + ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = [] @@ -878,13 +1000,31 @@ def _schedule_prefills( while self._passed_delay(time.time()) and waiting_queue: seq_group = waiting_queue[0] - waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + waiting_seqs = \ + seq_group.get_seqs(status=SequenceStatus.WAITING) assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.WAITING, - enable_chunking, budget) + + if partial_prefill_metadata is not None and \ + partial_prefill_metadata.cannot_schedule(seq_group): + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + + num_new_tokens = self._get_num_new_tokens( + seq_group, + SequenceStatus.WAITING, + enable_chunking, + budget, + partial_prefill_metadata=partial_prefill_metadata) + + num_new_seqs = seq_group.get_max_num_running_seqs() + # quick budget check + if num_new_tokens == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs): + break + if not enable_chunking: num_prompt_tokens = waiting_seqs[0].get_len() assert num_new_tokens == num_prompt_tokens @@ -935,18 +1075,15 @@ def _schedule_prefills( waiting_queue.popleft() continue - num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): - break - # Can schedule this request. if curr_loras is not None and lora_int_id > 0: curr_loras.add(lora_int_id) waiting_queue.popleft() self._allocate_and_set_running(seq_group) + if partial_prefill_metadata is not None: + partial_prefill_metadata.increment_partial_prefills(seq_group) + if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] # init_multi_step_from_lookahead_slots happens in append_slots @@ -1108,10 +1245,19 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: prefills = SchedulerPrefillOutputs.create_empty() swapped_in = SchedulerSwappedInOutputs.create_empty() + # Create partial prefill metadata + partial_prefill_metadata = PartialPrefillMetadata.from_queues( + running=self.running, + waiting=self.waiting, + scheduler_config=self.scheduler_config) + # Decoding should be always scheduled first by fcfs. - running_scheduled = self._schedule_running(budget, - curr_loras, - enable_chunking=True) + running_scheduled = self._schedule_running( + budget, + curr_loras, + enable_chunking=True, + partial_prefill_metadata=partial_prefill_metadata, + ) # Schedule swapped out requests. # If preemption happens, it means we don't have space for swap-in. @@ -1120,9 +1266,12 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: swapped_in = self._schedule_swapped(budget, curr_loras) # Schedule new prefills. - prefills = self._schedule_prefills(budget, - curr_loras, - enable_chunking=True) + prefills = self._schedule_prefills( + budget, + curr_loras, + enable_chunking=True, + partial_prefill_metadata=partial_prefill_metadata, + ) assert (budget.num_batched_tokens <= self.scheduler_config.max_num_batched_tokens) @@ -1303,6 +1452,7 @@ def schedule( encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, state=seq_group.state, + token_type_ids=seq_group.token_type_ids, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. # the subsequent comms can still use delta, but @@ -1584,9 +1734,14 @@ def _get_num_lookahead_slots(self, is_prefill: bool, return self.scheduler_config.num_lookahead_slots - def _get_num_new_tokens(self, seq_group: SequenceGroup, - status: SequenceStatus, enable_chunking: bool, - budget: SchedulingBudget) -> int: + def _get_num_new_tokens( + self, + seq_group: SequenceGroup, + status: SequenceStatus, + enable_chunking: bool, + budget: SchedulingBudget, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> int: """Get the next new tokens to compute for a given sequence group that's in a given `status`. @@ -1602,46 +1757,54 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, for seq in seqs: num_new_tokens += seq.get_num_new_tokens() assert num_new_tokens > 0 - # Chunk if a running request cannot fit in the given budget. + + if self.scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > self._get_prompt_limit(seq_group): + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + pass + else: + num_new_tokens = 0 \ + if num_new_tokens > budget.remaining_token_budget() \ + else num_new_tokens + # If number of seq > 1, it means it is doing beam search # in a decode phase. Do not chunk. - if enable_chunking and len(seqs) == 1: + elif enable_chunking and len(seqs) == 1: remaining_token_budget = budget.remaining_token_budget() - if self.scheduler_config.is_multi_step: - # The current multi-step + chunked prefill capability does - # not actually support chunking prompts. - # - # Therefore, `num_new_tokens` is computed in the same fashion - # for both multi-step+chunked-prefill & - # multi-step+chunked-prefill+APC - # - # Prompts with more tokens than the current remaining budget - # are postponed to future scheduler steps - if num_new_tokens > self._get_prompt_limit(seq_group): - # If the seq_group is in prompt-stage, pass the - # num_new_tokens as-is so the caller can ignore - # the sequence. - pass - else: - num_new_tokens = 0 \ - if num_new_tokens > remaining_token_budget \ - else num_new_tokens - elif self.cache_config.enable_prefix_caching: - # When prefix caching is enabled, we always allocate - # the number of new tokens that is dividable by the block - # size to avoid partial block matching. + # Get the number of tokens to allocate to this prefill slot + prefill_slot_budget = remaining_token_budget \ + if partial_prefill_metadata is None \ + else self.partial_prefill_budget_lookup_list[ + partial_prefill_metadata.partial_prefills + ] + + if self.cache_config.enable_prefix_caching: + # When prefix caching is enabled and we're partially prefilling + # a sequence, we always allocate a number of new tokens that is + # divisible by the block size to avoid partial block matching. block_size = self.cache_config.block_size - remainder = budget.token_budget % block_size - if remainder != 0: - raise ValueError("When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {remainder}") - if remaining_token_budget < num_new_tokens: - num_new_tokens = (remaining_token_budget // - block_size) * block_size - else: - num_new_tokens = min(num_new_tokens, remaining_token_budget) + # Don't exceed either the total budget or slot budget. + # Take min of those and get the next lowest multiple of the + # block size: + remaining_token_budget = \ + (min(remaining_token_budget, prefill_slot_budget) // + block_size) * block_size + # NB: In the case where num_new_tokens < budget, we are + # finishing prefill for this sequence, so we do not need to + # allocate a full block. + + num_new_tokens = min(num_new_tokens, remaining_token_budget, + prefill_slot_budget) + return num_new_tokens diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d73f95f59c71f..e37e48abb47c7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -90,7 +90,6 @@ class EngineArgs: task: TaskOption = "auto" skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' - chat_template_text_format: str = 'string' trust_remote_code: bool = False allowed_local_media_path: str = "" download_dir: Optional[str] = None @@ -120,6 +119,9 @@ class EngineArgs: cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None + max_num_partial_prefills: Optional[int] = 1 + max_long_partial_prefills: Optional[int] = 1 + long_prefill_threshold: Optional[float] = 0.04 max_num_seqs: int = 256 max_logprobs: int = 20 # Default value for OpenAI Chat Completions API disable_log_stats: bool = False @@ -258,14 +260,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'fast tokenizer if available.\n* "slow" will ' 'always use the slow tokenizer. \n* ' '"mistral" will always use the `mistral_common` tokenizer.') - parser.add_argument( - '--chat-template-text-format', - type=str, - default=EngineArgs.chat_template_text_format, - choices=['string', 'openai'], - help='The format to render text content within a chat template. ' - '"string" will keep the content field as a string whereas ' - '"openai" will parse content in the current OpenAI format.') parser.add_argument('--trust-remote-code', action='store_true', help='Trust remote code from huggingface.') @@ -474,6 +468,31 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.max_num_batched_tokens, help='Maximum number of batched tokens per ' 'iteration.') + parser.add_argument( + "--max-num-partial-prefills", + type=int, + default=EngineArgs.max_num_partial_prefills, + help="For chunked prefill, the max number of concurrent \ + partial prefills." + "Defaults to 1", + ) + parser.add_argument( + "--max-long-partial-prefills", + type=int, + default=EngineArgs.max_long_partial_prefills, + help="For chunked prefill, the max number of long concurrent " + "partial prefills. The length is determined by the " + "long-prefill-threshold argument. " + "Defaults to 1", + ) + parser.add_argument( + "--long-prefill-threshold", + type=float, + default=EngineArgs.long_prefill_threshold, + help="For chunked prefill, a request is considered long " + "if the prompt is longer than the " + "max_model_length * long_prefill_threshold. Defaults to 0.04%", + ) parser.add_argument('--max-num-seqs', type=int, default=EngineArgs.max_num_seqs, @@ -894,7 +913,6 @@ def create_model_config(self) -> ModelConfig: # We know this is not None because we set it in __post_init__ tokenizer=cast(str, self.tokenizer), tokenizer_mode=self.tokenizer_mode, - chat_template_text_format=self.chat_template_text_format, trust_remote_code=self.trust_remote_code, allowed_local_media_path=self.allowed_local_media_path, dtype=self.dtype, @@ -1095,7 +1113,10 @@ def create_engine_config(self) -> VllmConfig: multi_step_stream_outputs=self.multi_step_stream_outputs, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), - policy=self.scheduling_policy) + policy=self.scheduling_policy, + max_num_partial_prefills=self.max_num_partial_prefills, + max_long_partial_prefills=self.max_long_partial_prefills, + long_prefill_threshold=self.long_prefill_threshold) lora_config = LoRAConfig( bias_enabled=self.enable_lora_bias, max_lora_rank=self.max_lora_rank, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index aa9c7893c4cfe..e72dc81f35b67 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -262,8 +262,7 @@ def __init__( "num_scheduler_steps=%d, chunked_prefill_enabled=%s " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " - "chat_template_text_format=%s, mm_processor_kwargs=%s, " - "pooler_config=%r)", + "mm_processor_kwargs=%s, pooler_config=%r)", VLLM_VERSION, model_config.model, speculative_config, @@ -296,7 +295,6 @@ def __init__( cache_config.enable_prefix_caching, model_config.use_async_output_proc, use_cached_outputs, - model_config.chat_template_text_format, model_config.mm_processor_kwargs, model_config.pooler_config, ) @@ -1718,7 +1716,7 @@ def _get_stats(self, # not counted (to avoid double counting) actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore - num_generation_tokens_from_prefill_groups = 0. + num_generation_tokens_from_prefill_groups = 0 # NOTE: if scheduler_outputs.num_prefill_groups > 0 and # the len of scheduler_outputs.scheduled_seq_groups is != # scheduler_outputs.num_prefill_groups, this means that diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index e896bcdded2d1..47472c274ccb6 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -512,6 +512,11 @@ def _log_gauge(self, gauge, data: Union[int, float]) -> None: def _log_counter(self, counter, data: Union[int, float]) -> None: # Convenience function for logging to counter. + # Prevent ValueError from negative increment + if data < 0: + logger.warning("Skipping negative increment of %g to %s", data, + counter) + return counter.labels(**self.labels).inc(data) def _log_counter_labels(self, counter, data: CollectionsCounter, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 3ca460c47c3bd..abee5ac46391c 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -2,12 +2,14 @@ import codecs import json from abc import ABC, abstractmethod -from collections import defaultdict +from collections import defaultdict, deque from functools import lru_cache, partial from pathlib import Path from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, Literal, Mapping, Optional, Tuple, TypeVar, Union, cast) +import jinja2.nodes +import transformers.utils.chat_template_utils as hf_chat_utils # yapf conflicts with isort for this block # yapf: disable from openai.types.chat import (ChatCompletionAssistantMessageParam, @@ -153,6 +155,199 @@ class ConversationMessage(TypedDict, total=False): """The tool calls generated by the model, such as function calls.""" +# Passed in by user +ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] + +# Used internally +_ChatTemplateContentFormat = Literal["string", "openai"] + + +def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: + if isinstance(node, jinja2.nodes.Name): + return node.ctx == "load" and node.name == varname + + return False + + +def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: + if isinstance(node, jinja2.nodes.Getitem): + return (_is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key) + + if isinstance(node, jinja2.nodes.Getattr): + return _is_var_access(node.node, varname) and node.attr == key + + return False + + +def _is_var_or_elems_access( + node: jinja2.nodes.Node, + varname: str, + key: Optional[str] = None, +) -> bool: + if isinstance(node, jinja2.nodes.Filter): + return (node.node is not None + and _is_var_or_elems_access(node.node, varname, key)) + if isinstance(node, jinja2.nodes.Test): + return _is_var_or_elems_access(node.node, varname, key) + + if (isinstance(node, jinja2.nodes.Getitem) + and isinstance(node.arg, jinja2.nodes.Slice)): + return _is_var_or_elems_access(node.node, varname, key) + + # yapf: disable + return ( + _is_attr_access(node, varname, key) if key + else _is_var_access(node, varname) + ) # yapf: enable + + +def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): + # Global variable that is implicitly defined at the root + yield root, varname + + # Iterative BFS + related_varnames = deque([varname]) + while related_varnames: + related_varname = related_varnames.popleft() + + for assign_ast in root.find_all(jinja2.nodes.Assign): + lhs = assign_ast.target + rhs = assign_ast.node + + if _is_var_or_elems_access(rhs, related_varname): + assert isinstance(lhs, jinja2.nodes.Name) + yield assign_ast, lhs.name + + # Avoid infinite looping for self-assignment + if lhs.name != related_varname: + related_varnames.append(lhs.name) + + +# NOTE: The proper way to handle this is to build a CFG so that we can handle +# the scope in which each variable is defined, but that is too complicated +def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): + messages_varnames = [ + varname + for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") + ] + + # Search for {%- for message in messages -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in messages_varnames: + if _is_var_or_elems_access(loop_iter, varname): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): + message_varnames = [ + varname for _, varname in _iter_nodes_assign_messages_item(root) + ] + + # Search for {%- for content in message['content'] -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in message_varnames: + if _is_var_or_elems_access(loop_iter, varname, "content"): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]: + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + return jinja_compiled.environment.parse(chat_template) + except Exception: + logger.exception("Error when compiling Jinja template") + return None + + +def _detect_content_format( + chat_template: str, + *, + default: _ChatTemplateContentFormat, +) -> _ChatTemplateContentFormat: + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: + return default + + try: + next(_iter_nodes_assign_content_item(jinja_ast)) + except StopIteration: + return "string" + except Exception: + logger.exception("Error when parsing AST of Jinja template") + return default + else: + return "openai" + + +def _resolve_chat_template_content_format( + chat_template: Optional[str], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, +) -> _ChatTemplateContentFormat: + if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + tokenizer_chat_template = tokenizer.chat_template + else: + tokenizer_chat_template = None + + jinja_text: Optional[str] + if isinstance(tokenizer_chat_template, str) and chat_template is None: + jinja_text = tokenizer_chat_template + elif (isinstance(tokenizer_chat_template, dict) + and chat_template in tokenizer_chat_template): + jinja_text = tokenizer_chat_template[chat_template] + else: + jinja_text = load_chat_template(chat_template, is_literal=True) + + detected_format = ("string" if jinja_text is None else + _detect_content_format(jinja_text, default="string")) + + return detected_format if given_format == "auto" else given_format + + +@lru_cache +def resolve_chat_template_content_format( + chat_template: Optional[str], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, +) -> _ChatTemplateContentFormat: + detected_format = _resolve_chat_template_content_format( + chat_template, + given_format, + tokenizer, + ) + + logger.info( + "Detected the chat template content format to be '%s'. " + "You can set `--chat-template-content-format` to override this.", + detected_format, + ) + + if given_format != "auto" and given_format != detected_format: + logger.warning( + "You specified `--chat-template-content-format %s` " + "which is different from the detected format '%s'. " + "If our automatic detection is incorrect, please consider " + "opening a GitHub issue so that we can improve it: " + "https://github.com/vllm-project/vllm/issues/new/choose", + given_format, + detected_format, + ) + + return detected_format + + ModalityStr = Literal["image", "audio", "video"] _T = TypeVar("_T") @@ -407,12 +602,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): def load_chat_template( - chat_template: Optional[Union[Path, str]]) -> Optional[str]: + chat_template: Optional[Union[Path, str]], + *, + is_literal: bool = False, +) -> Optional[str]: if chat_template is None: return None + + if is_literal: + if isinstance(chat_template, Path): + raise TypeError("chat_template is expected to be read directly " + "from its value") + + return codecs.decode(chat_template, "unicode_escape") + try: with open(chat_template) as f: - resolved_chat_template = f.read() + return f.read() except OSError as e: if isinstance(chat_template, Path): raise @@ -426,10 +632,7 @@ def load_chat_template( # If opening a file fails, set chat template to be args to # ensure we decode so our escape are interpreted correctly - resolved_chat_template = codecs.decode(chat_template, "unicode_escape") - - logger.info("Using supplied chat template:\n%s", resolved_chat_template) - return resolved_chat_template + return load_chat_template(chat_template, is_literal=True) # TODO: Let user specify how to insert multimodal tokens into prompt @@ -464,7 +667,6 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _AudioParser = partial(cast, ChatCompletionContentPartAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _VideoParser = partial(cast, ChatCompletionContentPartVideoParam) -MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} # Define a mapping from part types to their corresponding parsing functions. MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = { @@ -542,18 +744,12 @@ def _parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionContentPartParam], mm_tracker: BaseMultiModalItemTracker, - chat_template_text_format: str, + *, + wrap_dicts: bool, ) -> List[ConversationMessage]: content: List[Union[str, Dict[str, str]]] = [] mm_parser = mm_tracker.create_parser() - model_config = mm_tracker.model_config - - wrap_dicts = (chat_template_text_format == "openai" - or (model_config.task == "embedding" - and model_config.is_multimodal_model) - or (model_config.hf_config.model_type - in MODEL_KEEP_MULTI_MODAL_CONTENT)) for part in parts: parse_res = _parse_chat_message_content_part( @@ -578,9 +774,11 @@ def _parse_chat_message_content_parts( def _parse_chat_message_content_part( - part: ChatCompletionContentPartParam, - mm_parser: BaseMultiModalContentParser, - wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]: + part: ChatCompletionContentPartParam, + mm_parser: BaseMultiModalContentParser, + *, + wrap_dicts: bool, +) -> Optional[Union[str, Dict[str, str]]]: """Parses a single part of a conversation. If wrap_dicts is True, structured dictionary pieces for texts and images will be wrapped in dictionaries, i.e., {"type": "text", "text", ...} and @@ -629,7 +827,7 @@ def _parse_chat_message_content_part( def _parse_chat_message_content( message: ChatCompletionMessageParam, mm_tracker: BaseMultiModalItemTracker, - chat_template_text_format: str, + content_format: _ChatTemplateContentFormat, ) -> List[ConversationMessage]: role = message["role"] content = message.get("content") @@ -645,7 +843,7 @@ def _parse_chat_message_content( role, content, # type: ignore mm_tracker, - chat_template_text_format, + wrap_dicts=(content_format == "openai"), ) for result_msg in result: @@ -684,6 +882,7 @@ def parse_chat_messages( messages: List[ChatCompletionMessageParam], model_config: ModelConfig, tokenizer: AnyTokenizer, + content_format: _ChatTemplateContentFormat, ) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]: conversation: List[ConversationMessage] = [] mm_tracker = MultiModalItemTracker(model_config, tokenizer) @@ -692,7 +891,7 @@ def parse_chat_messages( sub_messages = _parse_chat_message_content( msg, mm_tracker, - model_config.chat_template_text_format, + content_format, ) conversation.extend(sub_messages) @@ -706,6 +905,7 @@ def parse_chat_messages_futures( messages: List[ChatCompletionMessageParam], model_config: ModelConfig, tokenizer: AnyTokenizer, + content_format: _ChatTemplateContentFormat, ) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: conversation: List[ConversationMessage] = [] mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) @@ -714,7 +914,7 @@ def parse_chat_messages_futures( sub_messages = _parse_chat_message_content( msg, mm_tracker, - model_config.chat_template_text_format, + content_format, ) conversation.extend(sub_messages) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 4b33fc1458ee3..9c7c8e3b2aa13 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -13,10 +13,12 @@ TaskOption) from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ChatTemplateContentFormatOption, apply_hf_chat_template, apply_mistral_chat_template, - parse_chat_messages) -from vllm.inputs import PromptType, TextPrompt, TokensPrompt + parse_chat_messages, + resolve_chat_template_content_format) +from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -523,6 +525,7 @@ def chat( use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", add_generation_prompt: bool = True, continue_final_message: bool = False, tools: Optional[List[Dict[str, Any]]] = None, @@ -539,9 +542,11 @@ def chat( to the OpenAI API. Args: - messages: A list of conversations or a single conversation. - - Each conversation is represented as a list of messages. - - Each message is a dictionary with 'role' and 'content' keys. + messages: A list of conversations or a single conversation. + + - Each conversation is represented as a list of messages. + - Each message is a dictionary with 'role' and 'content' keys. + sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it @@ -551,11 +556,19 @@ def chat( lora_request: LoRA request to use for generation, if any. chat_template: The template to use for structuring the chat. If not provided, the model's default chat template will be used. + chat_template_content_format: The format to render message content. + + - "string" will render the content as a string. + Example: ``"Who are you?"`` + - "openai" will render the content as a list of dictionaries, + similar to OpenAI schema. + Example: ``[{"type": "text", "text": "Who are you?"}]`` + add_generation_prompt: If True, adds a generation template to each message. continue_final_message: If True, continues the final message in - the conversation instead of starting a new one. Cannot be `True` - if `add_generation_prompt` is also `True`. + the conversation instead of starting a new one. Cannot be + ``True`` if ``add_generation_prompt`` is also ``True``. mm_processor_kwargs: Multimodal processor kwarg overrides for this chat request. Only used for offline requests. @@ -576,17 +589,26 @@ def chat( cast(List[ChatCompletionMessageParam], messages) ] + tokenizer = self.get_tokenizer() + model_config = self.llm_engine.get_model_config() + resolved_content_format = resolve_chat_template_content_format( + chat_template, + chat_template_content_format, + tokenizer, + ) + prompts: List[Union[TokensPrompt, TextPrompt]] = [] for msgs in list_of_messages: - tokenizer = self.get_tokenizer() - model_config = self.llm_engine.get_model_config() - # NOTE: _parse_chat_message_content_parts() currently doesn't # handle mm_processor_kwargs, since there is no implementation in # the chat message parsing for it. conversation, mm_data = parse_chat_messages( - msgs, model_config, tokenizer) + msgs, + model_config, + tokenizer, + content_format=resolved_content_format, + ) prompt_data: Union[str, List[int]] if isinstance(tokenizer, MistralTokenizer): @@ -737,7 +759,7 @@ def encode( generation, if any. Returns: - A list of `EmbeddingRequestOutput` objects containing the + A list of ``EmbeddingRequestOutput`` objects containing the generated embeddings in the same order as the input prompts. Note: @@ -782,6 +804,111 @@ def encode( return self.engine_class.validate_outputs(outputs, EmbeddingRequestOutput) + def score( + self, + query: SingletonPrompt, + texts: Union[SingletonPrompt, Sequence[SingletonPrompt]], + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[EmbeddingRequestOutput]: + """Generates similarity scores for all pairs . + + This method pairs the input query with each of the texts to generate + a list of prompts for the cross encoder model. This class automatically + batches the prompts, considering the memory constraint. For the best + performance, put all of your texts into a single list and pass it to + this method. + + Args: + query: The query to compare against all other text input + texts: The texts to pair with the query to form the input + to the LLM. You may pass a sequence of texts for batch + inference. See :class:`~vllm.inputs.PromptType` for more + details about the format of each prompts. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of ``EmbeddingRequestOutput`` objects containing the + generated scores in the same order as the input prompts. + """ + task = self.llm_engine.model_config.task + if task != "embedding": + messages = ["LLM.score() is only supported for embedding models."] + + supported_tasks = self.llm_engine.model_config.supported_tasks + if "embedding" in supported_tasks: + messages.append( + "Your model supports the 'embedding' task, but is " + f"currently initialized for the '{task}' task. Please " + "initialize the model using `--task embedding`.") + + raise ValueError(" ".join(messages)) + + if not self.llm_engine.model_config.is_cross_encoder: + raise ValueError("Your model does not support the cross encoding") + + tokenizer = self.llm_engine.get_tokenizer() + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + # the tokenizer for models such as + # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing + # lists of tokens to the `text` and `text_pair` kwargs + def ensure_str(prompt: SingletonPrompt): + if isinstance(prompt, dict): + if "multi_modal_data" in prompt: + raise ValueError("Multi-modal prompt is not " + "supported for cross encoding") + elif "prompt_token_ids" in prompt: + prompt = tokenizer.decode( + cast(TokensPrompt, prompt)["prompt_token_ids"]) + elif "prompt" in prompt: + prompt = cast(TextPrompt, prompt)["prompt"] + assert type(prompt) is str + return prompt + + query = ensure_str(query) + if isinstance(texts, (str, dict)): + # Convert a single prompt to a list. + texts = [texts] + + input_pairs = [(query, ensure_str(t)) for t in texts] + pooling_params = PoolingParams() + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + parsed_prompts = [] + + for q, t in input_pairs: + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + parsed_prompts.append(engine_prompt) + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return self.engine_class.validate_outputs(outputs, + EmbeddingRequestOutput) + def start_profile(self) -> None: self.llm_engine.start_profile() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b13f6a228b4c6..f07d0d4732974 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -29,10 +29,12 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.multiprocessing.engine import run_mp_engine from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) +from vllm.entrypoints.openai.fim import get_supported_fim_encoders # yapf conflicts with isort for this block # yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -529,6 +531,9 @@ def init_app_state( state.engine_client = engine_client state.log_stats = not args.disable_log_stats + resolved_chat_template = load_chat_template(args.chat_template) + logger.info("Using supplied chat template:\n%s", resolved_chat_template) + state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, @@ -537,7 +542,8 @@ def init_app_state( lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, request_logger=request_logger, - chat_template=args.chat_template, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser, @@ -551,13 +557,15 @@ def init_app_state( prompt_adapters=args.prompt_adapters, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, + fim_encoder=args.fim, ) if model_config.task == "generate" else None state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, model_config, base_model_paths, request_logger=request_logger, - chat_template=args.chat_template, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, ) if model_config.task == "embedding" else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, @@ -565,7 +573,8 @@ def init_app_state( base_model_paths, lora_modules=args.lora_modules, request_logger=request_logger, - chat_template=args.chat_template, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, ) @@ -588,11 +597,19 @@ async def run_server(args, **uvicorn_kwargs) -> None: if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: ToolParserManager.import_tool_parser(args.tool_parser_plugin) - valide_tool_parses = ToolParserManager.tool_parsers.keys() - if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valide_tool_parses: - raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " - f"(chose from {{ {','.join(valide_tool_parses)} }})") + if args.enable_auto_tool_choice: + valid_tool_parsers = ToolParserManager.tool_parsers + if args.tool_call_parser not in valid_tool_parsers: + raise KeyError( + f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valid_tool_parsers.keys())} }})") + + if args.fim is not None: + valid_fim_encoders = get_supported_fim_encoders() + if args.fim not in valid_fim_encoders: + raise KeyError( + f"invalid FIM encoder: {args.fim} " + f"(chose from {{ {','.join(valid_fim_encoders)} }})") # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index eb08a89293370..5f035972fd407 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -7,10 +7,12 @@ import argparse import json import ssl -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union, get_args from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str -from vllm.entrypoints.chat_utils import validate_chat_template +from vllm.entrypoints.openai.fim import get_supported_fim_encoders +from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, + validate_chat_template) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -132,6 +134,18 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="The file path to the chat template, " "or the template in single-line form " "for the specified model") + parser.add_argument( + '--chat-template-content-format', + type=str, + default="auto", + choices=get_args(ChatTemplateContentFormatOption), + help='The format to render message content within a chat template.' + '\n\n' + '* "string" will render the content as a string. ' + 'Example: "Hello World"\n' + '* "openai" will render the content as a list of dictionaries, ' + 'similar to OpenAI schema. ' + 'Example: [{"type": "text", "text": "Hello world!"}]') parser.add_argument("--response-role", type=nullable_str, default="assistant", @@ -213,6 +227,16 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: " into OpenAI API format, the name register in this plugin can be used " "in --tool-call-parser.") + valid_fim_encoders = get_supported_fim_encoders() + parser.add_argument( + "--fim", + type=str, + metavar="{" + ",".join(valid_fim_encoders) + "}", + default=None, + help="Select the fill-in-the-middle (FIM) encoder depending on the" + " model that you're using. Required to use the suffix parameter of the" + " OpenAI Completions API.") + parser = AsyncEngineArgs.add_cli_args(parser) parser.add_argument('--max-log-len', diff --git a/vllm/entrypoints/openai/fim/__init__.py b/vllm/entrypoints/openai/fim/__init__.py new file mode 100644 index 0000000000000..d5f3431130165 --- /dev/null +++ b/vllm/entrypoints/openai/fim/__init__.py @@ -0,0 +1,69 @@ +from functools import partial +from inspect import isclass +from typing import Callable, Dict, Iterable, Optional, Tuple, Type, Union + +from vllm.entrypoints.openai.fim.codellama_fim import CodeLlamaFIMEncoder +from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder, + StringTemplateFIMEncoder) +from vllm.entrypoints.openai.fim.mistral_fim import MistralFIMEncoder +from vllm.transformers_utils.tokenizer import AnyTokenizer + +__all__ = [ + "FIMEncoder", "get_supported_fim_encoders", "get_fim_encoder_lookup" +] + +# Entries are either an FIMEncoder implementation class or +# tuple of (template, special_tokens_list). +_FIM_ENCODERS: Dict[str, Union[Type, Tuple[str, Iterable[str]]]] = { + "mistral": + MistralFIMEncoder, + "codellama": + CodeLlamaFIMEncoder, + "deepseek": ( + "<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>", + ("<|fim▁begin|>", "<|fim▁hole|>", "<|fim▁end|>"), + ), + "starcoder": ( + "{prefix}{suffix}", + ("", "", ""), + ) +} + + +def get_supported_fim_encoders() -> Iterable[str]: + """Return set of supported FIM encoder types.""" + return _FIM_ENCODERS.keys() + + +def get_fim_encoder_lookup( + name: Optional[str]) -> Optional[Callable[[AnyTokenizer], FIMEncoder]]: + """ + Get a function that returns a FIMEncoder instance for a given tokenizer. + Raise a KeyError exception if the name is not recognized. + """ + if name is None: + return None + + if (encoder := _FIM_ENCODERS.get(name)) is None: + raise ValueError(f"fim encoder '{name}' not recognized") + + factory: Callable[[AnyTokenizer], FIMEncoder] + if isclass(encoder): + assert issubclass(encoder, FIMEncoder) + factory = encoder + else: + assert isinstance(encoder, tuple) + template, special_tokens = encoder + factory = partial(StringTemplateFIMEncoder, + name=name, + template=template, + special_tokens=special_tokens) + + def for_tokenizer(tokenizer: AnyTokenizer) -> FIMEncoder: + fim_encoder = getattr(tokenizer, "fim_encoder", None) + if fim_encoder is None: + fim_encoder = factory(tokenizer) + tokenizer.fim_encoder = fim_encoder # type: ignore[union-attr] + return fim_encoder + + return for_tokenizer diff --git a/vllm/entrypoints/openai/fim/codellama_fim.py b/vllm/entrypoints/openai/fim/codellama_fim.py new file mode 100644 index 0000000000000..224d34d2288d5 --- /dev/null +++ b/vllm/entrypoints/openai/fim/codellama_fim.py @@ -0,0 +1,41 @@ +from typing import List + +from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class CodeLlamaFIMEncoder(FIMEncoder): + """ + FIM Encoder for Meta CodeLlama models + + Adapted from https://github.com/meta-llama/codellama/blob/e81b597e44dbecc2a0dedb9949fdf84adfc22395/llama/generation.py#L474 + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if not hasattr(tokenizer, "convert_tokens_to_ids"): + raise ValueError( + "tokenizer incompatible with 'codellama' FIM encoder") + + self.bos_id = tokenizer.convert_tokens_to_ids("") + self.prefix_id = tokenizer.convert_tokens_to_ids("▁
")
+        self.suffix_id = tokenizer.convert_tokens_to_ids("▁")
+        self.middle_id = tokenizer.convert_tokens_to_ids("▁")
+
+        unk_token_id = getattr(tokenizer, "unk_token_id", None)
+        if any(tid in
+               {self.bos_id, self.prefix_id, self.suffix_id, self.middle_id}
+               for tid in (None, unk_token_id)):
+            raise ValueError(
+                "tokenizer incompatible with 'codellama' FIM encoder")
+
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        prefix_tokens = self.tokenizer(prefix,
+                                       add_special_tokens=False).input_ids
+        # Encode a string without an implicit leading space.
+        suffix_tokens = self.tokenizer("☺" + suffix,
+                                       add_special_tokens=False).input_ids[2:]
+
+        return ([self.bos_id, self.prefix_id] + prefix_tokens[self.suffix_id] +
+                suffix_tokens + [self.middle_id])
diff --git a/vllm/entrypoints/openai/fim/fim_encoder.py b/vllm/entrypoints/openai/fim/fim_encoder.py
new file mode 100644
index 0000000000000..9b6f27a7a1e36
--- /dev/null
+++ b/vllm/entrypoints/openai/fim/fim_encoder.py
@@ -0,0 +1,52 @@
+from abc import ABC, abstractmethod
+from typing import Iterable, List, Optional
+
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+
+class FIMEncoder(ABC):
+    """
+    An encoder of fill-in-the-middle (FIM) prompts comprising prefix
+    and suffix strings.
+    """
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        self.tokenizer = tokenizer
+
+    @abstractmethod
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        """
+        Encode the provided prompt prefix and suffix
+        to a list of token ids
+        """
+        pass
+
+
+class StringTemplateFIMEncoder(FIMEncoder):
+    """FIMEncoder implementation using a simple string template
+    with prefix and suffix variables."""
+
+    def __init__(
+        self,
+        tokenizer: AnyTokenizer,
+        name: str,
+        template: str,
+        special_tokens: Optional[Iterable[str]] = None,
+    ):
+        super().__init__(tokenizer)
+
+        if not hasattr(tokenizer, "convert_tokens_to_ids"):
+            raise ValueError(
+                "tokenizer incompatible with 'codellama' FIM encoder")
+
+        unk_token_id = getattr(tokenizer, "unk_token_id", None)
+        for special_token in special_tokens or ():
+            token_id = tokenizer.convert_tokens_to_ids(special_token)
+            if token_id is None or token_id == unk_token_id:
+                raise ValueError(
+                    f"tokenizer incompatible with '{name}' FIM encoder")
+        self.template = template
+
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        prompt = self.template.format(prefix=prefix, suffix=suffix)
+        return self.tokenizer(prompt, add_special_tokens=False).input_ids
diff --git a/vllm/entrypoints/openai/fim/mistral_fim.py b/vllm/entrypoints/openai/fim/mistral_fim.py
new file mode 100644
index 0000000000000..21fd1cca9e217
--- /dev/null
+++ b/vllm/entrypoints/openai/fim/mistral_fim.py
@@ -0,0 +1,22 @@
+from typing import List
+
+from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerV2
+
+from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers import MistralTokenizer
+
+
+class MistralFIMEncoder(FIMEncoder):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+        # InstructTokenizerV3 is a subclass of InstructTokenizerV2
+        if not isinstance(tokenizer, MistralTokenizer) \
+            or not isinstance(tokenizer.instruct, InstructTokenizerV2):
+            raise ValueError(
+                "tokenizer incompatible with 'mistral' FIM encoder")
+
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        return self.tokenizer.encode_with_suffix(prefix=prefix, suffix=suffix)
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 820aefd8800d9..b7b064ae01f05 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -5,9 +5,8 @@
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import torch
-from openai.types.chat import ChatCompletionContentPartParam
 from pydantic import BaseModel, ConfigDict, Field, model_validator
-from typing_extensions import Annotated, Required, TypedDict
+from typing_extensions import Annotated
 
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.pooling_params import PoolingParams
@@ -35,26 +34,6 @@
 assert _LONG_INFO.max == _MOCK_LONG_INFO.max
 
 
-class CustomChatCompletionMessageParam(TypedDict, total=False):
-    """Enables custom roles in the Chat Completion API."""
-    role: Required[str]
-    """The role of the message's author."""
-
-    content: Union[str, List[ChatCompletionContentPartParam]]
-    """The contents of the message."""
-
-    name: str
-    """An optional name for the participant.
-
-    Provides the model information to differentiate between participants of the
-    same role.
-    """
-
-    tool_call_id: Optional[str]
-
-    tool_calls: Optional[List[dict]]
-
-
 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields
     model_config = ConfigDict(extra="forbid")
@@ -1054,16 +1033,56 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
     model: str
     prompt: str
 
-    add_special_tokens: bool = Field(default=True)
+    add_special_tokens: bool = Field(
+        default=True,
+        description=(
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."),
+    )
 
 
 class TokenizeChatRequest(OpenAIBaseModel):
     model: str
     messages: List[ChatCompletionMessageParam]
 
-    add_generation_prompt: bool = Field(default=True)
-    continue_final_message: bool = Field(default=False)
-    add_special_tokens: bool = Field(default=False)
+    add_generation_prompt: bool = Field(
+        default=True,
+        description=
+        ("If true, the generation prompt will be added to the chat template. "
+         "This is a parameter used by chat template in tokenizer config of the "
+         "model."),
+    )
+    continue_final_message: bool = Field(
+        default=False,
+        description=
+        ("If this is set, the chat will be formatted so that the final "
+         "message in the chat is open-ended, without any EOS tokens. The "
+         "model will continue this message rather than starting a new one. "
+         "This allows you to \"prefill\" part of the model's response for it. "
+         "Cannot be used at the same time as `add_generation_prompt`."),
+    )
+    add_special_tokens: bool = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to false (as is the "
+            "default)."),
+    )
+    chat_template: Optional[str] = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."),
+    )
+    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description=("Additional kwargs to pass to the template renderer. "
+                     "Will be accessible by the chat template."),
+    )
 
     @model_validator(mode="before")
     @classmethod
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index 1b422a93263b2..00cdb3b6839f5 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -222,6 +222,7 @@ async def main(args):
         prompt_adapters=None,
         request_logger=request_logger,
         chat_template=None,
+        chat_template_content_format="auto",
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
@@ -230,6 +231,7 @@ async def main(args):
         base_model_paths,
         request_logger=request_logger,
         chat_template=None,
+        chat_template_content_format="auto",
     ) if model_config.task == "embedding" else None
 
     tracker = BatchProgressTracker()
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 77cae00ae827f..2eef909eb9319 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -10,7 +10,8 @@
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import ConversationMessage, load_chat_template
+from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
+                                         ConversationMessage)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -38,20 +39,23 @@
 
 class OpenAIServingChat(OpenAIServing):
 
-    def __init__(self,
-                 engine_client: EngineClient,
-                 model_config: ModelConfig,
-                 base_model_paths: List[BaseModelPath],
-                 response_role: str,
-                 *,
-                 lora_modules: Optional[List[LoRAModulePath]],
-                 prompt_adapters: Optional[List[PromptAdapterPath]],
-                 request_logger: Optional[RequestLogger],
-                 chat_template: Optional[str],
-                 return_tokens_as_token_ids: bool = False,
-                 enable_auto_tools: bool = False,
-                 tool_parser: Optional[str] = None,
-                 enable_prompt_tokens_details: bool = False):
+    def __init__(
+        self,
+        engine_client: EngineClient,
+        model_config: ModelConfig,
+        base_model_paths: List[BaseModelPath],
+        response_role: str,
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
+        chat_template_content_format: ChatTemplateContentFormatOption,
+        return_tokens_as_token_ids: bool = False,
+        enable_auto_tools: bool = False,
+        tool_parser: Optional[str] = None,
+        enable_prompt_tokens_details: bool = False,
+    ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -61,8 +65,8 @@ def __init__(self,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)
 
         self.response_role = response_role
-        self.use_tool_use_model_template = False
-        self.chat_template = load_chat_template(chat_template)
+        self.chat_template = chat_template
+        self.chat_template_content_format: Final = chat_template_content_format
 
         # set up tool use
         self.enable_auto_tools: bool = enable_auto_tools
@@ -120,6 +124,7 @@ async def create_chat_completion(
             ) = self._maybe_get_adapters(request)
 
             tokenizer = await self.engine_client.get_tokenizer(lora_request)
+
             tool_parser = self.tool_parser
 
             # validation for OpenAI tools
@@ -157,6 +162,7 @@ async def create_chat_completion(
                 tokenizer,
                 request.messages,
                 chat_template=request.chat_template or self.chat_template,
+                chat_template_content_format=self.chat_template_content_format,
                 add_generation_prompt=request.add_generation_prompt,
                 continue_final_message=request.continue_final_message,
                 tool_dicts=tool_dicts,
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 936aae8f1c267..7f70afa5296bc 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -47,6 +47,7 @@ def __init__(
         prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
+        fim_encoder: Optional[str] = None,
     ):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
@@ -54,7 +55,8 @@ def __init__(
                          lora_modules=lora_modules,
                          prompt_adapters=prompt_adapters,
                          request_logger=request_logger,
-                         return_tokens_as_token_ids=return_tokens_as_token_ids)
+                         return_tokens_as_token_ids=return_tokens_as_token_ids,
+                         fim_encoder=fim_encoder)
 
     async def create_completion(
         self,
@@ -66,9 +68,6 @@ async def create_completion(
         See https://platform.openai.com/docs/api-reference/completions/create
         for the API specification. This API mimics the OpenAI Completion API.
 
-        NOTE: Currently we do not support the following feature:
-            - suffix (the language models we currently support do not support
-            suffix)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -80,11 +79,6 @@ async def create_completion(
         if self.engine_client.errored:
             raise self.engine_client.dead_error
 
-        # Return error for unsupported features.
-        if request.suffix is not None:
-            return self.create_error_response(
-                "suffix is not currently supported")
-
         model_name = self.base_model_paths[0].name
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())
@@ -105,6 +99,7 @@ async def create_completion(
                 request,
                 tokenizer,
                 request.prompt,
+                request.suffix,
                 truncate_prompt_tokens=request.truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
@@ -303,6 +298,14 @@ async def completion_stream_generator(
                             # Chunked prefill case, don't return empty chunks
                             continue
 
+                    previous_text_lens[i] += len(output.text)
+                    previous_num_tokens[i] += len(output.token_ids)
+                    finish_reason = output.finish_reason
+                    stop_reason = output.stop_reason
+
+                    if finish_reason and request.echo and request.suffix:
+                        delta_text += request.suffix
+
                     if request.logprobs is not None:
                         assert out_logprobs is not None, (
                             "Did not output logprobs")
@@ -316,11 +319,6 @@ async def completion_stream_generator(
                     else:
                         logprobs = None
 
-                    previous_text_lens[i] += len(output.text)
-                    previous_num_tokens[i] += len(output.token_ids)
-                    finish_reason = output.finish_reason
-                    stop_reason = output.stop_reason
-
                     chunk = CompletionStreamResponse(
                         id=request_id,
                         created=created_time,
@@ -388,6 +386,8 @@ def request_output_to_completion_response(
         num_prompt_tokens = 0
         num_generated_tokens = 0
 
+        suffix = "" if request.suffix is None else request.suffix
+
         for final_res in final_res_batch:
             prompt_token_ids = final_res.prompt_token_ids
             assert prompt_token_ids is not None
@@ -397,7 +397,6 @@ def request_output_to_completion_response(
             token_ids: GenericSequence[int]
             out_logprobs: Optional[GenericSequence[Optional[Dict[int,
                                                                  Logprob]]]]
-
             for output in final_res.outputs:
                 assert request.max_tokens is not None
                 if request.echo:
@@ -405,7 +404,7 @@ def request_output_to_completion_response(
                     if request.max_tokens == 0:
                         token_ids = prompt_token_ids
                         out_logprobs = prompt_logprobs
-                        output_text = prompt_text
+                        output_text = prompt_text + suffix
                     else:
                         token_ids = [*prompt_token_ids, *output.token_ids]
 
@@ -419,7 +418,7 @@ def request_output_to_completion_response(
                                 *output.logprobs,
                             ]
 
-                        output_text = prompt_text + output.text
+                        output_text = prompt_text + output.text + suffix
                 else:
                     token_ids = output.token_ids
                     out_logprobs = output.logprobs
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index bbe7db8f13231..91aa070b056d8 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -1,24 +1,34 @@
 import asyncio
 import base64
 import time
-from typing import AsyncGenerator, List, Literal, Optional, Union, cast
+from typing import (Annotated, Any, AsyncGenerator, Dict, Final, List, Literal,
+                    Optional, Sequence, Tuple, Union, cast)
 
 import numpy as np
 from fastapi import Request
+from pydantic import Field
 from typing_extensions import assert_never
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         ChatTemplateContentFormatOption,
+                                         ConversationMessage,
+                                         parse_chat_messages_futures)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                               EmbeddingRequest,
                                               EmbeddingResponse,
                                               EmbeddingResponseData,
                                               ErrorResponse, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
+                                                    OpenAIServing,
+                                                    RequestPrompt)
+from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import merge_async_iterators, random_uuid
 
 logger = init_logger(__name__)
@@ -77,7 +87,8 @@ def __init__(
         *,
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
-    ):
+        chat_template_content_format: ChatTemplateContentFormatOption,
+    ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -85,7 +96,8 @@ def __init__(
                          prompt_adapters=None,
                          request_logger=request_logger)
 
-        self.chat_template = load_chat_template(chat_template)
+        self.chat_template = chat_template
+        self.chat_template_content_format: Final = chat_template_content_format
 
     async def create_embedding(
         self,
@@ -134,29 +146,49 @@ async def create_embedding(
                 raise NotImplementedError("Prompt adapter is not supported "
                                           "for embedding models")
 
-            if isinstance(request, EmbeddingChatRequest):
-                (
-                    _,
-                    request_prompts,
-                    engine_prompts,
-                ) = await self._preprocess_chat(
-                    request,
-                    tokenizer,
-                    request.messages,
-                    chat_template=request.chat_template or self.chat_template,
-                    add_generation_prompt=request.add_generation_prompt,
-                    continue_final_message=request.continue_final_message,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
-                    add_special_tokens=request.add_special_tokens,
-                )
+            if self.model_config.is_cross_encoder:
+                if isinstance(request, EmbeddingChatRequest):
+                    (
+                        _,
+                        request_prompts,
+                        engine_prompts,
+                    ) = await self._preprocess_cross_encoding(
+                        tokenizer,
+                        request.messages,
+                        truncate_prompt_tokens=truncate_prompt_tokens,
+                    )
+                else:
+                    return self.create_error_response(
+                        "Cross encoding requests must "
+                        "use the chat embedding API")
             else:
-                request_prompts, engine_prompts = self._preprocess_completion(
-                    request,
-                    tokenizer,
-                    request.input,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
-                    add_special_tokens=request.add_special_tokens,
-                )
+                if isinstance(request, EmbeddingChatRequest):
+                    (
+                        _,
+                        request_prompts,
+                        engine_prompts,
+                    ) = await self._preprocess_chat(
+                        request,
+                        tokenizer,
+                        request.messages,
+                        chat_template=request.chat_template
+                        or self.chat_template,
+                        chat_template_content_format=self.
+                        chat_template_content_format,
+                        add_generation_prompt=request.add_generation_prompt,
+                        continue_final_message=request.continue_final_message,
+                        truncate_prompt_tokens=truncate_prompt_tokens,
+                        add_special_tokens=request.add_special_tokens,
+                    )
+                else:
+                    request_prompts, engine_prompts = self\
+                        ._preprocess_completion(
+                        request,
+                        tokenizer,
+                        request.input,
+                        truncate_prompt_tokens=truncate_prompt_tokens,
+                        add_special_tokens=request.add_special_tokens,
+                    )
         except ValueError as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
@@ -221,3 +253,44 @@ async def create_embedding(
             return self.create_error_response(str(e))
 
         return response
+
+    async def _preprocess_cross_encoding(
+        self,
+        tokenizer: AnyTokenizer,
+        messages: List[ChatCompletionMessageParam],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=3)]] = None,
+    ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
+               List[TokensPrompt]]:
+
+        conversation, mm_data_future = parse_chat_messages_futures(
+            messages, self.model_config, tokenizer, "string")
+        await mm_data_future
+
+        if len(conversation) != 2:
+            raise ValueError("For cross encoding two inputs must be provided")
+        prompts = []
+        for msg in conversation:
+            content = msg["content"]
+            assert type(msg["content"]) is str
+            prompts.append(content)
+
+        if isinstance(tokenizer, MistralTokenizer):
+            raise ValueError(
+                "MistralTokenizer not supported for cross-encoding")
+
+        request_prompt = f"{prompts[0]}{tokenizer.sep_token}{prompts[1]}"
+
+        tokenization_kwargs: Dict[str, Any] = {}
+        if truncate_prompt_tokens is not None:
+            tokenization_kwargs["truncation"] = True
+            tokenization_kwargs["max_length"] = truncate_prompt_tokens
+
+        prompt_inputs = tokenizer(prompts[0],
+                                  text_pair=prompts[1],
+                                  **tokenization_kwargs)
+
+        engine_prompt = TokensPrompt(
+            prompt_token_ids=prompt_inputs["input_ids"],
+            token_type_ids=prompt_inputs.get("token_type_ids"))
+
+        return conversation, [request_prompt], [engine_prompt]
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index fa315fa516632..912bdf679d725 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -11,12 +11,17 @@
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         ChatTemplateContentFormatOption,
                                          ConversationMessage,
                                          apply_hf_chat_template,
                                          apply_mistral_chat_template,
-                                         parse_chat_messages_futures)
+                                         parse_chat_messages_futures,
+                                         resolve_chat_template_content_format)
 from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.fim import FIMEncoder, get_fim_encoder_lookup
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -98,6 +103,7 @@ def __init__(
         prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
+        fim_encoder: Optional[str] = None,
     ):
         super().__init__()
 
@@ -138,6 +144,9 @@ def __init__(
         self.request_logger = request_logger
         self.return_tokens_as_token_ids = return_tokens_as_token_ids
 
+        self.get_fim_encoder: Optional[Callable[[AnyTokenizer], FIMEncoder]] = \
+            get_fim_encoder_lookup(fim_encoder)
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
@@ -225,22 +234,32 @@ def _normalize_prompt_text_to_input(
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt: str,
+        suffix: Optional[str],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
         add_special_tokens: bool,
     ) -> TextTokensPrompt:
-        if truncate_prompt_tokens is None:
-            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
+        if suffix:
+            if not (get_fim_encoder := self.get_fim_encoder):
+                raise ValueError("fim support must be enabled to use suffix")
+            if truncate_prompt_tokens is not None:
+                raise ValueError(
+                    "truncate_prompt_tokens is not supported with suffix")
+            fim_encoder = get_fim_encoder(tokenizer)
+            input_ids = fim_encoder.encode_with_suffix(prefix=prompt,
+                                                       suffix=suffix)
         else:
-            encoded = tokenizer(prompt,
-                                add_special_tokens=add_special_tokens,
-                                truncation=True,
-                                max_length=truncate_prompt_tokens)
-
-        input_ids = encoded.input_ids
+            if truncate_prompt_tokens is None:
+                encoded = tokenizer(prompt,
+                                    add_special_tokens=add_special_tokens)
+            else:
+                encoded = tokenizer(prompt,
+                                    add_special_tokens=add_special_tokens,
+                                    truncation=True,
+                                    max_length=truncate_prompt_tokens)
 
-        input_text = prompt
+            input_ids = encoded.input_ids
 
-        return self._validate_input(request, input_ids, input_text)
+        return self._validate_input(request, input_ids, input_text=prompt)
 
     def _normalize_prompt_tokens_to_input(
         self,
@@ -335,6 +354,7 @@ def _tokenize_prompt_inputs(
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_inputs: Iterable[Union[str, List[int]]],
+        suffix: Optional[str] = None,
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = True,
     ) -> Iterator[TextTokensPrompt]:
@@ -348,10 +368,14 @@ def _tokenize_prompt_inputs(
                     request,
                     tokenizer,
                     prompt=text,
+                    suffix=suffix,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=add_special_tokens,
                 )
             else:
+                if suffix:
+                    raise ValueError(
+                        "suffix is only supported with string prompt input")
                 yield self._normalize_prompt_tokens_to_input(
                     request,
                     tokenizer,
@@ -364,6 +388,7 @@ def _tokenize_prompt_input_or_inputs(
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
+        suffix: Optional[str] = None,
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = True,
     ) -> Iterator[TextTokensPrompt]:
@@ -384,10 +409,14 @@ def _tokenize_prompt_input_or_inputs(
                     request,
                     tokenizer,
                     prompt=prompt_input["content"],
+                    suffix=suffix,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=add_special_tokens,
                 )
             else:
+                if suffix:
+                    raise ValueError(
+                        "suffix is only supported with string prompt input")
                 yield self._normalize_prompt_tokens_to_input(
                     request,
                     tokenizer,
@@ -400,6 +429,7 @@ def _preprocess_completion(
         request: CompletionLikeRequest,
         tokenizer: AnyTokenizer,
         input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
+        suffix: Optional[str] = None,
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = True,
     ) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]:
@@ -409,6 +439,7 @@ def _preprocess_completion(
                 request,
                 tokenizer,
                 input_or_inputs,
+                suffix=suffix,
                 truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
             )
@@ -426,7 +457,8 @@ async def _preprocess_chat(
         request: ChatLikeRequest,
         tokenizer: AnyTokenizer,
         messages: List[ChatCompletionMessageParam],
-        chat_template: Optional[str] = None,
+        chat_template: Optional[str],
+        chat_template_content_format: ChatTemplateContentFormatOption,
         add_generation_prompt: bool = True,
         continue_final_message: bool = False,
         tool_dicts: Optional[List[Dict[str, Any]]] = None,
@@ -437,10 +469,16 @@ async def _preprocess_chat(
         add_special_tokens: bool = False,
     ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
                List[TokensPrompt]]:
+        resolved_content_format = resolve_chat_template_content_format(
+            chat_template,
+            chat_template_content_format,
+            tokenizer,
+        )
         conversation, mm_data_future = parse_chat_messages_futures(
             messages,
             self.model_config,
             tokenizer,
+            content_format=resolved_content_format,
         )
 
         _chat_template_kwargs: Dict[str, Any] = dict(
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index 1fd82304f7a4d..59b3b1311f881 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -1,8 +1,8 @@
-from typing import List, Optional, Union
+from typing import Final, List, Optional, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -33,7 +33,8 @@ def __init__(
         lora_modules: Optional[List[LoRAModulePath]],
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
-    ):
+        chat_template_content_format: ChatTemplateContentFormatOption,
+    ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -41,12 +42,8 @@ def __init__(
                          prompt_adapters=None,
                          request_logger=request_logger)
 
-        # If this is None we use the tokenizer's default chat template
-        # the list of commonly-used chat template names for HF named templates
-        hf_chat_templates: List[str] = ['default', 'tool_use']
-        self.chat_template = chat_template \
-            if chat_template in hf_chat_templates \
-            else load_chat_template(chat_template)
+        self.chat_template = chat_template
+        self.chat_template_content_format: Final = chat_template_content_format
 
     async def create_tokenize(
         self,
@@ -75,9 +72,12 @@ async def create_tokenize(
                     request,
                     tokenizer,
                     request.messages,
-                    chat_template=self.chat_template,
+                    chat_template=request.chat_template or self.chat_template,
+                    chat_template_content_format=self.
+                    chat_template_content_format,
                     add_generation_prompt=request.add_generation_prompt,
                     continue_final_message=request.continue_final_message,
+                    chat_template_kwargs=request.chat_template_kwargs,
                     add_special_tokens=request.add_special_tokens,
                 )
             else:
diff --git a/vllm/envs.py b/vllm/envs.py
index f320e35971f94..716e835a555f1 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -69,7 +69,6 @@
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_TORCH_COMPILE_LEVEL: int = 0
     VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None
-    VLLM_CUSTOM_OPS: List[str] = []
     VLLM_DISABLED_KERNELS: List[str] = []
     VLLM_USE_V1: bool = False
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
@@ -217,18 +216,6 @@ def get_default_config_root():
     "VLLM_TORCH_COMPILE_CONFIG":
     lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None),
 
-    # Fine-grained control over which custom ops to enable/disable.
-    # Use 'all' to enable all, 'none' to disable all.
-    # Also specify a list of custom op names to enable (prefixed with a '+'),
-    # or disable (prefixed with a '-').
-    # Examples:
-    # - 'all,-op1' to enable all except op1
-    # - 'none,+op1,+op2' to enable only op1 and op2
-    # By default, all custom ops are enabled when running without Inductor
-    # and disabled when running with Inductor (compile_level >= Inductor).
-    "VLLM_CUSTOM_OPS":
-    lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),
-
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index 4ceb5a837dd7f..1542a2ae367eb 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -2,9 +2,6 @@
 from functools import partial
 from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
 
-import vllm.envs as envs
-from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                   ResultHandler, WorkerMonitor)
@@ -13,7 +10,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port,
+from vllm.utils import (get_distributed_init_method, get_open_port,
                         get_vllm_instance_id, make_async)
 from vllm.worker.worker_base import WorkerWrapperBase
 
@@ -57,13 +54,6 @@ def _init_executor(self) -> None:
         os.environ["LOCAL_WORLD_SIZE"] = str(
             self.parallel_config.tensor_parallel_size)
 
-        self.model_config = _verify_and_get_model_config(self.model_config)
-        self.cache_config = _verify_and_get_cache_config(self.cache_config)
-        self.scheduler_config = _verify_and_get_scheduler_config(
-            self.scheduler_config)
-        self.parallel_config = _verify_and_get_parallel_config(
-            self.parallel_config)
-
         # Multiprocessing-based executor does not support multi-node setting.
         # Since it only works for single node, we can use the loopback address
         # 127.0.0.1 for communication.
@@ -313,62 +303,6 @@ async def check_health_async(self) -> None:
         self.check_health()
 
 
-def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
-    # Reminder: Please update docs/source/serving/compatibility_matrix.rst
-    # If the feature combo become valid
-    if not config.enforce_eager:
-        logger.warning(
-            "CUDA graph is not supported on CPU, fallback to the eager "
-            "mode.")
-        config.enforce_eager = True
-    return config
-
-
-def _verify_and_get_scheduler_config(
-        config: SchedulerConfig) -> SchedulerConfig:
-    # Reminder: Please update docs/source/serving/compatibility_matrix.rst
-    # If the feature combo become valid
-    if config.chunked_prefill_enabled:
-        logger.warning("Chunked prefill is not supported on CPU, disable it.")
-        config.chunked_prefill_enabled = False
-
-    return config
-
-
-def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
-    # Reminder: Please update docs/source/serving/compatibility_matrix.rst
-    # If the feature combo become valid
-    if config.enable_prefix_caching:
-        logger.warning("Prefix caching is not supported on CPU, disable it.")
-        config.enable_prefix_caching = False
-
-    kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
-
-    if kv_cache_space >= 0:
-        if kv_cache_space == 0:
-            config.cpu_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
-            logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
-                           "for CPU backend is not set, using 4 by default.")
-        else:
-            config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes  # type: ignore
-    else:
-        raise RuntimeError(
-            "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
-            f" {kv_cache_space}, expect a positive integer value.")
-
-    return config
-
-
-def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig:
-    if (config.distributed_executor_backend is not None
-            and config.distributed_executor_backend != "mp"):
-        logger.warning(
-            "%s is not supported on CPU, fallback to mp distributed executor "
-            "backend.", config.distributed_executor_backend)
-        config.distributed_executor_backend = "mp"
-    return config
-
-
 def _driver_method_invoker(driver, method: str, *args, **kwargs):
     return getattr(driver, method)(*args, **kwargs)
 
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index 41dd59bc65ec5..4f28efd639084 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -234,7 +234,7 @@ def initialize_ray_cluster(
     if current_platform.is_rocm() or current_platform.is_xpu():
         # Try to connect existing ray instance and create a new one if not found
         try:
-            ray.init("auto")
+            ray.init("auto", ignore_reinit_error=True)
         except ConnectionError:
             logger.warning(
                 "No existing RAY instance detected. "
diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py
index 07ff9faa50f13..fb7dbbebd7b90 100644
--- a/vllm/inputs/data.py
+++ b/vllm/inputs/data.py
@@ -38,6 +38,9 @@ class TokensPrompt(TypedDict):
     prompt_token_ids: List[int]
     """A list of token IDs to pass to the model."""
 
+    token_type_ids: NotRequired[List[int]]
+    """A list of token type IDs to pass to the cross encoder model."""
+
     multi_modal_data: NotRequired["MultiModalDataDict"]
     """
     DEPRECATED: Optional multi-modal data to pass to the model,
@@ -133,6 +136,9 @@ class TokenInputs(TypedDict):
     prompt_token_ids: List[int]
     """The token IDs of the prompt."""
 
+    token_type_ids: NotRequired[List[int]]
+    """The token type IDs of the prompt."""
+
     prompt: NotRequired[str]
     """
     The original prompt text corresponding to the token IDs, if available.
@@ -160,6 +166,7 @@ class TokenInputs(TypedDict):
 
 def token_inputs(
     prompt_token_ids: List[int],
+    token_type_ids: Optional[List[int]] = None,
     prompt: Optional[str] = None,
     multi_modal_data: Optional["MultiModalDataDict"] = None,
     multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
@@ -170,6 +177,8 @@ def token_inputs(
 
     if prompt is not None:
         inputs["prompt"] = prompt
+    if token_type_ids is not None:
+        inputs["token_type_ids"] = token_type_ids
     if multi_modal_data is not None:
         inputs["multi_modal_data"] = multi_modal_data
     if multi_modal_placeholders is not None:
@@ -234,6 +243,15 @@ def prompt_token_ids(self) -> List[int]:
 
         assert_never(inputs)
 
+    @cached_property
+    def token_type_ids(self) -> List[int]:
+        inputs = self.inputs
+
+        if inputs["type"] == "token" or inputs["type"] == "multimodal":
+            return inputs.get("token_type_ids", [])
+
+        assert_never(inputs)
+
     @cached_property
     def prompt_embeds(self) -> Optional[torch.Tensor]:
         inputs = self.inputs
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index aacff87df6d79..1801397811b22 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -305,6 +305,7 @@ def _prompt_to_llm_inputs(
             tokens_content = parsed["content"]
 
             prompt_token_ids = tokens_content["prompt_token_ids"]
+            token_type_ids = tokens_content.get("token_type_ids")
             multi_modal_data = tokens_content.get("multi_modal_data")
             mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
 
@@ -318,6 +319,7 @@ def _prompt_to_llm_inputs(
 
             return token_inputs(
                 prompt_token_ids=prompt_token_ids,
+                token_type_ids=token_type_ids,
                 multi_modal_data=multi_modal_data,
                 mm_processor_kwargs=mm_processor_kwargs,
             )
diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py
index 6a32387a6f36c..f176259fddc78 100644
--- a/vllm/lora/ops/bgmv_expand.py
+++ b/vllm/lora/ops/bgmv_expand.py
@@ -75,7 +75,9 @@ def _bgmv_expand_kernel(
             other=0.0,
         )  # [BLOCK_N,BLOCK_K]
         if ADD_INPUTS:
-            tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
+            tiled_out = tl.load(c_ptr + current_n * cn_stride,
+                                mask=c_mask,
+                                other=0.0)
             accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
         else:
             accumulator = tl.sum(tiled_a * tiled_b, 1)
diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py
index 73628fd20d327..2c6ed96c253f0 100644
--- a/vllm/lora/ops/bgmv_expand_slice.py
+++ b/vllm/lora/ops/bgmv_expand_slice.py
@@ -78,7 +78,13 @@ def _bgmv_expand_slice_kernel(
         )  # [BLOCK_N,BLOCK_K]
 
         if ADD_INPUTS:
-            tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
+            # explicitly pass in other=None to tell triton that masked values
+            # can be uninitialized. This is OK because the later tl.store
+            # operation uses the same mask, eliminating the risk of garbage
+            # values propagating
+            tiled_out = tl.load(c_ptr + current_n * cn_stride,
+                                mask=c_mask,
+                                other=None)
             accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
         else:
             accumulator = tl.sum(tiled_a * tiled_b, 1)
diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py
index 4910cb4061298..ee2cd2e05e2ee 100644
--- a/vllm/lora/ops/sgmv_expand.py
+++ b/vllm/lora/ops/sgmv_expand.py
@@ -88,7 +88,10 @@ def _sgmv_expand_kernel(
     c_mask = (offset_cm[:, None] <
               (cur_seq_start + M)) & (offset_cn[None, :] < N)
     if ADD_INPUTS:
-        tiled_out = tl.load(c_ptr, mask=c_mask)
+        # explicitly pass in other=None to tell triton that masked values
+        # can be uninitialized. This is OK because the later tl.store operation
+        # uses the same mask, eliminating the risk of garbage values propagating
+        tiled_out = tl.load(c_ptr, mask=c_mask, other=None)
         tiled_c += tiled_out
     tl.store(c_ptr, tiled_c, mask=c_mask)
 
diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py
index 844f5cec39e93..5244fa14913a4 100644
--- a/vllm/lora/ops/sgmv_expand_slice.py
+++ b/vllm/lora/ops/sgmv_expand_slice.py
@@ -94,7 +94,10 @@ def _sgmv_expand_slice_kernel(
     c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
                                                            (slice_offset + N))
     if ADD_INPUTS:
-        tiled_out = tl.load(c_ptr, mask=c_mask)
+        # explicitly pass in other=None to tell triton that masked values
+        # can be uninitialized. This is OK because the later tl.store operation
+        # uses the same mask, eliminating the risk of garbage values propagating
+        tiled_out = tl.load(c_ptr, mask=c_mask, other=None)
         tiled_c += tiled_out
     tl.store(c_ptr, tiled_c, mask=c_mask)
 
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index 24d75f4df4e02..6ae7d7cf6964f 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -1,12 +1,10 @@
-from functools import lru_cache
 from typing import Dict, Type
 
 import torch.nn as nn
 
-import vllm.envs as envs
-from vllm.compilation.levels import CompilationLevel
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.plugins import get_current_vllm_config
 from vllm.utils import print_warning_once
 
 logger = init_logger(__name__)
@@ -87,6 +85,8 @@ def dispatch_forward(self):
     @classmethod
     def enabled(cls) -> bool:
         # if no name, then it was not registered
+        compilation_config = get_current_vllm_config().compilation_config
+        custom_ops = compilation_config.custom_ops
         if not hasattr(cls, "name"):
             print_warning_once(
                 f"Custom op {cls.__name__} was not registered, "
@@ -94,22 +94,25 @@ def enabled(cls) -> bool:
                 f"It will be enabled/disabled based on the global settings.")
             return CustomOp.default_on()
 
-        enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS
-        disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS
+        enabled = f"+{cls.name}" in custom_ops
+        disabled = f"-{cls.name}" in custom_ops
         assert not (enabled
                     and disabled), f"Cannot enable and disable {cls.name}"
 
         return (CustomOp.default_on() or enabled) and not disabled
 
-    # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE
-    # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
     @staticmethod
-    @lru_cache
     def default_on() -> bool:
-        count_none = envs.VLLM_CUSTOM_OPS.count("none")
-        count_all = envs.VLLM_CUSTOM_OPS.count("all")
-        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
-        return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE and \
+        """
+        On by default if level < CompilationLevel.PIECEWISE
+        Specifying 'all' or 'none' in custom_op takes precedence.
+        """
+        from vllm.config import CompilationLevel
+        compilation_config = get_current_vllm_config().compilation_config
+        custom_ops = compilation_config.custom_ops
+        count_none = custom_ops.count("none")
+        count_all = custom_ops.count("all")
+        return compilation_config.level < CompilationLevel.PIECEWISE and \
             not count_none > 0 or count_all > 0
 
     # Dictionary of all custom ops (classes, indexed by registered name).
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 340da32263c1c..e6f9f01ef0f74 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -105,16 +105,18 @@ def fused_moe_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
+        tl.int64)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     token_mask = offs_token < num_valid_tokens
 
-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_bn = (pid_n * BLOCK_SIZE_N +
+               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                       offs_k[None, :] * stride_ak)
 
-    off_experts = tl.load(expert_ids_ptr + pid_m)
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
     b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                 offs_bn[None, :] * stride_bn)
     if use_int8_w8a16:
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 94f30412e43b3..226909c72e8be 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Type
 
 import torch
 import torch.nn.functional as F
@@ -27,7 +27,7 @@
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
     "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
     "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
-    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod"
+    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod"
 ]
 
 
@@ -135,6 +135,36 @@ def apply(self,
         return F.linear(x, layer.weight, bias)
 
 
+class TiedWeightLinearMethod(UnquantizedLinearMethod):
+    """Linear method base with noop create_weights
+
+    Can be used to prevent the initialization of weights
+    during the initialization of modules with weight tying.
+    """
+
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: List[int], input_size: int,
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        ...
+
+
+class QuantizationConfigOverride(QuantizationConfig):
+    """Config class to inject a specific LinearMethod.
+    """
+
+    def __init__(self, cls: Type[LinearMethodBase]):
+        self.cls = cls
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional[LinearMethodBase]:
+        return self.cls()
+
+
+QuantizationConfigOverride.__abstractmethods__ = frozenset()
+
+
 class LinearBase(torch.nn.Module):
     """Base linear layer.
 
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 6fee57a0a03eb..bfe2d7d0f382e 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -118,14 +118,13 @@ def forward(
             if returned_token_ids is not None and len(returned_token_ids) > 0:
                 hidden_states = hidden_states[:, returned_token_ids]
 
-            logits = hidden_states.softmax(dim=-1)
             step_tag_id = self.step_tag_id
 
             offset = 0
             pooled_data_lst = []
             for prompt_len, seq_data_i in zip(
                     prompt_lens, pooling_metadata.seq_data.values()):
-                pooled_data_i = logits[offset:offset + prompt_len]
+                pooled_data_i = hidden_states[offset:offset + prompt_len]
                 if step_tag_id is not None:
                     token_ids = torch.tensor(seq_data_i.prompt_token_ids)
                     pooled_data_i = pooled_data_i[token_ids == step_tag_id]
diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py
index bbb7fc8ad5087..ace8f4a348812 100644
--- a/vllm/model_executor/layers/quantization/awq_triton.py
+++ b/vllm/model_executor/layers/quantization/awq_triton.py
@@ -42,7 +42,7 @@ def awq_dequantize_kernel(
     result_masks = result_masks_y[:, None] & result_masks_x[None, :]
 
     # Load the weights.
-    iweights = tl.load(qweight_ptr + offsets, masks)
+    iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
     iweights = tl.interleave(iweights, iweights)
     iweights = tl.interleave(iweights, iweights)
     iweights = tl.interleave(iweights, iweights)
@@ -71,7 +71,7 @@ def awq_dequantize_kernel(
     zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]
 
     # Load the zeros.
-    zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
+    zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
     zeros = tl.interleave(zeros, zeros)
     zeros = tl.interleave(zeros, zeros)
     zeros = tl.interleave(zeros, zeros)
@@ -91,7 +91,7 @@ def awq_dequantize_kernel(
     scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]
 
     # Load the scales.
-    scales = tl.load(scales_ptr + scale_offsets, scale_masks)
+    scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
     scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
 
     # Dequantize.
@@ -165,10 +165,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
         masks_k = offsets_k < K
         masks_a = masks_am[:, None] & masks_k[None, :]
-        a = tl.load(a_ptrs, mask=masks_a)
+        a = tl.load(a_ptrs, mask=masks_a, other=0.0)
 
         masks_b = masks_k[:, None] & masks_bn[None, :]
-        b = tl.load(b_ptrs, mask=masks_b)
+        b = tl.load(b_ptrs, mask=masks_b, other=0.0)
         b = tl.interleave(b, b)
         b = tl.interleave(b, b)
         b = tl.interleave(b, b)
@@ -181,7 +181,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
         masks_zk = offsets_szk < K // group_size
         masks_z = masks_zk[:, None] & masks_zn[None, :]
         zeros_ptrs = zeros_ptr + offsets_z
-        zeros = tl.load(zeros_ptrs, mask=masks_z)
+        zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
         zeros = tl.interleave(zeros, zeros)
         zeros = tl.interleave(zeros, zeros)
         zeros = tl.interleave(zeros, zeros)
@@ -191,7 +191,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
         masks_sk = offsets_szk < K // group_size
         masks_s = masks_sk[:, None] & masks_sn[None, :]
         scales_ptrs = scales_ptr + offsets_s
-        scales = tl.load(scales_ptrs, mask=masks_s)
+        scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
         scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
 
         b = (b >> shifts) & 0xF
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 0aa605e62454e..abafad0f1047e 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -210,7 +210,6 @@ def create_weights(
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # for torch.compile
-        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
         layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
         layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
         layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 1f72e3afbbce5..a3e58bf1b2a4c 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -23,6 +23,7 @@
                                            PackedColumnParameter,
                                            PackedvLLMParameter,
                                            RowvLLMParameter)
+from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -134,6 +135,9 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         sym = quant_config.get("sym")
         desc_act = quant_config.get("desc_act")
 
+        if not current_platform.is_cuda():
+            return False
+
         if quant_method != "gptq":
             return False
 
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index 330c2ad195d78..c16a962134d06 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -2,21 +2,26 @@
 
 import torch
 
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
-from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
+                                                         is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
 from vllm.platforms import current_platform
 
+MIN_IPEX_VERSION = "2.5.0"
+
 
 class IPEXConfig(QuantizationConfig):
-    """INT8 quantization config class using IPEX for the CPU backend,
-    including AWQ.
+    """INT8 quantization config class using IPEX for the CPU/XPU backend,
+    including AWQ, GPTQ.
     """
 
     IPEX_QUANT_METHOD_MAP = {
         "awq": 1,
-        "gptq": 2,
+        "gptq": 0,
     }
 
     def __init__(
@@ -24,29 +29,30 @@ def __init__(
         method: str,
         weight_bits: int,
         group_size: int,
+        modules_to_not_convert: Optional[List[str]] = None,
+        desc_act: Optional[bool] = None,
+        lm_head_quantized: Optional[bool] = None,
     ) -> None:
         self.method = method
         self.weight_bits = weight_bits
         self.group_size = group_size
+        self.modules_to_not_convert = modules_to_not_convert or []
+        self.desc_act = desc_act
+        self.lm_head_quantized = lm_head_quantized
         self.pack_factor = 32 // self.weight_bits
 
         if self.weight_bits not in [4]:
             raise ValueError(f"IPEX quantization supports weight bits [4], "
                              f"but got {self.weight_bits}.")
 
-        if self.method == "awq":
-            self.quant_method = IPEXAWQLinearMethod
-        else:
-            raise ValueError(f"IPEX quantization supports [awq], "
+        if self.method not in ["awq", "gptq"]:
+            raise ValueError(f"IPEX quantization supports [awq, gptq], "
                              f"but got {self.method}.")
 
     def __repr__(self) -> str:
-        return (f"IPEXConfig(method={self.method}"
+        return (f"IPEXConfig(method={self.method},"
                 f"weight_bits={self.weight_bits}, "
-                f"group_size={self.group_size}")
-
-    def get_ipex_quant_method_id(self) -> int:
-        return IPEXConfig.IPEX_QUANT_METHOD_MAP[self.method]
+                f"group_size={self.group_size})")
 
     @classmethod
     def get_name(cls) -> str:
@@ -70,19 +76,32 @@ def get_config_filenames() -> List[str]:
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
         method = cls.get_from_keys(config, ["quant_method"]).lower()
-        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
-        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
-        return cls(method, weight_bits, group_size)
+        if method == "awq":
+            weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+            group_size = cls.get_from_keys(config,
+                                           ["q_group_size", "group_size"])
+            modules_to_not_convert = cls.get_from_keys_or(
+                config, ["modules_to_not_convert"], None)
+            return cls(method, weight_bits, group_size, modules_to_not_convert,
+                       False, False)
+        # otherwise for gptq
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                 default=False)
+        desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
+        return cls(method, weight_bits, group_size, [], desc_act,
+                   lm_head_quantized)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
                                      user_quant) -> Optional[str]:
-        if not current_platform.is_cpu():
+        if not current_platform.is_cpu() and not current_platform.is_xpu():
             return None
 
         quant_method = hf_quant_cfg.get("quant_method", "").lower()
 
-        if quant_method in ["awq"]:
+        if quant_method in ["awq", "gptq"]:
             return cls.get_name()
 
         return None
@@ -90,12 +109,81 @@ def override_quantization_method(cls, hf_quant_cfg,
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
-            return self.quant_method(self)
+            if self.method == "awq":
+                if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                    return UnquantizedLinearMethod()
+                return IPEXAWQLinearMethod(self)
+            if self.method == "gptq":
+                return IPEXGPTQLinearMethod(self)
         return None
 
 
+class IPEXGPTQLinearMethod(GPTQLinearMethod):
+    """GPTQ linear method using IPEX for the CPU/XPU backend.
+    """
+
+    def __init__(self, quant_config: IPEXConfig):
+        self.quant_config = quant_config  # type: ignore
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        bias = layer.bias if not layer.skip_bias_add else None
+
+        try:
+            import intel_extension_for_pytorch as ipex
+            if ipex.__version__ < MIN_IPEX_VERSION:
+                raise ImportError(
+                    "intel_extension_for_pytorch version is "
+                    "wrong. Please install "
+                    f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
+        except ImportError as err:
+            raise ImportError(
+                "Please install "
+                f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
+                f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
+                " to use IPEX-AWQ linear method.") from err
+        # Using the compute dtype (lowp_mode) as INT8 to leverage instructions
+        # with better performance.
+        lowp_mode = ipex.quantization.WoqLowpMode.INT8
+        # The weight will be de-packed from INT4 to INT8.
+        weight_dtype = ipex.quantization.WoqWeightDtype.INT4
+        # The float activation will be quantized (dynamic, per-token) to INT8.
+        act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK
+
+        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
+            weight_dtype=weight_dtype,
+            lowp_mode=lowp_mode,
+            act_quant_mode=act_quant_mode,
+            group_size=self.quant_config.group_size,
+        )
+        layer.ipex_output_size = layer.qweight.shape[-1]
+        g_idx = layer.g_idx if self.quant_config.desc_act else None
+        layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
+            IPEXWeightOnlyQuantizedLinear.from_weight(
+            layer.qweight,
+            layer.scales,
+            layer.qzeros,
+            layer.qweight.size(0),
+            layer.ipex_output_size,
+            qconfig=qconfig,
+            g_idx=g_idx,
+            bias=bias,
+            group_size=self.quant_config.group_size,
+            quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"]
+        )
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out = layer.ipex_qlinear(reshaped_x)
+        if bias is not None:
+            out.add_(bias)
+        return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
+
+
 class IPEXAWQLinearMethod(AWQLinearMethod):
-    """AWQ linear method using IPEX for the CPU backend.
+    """AWQ linear method using IPEX for the CPU/XPU backend.
     """
 
     def __init__(self, quant_config: IPEXConfig):
@@ -108,15 +196,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         try:
             import intel_extension_for_pytorch as ipex
-            if ipex.__version__ < "2.4.0":
-                raise ImportError("intel_extension_for_pytorch version is "
-                                  "wrong. Please install "
-                                  "intel_extension_for_pytorch>=2.4.0.")
+            if ipex.__version__ < MIN_IPEX_VERSION:
+                raise ImportError(
+                    "intel_extension_for_pytorch version is "
+                    "wrong. Please install "
+                    f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
         except ImportError as err:
             raise ImportError(
                 "Please install "
-                "intel_extension_for_pytorch>=2.4.0 via "
-                "`pip install intel_extension_for_pytorch>=2.4.0`"
+                f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
+                f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
                 " to use IPEX-AWQ linear method.") from err
 
         # Using the compute dtype (lowp_mode) as INT8 to leverage instructions
@@ -136,19 +225,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         layer.ipex_output_size = layer.qweight.size(
             1) * self.quant_config.pack_factor
-        layer.ipex_qlinear = ipex.nn.modules.weight_only_quantization.\
-            WeightOnlyQuantizedLinear.from_weight(
-                layer.qweight,
-                layer.scales,
-                layer.qzeros,
-                layer.qweight.size(0),
-                layer.ipex_output_size,
-                qconfig=qconfig,
-                bias=bias,
-                group_size=self.quant_config.group_size,
-                quant_method=
-                    self.quant_config.get_ipex_quant_method_id() # type: ignore
-            )
+        layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
+            IPEXWeightOnlyQuantizedLinear.from_weight(
+            layer.qweight,
+            layer.scales,
+            layer.qzeros,
+            layer.qweight.size(0),
+            layer.ipex_output_size,
+            qconfig=qconfig,
+            bias=bias,
+            group_size=self.quant_config.group_size,
+            quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"]  # type: ignore
+        )
 
     def apply(self,
               layer: torch.nn.Module,
@@ -156,5 +244,4 @@ def apply(self,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         reshaped_x = x.reshape(-1, x.shape[-1])
         out = layer.ipex_qlinear(reshaped_x)
-
         return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/machete.py
index e5696d08f30f5..15df0200f30b5 100644
--- a/vllm/model_executor/layers/quantization/kernels/machete.py
+++ b/vllm/model_executor/layers/quantization/kernels/machete.py
@@ -79,7 +79,9 @@ def transform_w_q(x):
                                                           c.weight_type,
                                                           packed_dim=0)
             x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
-                                           self.config.weight_type)
+                                           a_type=c.act_type,
+                                           b_type=c.weight_type,
+                                           group_scales_type=c.act_type)
             return x
 
         def transform_w_s(x):
@@ -105,12 +107,12 @@ def apply_weights(self,
         if c.has_g_idx:
             x_2d = self.act_perm(x_2d)
 
-        output = ops.machete_gemm(a=x_2d,
-                                  b_q=w_q,
-                                  b_type=c.weight_type,
-                                  b_zeros=None,
-                                  b_scales=w_s,
-                                  b_group_size=c.group_size)
+        output = ops.machete_mm(a=x_2d,
+                                b_q=w_q,
+                                b_type=c.weight_type,
+                                b_group_zeros=None,
+                                b_group_scales=w_s,
+                                b_group_size=c.group_size)
 
         if bias is not None:
             output.add_(bias)  # In-place add
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index c217f5ca620a1..83055d6000d83 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -126,11 +126,14 @@ def permute_rows(q_w: torch.Tensor,
 
 def quantize_weights(w: torch.Tensor,
                      quant_type: ScalarType,
-                     group_size: int,
+                     group_size: Optional[int],
                      zero_points: bool = False,
                      ref_zero_points_after_scales: bool = False):
     assert quant_type.is_integer(), \
         "Floating point quantization may work but has not been tested"
+    assert not zero_points or group_size is not None, \
+        "to have group zero points, group_size must be provided "\
+        "(-1 group_size is channelwise)"
 
     orig_device = w.device
     orig_type = w.dtype
@@ -140,10 +143,9 @@ def quantize_weights(w: torch.Tensor,
 
     if group_size == -1:
         group_size = size_k
-    assert group_size <= size_k
 
     # Reshape to [groupsize, -1]
-    if group_size < size_k:
+    if group_size is not None and group_size < size_k:
         w = w.reshape((-1, group_size, size_n))
         w = w.permute(1, 0, 2)
         w = w.reshape((group_size, -1))
@@ -155,18 +157,20 @@ def quantize_weights(w: torch.Tensor,
     max_q_val = quant_type.max()
     min_q_val = quant_type.min()
 
-    if zero_points:
-        assert not quant_type.is_signed() and quant_type.max() > 0
-        w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
-        maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
-            .clamp(min_q_val, max_q_val).int()
-    else:
-        # If the bias is such that there are no possible negative/positive
-        #  values, set the max value to inf to avoid divide by 0
-        w_s = torch.max(
-            abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
-            abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
-        maybe_w_zp = None
+    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+    maybe_w_zp = None
+    if group_size is not None:
+        if zero_points:
+            assert not quant_type.is_signed() and quant_type.max() > 0
+            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+            maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
+                .clamp(min_q_val, max_q_val).int()
+        else:
+            # If the bias is such that there are no possible negative/positive
+            #  values, set the max value to inf to avoid divide by 0
+            w_s = torch.max(
+                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
 
     # Quantize
     w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
@@ -176,7 +180,7 @@ def quantize_weights(w: torch.Tensor,
     # For some kernels (namely Machete) the zero-points are applied after the
     # scales are applied, for this case computing the reference in similar way
     # allows us to use tighter error tolerances in our unit tests.
-    if ref_zero_points_after_scales and zero_points:
+    if ref_zero_points_after_scales and maybe_w_zp is not None:
         w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
     else:
         w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
@@ -185,7 +189,7 @@ def quantize_weights(w: torch.Tensor,
         w_q += quant_type.bias
 
     # Restore original shapes
-    if group_size < size_k:
+    if group_size is not None and group_size < size_k:
 
         def reshape_w(w):
             w = w.reshape((group_size, -1, size_n))
@@ -195,17 +199,16 @@ def reshape_w(w):
 
         w_q = reshape_w(w_q)
         w_ref = reshape_w(w_ref)
+        w_s = w_s.reshape((-1, size_n)).contiguous()
 
-    w_s = w_s.reshape((-1, size_n)).contiguous()
-
-    if zero_points:
+    if maybe_w_zp is not None:
         maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
         maybe_w_zp = maybe_w_zp.to(device=orig_device)
 
     return (
         w_ref.to(device=orig_device),
         w_q.to(device=orig_device),
-        w_s.to(device=orig_device),
+        w_s if group_size is not None else None,
         maybe_w_zp,
     )
 
diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py
index 2e9a0e170693b..3ab0ba9e9f5c2 100644
--- a/vllm/model_executor/layers/rejection_sampler.py
+++ b/vllm/model_executor/layers/rejection_sampler.py
@@ -368,7 +368,7 @@ def _smallest_positive_value(self) -> float:
 # Note that we always sample with replacement.
 # probs will be modified in place, but this is fine, as we pass
 # in a copy already.
-@torch.jit.script
+@torch.compile(dynamic=True)
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 63ceec63e8317..117fe086e5e87 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -847,6 +847,7 @@ def get_input_positions(
         vision_end_token_id: int,
         spatial_merge_size: int,
         context_len: int = 0,
+        seq_len: Optional[int] = None,
     ) -> Tuple[List[List[int]], int]:
         """Get mrope input positions and delta value."""
 
@@ -921,9 +922,9 @@ def get_input_positions(
                 torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
 
         llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        llm_positions = llm_positions[:, context_len:]
         mrope_position_delta = (llm_positions.max() + 1 -
                                 len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
 
         return llm_positions.tolist(), mrope_position_delta
 
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index c10efefea5471..8792bd42d54d2 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -995,7 +995,9 @@ def get_logprobs(
     if len(query_indices) == 0:
         empty_sampled_logprob: SampleLogprobs = []
         empty_prompt_logprob: Optional[PromptLogprobs] = None
-        return [empty_prompt_logprob], [empty_sampled_logprob]
+        num_seq_groups = len(sampling_metadata.seq_groups)
+        return [empty_prompt_logprob
+                ] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups
 
     selected_logprobs, ranks = None, None
     top_logprobs, top_token_ids = None, None
@@ -1262,6 +1264,10 @@ def _build_sampler_output(
         assert sample_logprobs is not None
         assert not isinstance(maybe_deferred_sample_results,
                               SampleResultArgsType)
+        assert len(sampling_metadata.seq_groups) \
+            == len(maybe_deferred_sample_results) \
+            == len(prompt_logprobs) \
+            == len(sample_logprobs)
         deferred_sample_results_args = None
 
         for (seq_group, sample_result, group_prompt_logprobs,
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index 52771f50a7a23..30548e656c557 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -133,13 +133,13 @@ def __post_init__(self):
         assert self.num_added_elements <= self.num_added_elements_padded
 
 
-@torch.jit.script
+@torch.compile(dynamic=True)
 def get_masked_input_and_mask(
         input_: torch.Tensor, org_vocab_start_index: int,
         org_vocab_end_index: int, num_org_vocab_padding: int,
         added_vocab_start_index: int,
         added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    # torch.jit.script will fuse all of the pointwise ops below
+    # torch.compile will fuse all of the pointwise ops below
     # into a single kernel, making it very fast
     org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
                                                           org_vocab_end_index)
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 140b61fe6d56a..b41c23704b7ff 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -29,6 +29,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ReplicatedLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizeMethodBase)
 from vllm.model_executor.model_loader.tensorizer import (
     TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
     serialize_vllm_model, tensorizer_weights_iterator)
@@ -42,6 +44,7 @@
     safetensors_weights_iterator)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.plugins import set_current_vllm_config
 from vllm.utils import is_pin_memory_available
 
 
@@ -97,7 +100,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
     all_params = [param.name for param in signatures.parameters.values()]
     if "vllm_config" in all_params and "prefix" in all_params:
         # new-style model class
-        return model_class(vllm_config=vllm_config, prefix=prefix)
+        with set_current_vllm_config(vllm_config):
+            return model_class(vllm_config=vllm_config, prefix=prefix)
     msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
            "input arguments. Possibly you have an old-style model class"
            " registered from out of tree and it is used for new vLLM version. "
@@ -121,7 +125,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
         kwargs["lora_config"] = vllm_config.lora_config
     if "scheduler_config" in all_params:
         kwargs["scheduler_config"] = vllm_config.scheduler_config
-    return model_class(**kwargs)
+    with set_current_vllm_config(vllm_config):
+        return model_class(**kwargs)
 
 
 class BaseModelLoader(ABC):
@@ -331,11 +336,21 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
             with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
 
-            model.load_weights(self._get_all_weights(model_config, model))
+            weights_to_load = {name for name, _ in model.named_parameters()}
+            loaded_weights = model.load_weights(
+                self._get_all_weights(model_config, model))
+            # We only enable strict check for non-quantiized models
+            # that have loaded weights tracking currently.
+            if model_config.quantization is None and loaded_weights is not None:
+                weights_not_loaded = weights_to_load - loaded_weights
+                if weights_not_loaded:
+                    raise ValueError(
+                        "Following weights were not initialized from "
+                        f"checkpoint: {weights_not_loaded}")
 
             for _, module in model.named_modules():
                 quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
+                if isinstance(quant_method, QuantizeMethodBase):
                     # When quant methods need to process weights after loading
                     # (for repacking, quantizing, etc), they expect parameters
                     # to be on the global target device. This scope is for the
diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py
index 9ee2a2cc09a24..e58ad19cab54c 100644
--- a/vllm/model_executor/models/arctic.py
+++ b/vllm/model_executor/models/arctic.py
@@ -1,5 +1,5 @@
 """Inference-only Snowflake Arctic model."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -389,6 +389,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -396,9 +399,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -439,6 +446,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -446,9 +456,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -468,7 +480,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -506,6 +519,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                         ("ws", f"experts.{expert_id}.w3.weight", expert_id))
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
 
         logger.info(
             "It will take ~10 minutes loading from the 16-bit weights. "
@@ -561,3 +575,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                         weight_loader = getattr(param, "weight_loader",
                                                 default_weight_loader)
                         weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index aabbd31192a40..3749a16a38994 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -18,7 +18,7 @@
 # limitations under the License.
 """Inference-only BaiChuan model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -284,6 +284,9 @@ def __init__(
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -291,9 +294,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -363,6 +370,9 @@ def __init__(
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -370,9 +380,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -392,13 +404,15 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -437,6 +451,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 42dd6119e76f1..0f6347a7fd78b 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -17,8 +17,11 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.model_executor.models.interfaces import SupportsCrossEncoding
+from vllm.model_executor.pooling_metadata import (PoolingMetadata,
+                                                  PoolingTensors)
+from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors,
+                           PoolerOutput)
 
 from .utils import maybe_prefix
 
@@ -48,7 +51,9 @@ def __init__(self, config: BertConfig):
     def forward(
         self,
         input_ids: torch.Tensor,
-        position_ids: Optional[torch.Tensor] = None,
+        seq_lens: torch.Tensor,
+        position_ids: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
 
@@ -58,17 +63,34 @@ def forward(
         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
 
-        # Token type embeddings. (TODO: move off hotpath?)
-        token_type_embeddings = self.token_type_embeddings(
-            torch.zeros(input_shape,
-                        dtype=torch.long,
-                        device=inputs_embeds.device))
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape,
+                                         dtype=torch.long,
+                                         device=inputs_embeds.device)
+
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
         embeddings = inputs_embeds + token_type_embeddings + position_embeddings
         embeddings = self.LayerNorm(embeddings)
         return embeddings
 
 
+class BertPooler(nn.Module):
+
+    def __init__(self, config: BertConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[0, :]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
 class BertEncoder(nn.Module):
 
     def __init__(self,
@@ -309,7 +331,8 @@ def __init__(self,
                  *,
                  vllm_config: VllmConfig,
                  prefix: str = "",
-                 embedding_class: type = BertEmbedding):
+                 embedding_class: type = BertEmbedding,
+                 add_pooling_layer: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -319,6 +342,7 @@ def __init__(self,
                                    cache_config,
                                    quant_config,
                                    prefix=f"{prefix}.encoder")
+        self.pooler = BertPooler(config) if add_pooling_layer else None
 
     def forward(
         self,
@@ -328,16 +352,21 @@ def forward(
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
         else:
-            hidden_states = self.embeddings(input_ids=input_ids,
-                                            position_ids=position_ids)
-
+            assert hasattr(attn_metadata, "seq_lens_tensor")
+            hidden_states = self.embeddings(
+                input_ids=input_ids,
+                seq_lens=attn_metadata.seq_lens_tensor,
+                position_ids=position_ids,
+                token_type_ids=token_type_ids)
         return self.encoder(hidden_states, kv_caches, attn_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "query", "q"),
@@ -346,8 +375,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "pooler" in name:
+            if self.pooler is None and "pooler" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
@@ -368,6 +398,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class BertEmbeddingModel(nn.Module):
@@ -426,3 +458,95 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
                                                 pooling_type=PoolingType.CLS,
                                                 normalize=True,
                                                 softmax=False)
+
+
+class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
+    """A model that uses Bert to provide embedding functionalities.
+
+   This class encapsulates the BertModel and provides an interface for
+   embedding operations and customized pooling functions.
+
+   Attributes:
+       model: An instance of BertModel used for forward operations.
+       _pooler: An instance of Pooler used for pooling operations.
+   """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+
+        self.num_labels = config.num_labels
+        self.bert = BertModel(vllm_config=vllm_config,
+                              prefix=maybe_prefix(prefix, "bert"),
+                              embedding_class=BertEmbedding,
+                              add_pooling_layer=True)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("bert."):
+                    yield (name[len("bert."):], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.bert.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+
+        prompt_lens = PoolingTensors.from_pooling_metadata(
+            pooling_metadata, hidden_states.device).prompt_lens
+
+        offset = 0
+        pooled_data_lst = []
+        for prompt_len in prompt_lens:
+            pooled_data_i = hidden_states[offset:offset + prompt_len]
+
+            pooled_data_i = self.bert.pooler(pooled_data_i)
+
+            pooled_data_lst.append(pooled_data_i)
+            offset += prompt_len
+
+        pooled_output = torch.stack(pooled_data_lst)
+
+        classifier_output = self.classifier(pooled_output)
+
+        pooled_outputs = [
+            EmbeddingSequenceGroupOutput(data.tolist())
+            for data in classifier_output
+        ]
+        return PoolerOutput(outputs=pooled_outputs)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return self.bert(input_ids=input_ids,
+                         position_ids=positions,
+                         kv_caches=kv_caches,
+                         inputs_embeds=inputs_embeds,
+                         intermediate_tensors=intermediate_tensors,
+                         attn_metadata=attn_metadata,
+                         token_type_ids=token_type_ids)
diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py
index e612010677364..6af59697160a0 100644
--- a/vllm/model_executor/models/blip.py
+++ b/vllm/model_executor/models/blip.py
@@ -1,13 +1,14 @@
 """Minimal implementation of BlipVisionModel intended to be only used 
 within a vision language model."""
-from typing import Iterable, Optional, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from PIL import Image
 from transformers import Blip2VisionConfig, BlipVisionConfig
-from transformers.models.blip.modeling_blip import BlipAttention
 
+from vllm.attention.selector import _Backend
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -21,11 +22,7 @@
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData
 
-try:
-    from xformers import ops as xops
-    USE_XFORMERS_OPS = True
-except ImportError:
-    USE_XFORMERS_OPS = False
+from .utils import get_vit_attn_backend
 
 
 def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@@ -168,7 +165,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return embeddings
 
 
-class BlipParallelAttention(nn.Module):
+class BlipAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(
@@ -208,6 +205,12 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
 
+        # Detect attention implementation.
+        self.attn_backend = get_vit_attn_backend(support_fa=False)
+        if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
+            raise RuntimeError(
+                f"BLIP does not support {self.attn_backend} backend now.")
+
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads,
                            self.head_dim).transpose(1, 2).contiguous()
@@ -231,11 +234,26 @@ def forward(
                                          self.num_heads_per_partition,
                                          self.head_dim)
 
-        out = xops.memory_efficient_attention_forward(query_states,
-                                                      key_states,
-                                                      value_states,
-                                                      p=self.dropout,
-                                                      scale=self.scale)
+        if self.attn_backend == _Backend.XFORMERS:
+            from xformers import ops as xops
+
+            out = xops.memory_efficient_attention_forward(query_states,
+                                                          key_states,
+                                                          value_states,
+                                                          p=self.dropout,
+                                                          scale=self.scale)
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            query_states, key_states, value_states = (x.transpose(1, 2)
+                                                      for x in (query_states,
+                                                                key_states,
+                                                                value_states))
+            out = F.scaled_dot_product_attention(query_states,
+                                                 key_states,
+                                                 value_states,
+                                                 dropout_p=self.dropout,
+                                                 scale=self.scale)
+            out = out.transpose(1, 2)
+
         out = out.view(bsz, tgt_len, -1)
         attn_output, _ = self.projection(out)
 
@@ -285,18 +303,11 @@ def __init__(
         super().__init__()
 
         # fallback to sdpa attention if tp unavailable
-        num_heads = config.num_attention_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = BlipParallelAttention(
-                config,
-                quant_config=quant_config,
-                prefix=f"{prefix}.self_attn",
-            )
-        else:
-            # Blip doesn't have SDPA attention implemented in transformers
-            # use eager attention instead for cpu backend
-            self.self_attn = BlipAttention(config)
+        self.self_attn = BlipAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
         self.mlp = BlipMLP(config,
@@ -374,11 +385,6 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
-
-        tp_size = get_tensor_model_parallel_world_size()
-        num_heads = config.num_attention_heads
-        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
-
         self.config = config
 
         self.embeddings = BlipVisionEmbeddings(config)
@@ -415,14 +421,16 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
 
         return self.post_layernorm(hidden_states)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-        ] if self.shard_weight else []
+        ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         layer_count = len(self.encoder.layers)
 
         for name, loaded_weight in weights:
@@ -440,8 +448,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
@@ -450,3 +458,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index 03dc1d15ab697..7d7639b4a92ce 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import torch
@@ -692,6 +692,7 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index 84adf574af5e2..1060d418474ef 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 """Inference-only BLOOM model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -251,6 +251,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.word_embeddings_layernorm(self.word_embeddings(input_ids))
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -258,10 +261,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.word_embeddings(input_ids)
-            hidden_states = self.word_embeddings_layernorm(hidden_states)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -301,6 +307,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -308,9 +317,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors)
+                                         attn_metadata, intermediate_tensors,
+                                         inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -330,8 +341,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if name == "lm_head.weight":
                 continue
@@ -360,3 +373,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 7b59c818e0b60..8f91abffaea90 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
                     Tuple, TypedDict, Union)
 
 import torch
@@ -1034,7 +1034,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -1044,6 +1045,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -1111,3 +1113,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 70e9b607b0642..625e31bb0d368 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -3,7 +3,8 @@
 """Inference-only ChatGLM model compatible with THUDM weights."""
 from argparse import Namespace
 from array import array
-from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
+from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
+                    TypedDict)
 
 import torch
 from PIL import Image
@@ -29,6 +30,7 @@
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
@@ -573,25 +575,8 @@ def forward(
         return hidden_states
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
-class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
-                         SupportsMultiModal):
-    packed_modules_mapping = {
-        "query_key_value": ["query_key_value"],
-        "dense_h_to_4h": ["dense_h_to_4h"]
-    }
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "query_key_value",
-        "dense",
-        "dense_h_to_4h",
-        "dense_4h_to_h",
-    ]
-    embedding_modules = {}
-    embedding_padding_modules = []
+class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP,
+                       SupportsMultiModal):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -645,7 +630,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
         merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
             "transformer.vision.linear_proj.merged_proj.weight": {
@@ -655,6 +641,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         }
 
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             is_weight_to_be_merge = False
             for _, merged_weight_dict in merged_weights_dict.items():
@@ -677,6 +664,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
 
         for combined_name, merged_weight_dict in merged_weights_dict.items():
             if combined_name in params_dict:
@@ -686,3 +674,81 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, combined_weight)
+                loaded_params.add(combined_name)
+        return loaded_params
+
+
+class ChatGLM(ChatGLMBaseModel):
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+        "dense_h_to_4h": ["dense_h_to_4h"]
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "query_key_value",
+        "dense",
+        "dense_h_to_4h",
+        "dense_4h_to_h",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+
+class ChatGLMV(ChatGLMBaseModel):
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+        "dense_h_to_4h": ["dense_h_to_4h"],
+        "merged_proj": ["gate_proj", "dense_h_to_4h"]
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "query_key_value",
+        "dense",
+        "dense_h_to_4h",
+        "dense_4h_to_h",
+        # vision
+        "fc1",
+        "fc2",
+        "merged_proj",
+        "linear_proj"
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="transformer.encoder",
+            connector="transformer.vision.linear_proj",
+            tower_model="transformer.vision.transformer")
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
+class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
+                         SupportsMultiModal):
+    # Ensure that the LoRA support check passes when the class is not
+    # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __new__(
+        cls,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ) -> None:
+        config = vllm_config.model_config.hf_config
+        # Initialize VL
+        if hasattr(config, "visual"):
+            return ChatGLM(vllm_config=vllm_config, prefix=prefix)
+        # Initialize LLM
+        else:
+            return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index 2d81b9266826b..7f638506f9fb2 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -1,14 +1,15 @@
 """Minimal implementation of CLIPVisionModel intended to be only used
 within a vision language model."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from PIL import Image
 from transformers import CLIPVisionConfig
-from transformers.models.clip.modeling_clip import CLIPSdpaAttention
 
+from vllm.attention.selector import _Backend
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -23,11 +24,7 @@
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData
 
-try:
-    from xformers import ops as xops
-    USE_XFORMERS_OPS = True
-except ImportError:
-    USE_XFORMERS_OPS = False
+from .utils import get_vit_attn_backend
 
 
 def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@@ -197,7 +194,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return embeddings
 
 
-class CLIPParallelAttention(nn.Module):
+class CLIPAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(
@@ -237,6 +234,12 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
 
+        # Detect attention implementation.
+        self.attn_backend = get_vit_attn_backend(support_fa=False)
+        if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
+            raise RuntimeError(
+                f"CLIP does not support {self.attn_backend} backend now.")
+
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads,
                            self.head_dim).transpose(1, 2).contiguous()
@@ -261,11 +264,26 @@ def forward(
                                          self.num_heads_per_partition,
                                          self.head_dim)
 
-        out = xops.memory_efficient_attention_forward(query_states,
-                                                      key_states,
-                                                      value_states,
-                                                      p=self.dropout,
-                                                      scale=self.scale)
+        if self.attn_backend == _Backend.XFORMERS:
+            from xformers import ops as xops
+
+            out = xops.memory_efficient_attention_forward(query_states,
+                                                          key_states,
+                                                          value_states,
+                                                          p=self.dropout,
+                                                          scale=self.scale)
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            query_states, key_states, value_states = (x.transpose(1, 2)
+                                                      for x in (query_states,
+                                                                key_states,
+                                                                value_states))
+            out = F.scaled_dot_product_attention(query_states,
+                                                 key_states,
+                                                 value_states,
+                                                 dropout_p=self.dropout,
+                                                 scale=self.scale)
+            out = out.transpose(1, 2)
+
         out = out.view(bsz, tgt_len, -1)
         attn_output, _ = self.out_proj(out)
 
@@ -311,17 +329,11 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
-
-        num_heads = config.num_attention_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = CLIPParallelAttention(
-                config,
-                quant_config=quant_config,
-                prefix=f"{prefix}.self_attn",
-            )
-        else:
-            self.self_attn = CLIPSdpaAttention(config)
+        self.self_attn = CLIPAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
         self.mlp = CLIPMLP(config,
@@ -461,11 +473,6 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
-
-        tp_size = get_tensor_model_parallel_world_size()
-        num_heads = config.num_attention_heads
-        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
-
         self.vision_model = CLIPVisionTransformer(
             config=config,
             quant_config=quant_config,
@@ -483,14 +490,16 @@ def device(self):
 
     # (TODO) Add prefix argument for filtering out weights to be loaded
     #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-        ] if self.shard_weight else []
+        ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         layer_count = len(self.vision_model.encoder.layers)
 
         for name, loaded_weight in weights:
@@ -508,8 +517,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                name = name.replace(weight_name, param_name)
 
-                param = params_dict[name.replace(weight_name, param_name)]
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
@@ -518,3 +528,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index cd5c1d6844716..9fd083e5a02a9 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -280,6 +280,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -287,9 +290,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -354,6 +361,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     @torch.no_grad()
     def forward(
         self,
@@ -362,9 +372,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -390,7 +402,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -435,3 +448,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index fff8710f6b475..eab338800249e 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -321,6 +321,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.d_model))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.wte(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -328,9 +331,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.wte(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             assert intermediate_tensors
             hidden_states = intermediate_tensors["hidden_states"]
@@ -376,6 +383,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -383,9 +393,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors)
+                                         attn_metadata, intermediate_tensors,
+                                         inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -405,13 +417,15 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
 
         expert_params_mapping = [(
             "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
             f"mlp.{weight_name}",
         ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
@@ -435,3 +449,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py
index b38fd9fa49c21..c551853956b92 100644
--- a/vllm/model_executor/models/decilm.py
+++ b/vllm/model_executor/models/decilm.py
@@ -22,7 +22,7 @@
 # limitations under the License.
 """Inference-only DeciLM model compatible with HuggingFace weights."""
 
-from typing import Iterable, Tuple
+from typing import Iterable, Set, Tuple
 
 import torch
 
@@ -57,7 +57,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         delattr(config, "num_key_value_heads_per_layer")
         super().__init__(vllm_config=vllm_config)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -67,6 +68,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -97,6 +99,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
     def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
         hidden_size = self.config.hidden_size
diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py
index a9bf1440c4d60..8c5ad9904e925 100644
--- a/vllm/model_executor/models/deepseek.py
+++ b/vllm/model_executor/models/deepseek.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Deepseek model."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -353,6 +353,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -360,9 +363,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             residual = None
         else:
             hidden_states = intermediate_tensors["hidden_states"]
@@ -401,6 +408,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -408,9 +418,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -430,7 +442,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -441,6 +454,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -475,3 +489,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 4fb1eed15a2e7..d2c4ca0bf85e9 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only DeepseekV2 model."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -445,6 +445,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -452,9 +455,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -495,6 +502,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -502,9 +512,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -538,7 +550,8 @@ def make_empty_intermediate_tensors(
                         device=device),
         })
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -554,6 +567,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.n_routed_experts)
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -611,3 +625,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py
index 85c51e8404584..f138d13630263 100644
--- a/vllm/model_executor/models/eagle.py
+++ b/vllm/model_executor/models/eagle.py
@@ -78,6 +78,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def sampler(self):
         return self.model.sampler
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -86,11 +89,14 @@ def forward(
         attn_metadata: AttentionMetadata,
         previous_hidden_states: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        tok_embeds = self.model.model.embed_tokens(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings(input_ids)
+
         inputs_embeds = self.fc(
-            torch.cat([tok_embeds, previous_hidden_states], dim=-1))
+            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
 
         inputs_embeds[positions == 0] = 0  # masking inputs at position=0
 
@@ -100,7 +106,8 @@ def forward(
             positions=positions,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
-            intermediate_tensors=intermediate_tensors)
+            intermediate_tensors=intermediate_tensors,
+        )
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index cd3e7da657e0e..9d739d0479548 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -22,7 +22,7 @@
 # limitations under the License.
 """Inference-only Exaone model compatible with HuggingFace weights."""
 
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -479,6 +479,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -486,9 +489,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         model_output = self.transformer(input_ids, positions, kv_caches,
-                                        attn_metadata, intermediate_tensors)
+                                        attn_metadata, intermediate_tensors,
+                                        inputs_embeds)
         return model_output
 
     def compute_logits(
@@ -508,7 +513,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -518,6 +524,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".c_fc_1", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -538,6 +545,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                         default_weight_loader)
                 loaded_weight = loaded_weight[0]
                 weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
@@ -571,6 +579,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index b3dbf063ac298..2aa4b67d99894 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -18,7 +18,7 @@
 """PyTorch Falcon model."""
 
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -367,6 +367,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.word_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -374,9 +377,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.word_embeddings(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             hidden_states = intermediate_tensors["hidden_states"]
         for i in range(self.start_layer, self.end_layer):
@@ -432,6 +439,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -439,9 +449,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors)
+                                         attn_metadata, intermediate_tensors,
+                                         inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -461,7 +473,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         total_num_heads = self.config.num_attention_heads
         if self.config.new_decoder_architecture:
             total_num_kv_heads = self.config.num_kv_heads
@@ -471,6 +484,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             total_num_kv_heads = total_num_heads
         num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if name == "lm_head.weight" and self.tie_word_embeddings:
                 # Falcon uses tied embeddings except Falcon-11b.
@@ -507,3 +521,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py
index 971a71180164b..d3a9ff6915b84 100644
--- a/vllm/model_executor/models/florence2.py
+++ b/vllm/model_executor/models/florence2.py
@@ -1,5 +1,5 @@
 import math
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 import torch.nn as nn
@@ -156,7 +156,8 @@ def sample(self, logits: torch.Tensor,
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -165,12 +166,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
@@ -183,6 +185,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class Florence2ForConditionalGeneration(nn.Module):
@@ -248,10 +252,11 @@ def sample(
     ) -> SamplerOutput:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         skip_prefixes = [
             'image_projection', "vision_tower", "image_proj_norm",
             "image_pos_embed", "visual_temporal_embed"
         ]
         loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 31fc098a8bb3f..7b46907ac83ab 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -16,7 +16,8 @@
 """ PyTorch Fuyu model."""
 import math
 from array import array
-from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
+                    TypedDict)
 
 import torch
 import torch.nn as nn
@@ -354,6 +355,7 @@ def sample(
         next_tokens = self.language_model.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 55baba809e58f..64e03b30bf2f1 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -390,6 +390,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -397,9 +400,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -419,7 +424,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -464,3 +470,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             logger.warning(
                 "Some weights are not initialized from checkpoints: %s",
                 unloaded_params)
+        return loaded_params
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index eeb3fd98a7eac..4ba39223cc07f 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -272,6 +272,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -285,7 +288,7 @@ def forward(
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.embed_tokens(input_ids)
+                hidden_states = self.get_input_embeddings(input_ids)
             hidden_states *= self.normalizer
             residual = None
         else:
@@ -309,7 +312,8 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -351,6 +355,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             logger.warning(
                 "Some weights are not initialized from checkpoints: %s",
                 unloaded_params)
+        return loaded_params
 
 
 class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
@@ -414,6 +419,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -421,9 +429,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -443,13 +453,14 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."]
                            if self.config.tie_word_embeddings else None),
         )
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
 
 
 class Gemma2EmbeddingModel(nn.Module, SupportsPP):
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index cc85693f99526..1c61408ae1dd9 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -16,7 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -209,6 +209,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.n_embd))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.wte(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -220,7 +223,7 @@ def forward(
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
             if inputs_embeds is None:
-                inputs_embeds = self.wte(input_ids)
+                inputs_embeds = self.get_input_embeddings(input_ids)
             position_embeds = self.wpe(position_ids)
             hidden_states = inputs_embeds + position_embeds
         else:
@@ -262,7 +265,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.transformer.make_empty_intermediate_tensors)
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.wte(input_ids)
+        return self.transformer.get_input_embeddings(input_ids)
 
     def forward(
         self,
@@ -295,8 +298,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
@@ -325,3 +330,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index ab25c66c3a887..bcbb5a278ad29 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -17,7 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPTBigCode model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -28,14 +28,19 @@
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               RowParallelLinear)
+                                               QuantizationConfigOverride,
+                                               RowParallelLinear,
+                                               TiedWeightLinearMethod)
+# yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
@@ -204,9 +209,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
-        self.wte = VocabParallelEmbedding(self.vocab_size,
-                                          self.embed_dim,
-                                          org_num_embeddings=config.vocab_size)
+        self.wte = VocabParallelEmbedding(
+            self.vocab_size,
+            self.embed_dim,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
@@ -218,6 +229,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.n_embd))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.wte(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -225,11 +239,12 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            inputs_embeds = self.wte(input_ids)
-            position_embeds = self.wpe(position_ids)
-            hidden_states = inputs_embeds + position_embeds
+            if inputs_embeds is None:
+                inputs_embeds = self.get_input_embeddings(input_ids)
+            hidden_states = inputs_embeds + self.wpe(position_ids)
         else:
             hidden_states = intermediate_tensors["hidden_states"]
 
@@ -248,7 +263,7 @@ def forward(
 class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {"c_attn": ["c_attn"]}
 
-    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
+    supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
 
     embedding_modules = {
         "wte": "input_embeddings",
@@ -269,22 +284,47 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.quant_config = quant_config
         self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
                                            prefix=prefix)
-        if self.config.tie_word_embeddings:
-            self.lm_head = self.transformer.wte
-        else:
-            self.lm_head = ParallelLMHead(
-                self.transformer.vocab_size,
-                self.transformer.embed_dim,
-                org_num_embeddings=self.config.vocab_size)
+
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+
+        if self.config.tie_word_embeddings:
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=QuantizationConfigOverride(
+                    TiedWeightLinearMethod),
+                params_dtype=self.transformer.wte.weight.dtype,
+            )
+            self.lm_head.register_parameter("weight",
+                                            self.transformer.wte.weight)
+        else:
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=quant_config,
+            )
+
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -292,9 +332,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors)
+                                         attn_metadata, intermediate_tensors,
+                                         inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -314,10 +356,12 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
+            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                 continue
             if ".attn.bias" in name:
                 # Skip attention mask.
@@ -335,3 +379,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, 'v')
             else:
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index a83d03480dde1..d5defc60764e6 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-J model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -201,6 +201,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.n_embd))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.wte(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -208,9 +211,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.wte(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             hidden_states = intermediate_tensors["hidden_states"]
         for i in range(self.start_layer, self.end_layer):
@@ -250,6 +257,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -257,9 +267,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors)
+                                         attn_metadata, intermediate_tensors,
+                                         inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -279,7 +291,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -289,6 +302,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
@@ -318,3 +332,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 794b141bfa4aa..0bb5e2f9b95f9 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -214,6 +214,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_in(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -221,9 +224,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_in(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             hidden_states = intermediate_tensors["hidden_states"]
         for i in range(self.start_layer, self.end_layer):
@@ -262,6 +269,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.gpt_neox.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.gpt_neox.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -269,9 +279,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
-                                      attn_metadata, intermediate_tensors)
+                                      attn_metadata, intermediate_tensors,
+                                      inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -291,8 +303,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
@@ -325,3 +339,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index d1e6e31f2b8d1..c1e2e87f08ec3 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only IBM Granite model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -409,6 +409,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         else:
             self.lm_head = PPMissingLayer()
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -416,9 +419,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         model_output = self.model(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors)
+                                  attn_metadata, intermediate_tensors,
+                                  inputs_embeds)
         return model_output
 
     def compute_logits(
@@ -450,7 +455,8 @@ def make_empty_intermediate_tensors(
                         device=device),
         })
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -460,6 +466,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -480,6 +487,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                         default_weight_loader)
                 loaded_weight = loaded_weight[0]
                 weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
@@ -513,6 +521,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index 2ed115c56af45..a91a18816995f 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GraniteMoe model."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -277,6 +277,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -284,9 +287,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             hidden_states *= self.embedding_multiplier
             residual = None
         else:
@@ -366,6 +373,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.sampler = get_sampler()
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -373,9 +383,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -407,7 +419,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         new_weights = {}
         for n, p in weights:
             if n.endswith('.block_sparse_moe.input_linear.weight'):
@@ -440,4 +453,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 pass
             else:
                 new_weights[n] = p
-        mixtral.MixtralForCausalLM.load_weights(self, new_weights.items())
+        return mixtral.MixtralForCausalLM.load_weights(self,
+                                                       new_weights.items())
diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py
index b21bc2a3f9ce1..16192928beb1f 100644
--- a/vllm/model_executor/models/idefics2_vision_model.py
+++ b/vllm/model_executor/models/idefics2_vision_model.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 """PyTorch Idefics2 model."""
 
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -331,7 +331,8 @@ def forward(
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -339,11 +340,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
@@ -352,3 +355,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 0cecc754e916f..5d176b2a4e416 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -15,7 +15,7 @@
 
 import math
 from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple,
-                    Optional, Tuple, TypedDict, Union)
+                    Optional, Set, Tuple, TypedDict, Union)
 
 import torch
 import torch.utils.checkpoint
@@ -751,9 +751,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index dcead65115132..4f0c75b2c6a57 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -7,6 +7,8 @@
 from vllm.logger import init_logger
 from vllm.utils import supports_kw
 
+from .interfaces_base import is_embedding_model
+
 if TYPE_CHECKING:
     from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
     from vllm.sequence import IntermediateTensors
@@ -350,3 +352,37 @@ def is_attention_free(
         return isinstance(model, _IsAttentionFreeType)
 
     return isinstance(model, IsAttentionFree)
+
+
+@runtime_checkable
+class SupportsCrossEncoding(Protocol):
+    """The interface required for all models that support cross encoding."""
+
+    supports_cross_encoding: ClassVar[Literal[True]] = True
+
+
+@overload
+def supports_cross_encoding(
+        model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]:
+    ...
+
+
+@overload
+def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
+    ...
+
+
+def _supports_cross_encoding(
+    model: Union[Type[object], object],
+) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
+
+    if isinstance(model, type):
+        return isinstance(model, SupportsCrossEncoding)
+
+    return isinstance(model, SupportsCrossEncoding)
+
+
+def supports_cross_encoding(
+    model: Union[Type[object], object],
+) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
+    return is_embedding_model(model) and _supports_cross_encoding(model)
diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index 9761635d2a6c2..c4346fcb3bd2a 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -5,13 +5,14 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
 from functools import partial
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
 
+from vllm.attention.selector import _Backend
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -24,11 +25,7 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-try:
-    from xformers import ops as xops
-    USE_XFORMERS_OPS = True
-except ImportError:
-    USE_XFORMERS_OPS = False
+from .utils import get_vit_attn_backend
 
 NORM2FN = {
     'rms_norm': RMSNorm,
@@ -186,6 +183,11 @@ def __init__(
             prefix=f"{prefix}.proj",
         )
 
+        self.attn_backend = get_vit_attn_backend(support_fa=False)
+        if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
+            raise RuntimeError(
+                f"InternViT does not support {self.attn_backend} backend now.")
+
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
@@ -211,11 +213,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
         v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
 
-        x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
-        x = x.view(B, N, -1)
+        if self.attn_backend == _Backend.XFORMERS:
+            from xformers import ops as xops
 
-        x, _ = self.proj(x)
-        return x
+            out = xops.memory_efficient_attention_forward(q,
+                                                          k,
+                                                          v,
+                                                          scale=self.scale)
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            q, k, v = (x.transpose(1, 2) for x in (q, k, v))
+            out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+            out = out.transpose(1, 2)
+
+        out = out.view(B, N, -1)
+        out, _ = self.proj(out)
+        return out
 
 
 class InternSdpaAttention(nn.Module):
@@ -362,7 +374,7 @@ def _init_attn(
         tp_size = get_tensor_model_parallel_world_size()
         num_heads = config.num_attention_heads
 
-        if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
+        if (num_heads + num_dummy_heads) % tp_size == 0:
             return InternParallelAttention(config,
                                            quant_config=quant_config,
                                            num_dummy_heads=num_dummy_heads,
@@ -469,10 +481,14 @@ def forward(
 
         return encoder_outputs
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 21fa6983063b8..94b819b5d9366 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -290,7 +290,7 @@ def forward(
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.tok_embeddings(input_ids)
+                hidden_states = self.get_input_embeddings(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -335,6 +335,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -342,9 +345,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -364,13 +369,15 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w1", 0),
             ("gate_up_proj", "w3", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -397,3 +404,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 92579e3aae949..7ea2f9be2191d 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -6,7 +6,7 @@
 # --------------------------------------------------------
 import re
 from functools import cached_property, partial
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import torch
@@ -663,6 +663,7 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index 65800c44e5a93..41db85b678456 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -19,7 +19,7 @@
 """Inference-only Jais model compatible with HuggingFace weights."""
 
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -250,6 +250,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.n_embd))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.wte(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -257,9 +260,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[IntermediateTensors, torch.Tensor]:
         if get_pp_group().is_first_rank:
-            inputs_embeds = self.wte(input_ids)
+            if inputs_embeds is None:
+                inputs_embeds = self.get_input_embeddings(input_ids)
             if self.wpe is not None:
                 position_embeds = self.wpe(position_ids)
                 hidden_states = inputs_embeds + position_embeds
@@ -311,6 +316,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -318,9 +326,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[IntermediateTensors, torch.Tensor]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors)
+                                         attn_metadata, intermediate_tensors,
+                                         inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -340,8 +350,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
@@ -372,3 +384,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 88fb8d5cf555a..f83f0fce7275f 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -1,5 +1,5 @@
 """Inference-only Jamba model."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -292,6 +292,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.final_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -299,8 +302,12 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.get_input_embeddings(input_ids)
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -381,12 +388,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                                 config.vocab_size)
         self.sampler = get_sampler()
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: List[KVCache],
                 attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
             max_batch_size = (_get_graph_batch_size(
@@ -409,7 +420,8 @@ def forward(self,
                                               mamba_cache_tensors[1],
                                               state_indices_tensor)
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, mamba_cache_params)
+                                   attn_metadata, mamba_cache_params,
+                                   inputs_embeds)
         return hidden_states
 
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
@@ -450,7 +462,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -467,6 +480,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -522,6 +536,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 def _is_moe_layer(name: str):
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index e53631ef19f31..2b40e9ec73fad 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -350,7 +350,8 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -360,6 +361,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -375,6 +377,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                         default_weight_loader)
                 loaded_weight = loaded_weight[0]
                 weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
@@ -390,7 +393,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
-
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
@@ -408,6 +410,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
@@ -577,13 +581,14 @@ def sample(self, logits: torch.Tensor,
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."]
                            if self.config.tie_word_embeddings else None),
         )
-        loader.load_weights(
+        return loader.load_weights(
             self.maybe_remap_mistral(name, loaded_weight)
             for name, loaded_weight in weights)
 
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index b13bcfa676811..e7d3161a7cb2d 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
+from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
                     Tuple, TypedDict, Union)
 
 import torch
@@ -547,6 +547,7 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index dd2fa6cac969f..37e2227a52dcd 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import torch
@@ -654,6 +654,7 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index 5d5598d07bfde..e2880c76cf43d 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -1,6 +1,6 @@
 import math
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import numpy as np
@@ -445,10 +445,11 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(
             self,
             # This model doesn't support images for now
             ignore_unexpected_prefixes=["image_newline"],
         )
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index a5b2108177830..705ca1e4ab6e6 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -1,6 +1,6 @@
 import math
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import numpy as np
@@ -887,6 +887,7 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 55c575e22a0f6..405b8f7787ba8 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -1,5 +1,5 @@
 """PyTorch MAMBA model."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -106,15 +106,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.norm_f = RMSNorm(config.hidden_size,
                               eps=config.layer_norm_epsilon)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        hidden_states = self.embeddings(input_ids)
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.get_input_embeddings(input_ids)
         residual = None
 
         for i in range(len(self.layers)):
@@ -168,12 +175,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                                 config.vocab_size)
         self.sampler = get_sampler()
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.backbone.get_input_embeddings(input_ids)
+
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: List[KVCache],
                 attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
             max_batch_size = (_get_graph_batch_size(
@@ -194,7 +205,7 @@ def forward(self,
                                               state_indices_tensor)
 
         hidden_states = self.backbone(input_ids, positions, attn_metadata,
-                                      mamba_cache_params)
+                                      mamba_cache_params, inputs_embeds)
 
         return hidden_states
 
@@ -232,8 +243,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "A_log" in name:
                 name = name.replace("A_log", "A")
@@ -245,3 +258,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py
index de5b2d89c0962..b4ed6538bddac 100644
--- a/vllm/model_executor/models/medusa.py
+++ b/vllm/model_executor/models/medusa.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 import torch.nn as nn
@@ -14,11 +14,14 @@
 
 class ResidualBlock(nn.Module):
 
-    def __init__(self, hidden_size: int, num_layers: int) -> None:
+    def __init__(self, config: VllmConfig, hidden_size: int,
+                 num_layers: int) -> None:
         super().__init__()
 
         self.layers = nn.ModuleList([
-            nn.Linear(hidden_size, hidden_size, bias=False)
+            nn.Linear(hidden_size,
+                      hidden_size,
+                      bias=getattr(config, "medusa_fc_bias", False))
             for _ in range(num_layers)
         ])
         self.act = nn.SiLU()
@@ -49,7 +52,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         self.config = config
         self.blocks = nn.ModuleList([
-            ResidualBlock(hidden_size=self.config.hidden_size,
+            ResidualBlock(config=config,
+                          hidden_size=self.config.hidden_size,
                           num_layers=self.config.num_hidden_layers)
             for _ in range(self.config.num_heads)
         ])
@@ -152,8 +156,10 @@ def generate_proposals(
             sampling_metadata=sampling_metadata,
         )
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
 
         weights_map = {}
 
@@ -177,9 +183,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
 
         if self.token_map is not None:
             self.token_map.to(device=self.lm_heads[0].weight.device)
 
         assert (self.truncated_vocab_size
                 == self.orig_vocab_size) or (self.token_map is not None)
+
+        return loaded_params
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index 2db953329fd91..b92bff4d7c28c 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -21,7 +21,7 @@
 # limitations under the License.
 """Inference-only MiniCPM model compatible with HuggingFace weights."""
 import math
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -504,6 +504,9 @@ def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = MiniCPMModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -511,9 +514,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -534,7 +539,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -551,6 +557,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for weight_name in ["w1", "w2", "w3"]
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -601,3 +608,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index fd8eda997f76f..99bf1d42d0355 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -24,7 +24,7 @@
 import re
 from functools import partial
 from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
-                    Tuple, TypedDict, Union)
+                    Set, Tuple, TypedDict, Union)
 
 import torch
 import torch.types
@@ -602,7 +602,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -612,6 +613,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                 if key_to_modify in name:
@@ -630,10 +632,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 for param_name, weight_name, shard_id in stacked_params_mapping:
                     if weight_name not in name:
                         continue
-                    if is_pp_missing_parameter(
-                            name.replace(weight_name, param_name), self):
+                    name = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name, self):
                         continue
-                    param = params_dict[name.replace(weight_name, param_name)]
+                    param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param, loaded_weight, shard_id)
                     break
@@ -646,6 +648,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 3eb2f60fd4fc7..0faffb4f1b00c 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -281,6 +281,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -288,9 +291,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -363,6 +370,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -370,9 +380,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -392,7 +404,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -409,6 +422,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.num_local_experts)
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -466,3 +480,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index 95cfb6f54dc10..ddd6afcf6a1b6 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -318,6 +318,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -325,9 +328,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -368,6 +375,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -375,9 +385,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -397,7 +409,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -406,6 +419,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -436,3 +450,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index db7ee7b2d8537..41f62b37f3bd9 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 """PyTorch Mllama model."""
 import math
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import numpy as np
@@ -1427,7 +1427,8 @@ def forward(
 
         return outputs
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -1437,7 +1438,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        updated_params = set()
+        updated_params: Set[str] = set()
         for name, loaded_weight in weights:
             if 'patch_embedding.weight' in name:
                 name = name.replace('patch_embedding.weight',
@@ -1457,6 +1458,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+                updated_params.add(name)
+        return updated_params
 
 
 def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py
index 4d7e82880041d..f2aa2653c4f5c 100644
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -1,5 +1,5 @@
 import math
-from typing import Iterable, List, Tuple
+from typing import Iterable, List, Set, Tuple
 
 import torch
 import torch.nn as nn
@@ -188,11 +188,15 @@ def generate_proposals(
 
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             param = params_dict.get(name.replace("speculator.", ""))
             if param is not None:
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 035a1e2ab7b02..a7c90a3f5031b 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -187,7 +187,7 @@ def __init__(
         )
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend()
+        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
         }:
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index e15c0fe8db060..8716e92b0f1c2 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -1,6 +1,6 @@
 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -237,6 +237,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.d_model))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.wte(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -244,9 +247,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.wte(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -283,6 +290,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -290,9 +300,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors)
+                                         attn_metadata, intermediate_tensors,
+                                         inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -312,8 +324,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
@@ -324,3 +338,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index e09d7088a69ce..ceab299a7950a 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Nemotron model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -440,6 +440,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -447,9 +450,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         model_output = self.model(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors)
+                                  attn_metadata, intermediate_tensors,
+                                  inputs_embeds)
         return model_output
 
     def compute_logits(
@@ -469,7 +474,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -477,6 +483,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".qkv_proj", ".v_proj", "v"),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -517,3 +524,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 3467ae5896494..dc138e2e636ad 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only OLMo model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -248,6 +248,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -255,17 +258,16 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """
         :param input_ids: A tensor of shape `(batch_size, seq_len)`.
         """
         if get_pp_group().is_first_rank:
-            # Get embeddings of input.
-            # shape: (batch_size, seq_len, d_model)
-            inputs_embeds = self.embed_tokens(input_ids)
-
-            # embed positions
-            hidden_states = inputs_embeds
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -315,6 +317,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -322,6 +327,7 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(
             input_ids=input_ids,
@@ -329,6 +335,7 @@ def forward(
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
             intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
         )
         return hidden_states
 
@@ -349,7 +356,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -359,6 +367,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -395,3 +404,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index 3d31919edd862..ab87695d8e650 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -10,7 +10,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only OLMoE model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -269,6 +269,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -276,9 +279,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -326,6 +333,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -333,9 +343,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
@@ -352,7 +364,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -371,6 +384,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -443,3 +457,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 997fe642439e6..db85a494980a7 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -16,7 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only OPT model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -394,7 +394,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -402,6 +403,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "lm_head.weight" in name and self.config.tie_word_embeddings:
                 continue
@@ -431,3 +433,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index 38821c8288347..b01734af8ddd8 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -3,7 +3,7 @@
 # Copyright (c) OrionStar Inc.
 # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
 """Inference-only Orion-14B model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -237,6 +237,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 "hidden_states",
             ], config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -244,9 +247,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -286,6 +293,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -293,9 +303,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -315,7 +327,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -325,6 +338,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -356,3 +370,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index eea229359255e..dd5256eb87ab3 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -1,4 +1,4 @@
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import torch
@@ -295,6 +295,7 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py
index 2e34a7cc30873..3b8199f4f1661 100644
--- a/vllm/model_executor/models/persimmon.py
+++ b/vllm/model_executor/models/persimmon.py
@@ -19,7 +19,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only persimmon model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -235,6 +235,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -248,7 +251,7 @@ def forward(
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.embed_tokens(input_ids)
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -282,6 +285,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -318,8 +324,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -352,3 +360,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index 262f6996fc374..0a117bf16c9b3 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -34,7 +34,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """Inference-only Phi-1.5 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -218,6 +218,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -225,9 +228,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -303,6 +310,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -310,9 +320,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
 
         return hidden_states
 
@@ -333,7 +345,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -341,6 +354,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v")
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -371,3 +385,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py
index 8a5fb6d303e60..f71cbd1264c45 100644
--- a/vllm/model_executor/models/phi3_small.py
+++ b/vllm/model_executor/models/phi3_small.py
@@ -1,5 +1,5 @@
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -54,12 +54,12 @@ def weight_loader(self, param: torch.nn.Parameter,
         return load_column_parallel_weight(param, loaded_weight)
 
 
-@torch.jit.script
+@torch.compile(dynamic=True)
 def quick_gelu(x):
     return x * torch.sigmoid(1.702 * x)
 
 
-@torch.jit.script
+@torch.compile(dynamic=True)
 def gegelu(input, limit: Optional[float] = None):
     a_gelu, a_linear = input[..., ::2], input[..., 1::2]
     if limit is not None:
@@ -324,11 +324,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
 
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.embed_tokens = value
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
 
     def forward(
         self,
@@ -337,9 +334,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             if (self.mup_embedding_multiplier is not None
                     and self.mup_embedding_multiplier > 0.0):
                 hidden_states = hidden_states * self.mup_embedding_multiplier
@@ -397,8 +398,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         else:
             self.dummy_token_indices = None
 
-    def get_input_embeddings(self):
-        return self.model.embed_tokens
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
 
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value
@@ -433,6 +434,7 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         output_hidden_states = self.model(
             input_ids=input_ids,
@@ -440,6 +442,7 @@ def forward(
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
             intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
         )
         output_hidden_states = output_hidden_states
         return output_hidden_states
@@ -454,9 +457,11 @@ def sample(
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -468,3 +473,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 4db65edc174f1..2e583bb08e87a 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -15,7 +15,7 @@
 import itertools
 import re
 from functools import cached_property, lru_cache
-from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
                     Tuple, TypedDict, Union)
 
 import numpy as np
@@ -744,7 +744,8 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         hf_to_vllm_mapper = WeightsMapper(
             orig_to_new_prefix={
                 "model.vision_embed_tokens.wte": "embed_tokens",
@@ -759,5 +760,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
         # The HF config doesn't specify whether these are tied,
         # so we detect it this way
-        if "embed_tokens" not in autoloaded_weights:
+        if "embed_tokens.weight" not in autoloaded_weights:
             self.embed_tokens = self.language_model.model.embed_tokens
+            autoloaded_weights.add("embed_tokens.weight")
+        return autoloaded_weights
diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py
index 6d71a8949111b..e475d286bd7ea 100644
--- a/vllm/model_executor/models/phimoe.py
+++ b/vllm/model_executor/models/phimoe.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only PhiMoE model."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -465,6 +465,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -472,9 +475,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -560,6 +567,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -567,9 +577,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
@@ -586,7 +598,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -601,6 +614,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.num_local_experts)
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -654,3 +668,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index a3e30ea2dd299..f7f46770057e2 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass, fields
 from functools import cached_property
 from itertools import tee
-from typing import Iterable, List, Mapping, Optional, Tuple, Union
+from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
 
 import numpy
 import torch
@@ -17,6 +17,7 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import ModelConfig, VllmConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
@@ -30,6 +31,7 @@
 from vllm.model_executor.models.utils import merge_multimodal_embeddings
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges)
 from vllm.sequence import IntermediateTensors, SequenceData
@@ -773,15 +775,28 @@ def input_processor_for_pixtral_hf(
         replace_tokens[-1] = image_end_id
         replace_tokens_list.append(replace_tokens)
 
+    reverse_offsets: List[int] = []
     # Backward iteration for replacement without affecting known indices
     for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
                                                reversed(replace_tokens_list)):
+        reverse_offsets.append(
+            len(new_token_ids) - placeholder_idx + len(replace_tokens))
         new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens
 
+    placeholder_ranges: List[PlaceholderRange] = []
+    for reverse_offset, replace_tokens in zip(reversed(reverse_offsets),
+                                              replace_tokens_list):
+        placeholder_ranges.append(
+            PlaceholderRange(
+                offset=len(new_token_ids) - reverse_offset,
+                length=len(replace_tokens),
+            ))
+
     # NOTE: Create a defensive copy of the original inputs
     return token_inputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+                        multi_modal_data=multi_modal_data,
+                        multi_modal_placeholders={"image": placeholder_ranges})
 
 
 class PixtralHFMLP(nn.Module):
@@ -829,17 +844,20 @@ def __init__(
 
         self.config = config
         assert not config.hidden_size % config.num_attention_heads
-        self.n_heads = config.num_attention_heads
+        self.total_num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        self.n_heads = divide(config.num_attention_heads, tp_size)
         self.head_dim = config.hidden_size // config.num_attention_heads
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size=config.hidden_size,
             head_size=self.head_dim,
-            total_num_heads=self.n_heads,
+            total_num_heads=self.total_num_heads,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
+        assert self.total_num_heads * self.head_dim == config.hidden_size
         self.o_proj = RowParallelLinear(
             input_size=config.hidden_size,
             output_size=config.hidden_size,
@@ -1053,7 +1071,8 @@ def forward(
 
     # (TODO) Add prefix argument for filtering out weights to be loaded
     #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -1063,6 +1082,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         layer_count = len(self.transformer.layers)
 
         for name, loaded_weight in weights:
@@ -1075,8 +1095,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
@@ -1085,3 +1105,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 3d26ede722dd1..3978c176a2144 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -8,7 +8,7 @@
 import re
 from functools import partial
 from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-                    Optional, Tuple, TypedDict, Union)
+                    Optional, Set, Tuple, TypedDict, Union)
 
 import numpy as np
 import torch
@@ -578,6 +578,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                         quant_config=quant_config) if hasattr(
                                             config, "visual") else None
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.wte(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -586,6 +589,7 @@ def forward(
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         pixel_values: Optional[QwenImageInputs],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         img_pos = None
         # If pixel / visual embeddings are provided, this is a visual model
@@ -606,6 +610,10 @@ def forward(
                 )
 
         if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             hidden_states = self.wte(input_ids)
             # Merge the image embeddings into the hidden states if actually have
             # visual features and the corresponding image tokens
@@ -915,6 +923,9 @@ def _get_image_input_type(
                 )
         return None
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -922,7 +933,8 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-        pixel_values: Optional[torch.Tensor] = None
+        pixel_values: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if intermediate_tensors is not None:
             input_ids = None
@@ -932,7 +944,7 @@ def forward(
 
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata, intermediate_tensors,
-                                         pixel_values)
+                                         pixel_values, inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -952,13 +964,15 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w2", 0),
             ("gate_up_proj", "w1", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -987,6 +1001,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class QWenLLM(QWenBaseModel):
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 431e397e1e10d..370cff5fa153f 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -21,7 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -309,7 +309,7 @@ def forward(
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.embed_tokens(input_ids)
+                hidden_states = self.get_input_embeddings(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -332,7 +332,8 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -342,6 +343,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -372,6 +374,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
@@ -494,13 +498,14 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."]
                            if self.config.tie_word_embeddings else None),
         )
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
 
 
 class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
@@ -564,7 +569,8 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["lm_head."])
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index d30950361ad89..a4965f34b1ca8 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -20,7 +20,8 @@
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
 from functools import lru_cache
-from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+                    Union)
 
 import librosa
 import numpy as np
@@ -420,7 +421,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -430,6 +432,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -463,3 +466,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py
index 120403e948686..dc5dabf6fc38b 100644
--- a/vllm/model_executor/models/qwen2_cls.py
+++ b/vllm/model_executor/models/qwen2_cls.py
@@ -4,7 +4,7 @@
 # Copyright 2024 The Qwen team.
 # Copyright 2023 The vLLM team.
 """Inference-only Qwen2-Classification model compatible with HF weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -72,6 +72,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             normalize=False,
             softmax=True)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -79,9 +82,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         logits, _ = self.score(hidden_states)
         return logits
 
@@ -92,7 +97,8 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["lm_head."])
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 51c0cd5664fd2..96a9bc451f4df 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -21,7 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -344,6 +344,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -351,9 +354,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -395,6 +402,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -402,9 +412,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -424,7 +436,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -443,6 +456,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -520,3 +534,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py
index 55843d8325348..988d682d36be3 100644
--- a/vllm/model_executor/models/qwen2_rm.py
+++ b/vllm/model_executor/models/qwen2_rm.py
@@ -3,7 +3,7 @@
 # Copyright 2024 The Qwen team.
 # Copyright 2023 The vLLM team.
 """Inference-only Qwen2-RM model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -85,6 +85,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -92,9 +95,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         logits, _ = self.score(hidden_states)
         return logits
 
@@ -105,7 +110,8 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["lm_head."])
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 2335baf459771..a929b9323b245 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -23,7 +23,7 @@
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import partial
 from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-                    Optional, Tuple, Type, TypedDict, Union)
+                    Optional, Set, Tuple, Type, TypedDict, Union)
 
 import torch
 import torch.nn as nn
@@ -260,7 +260,7 @@ def __init__(
                                       prefix=f"{prefix}.proj")
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend()
+        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
         }:
@@ -1333,7 +1333,8 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -1343,6 +1344,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "gate_proj", 0),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -1392,3 +1394,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 22c2e328bfb65..17ecd8001a3e5 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -21,7 +21,8 @@
 from vllm.platforms import current_platform
 
 from .interfaces import (has_inner_state, is_attention_free,
-                         supports_multimodal, supports_pp)
+                         supports_cross_encoding, supports_multimodal,
+                         supports_pp)
 from .interfaces_base import is_embedding_model, is_text_generation_model
 
 logger = init_logger(__name__)
@@ -121,6 +122,14 @@
     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
 }
 
+_CROSS_ENCODER_MODELS = {
+    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
+    "RobertaForSequenceClassification": ("roberta",
+                                         "RobertaForSequenceClassification"),
+    "XLMRobertaForSequenceClassification": ("roberta",
+                                            "RobertaForSequenceClassification"),
+}
+
 _MULTIMODAL_MODELS = {
     # [Decoder-only]
     "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
@@ -159,6 +168,7 @@
 _VLLM_MODELS = {
     **_TEXT_GENERATION_MODELS,
     **_EMBEDDING_MODELS,
+    **_CROSS_ENCODER_MODELS,
     **_MULTIMODAL_MODELS,
     **_SPECULATIVE_DECODING_MODELS,
 }
@@ -193,6 +203,7 @@
 class _ModelInfo:
     is_text_generation_model: bool
     is_embedding_model: bool
+    supports_cross_encoding: bool
     supports_multimodal: bool
     supports_pp: bool
     has_inner_state: bool
@@ -203,6 +214,7 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
         return _ModelInfo(
             is_text_generation_model=is_text_generation_model(model),
             is_embedding_model=is_embedding_model(model),
+            supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
             supports_pp=supports_pp(model),
             has_inner_state=has_inner_state(model),
@@ -415,6 +427,12 @@ def is_embedding_model(
     ) -> bool:
         return self.inspect_model_cls(architectures).is_embedding_model
 
+    def is_cross_encoder_model(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> bool:
+        return self.inspect_model_cls(architectures).supports_cross_encoding
+
     def is_multimodal_model(
         self,
         architectures: Union[str, List[str]],
@@ -489,4 +507,4 @@ def _run() -> None:
 
 
 if __name__ == "__main__":
-    _run()
\ No newline at end of file
+    _run()
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index c1dcdd36ec3de..b6b17332e58d2 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -8,8 +8,14 @@
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
-from vllm.sequence import IntermediateTensors
+from vllm.model_executor.models.interfaces import SupportsCrossEncoding
+from vllm.model_executor.models.utils import maybe_prefix
+from vllm.model_executor.pooling_metadata import (PoolingMetadata,
+                                                  PoolingTensors)
+from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors,
+                           PoolerOutput)
 
 
 class RobertaEmbedding(nn.Module):
@@ -39,34 +45,93 @@ def __init__(self, config: RobertaConfig):
     def forward(
         self,
         input_ids: torch.Tensor,
-        position_ids: Optional[torch.Tensor] = None,
+        seq_lens: torch.Tensor,
+        position_ids: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
-
-        # Input embeddings.
         inputs_embeds = self.word_embeddings(input_ids)
 
-        # TODO: figure out if there is a better way
-        # to make to make position ids start at padding_idx + 1
+        # Replace position ids because in RoBERTa models
+        # they have to start at padding_idx + 1 and ignore
+        # existing padding tokens
         # References:
         # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
         # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
-        position_ids += self.padding_idx + 1
+        pos_list = []
+        token_list = []
+        offset = 0
+        for seq_len in seq_lens:
+            pos_list.append(position_ids[offset:offset + seq_len])
+            token_list.append(input_ids[offset:offset + seq_len])
+            offset += seq_len
+
+        new_pos_list = []
+        for positions, tokens in zip(pos_list, token_list):
+            # Verify assumption that incoming position are
+            # always a sequence from 0 to N.
+            expected_pos = torch.arange(positions.size()[0],
+                                        dtype=torch.long,
+                                        device=inputs_embeds.device)
+            assert torch.equal(positions, expected_pos)
+            new_pos_list.append(
+                create_position_ids_from_input_ids(tokens, self.padding_idx))
+        position_ids = torch.cat(new_pos_list)
 
         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape,
+                                         dtype=torch.long,
+                                         device=inputs_embeds.device)
 
-        # Token type embeddings. (TODO: move off hotpath?)
-        token_type_embeddings = self.token_type_embeddings(
-            torch.zeros(input_shape,
-                        dtype=torch.long,
-                        device=inputs_embeds.device))
-
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
         embeddings = inputs_embeds + token_type_embeddings + position_embeddings
         embeddings = self.LayerNorm(embeddings)
         return embeddings
 
 
+# Adapted from transformers
+def create_position_ids_from_input_ids(input_ids,
+                                       padding_idx,
+                                       past_key_values_length=0):
+    """
+    Replace non-padding symbols with their position numbers.
+    Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        x: torch.Tensor x:
+
+    Returns: torch.Tensor
+    """
+    # The series of casts and type-conversions here are carefully
+    # balanced to both work with ONNX export and XLA.
+    mask = input_ids.ne(padding_idx).int()
+
+    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
+                           past_key_values_length) * mask
+
+    return incremental_indices.long() + padding_idx
+
+
+# Adapted from transformers
+class RobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config: RobertaConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[0, :]  # take  token (equiv. to [CLS])
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.out_proj(x)
+        return x
+
+
 class RobertaEmbeddingModel(BertEmbeddingModel):
     """A model that uses Roberta to provide embedding functionalities.
 
@@ -85,6 +150,78 @@ def _build_model(self,
                          prefix=prefix,
                          embedding_class=RobertaEmbedding)
 
+
+class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
+    """A model that uses Roberta to provide embedding functionalities.
+
+   This class encapsulates the BertModel and provides an interface for
+   embedding operations and customized pooling functions.
+
+   Attributes:
+       roberta: An instance of BertModel used for forward operations.
+       _pooler: An instance of Pooler used for pooling operations.
+   """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+
+        self.num_labels = config.num_labels
+        self.roberta = BertModel(vllm_config=vllm_config,
+                                 prefix=maybe_prefix(prefix, "bert"),
+                                 embedding_class=RobertaEmbedding,
+                                 add_pooling_layer=False)
+        self.classifier = RobertaClassificationHead(config)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("roberta."):
+                    yield (name[len("roberta."):], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.roberta.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+
+        prompt_lens = PoolingTensors.from_pooling_metadata(
+            pooling_metadata, hidden_states.device).prompt_lens
+
+        offset = 0
+        pooled_data_lst = []
+        for prompt_len in prompt_lens:
+            pooled_data_i = hidden_states[offset:offset + prompt_len]
+
+            pooled_data_i = self.classifier(pooled_data_i)
+
+            pooled_data_lst.append(pooled_data_i)
+            offset += prompt_len
+
+        pooled_output = torch.stack(pooled_data_lst)
+
+        pooled_outputs = [
+            EmbeddingSequenceGroupOutput(data.tolist())
+            for data in pooled_output
+        ]
+        return PoolerOutput(outputs=pooled_outputs)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -93,25 +230,12 @@ def forward(
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        # Verify assumption that position are always a sequence from
-        # 0 to N. (Actually here we just check 0 and N to simplify).
-        # This is important to fix the position which are assumed to
-        # start from padding_idx + 1 instead of 0 in the Roberta models.
-        assert hasattr(attn_metadata, "seq_lens_tensor")
-        cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0)
-        start_pos = torch.cat(
-            (torch.tensor([0], device=attn_metadata.seq_lens_tensor.device),
-             cumulative[:-1]))
-        assert len(torch.nonzero(positions[start_pos])) == 0
-        end_pos = cumulative - 1
-        last_tokens = attn_metadata.seq_lens_tensor - 1
-        assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0
-
-        return super().forward(input_ids=input_ids,
-                               positions=positions,
-                               kv_caches=kv_caches,
-                               attn_metadata=attn_metadata,
-                               intermediate_tensors=intermediate_tensors,
-                               inputs_embeds=inputs_embeds)
+        return self.roberta(input_ids=input_ids,
+                            position_ids=positions,
+                            kv_caches=kv_caches,
+                            inputs_embeds=inputs_embeds,
+                            intermediate_tensors=intermediate_tensors,
+                            attn_metadata=attn_metadata,
+                            token_type_ids=token_type_ids)
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index acaf4afdecfe5..c58ad99692900 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -2,15 +2,16 @@
 within a vision language model."""
 
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from PIL import Image
 from torch import nn
 from transformers import SiglipVisionConfig
-from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
 
+from vllm.attention.selector import _Backend
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -27,11 +28,7 @@
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData
 
-try:
-    from xformers import ops as xops
-    USE_XFORMERS_OPS = True
-except ImportError:
-    USE_XFORMERS_OPS = False
+from .utils import get_vit_attn_backend
 
 
 def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@@ -254,7 +251,7 @@ def forward(self,
         return embeddings
 
 
-class SiglipParallelAttention(nn.Module):
+class SiglipAttention(nn.Module):
 
     def __init__(
         self,
@@ -293,6 +290,11 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
 
+        self.attn_backend = get_vit_attn_backend(support_fa=False)
+        if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
+            raise RuntimeError(
+                f"SIGLIP does not support {self.attn_backend} backend now.")
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -313,11 +315,26 @@ def forward(
                                          self.num_heads_per_partition,
                                          self.head_dim)
 
-        out = xops.memory_efficient_attention_forward(query_states,
-                                                      key_states,
-                                                      value_states,
-                                                      p=self.dropout,
-                                                      scale=self.scale)
+        if self.attn_backend == _Backend.XFORMERS:
+            from xformers import ops as xops
+
+            out = xops.memory_efficient_attention_forward(query_states,
+                                                          key_states,
+                                                          value_states,
+                                                          p=self.dropout,
+                                                          scale=self.scale)
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            query_states, key_states, value_states = (x.transpose(1, 2)
+                                                      for x in (query_states,
+                                                                key_states,
+                                                                value_states))
+            out = F.scaled_dot_product_attention(query_states,
+                                                 key_states,
+                                                 value_states,
+                                                 dropout_p=self.dropout,
+                                                 scale=self.scale)
+            out = out.transpose(1, 2)
+
         out = out.view(batch_size, q_len, -1)
         attn_output, _ = self.out_proj(out)
 
@@ -372,17 +389,11 @@ def __init__(
 
         self.embed_dim = config.hidden_size
 
-        num_heads = config.num_attention_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = SiglipParallelAttention(
-                config,
-                quant_config=quant_config,
-                prefix=f"{prefix}.self_attn",
-            )
-        else:
-            self.self_attn = SiglipSdpaAttention(config)
-
+        self.self_attn = SiglipAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -569,10 +580,6 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        num_heads = config.num_attention_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
-
         self.vision_model = SiglipVisionTransformer(
             config,
             quant_config,
@@ -594,14 +601,16 @@ def forward(
             interpolate_pos_encoding=interpolate_pos_encoding,
         )
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-        ] if self.shard_weight else []
+        ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         layer_count = len(self.vision_model.encoder.layers)
 
         for name, loaded_weight in weights:
@@ -619,8 +628,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                name = name.replace(weight_name, param_name)
 
-                param = params_dict[name.replace(weight_name, param_name)]
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
@@ -629,3 +639,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index 4f03ca501fb68..6d6fafc5ab0eb 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -21,7 +21,7 @@
 # limitations under the License.
 """Inference-only Solar model compatible with HuggingFace weights."""
 
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -456,9 +456,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         model_output = self.model(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors)
+                                  attn_metadata, intermediate_tensors,
+                                  inputs_embeds)
         return model_output
 
     def compute_logits(self, hidden_states: torch.Tensor,
@@ -475,7 +477,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -485,6 +488,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -500,6 +504,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                         default_weight_loader)
                 loaded_weight = loaded_weight[0]
                 weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
@@ -533,6 +538,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 1125f9e9f9617..e11d2e916730a 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -18,7 +18,7 @@
 # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
 """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -218,6 +218,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -225,9 +228,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -265,6 +272,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -272,9 +282,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -294,7 +306,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -304,6 +317,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -335,3 +349,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index ce7a7957f52c4..74c66042226de 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -17,7 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Starcoder2 model."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -221,6 +221,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -228,9 +231,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -273,6 +280,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -280,9 +290,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -302,7 +314,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -311,6 +324,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -334,3 +348,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 9fde22c016de0..512adbc7db35e 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -3,7 +3,7 @@
 
 import math
 from functools import cached_property, lru_cache
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union, cast)
 
 import numpy as np
@@ -504,10 +504,11 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         hf_to_vllm_mapper = WeightsMapper(
             orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
 
         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["audio_tower."])
-        loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 1d51885f9094a..03226f42ee053 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -1,7 +1,7 @@
 import itertools
 from dataclasses import dataclass, field
 from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-                    Optional, Protocol, Tuple, Union, overload)
+                    Optional, Protocol, Set, Tuple, Union, overload)
 
 import torch
 import torch.nn as nn
@@ -172,8 +172,9 @@ def _load_module(
         if module != self.module:
             module_load_weights = getattr(module, "load_weights", None)
             if callable(module_load_weights):
-                module_load_weights(weights)
-                return
+                loaded_params = module_load_weights(weights)
+                yield from map(lambda x: self._get_qualname(base_prefix, x),
+                               loaded_params)
 
         child_modules = dict(module.named_children())
         child_params = dict(module.named_parameters(recurse=False))
@@ -222,11 +223,11 @@ def load_weights(
         weights: Iterable[Tuple[str, torch.Tensor]],
         *,
         mapper: Optional[WeightsMapper] = None,
-    ) -> List[str]:
+    ) -> Set[str]:
         if mapper is not None:
             weights = mapper.apply(weights)
 
-        autoloaded_weights = list(self._load_module("", self.module, weights))
+        autoloaded_weights = set(self._load_module("", self.module, weights))
         return autoloaded_weights
 
 
@@ -586,7 +587,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
         return llm(*args, **kwargs)
 
 
-def get_vit_attn_backend() -> _Backend:
+def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
+    """
+    Get the available attention backend for Vision Transformer.
+    """
+    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
     selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
     if selected_backend is None:
         backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
@@ -595,7 +600,7 @@ def get_vit_attn_backend() -> _Backend:
     if selected_backend is None:
         # For Volta and Turing GPUs, use xformers instead.
         device_available = current_platform.has_device_capability(80)
-        if device_available:
+        if device_available and support_fa:
             from transformers.utils import is_flash_attn_2_available
             if is_flash_attn_2_available():
                 selected_backend = _Backend.FLASH_ATTN
@@ -605,7 +610,8 @@ def get_vit_attn_backend() -> _Backend:
                     "so we use xformers backend instead. You can run "
                     "`pip install flash-attn` to use flash-attention backend.")
                 selected_backend = _Backend.XFORMERS
-        elif current_platform.is_cpu():
+        elif current_platform.is_cpu() or current_platform.is_rocm():
+            # ROCM doesn't support xformers
             selected_backend = _Backend.TORCH_SDPA
         else:
             selected_backend = _Backend.XFORMERS
diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py
index 153527da20d75..bc37a997eabb5 100644
--- a/vllm/model_executor/models/xverse.py
+++ b/vllm/model_executor/models/xverse.py
@@ -19,7 +19,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Xverse model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -252,6 +252,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -259,9 +262,13 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
             residual = None
         else:
             hidden_states = intermediate_tensors["hidden_states"]
@@ -335,6 +342,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -342,9 +352,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
@@ -364,7 +376,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
@@ -373,6 +386,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if ("rotary_emb.inv_freq" in name
                     or "rotary_emb.cos_cached" in name
@@ -401,3 +415,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 64a4c58d5509c..4035a87231712 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -6,7 +6,7 @@
 import torch
 import torch.types
 from PIL.Image import Image
-from typing_extensions import TypeAlias
+from typing_extensions import NotRequired, TypeAlias
 
 from vllm.utils import JSONTree, is_list_of, json_map_leaves
 
@@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict):
     prompt_token_ids: List[int]
     """The processed token IDs which includes placeholder tokens."""
 
+    token_type_ids: NotRequired[List[int]]
+    """The token type IDs of the prompt."""
+
     mm_kwargs: MultiModalKwargs
     """Keyword arguments to be directly passed to the model after batching."""
 
diff --git a/vllm/outputs.py b/vllm/outputs.py
index badf50d0602d6..4ae9b377ae693 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -5,6 +5,7 @@
 from typing import Union
 
 from vllm.lora.request import LoRARequest
+from vllm.multimodal.inputs import MultiModalPlaceholderDict
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                            SequenceGroup, SequenceGroupBase, SequenceStatus)
@@ -103,10 +104,13 @@ def __init__(
         encoder_prompt: Optional[str] = None,
         encoder_prompt_token_ids: Optional[List[int]] = None,
         num_cached_tokens: Optional[int] = None,
+        *,
+        multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
+        self.multi_modal_placeholders = multi_modal_placeholders or {}
         self.prompt_logprobs = prompt_logprobs
         self.outputs = outputs
         self.finished = finished
@@ -275,17 +279,26 @@ def from_seq_group(
         finished_time = time.time() if finished else None
         seq_group.set_finished_time(finished_time)
 
-        init_args = (seq_group.request_id, prompt, prompt_token_ids,
-                     prompt_logprobs, outputs, finished, seq_group.metrics,
-                     seq_group.lora_request, encoder_prompt,
-                     encoder_prompt_token_ids, num_cached_tokens)
+        init_kwargs = {
+            "request_id": seq_group.request_id,
+            "prompt": prompt,
+            "prompt_token_ids": prompt_token_ids,
+            "prompt_logprobs": prompt_logprobs,
+            "outputs": outputs,
+            "finished": finished,
+            "metrics": seq_group.metrics,
+            "lora_request": seq_group.lora_request,
+            "encoder_prompt": encoder_prompt,
+            "encoder_prompt_token_ids": encoder_prompt_token_ids,
+            "num_cached_tokens": num_cached_tokens,
+            "multi_modal_placeholders": seq_group.multi_modal_placeholders
+        }
 
         if use_cache:
             request_output = seq_group.cached_request_output
-            request_output.__init__(*init_args)  # type: ignore
-
+            request_output.__init__(**init_kwargs)  # type: ignore
         else:
-            request_output = cls(*init_args)
+            request_output = cls(**init_kwargs)  # type: ignore
 
         return request_output
 
@@ -300,7 +313,8 @@ def __repr__(self) -> str:
                 f"finished={self.finished}, "
                 f"metrics={self.metrics}, "
                 f"lora_request={self.lora_request}, "
-                f"num_cached_tokens={self.num_cached_tokens})")
+                f"num_cached_tokens={self.num_cached_tokens}, "
+                f"multi_modal_placeholders={self.multi_modal_placeholders})")
 
 
 class EmbeddingRequestOutput:
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 5243f59203afc..42bee31dfb0e9 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -1,8 +1,19 @@
+from typing import TYPE_CHECKING
+
 import psutil
 import torch
 
+from vllm.logger import init_logger
+
 from .interface import Platform, PlatformEnum
 
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
+
+logger = init_logger(__name__)
+
 
 class CpuPlatform(Platform):
     _enum = PlatformEnum.CPU
@@ -18,3 +29,52 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
+
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        import vllm.envs as envs
+        from vllm.utils import GiB_bytes
+        model_config = vllm_config.model_config
+        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
+        # If the feature combo become valid
+        if not model_config.enforce_eager:
+            logger.warning(
+                "CUDA graph is not supported on CPU, fallback to the eager "
+                "mode.")
+            model_config.enforce_eager = True
+
+        cache_config = vllm_config.cache_config
+
+        if cache_config.enable_prefix_caching:
+            logger.warning(
+                "Prefix caching is not supported on CPU, disable it.")
+            cache_config.enable_prefix_caching = False
+
+        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
+
+        if kv_cache_space >= 0:
+            if kv_cache_space == 0:
+                cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
+                logger.warning(
+                    "Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
+                    "for CPU backend is not set, using 4 by default.")
+            else:
+                cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes  # type: ignore # noqa
+        else:
+            raise RuntimeError(
+                "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
+                f" {kv_cache_space}, expect a positive integer value.")
+
+        scheduler_config = vllm_config.scheduler_config
+        if scheduler_config.chunked_prefill_enabled:
+            logger.warning(
+                "Chunked prefill is not supported on CPU, disable it.")
+            scheduler_config.chunked_prefill_enabled = False
+
+        parallel_config = vllm_config.parallel_config
+        if (parallel_config.distributed_executor_backend is not None
+                and parallel_config.distributed_executor_backend != "mp"):
+            logger.warning(("%s is not supported on CPU, fallback to mp "
+                            "distributed executor backend."),
+                           parallel_config.distributed_executor_backend)
+            parallel_config.distributed_executor_backend = "mp"
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 81d8bdae2383c..970c0d1be617e 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -1,10 +1,15 @@
 import enum
 import random
-from typing import NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
+
 
 class PlatformEnum(enum.Enum):
     CUDA = enum.auto()
@@ -129,6 +134,19 @@ def seed_everything(cls, seed: int) -> None:
         np.random.seed(seed)
         torch.manual_seed(seed)
 
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        """
+        Check and update the configuration for the current platform.
+
+        It can raise an exception if the configuration is not compatible with
+        the current platform, or it can update the configuration to make it
+        compatible with the current platform.
+
+        The config is passed by reference, so it can be modified in place.
+        """
+        pass
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 8d0ce47df4040..643db835c85ff 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -1,20 +1,14 @@
 import os
+from typing import TYPE_CHECKING
 
 import torch
 
-import vllm.envs as envs
-from vllm.compilation.levels import CompilationLevel
-from vllm.plugins import set_torch_compile_backend
-
 from .interface import Platform, PlatformEnum
 
-if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
-    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)
-
-assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\
-     "TPU does not support Inductor."
-
-set_torch_compile_backend("openxla")
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
 
 
 class TpuPlatform(Platform):
@@ -31,3 +25,15 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
+
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        from vllm.config import CompilationLevel
+        compilation_config = vllm_config.compilation_config
+        if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
+            compilation_config.level = CompilationLevel.DYNAMO_ONCE
+        assert compilation_config.level < CompilationLevel.PIECEWISE,\
+            "TPU does not support Inductor."
+
+        if compilation_config.backend == "":
+            compilation_config.backend = "openxla"
diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
index 8373e11cfff9f..a0c73a752b5e8 100644
--- a/vllm/plugins/__init__.py
+++ b/vllm/plugins/__init__.py
@@ -1,11 +1,11 @@
 import logging
-from typing import TYPE_CHECKING, Callable, Optional, Union
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Optional
 
 import vllm.envs as envs
 
 if TYPE_CHECKING:
-    from vllm.compilation.config import CompilationConfig
-    from vllm.config import VllmConfig
+    from vllm.config import CompilationConfig, VllmConfig
 else:
     CompilationConfig = None
     VllmConfig = None
@@ -27,28 +27,27 @@ def load_general_plugins():
     allowed_plugins = envs.VLLM_PLUGINS
 
     discovered_plugins = entry_points(group='vllm.general_plugins')
+    if len(discovered_plugins) == 0:
+        logger.info("No plugins found.")
+        return
+    logger.info("Available plugins:")
+    for plugin in discovered_plugins:
+        logger.info("name=%s, value=%s, group=%s", plugin.name, plugin.value,
+                    plugin.group)
+    if allowed_plugins is None:
+        logger.info("all available plugins will be loaded.")
+        logger.info("set environment variable VLLM_PLUGINS to control"
+                    " which plugins to load.")
+    else:
+        logger.info("plugins to load: %s", allowed_plugins)
     for plugin in discovered_plugins:
-        logger.info("Found general plugin: %s", plugin.name)
         if allowed_plugins is None or plugin.name in allowed_plugins:
             try:
                 func = plugin.load()
                 func()
-                logger.info("Loaded general plugin: %s", plugin.name)
+                logger.info("plugin %s loaded.", plugin.name)
             except Exception:
-                logger.exception("Failed to load general plugin: %s",
-                                 plugin.name)
-
-
-_torch_compile_backend: Optional[Union[Callable, str]] = None
-
-
-def set_torch_compile_backend(backend: Union[Callable, str]):
-    global _torch_compile_backend
-    _torch_compile_backend = backend
-
-
-def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
-    return _torch_compile_backend
+                logger.exception("Failed to load plugin %s", plugin.name)
 
 
 _compilation_config: Optional[CompilationConfig] = None
@@ -61,3 +60,29 @@ def set_compilation_config(config: Optional[CompilationConfig]):
 
 def get_compilation_config() -> Optional[CompilationConfig]:
     return _compilation_config
+
+
+_current_vllm_config: Optional[VllmConfig] = None
+
+
+@contextmanager
+def set_current_vllm_config(vllm_config: VllmConfig):
+    """
+    Temporarily set the current VLLM config.
+    Used during model initialization.
+    We save the current VLLM config in a global variable,
+    so that all modules can access it, e.g. custom ops
+    can access the VLLM config to determine how to dispatch.
+    """
+    global _current_vllm_config
+    old_vllm_config = _current_vllm_config
+    try:
+        _current_vllm_config = vllm_config
+        yield
+    finally:
+        _current_vllm_config = old_vllm_config
+
+
+def get_current_vllm_config() -> VllmConfig:
+    assert _current_vllm_config is not None, "Current VLLM config is not set."
+    return _current_vllm_config
diff --git a/vllm/scripts.py b/vllm/scripts.py
index 4e4c071784287..a51c21cfa29e7 100644
--- a/vllm/scripts.py
+++ b/vllm/scripts.py
@@ -9,6 +9,7 @@
 from openai import OpenAI
 from openai.types.chat import ChatCompletionMessageParam
 
+import vllm.version
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.openai.api_server import run_server
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@@ -143,6 +144,11 @@ def main():
     env_setup()
 
     parser = FlexibleArgumentParser(description="vLLM CLI")
+    parser.add_argument('-v',
+                        '--version',
+                        action='version',
+                        version=vllm.version.__version__)
+
     subparsers = parser.add_subparsers(required=True, dest="subparser")
 
     serve_parser = subparsers.add_parser(
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 3b41d25a2fe42..16c8d8654a6a9 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -449,6 +449,10 @@ def prompt_token_ids(self) -> List[int]:
     def prompt_embeds(self) -> Optional[torch.Tensor]:
         return self.inputs.prompt_embeds
 
+    @property
+    def token_type_ids(self) -> List[int]:
+        return self.inputs.token_type_ids
+
     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
         return self.inputs.multi_modal_data
@@ -684,6 +688,10 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]:
         return (self.encoder_seq.prompt_token_ids
                 if self.encoder_seq is not None else None)
 
+    @property
+    def token_type_ids(self) -> Optional[List[int]]:
+        return self.first_seq.token_type_ids
+
     @property
     def multi_modal_data(self) -> MultiModalDataDict:
         return self.first_seq.multi_modal_data
@@ -906,6 +914,7 @@ class SequenceGroupMetadata(
         default_factory=lambda: SequenceGroupState())
     # "MultiModalDataDict" types. We have to use Any due to msgspec
     # doesn't allow to have union of 2 different dicts.
+    token_type_ids: Optional[List[int]] = None
     multi_modal_data: Optional[Any] = None
     multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 83b3c37d6f04c..d4e195118ccf7 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -7,6 +7,7 @@
 import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.tokens.instruct.request import FIMRequest
 from mistral_common.tokens.tokenizers.base import SpecialTokens
 # yapf: disable
 from mistral_common.tokens.tokenizers.mistral import (
@@ -251,6 +252,10 @@ def encode(self, prompt: str) -> List[int]:
         # For chat completion use `apply_chat_template`
         return self.tokenizer.encode(prompt, bos=True, eos=False)
 
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        fim = FIMRequest(prompt=prefix, suffix=suffix)
+        return self.mistral.encode_fim(fim).tokens
+
     def apply_chat_template(self,
                             messages: List["ChatCompletionMessageParam"],
                             tools: Optional[Dict[str, Any]] = None,
diff --git a/vllm/utils.py b/vllm/utils.py
index 111460a29de47..5d0514cd9d168 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1600,3 +1600,12 @@ def direct_register_custom_op(
     my_lib.impl(op_name, op_func, "CUDA")
     if fake_impl is not None:
         my_lib._register_fake(op_name, fake_impl)
+
+
+def resolve_obj_by_qualname(qualname: str) -> Any:
+    """
+    Resolve an object by its fully qualified name.
+    """
+    module_name, obj_name = qualname.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, obj_name)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index eebd1de96537f..d60f93a44f6dd 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1,4 +1,3 @@
-import os
 import time
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
@@ -8,11 +7,8 @@
 import torch.distributed
 import torch.nn as nn
 
-from vllm import envs
 from vllm.compilation.compile_context import set_compile_context
-from vllm.compilation.config import CompilationConfig
-from vllm.compilation.levels import CompilationLevel
-from vllm.config import VllmConfig
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
@@ -99,7 +95,7 @@ def __init__(
             pin_memory=self.pin_memory,
         )
 
-        self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL
+        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE
                                and not self.model_config.enforce_eager)
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
@@ -517,9 +513,9 @@ def load_model(self) -> None:
             # CUDA graphs do not work properly with the custom CUDA kernels.
             # FIXME(woosuk): Disable inductor to reduce the compilation time
             # and avoid any potential issues with the inductor.
-            os.environ["VLLM_CUSTOM_OPS"] = "none"
             set_compilation_config(
                 CompilationConfig(
+                    custom_ops=["none"],
                     use_cudagraph=True,
                     non_cudagraph_ops=["vllm.unified_v1_flash_attention"],
                     use_inductor=True,
diff --git a/vllm/worker/cpu_embedding_model_runner.py b/vllm/worker/cpu_embedding_model_runner.py
index 7053075bf4d8f..941e9910413fd 100644
--- a/vllm/worker/cpu_embedding_model_runner.py
+++ b/vllm/worker/cpu_embedding_model_runner.py
@@ -49,6 +49,9 @@ def execute_model(
         ]
 
         model_executable = self.model
+        cross_enc_kwargs = {}
+        if model_input.token_type_ids is not None:
+            cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
         execute_model_kwargs = {
             "input_ids":
             model_input.input_tokens,
@@ -60,12 +63,17 @@ def execute_model(
             model_input.attn_metadata,
             **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
                                          device=self.device),
+            **cross_enc_kwargs,
             "intermediate_tensors":
             intermediate_tensors,
         }
 
         hidden_states = model_executable(**execute_model_kwargs)
 
+        # Only perform pooling in the driver worker.
+        if not self.is_driver_worker:
+            return []
+
         return [
             self.model.pooler(hidden_states=hidden_states,
                               pooling_metadata=model_input.pooling_metadata)
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index d3e1202c15e61..bb1d5e58c3c8d 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -43,6 +43,7 @@ class ModelInputForCPU(ModelRunnerInputBase):
     """
     input_tokens: Optional[torch.Tensor] = None
     input_positions: Optional[torch.Tensor] = None
+    token_type_ids: Optional[torch.Tensor] = None
     attn_metadata: Optional["AttentionMetadata"] = None
     multi_modal_kwargs: Optional[BatchedTensorInputs] = None
     virtual_engine: Optional[int] = None
@@ -54,6 +55,7 @@ def as_broadcastable_tensor_dict(
         tensor_dict = {
             "input_tokens": self.input_tokens,
             "input_positions": self.input_positions,
+            "token_type_ids": self.token_type_ids,
             "multi_modal_kwargs": self.multi_modal_kwargs,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
@@ -83,6 +85,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
             "input_tokens": self.input_tokens,
             "input_positions": self.input_positions,
+            "token_type_ids": self.token_type_ids,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         _add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -127,18 +130,20 @@ def build(self) -> ModelInputForCPU:
         is_prompt = self.seq_group_metadata_list[0].is_prompt
         # Prepare input tensors.
         if is_prompt:
-            (input_tokens, input_positions, attn_metadata, seq_lens,
-             multi_modal_kwargs) = self._prepare_prompt(
+            (input_tokens, input_positions, token_type_ids, attn_metadata,
+             seq_lens, multi_modal_kwargs) = self._prepare_prompt(
                  self.seq_group_metadata_list)
         else:
             (input_tokens, input_positions,
              attn_metadata) = self._prepare_decode(
                  self.seq_group_metadata_list)
             seq_lens = None
+            token_type_ids = None
 
         return self.model_input_cls(
             input_tokens=input_tokens,
             input_positions=input_positions,
+            token_type_ids=token_type_ids,
             attn_metadata=attn_metadata,
             multi_modal_kwargs=multi_modal_kwargs,
             # query_lens is not needed if chunked prefill is not
@@ -203,11 +208,12 @@ def _compute_multi_modal_input(
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
-               BatchedTensorInputs]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, AttentionMetadata,
+               List[int], BatchedTensorInputs]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
+        token_type_ids: List[int] = []
         input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
 
         slot_mapping: List[int] = []
@@ -225,11 +231,13 @@ def _prepare_prompt(
 
             seq_data = seq_group_metadata.seq_data[seq_id]
             prompt_tokens = seq_data.get_token_ids()
+            token_types = seq_group_metadata.token_type_ids
             computed_len = seq_data.get_num_computed_tokens()
             seq_len = len(prompt_tokens)
 
             seq_lens.append(seq_len)  # Prompt token num
             input_tokens.extend(prompt_tokens)  # Token ids
+            token_type_ids.extend(token_types if token_types else [])
 
             mrope_positions = None
             if seq_group_metadata.multi_modal_data:
@@ -293,6 +301,11 @@ def _prepare_prompt(
                                        or input_mrope_positions,
                                        dtype=torch.long,
                                        device=self.device)  # type: ignore
+        token_type_ids = torch.tensor(token_type_ids,
+                                dtype=torch.long,
+                                device=self.device)\
+                                if token_type_ids else None
+
         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.long,
                                     device=self.device)  # type: ignore
@@ -317,8 +330,8 @@ def _prepare_prompt(
 
         multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
 
-        return (input_tokens, input_positions, attn_metadata, seq_lens,
-                multi_modal_kwargs)
+        return (input_tokens, input_positions, token_type_ids, attn_metadata,
+                seq_lens, multi_modal_kwargs)
 
     def _prepare_decode(
         self,
diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py
index 37cfcbf13d7a3..19951c4787cf9 100644
--- a/vllm/worker/embedding_model_runner.py
+++ b/vllm/worker/embedding_model_runner.py
@@ -97,6 +97,10 @@ def execute_model(
             model_forward_end = torch.cuda.Event(enable_timing=True)
             model_forward_start.record()
 
+        cross_enc_kwargs = {}
+        if model_input.token_types is not None:
+            cross_enc_kwargs["token_type_ids"] = model_input.token_types
+
         with set_forward_context(model_input.attn_metadata):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
@@ -105,7 +109,8 @@ def execute_model(
                 attn_metadata=model_input.attn_metadata,
                 intermediate_tensors=intermediate_tensors,
                 **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
-                                             device=self.device))
+                                             device=self.device),
+                **cross_enc_kwargs)
 
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time):
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 1ff30d685c6b1..99cf9a7e67256 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -272,6 +272,19 @@ def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
     return indices, offsets
 
 
+def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"):
+    if module.__class__.__name__.endswith(suffix):
+
+        def forward_hook(module, args, output):
+            htorch.core.mark_step()
+            return output
+
+        module.register_forward_hook(forward_hook)
+
+    for child_name, child_module in module.named_children():
+        modify_decoder_layer(child_module)
+
+
 class HpuModelAdapter:
 
     def __init__(self, model, block_size, dtype, enforce_eager):
@@ -636,6 +649,7 @@ def load_model(self) -> None:
             else:
                 self.model = self.model.to("hpu")
                 htcore.mark_step()
+            modify_decoder_layer(self.model)
             torch.hpu.synchronize()
 
             with HabanaMemoryProfiler() as m_wrap:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 042f9f07eace6..c97ac33490c40 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -19,8 +19,7 @@
 from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.compilation.compile_context import set_compile_context
-from vllm.compilation.levels import CompilationLevel
-from vllm.config import VllmConfig
+from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_pp_group
 from vllm.distributed.parallel_state import graph_capture
@@ -93,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
     """
     input_tokens: Optional[torch.Tensor] = None
     input_positions: Optional[torch.Tensor] = None
+    token_types: Optional[torch.Tensor] = None
     seq_lens: Optional[List[int]] = None
     query_lens: Optional[List[int]] = None
     lora_mapping: Optional["LoRAMapping"] = None
@@ -201,6 +201,7 @@ class InterDataForSeqGroup:
         def simple_reinit(self):
             self.input_tokens[0].clear()  # type: ignore
             self.input_positions[0].clear()  # type: ignore
+            self.token_types[0].clear()  # type: ignore
             self.mrope_input_positions = None  # type: ignore
             self.seq_lens[0] = 0  # type: ignore
             self.orig_seq_lens[0] = 0  # type: ignore
@@ -227,6 +228,7 @@ def __init__(
             # Input tokens and positions.
             input_tokens: Optional[List[List[int]]] = None,
             input_positions: Optional[List[List[int]]] = None,
+            token_types: Optional[List[List[int]]] = None,
             mrope_input_positions: Optional[List[List[List[int]]]] = None,
 
             # The sequence length (may be capped to the sliding window).
@@ -292,6 +294,12 @@ def __init__(
                         for seq_id in range(len(self.seq_ids)):
                             self.input_positions[seq_id].clear()
 
+                    if token_types:
+                        self.token_types = token_types
+                    else:
+                        for seq_id in range(len(self.seq_ids)):
+                            self.token_types[seq_id].clear()
+
                     self.mrope_input_positions = None
 
                     if seq_lens:
@@ -355,6 +363,7 @@ def __init__(
             else:
                 self.input_tokens = input_tokens or []
                 self.input_positions = input_positions or []
+                self.token_types = token_types or []
                 self.mrope_input_positions = mrope_input_positions or None
                 self.seq_lens = seq_lens or []
                 self.orig_seq_lens = orig_seq_lens or []
@@ -387,6 +396,7 @@ def __post_init__(self):
 
             self.input_tokens = [[] for _ in range(self.n_seqs)]
             self.input_positions = [[] for _ in range(self.n_seqs)]
+            self.token_types = [[] for _ in range(self.n_seqs)]
             self.mrope_input_positions = None
             self.seq_lens = [0] * self.n_seqs
             self.orig_seq_lens = [0] * self.n_seqs
@@ -499,12 +509,15 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
 
         # Compute tokens.
         tokens = seq_data.get_token_ids()[context_len:seq_len]
+        token_types = seq_group_metadata.token_type_ids
 
         inter_data.seq_lens[seq_idx] = seq_len
         inter_data.orig_seq_lens[seq_idx] = seq_len
         inter_data.context_lens[seq_idx] = context_len
         inter_data.input_tokens[seq_idx].extend(tokens)
         inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
+        inter_data.token_types[seq_idx].extend(
+            token_types if token_types else [])
         inter_data.query_lens[seq_idx] = seq_len - context_len
 
         if seq_data.mrope_position_delta is not None:
@@ -562,6 +575,8 @@ def _compute_for_prefix_cache_hit(
                 seq_idx][uncomputed_start:]
             inter_data.input_positions[seq_idx] = inter_data.input_positions[
                 seq_idx][uncomputed_start:]
+            inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
+                uncomputed_start:]
             context_len = prefix_cache_len
 
             inter_data.context_lens[seq_idx] = context_len
@@ -576,6 +591,8 @@ def _compute_for_prefix_cache_hit(
                 seq_idx][-1:]
             inter_data.input_positions[seq_idx] = inter_data.input_positions[
                 seq_idx][-1:]
+            inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
+                -1:]
             inter_data.query_lens[seq_idx] = 1
             inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
 
@@ -700,6 +717,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
                         spatial_merge_size=hf_config.vision_config.
                         spatial_merge_size,
                         context_len=inter_data.context_lens[seq_idx],
+                        seq_len=inter_data.seq_lens[seq_idx],
                     )
 
                 seq_data.mrope_position_delta = mrope_position_delta
@@ -803,9 +821,12 @@ def build(self) -> ModelInputForGPU:
         """
         # Combine and flatten intermediate data.
         input_tokens = []
+        token_types = []
         for inter_data in self.inter_data_list:
             for cur_input_tokens in inter_data.input_tokens:
                 input_tokens.extend(cur_input_tokens)
+            for cur_token_types in inter_data.token_types:
+                token_types.extend(cur_token_types)
 
         if not input_tokens:
             # This may happen when all prefill requests hit
@@ -874,6 +895,12 @@ def build(self) -> ModelInputForGPU:
         input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                                self.runner.device,
                                                self.runner.pin_memory)
+
+        token_types_tensor = async_tensor_h2d(token_types, torch.long,
+                                               self.runner.device,
+                                               self.runner.pin_memory) \
+                                                if token_types else None
+
         if mrope_input_positions is not None:
             for idx in range(3):
                 mrope_input_positions[idx].extend(
@@ -952,6 +979,7 @@ def build(self) -> ModelInputForGPU:
         return self.model_input_cls(
             input_tokens=input_tokens_tensor,
             input_positions=input_positions_tensor,
+            token_types=token_types_tensor,
             attn_metadata=attn_metadata,
             seq_lens=seq_lens,
             query_lens=query_lens,
@@ -1141,10 +1169,9 @@ def load_model(self) -> None:
                     "provided. Defaulting to scaling factors of 1.0. "
                     "This may lead to less accurate results!")
 
-        if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \
-            and supports_dynamo():
-            from vllm.plugins import get_torch_compile_backend
-            backend = get_torch_compile_backend() or "eager"
+        if self.vllm_config.compilation_config.level ==\
+            CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
+            backend = self.vllm_config.compilation_config.init_backend()
             self.model = torch.compile(
                 self.model,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
@@ -1770,7 +1797,7 @@ def capture(
         # Run the model a few times without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        # Note one iteration is not enough for torch.jit.script
+        # Note one iteration is not enough for torch.compile
         for _ in range(_NUM_WARMUP_ITERS):
             self.model(
                 input_ids=input_ids,
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index a721186137328..d7a641857a613 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -140,7 +140,7 @@ def load_model(self) -> None:
             model = get_model(vllm_config=self.vllm_config)
         model = model.eval()
         xm.wait_device_ops()
-        self.model = ModelWrapper(model)
+        self.model = ModelWrapper(model, self.vllm_config)
 
     def _dummy_run(
         self,
@@ -669,13 +669,15 @@ def execute_model(
 
 class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
 
-    def __init__(self, model: nn.Module):
+    def __init__(self, model: nn.Module, vllm_config: VllmConfig):
         self.model = model
         compiled_callable = torch.compile(self.forward,
                                           backend="openxla",
                                           fullgraph=True,
                                           dynamic=False)
-        super().__init__(compiled_callable)
+        super().__init__(
+            compiled_callable,
+            compilation_level=vllm_config.compilation_config.level)
 
     def __call__(self, *args, is_prompt: bool, **kwargs):
         if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: