diff --git a/.buildkite/run-hpu-test.sh b/.buildkite/run-hpu-test.sh
index 4505dc7a9373c..fa4f74fca7a11 100644
--- a/.buildkite/run-hpu-test.sh
+++ b/.buildkite/run-hpu-test.sh
@@ -13,4 +13,4 @@ trap remove_docker_container EXIT
remove_docker_container
# Run the image and launch offline inference
-docker run --runtime=habana --name=hpu-test --network=host -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference.py
\ No newline at end of file
+docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference.py
\ No newline at end of file
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index cd721971d01d6..3cb91fc0f8232 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -3,13 +3,16 @@
# This lists cover the "core" components of vLLM that require careful review
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-/vllm/core @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-/vllm/engine/llm_engine.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-/vllm/executor/executor_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-/vllm/worker/worker_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-/vllm/worker/worker.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-/vllm/model_executor/layers/sampler.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-CMakeLists.txt @tlrmchlsmth @WoosukKwon
+/vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
+/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
+/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
+/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
+/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
+/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
+CMakeLists.txt @tlrmchlsmth
+
+# vLLM V1
+/vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic
# Test ownership
/tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo
diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
index 71f4e520135d4..d1f6105a47166 100644
--- a/.github/FUNDING.yml
+++ b/.github/FUNDING.yml
@@ -1,2 +1,2 @@
github: [vllm-project]
-open_collective: [vllm]
+open_collective: vllm
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000000000..51a73c857ccb2
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,5 @@
+FILL IN THE PR DESCRIPTION HERE
+
+FIX #xxxx (*link existing issues this PR will resolve*)
+
+**BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html **
diff --git a/.github/scripts/cleanup_pr_body.sh b/.github/scripts/cleanup_pr_body.sh
index 3b2da7b9f8966..3246c6f9bc4b7 100755
--- a/.github/scripts/cleanup_pr_body.sh
+++ b/.github/scripts/cleanup_pr_body.sh
@@ -15,19 +15,36 @@ NEW=/tmp/new_pr_body.txt
gh pr view --json body --template "{{.body}}" "${PR_NUMBER}" > "${OLD}"
cp "${OLD}" "${NEW}"
-# Remove all lines after and including "**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**"
-sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE\*\*/,$d' "${NEW}"
-
# Remove "FIX #xxxx (*link existing issues this PR will resolve*)"
sed -i '/FIX #xxxx.*$/d' "${NEW}"
# Remove "FILL IN THE PR DESCRIPTION HERE"
sed -i '/FILL IN THE PR DESCRIPTION HERE/d' "${NEW}"
+# Remove all lines after and including "**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**"
+sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}"
+
+# Remove HTML <details> section that includes <summary> text of "PR Checklist (Click to Expand)"
+python3 - <<EOF
+import re
+
+with open("${NEW}", "r") as file:
+    content = file.read()
+
+pattern = re.compile(r'<details>.*?<summary>.*?PR Checklist \(Click to Expand\).*?</summary>.*?</details>', re.DOTALL)
+content = re.sub(pattern, '', content)
+
+with open("${NEW}", "w") as file:
+    file.write(content)
+EOF
+
# Run this only if ${NEW} is different than ${OLD}
if ! cmp -s "${OLD}" "${NEW}"; then
- echo "Updating PR body"
gh pr edit --body-file "${NEW}" "${PR_NUMBER}"
+ echo
+ echo "Updated PR body:"
+ echo
+ cat "${NEW}"
else
echo "No changes needed"
fi
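For reference (not part of the patch), a minimal standalone Python sketch of the same `<details>`-stripping step that the heredoc above performs, run against a hypothetical PR body:

```python
import re

# Hypothetical PR body, for illustration only.
body = """My change description.

<details>
<summary> PR Checklist (Click to Expand) </summary>
- [ ] CI passes
</details>
"""

# Same idea as the script: drop the collapsible checklist section wholesale.
pattern = re.compile(
    r'<details>.*?<summary>.*?PR Checklist \(Click to Expand\).*?</summary>.*?</details>',
    re.DOTALL)
print(re.sub(pattern, '', body))  # leaves just the description (plus blank lines)
```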
diff --git a/Dockerfile.ubi b/Dockerfile.ubi
index faf3f498c1d2c..527c6b1557b12 100644
--- a/Dockerfile.ubi
+++ b/Dockerfile.ubi
@@ -201,7 +201,6 @@ WORKDIR /home/vllm
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
-
FROM vllm-openai as vllm-grpc-adapter
USER root
@@ -209,7 +208,12 @@ USER root
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=build,src=/workspace/dist,target=/workspace/dist \
- HOME=/root uv pip install "$(echo /workspace/dist/*.whl)[tensorizer]" vllm-tgis-adapter==0.5.3
+ uv pip install $(echo /workspace/dist/*.whl)'[tensorizer]' --verbose && \
+ uv pip install \
+ "git+https://github.com/opendatahub-io/vllm-tgis-adapter@ibm-20241106-adapter" --verbose
+
+RUN --mount=type=bind,from=build,src=/workspace/dist,target=/workspace/dist \
+ echo "Local dir and dist:" && pwd && ls -l /workspace/dist
ENV GRPC_PORT=8033 \
PORT=8000 \
diff --git a/README.md b/README.md
index 6530886ed7de2..0ef073210d070 100644
--- a/README.md
+++ b/README.md
@@ -100,6 +100,7 @@ vLLM is a community project. Our compute resources for development and testing a
- Dropbox
- Google Cloud
- Lambda Lab
+- Nebius
- NVIDIA
- Replicate
- Roblox
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index bdb8ea8e2a5dc..e9fc037a46965 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -251,6 +251,19 @@ def sample_hf_requests(
"url": f"data:image/jpeg;base64,{image_base64}"
},
}
+ elif "image" in data and isinstance(data["image"], str):
+ if (data["image"].startswith("http://") or \
+ data["image"].startswith("file://")):
+ image_url = data["image"]
+ else:
+ image_url = f"file://{data['image']}"
+
+ mm_content = {
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ },
+ }
else:
mm_content = None
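To make the new string-image branch above concrete (a sketch, not code from the diff): `image` fields that are already `http://` or `file://` URLs are passed through unchanged, while bare paths are wrapped as `file://` URLs.

```python
def normalize_image_url(image: str) -> str:
    # Mirrors the string case added to sample_hf_requests above.
    if image.startswith("http://") or image.startswith("file://"):
        return image
    return f"file://{image}"

assert normalize_image_url("http://example.com/cat.jpg") == "http://example.com/cat.jpg"
assert normalize_image_url("/data/images/cat.jpg") == "file:///data/images/cat.jpg"
```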
diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py
index 665b50bf18cf0..a0342d08f1db8 100644
--- a/benchmarks/kernels/benchmark_machete.py
+++ b/benchmarks/kernels/benchmark_machete.py
@@ -2,8 +2,10 @@
import copy
import itertools
import math
+import os
import pickle as pkl
import time
+from dataclasses import dataclass
from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple
@@ -15,11 +17,12 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
- GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
+ GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales,
+ marlin_zero_points)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
- gptq_pack, pack_rows, quantize_weights)
+ pack_rows, quantize_weights)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser
@@ -27,149 +30,349 @@
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]
+NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False)
+
+if NVTX_PROFILE:
+ import nvtx
+
+
+def terse_type_name(dt):
+ return {
+ torch.bfloat16: "bf16",
+ torch.float16: "fp16",
+ torch.int8: "int8",
+ torch.float8_e4m3fn: "fp8",
+ torch.float: "float",
+ torch.int: "int",
+ }[dt]
+
+
+@dataclass
+class BenchmarkTensors:
+ w_ref: torch.Tensor
+ a: torch.Tensor
+
+ w_q: torch.Tensor
+ group_size: Optional[int]
+ wtype: ScalarType
+ w_g_s: torch.Tensor
+ w_g_zp: Optional[torch.Tensor]
+ w_ch_s: Optional[torch.Tensor]
+ w_tok_s: Optional[torch.Tensor]
+
+
+@dataclass
+class TypeConfig:
+ act_type: torch.dtype
+ weight_type: ScalarType
+ output_type: Optional[torch.dtype]
+ group_scale_type: Optional[torch.dtype]
+ group_zero_type: Optional[torch.dtype]
+ channel_scale_type: Optional[torch.dtype]
+ token_scale_type: Optional[torch.dtype]
+
+
+def rand_data(shape, dtype=torch.float16, scale=1):
+ if dtype.is_floating_point:
+ return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype)
+ else:
+ return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
+
+
+def quantize_and_pack(atype: torch.dtype,
+ w: torch.Tensor,
+ wtype: ScalarType,
+ stype: Optional[torch.dtype],
+ group_size: Optional[int],
+ zero_points: bool = False):
+ assert wtype.is_integer(), "TODO: support floating point weights"
+
+ w_ref, w_q, w_s, w_zp = quantize_weights(
+ w,
+ wtype,
+ group_size=group_size,
+ zero_points=zero_points,
+ # to match how the kernel applies zps
+ ref_zero_points_after_scales=True)
-def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor:
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
- w_q = w_q.t().contiguous().t() # make col major
- return ops.machete_prepack_B(w_q, wtype)
+ return w_ref, w_q, w_s, w_zp
-def make_bench_tensors(
- atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int,
- k: int
-) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor,
- torch.tensor]]]:
- assert wtype.is_integer(), "TODO: support floating point weights"
+def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig,
+ group_size: Optional[int]) -> List[BenchmarkTensors]:
+ m, n, k = shape
# we want to make sure that weights don't fit into L2 cache between runs so
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
# so we target total weight size > 2*50mb
- num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits))
-
- a = torch.randn((m, k), device="cuda", dtype=atype) * 5
- weights = [
- torch.randn((k, n), device="cuda", dtype=atype)
- for _ in range(num_weights)
- ]
- quanitized_weights = [
- quantize_weights(w, wtype, group_size) for w in weights
- ]
-
- return a, quanitized_weights
+ num_weights = math.ceil(2 * 50 * 1024**2 * 8 /
+ (k * n * types.weight_type.size_bits))
+
+ a = rand_data((m, k), types.act_type, scale=5)
+
+ benchmark_tensors: List[BenchmarkTensors] = []
+ for _ in range(num_weights):
+ w = rand_data((k, n), types.act_type, scale=5)
+
+ if types.group_scale_type is not None:
+ w = w.to(types.group_scale_type)
+ if w.dtype.itemsize == 1:
+ w = w.to(torch.float16)
+
+ w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
+ a.dtype, w, types.weight_type, types.group_scale_type, group_size,
+ types.group_zero_type is not None)
+
+ if not a.dtype.is_floating_point:
+ aiinfo = torch.iinfo(a.dtype)
+ w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)
+
+ w_ref = w_ref.to(torch.float32)
+
+ w_ch_s = None if types.channel_scale_type is None else\
+ rand_data((n,), types.channel_scale_type)
+ w_tok_s = None if types.token_scale_type is None else\
+ rand_data((m,), types.token_scale_type)
+
+ benchmark_tensors.append(
+ BenchmarkTensors(w_ref=w_ref,
+ a=a,
+ w_q=w_q_packed,
+ wtype=types.weight_type,
+ w_g_s=w_s,
+ w_g_zp=w_zp,
+ group_size=group_size,
+ w_ch_s=w_ch_s,
+ w_tok_s=w_tok_s))
+
+ return benchmark_tensors
+
+
+def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable:
+ a = bt.a
+ w = bt.w_ref.to(bt.a.dtype) # use float reference tensor
+ if a.dtype not in [torch.float16, torch.bfloat16]:
+ a = a.to(torch.float16)
+ w = w.to(torch.float16)
+ return lambda: torch.matmul(a, w)
+
+
+def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
+ if bt.w_ch_s is not None and bt.w_tok_s is not None:
+ scale_a = bt.w_tok_s.to(torch.float32)
+ scale_b = bt.w_ch_s.to(torch.float32)
+ else:
+ scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
+ scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
+ w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
+ return lambda: ops.cutlass_scaled_mm(
+ bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16)
+
+
+def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
+ device = bt.a.device
+
+ workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
+ GPTQ_MARLIN_MAX_PARALLEL)
+
+ if bt.w_g_zp is None:
+ w_zp = torch.empty(0, dtype=torch.int, device=device)
+ else:
+ w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0],
+ bt.w_ref.shape[1], bt.wtype.size_bits)
+
+ if bt.group_size is None:
+ w_s = torch.tensor([], device="cuda", dtype=torch.half)
+ else:
+ w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0],
+ bt.w_ref.shape[1], bt.group_size)
+
+ sort_indices = torch.empty(0, dtype=torch.int, device=device)
+ g_idx = torch.empty(0, dtype=torch.int, device=device)
+ w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0],
+ bt.w_ref.shape[1], bt.wtype.size_bits)
+
+ if bt.a.dtype.is_floating_point:
+ assert bt.w_ch_s is None
+ assert bt.w_tok_s is None
+ assert bt.group_size is not None
+
+ fn = lambda: ops.gptq_marlin_gemm(a=bt.a,
+ b_q_weight=w_q,
+ b_scales=w_s,
+ b_zeros=w_zp,
+ g_idx=g_idx,
+ perm=sort_indices,
+ workspace=workspace.scratch,
+ b_q_type=bt.wtype,
+ size_m=bt.a.shape[0],
+ size_n=bt.w_ref.shape[1],
+ size_k=bt.w_ref.shape[0],
+ is_k_full=True)
+ else:
+ assert bt.a.dtype == torch.int8
+ assert bt.wtype == scalar_types.uint4b8
+
+ if bt.w_ch_s is not None:
+ s_ch = bt.w_ch_s.to(torch.float32)
+ else:
+ s_ch = torch.ones(bt.w_ref.shape[1],
+ dtype=torch.float32,
+ device=device)
+
+ if bt.w_tok_s is not None:
+ s_tok = bt.w_tok_s.to(torch.float32)
+ else:
+ s_tok = torch.ones(bt.a.shape[0],
+ dtype=torch.float32,
+ device=device)
+
+ fn = lambda: ops.marlin_qqq_gemm(a=bt.a,
+ b_q_weight=w_q,
+ s_group=w_s,
+ s_tok=s_tok,
+ s_ch=s_ch,
+ workspace=workspace.scratch,
+ size_m=bt.a.shape[0],
+ size_n=bt.w_ref.shape[1],
+ size_k=bt.w_ref.shape[0])
+
+ return fn
+
+
+def machete_create_bench_fn(bt: BenchmarkTensors,
+ out_type: Optional[torch.dtype] = None,
+ schedule=None) -> Callable:
+ w_q = bt.w_q.t().contiguous().t() # make col major
+ w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype,
+ None if bt.w_g_s is None else bt.w_g_s.dtype)
+
+ w_g_zp = bt.w_g_zp
+ if w_g_zp is not None:
+ w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype))
+
+ return lambda: ops.machete_mm(
+ a=bt.a,
+ b_q=w_q,
+ b_type=bt.wtype,
+ b_group_scales=bt.w_g_s,
+ b_group_zeros=w_g_zp,
+ b_group_size=bt.group_size,
+ b_channel_scales=bt.w_ch_s,
+ a_token_scales=bt.w_tok_s,
+ out_type=out_type,
+ schedule=schedule,
+ )
# impl
-
# bench
-def bench_fn(label: str, sub_label: str, description: str,
- fn: Callable) -> TMeasurement:
- min_run_time = 1
- return TBenchmark.Timer(
- stmt="fn()",
+
+def bench_fns(label: str, sub_label: str, description: str,
+ fns: List[Callable]):
+
+ min_run_time = 1 if not NVTX_PROFILE else 0.1
+ res = TBenchmark.Timer(
+ stmt="""
+ for fn in fns:
+ fn()
+ """,
globals={
- "fn": fn
+ "fns": fns
},
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
+ if NVTX_PROFILE:
+ with nvtx.annotate("mm-bench"), nvtx.annotate(
+ f"{label}|{sub_label}|{description}"):
+ fns[0]()
-def loop_over_weights(
- a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor,
- torch.tensor, torch.tensor]],
- fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor],
- None]):
- for w_ref, w_q, w_s, _ in weights:
- fn(a, w_ref, w_q, w_s)
+ return res
_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
-def bench(atype: torch.dtype,
- wtype: ScalarType,
+def bench(types: TypeConfig,
group_size: int,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
- benchmark_marlinv1: bool = True,
- sweep_schedules: bool = True) -> Iterable[TMeasurement]:
- global _SWEEP_SCHEDULES_RESULTS
-
- a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
- sub_label += f", L={len(weights)}"
-
- weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
- for w_ref, w_q, w_s, w_zp in weights]
+ sweep_schedules: bool = True) -> List[TMeasurement]:
+ benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
+ sub_label += f", L={len(benchmark_tensors)}"
+
+ name_type_string = f"W{types.weight_type}"+\
+ f"-A{terse_type_name(types.act_type)}"
+ if types.group_scale_type is not None:
+ name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
+ if types.group_zero_type is not None:
+ name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}"
+ if group_size is not None:
+ name_type_string += f"-G{group_size}"
+ if types.channel_scale_type is not None:
+ name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}"
+ if types.token_scale_type is not None:
+ name_type_string += f"-TS{terse_type_name(types.token_scale_type)}"
timers = []
# pytorch impl
timers.append(
- bench_fn(
- label, sub_label, "torch.matmul", lambda: loop_over_weights(
- a,
- weights,
- lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref),
- )))
+ bench_fns(
+ label, sub_label, "torch.matmul (fp16)",
+ [torch_matmul_f16_create_bench_fn(bt)
+ for bt in benchmark_tensors]))
- if benchmark_marlinv1:
- w_ref = weights[0][0]
-
- w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device)
- sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device)
- g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device)
-
- def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor:
- w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape)
- return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape,
- wtype.size_bits)
-
- def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
- return marlin_permute_scales(w_s, *w_ref.shape, group_size)
-
- weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q),
- marlinv1_permute_scales(w_s), w_zp)
- for w_ref, w_q, w_s, w_zp in weights]
-
- workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
- GPTQ_MARLIN_MAX_PARALLEL)
-
- # marlinv1
+ if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
+ timers.append(
+ bench_fns(
+ label, sub_label,
+ f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [
+ cutlass_scaled_mm_create_bench_fn(bt)
+ for bt in benchmark_tensors
+ ]))
+
+ if types.act_type != torch.float8_e4m3fn:
timers.append(
- bench_fn(
- label, sub_label, "marlin_orig", lambda: loop_over_weights(
- a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops.
- gptq_marlin_gemm(a,
- w_q,
- w_s,
- w_zp_empty,
- g_idx,
- sort_indices,
- workspace.scratch,
- wtype,
- size_m=a.shape[0],
- size_n=w_ref.shape[1],
- size_k=w_ref.shape[0],
- is_k_full=True))))
+ bench_fns(label, sub_label, f"marlin ({name_type_string})",
+ [marlin_create_bench_fn(bt)
+ for bt in benchmark_tensors]))
# machete
timers.append(
- bench_fn(
- label, sub_label, "machete_heuristic", lambda: loop_over_weights(
- a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
- a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))
+ bench_fns(label, sub_label, f"machete ({name_type_string})", [
+ machete_create_bench_fn(bt, out_type=types.output_type)
+ for bt in benchmark_tensors
+ ]))
if sweep_schedules:
+ global _SWEEP_SCHEDULES_RESULTS
+
print("Finding best schedule for machete")
best = None
best_schedule = None
- schedules = ops.machete_supported_schedules(wtype)
+ schedules = ops.machete_supported_schedules(
+ a_type=types.act_type,
+ b_type=types.weight_type,
+ group_scales_type=types.group_scale_type,
+ group_zeros_type=types.group_zero_type,
+ token_scales_type=types.token_scale_type,
+ channel_scales_type=types.channel_scale_type,
+ out_type=types.output_type)
+
+ if schedules is None or len(schedules) == 0:
+ raise ValueError("No schedules found to sweep")
+
for schedule in reversed(schedules):
schedule_M = int(schedule.split("_")[0].split("x")[1])
@@ -177,16 +380,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
continue
- def run(a, _, w_q, w_s, schedule=schedule):
- ops.machete_gemm(a,
- w_q,
- wtype,
- w_s,
- b_group_size=group_size,
- schedule=schedule)
-
- res = bench_fn(label, sub_label, "machete_best",
- lambda: loop_over_weights(a, weights_machete, run))
+ res = bench_fns(label, sub_label, "machete_best", [
+ machete_create_bench_fn(
+ bt, out_type=types.output_type, schedule=schedule)
+ for bt in benchmark_tensors
+ ])
results_row = {
"M": m,
@@ -213,25 +411,33 @@ def run(a, _, w_q, w_s, schedule=schedule):
# runner
-def print_timers(timers: Iterable[TMeasurement]):
+def print_timers(timers: List[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
-def run(dtype: torch.dtype, sweep_schedules: bool,
- MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
+def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
+ types = TypeConfig(
+ act_type=args.act_type,
+ weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
+ else scalar_types.uint4,
+ output_type=args.out_type,
+ group_scale_type=args.group_scale_type,
+ group_zero_type=args.group_zero_type,
+ channel_scale_type=args.channel_scale_type,
+ token_scale_type=args.token_scale_type,
+ )
- results = []
+ results: List[TMeasurement] = []
for m, k, n in MKNs:
- timers = bench(dtype,
- scalar_types.uint4b8,
- 128,
+ timers = bench(types,
+ args.group_size,
m,
k,
n,
- f"{dtype}-gemm",
+ f"{args.act_type}-gemm",
f"MKN=({m}x{k}x{n})",
- sweep_schedules=sweep_schedules)
+ sweep_schedules=args.sweep_schedules)
print_timers(timers)
results.extend(timers)
@@ -240,7 +446,7 @@ def run(dtype: torch.dtype, sweep_schedules: bool,
# output makers
def make_output(
- data: Iterable[TMeasurement],
+ data: List[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
base_description: str,
timestamp=None,
@@ -262,7 +468,6 @@ def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
-
data = run(args.dtype, args.sweep_schedules, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}")
@@ -306,33 +511,49 @@ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
for k, n in KNs:
MKNs.append((m, k, n))
- data = run(args.dtype, args.sweep_schedules, MKNs)
+ data = run(args, MKNs)
model_bench_data.append(data)
+ type_string = f"{args.act_type}"
+
# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
- print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
+ print(f"== Results {type_string} {model}-TP{tp_size} ====")
print_timers(data)
- timestamp = int(time.time())
+ timestr = time.strftime("%Y%m%d-%H%M%S")
- all_data = []
+ all_results = []
for d in model_bench_data:
- all_data.extend(d)
+ all_results.extend(d)
+
# pickle all data
- with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
- pkl.dump(all_data, f)
+ with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
+ args_dict = vars(args)
+ args_dict.pop("func")
+ pkl.dump({
+ "args": args_dict,
+ "results": all_results,
+ }, f)
if __name__ == "__main__":
def to_torch_dtype(dt):
- if dt == "bfloat16":
- return torch.bfloat16
- if dt == "float16":
- return torch.float16
- raise ValueError("unsupported dtype")
+ return {
+ "bfloat16": torch.bfloat16,
+ "float16": torch.float16,
+ "int8": torch.int8,
+ "float8_e4m3fn": torch.float8_e4m3fn,
+ "int": torch.int,
+ "float": torch.float,
+ }[dt]
+
+ class ToTorchDtype(argparse.Action):
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ setattr(namespace, self.dest, to_torch_dtype(values))
parser = FlexibleArgumentParser(
description="""
@@ -352,12 +573,42 @@ def to_torch_dtype(dt):
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter,
)
-
parser.add_argument(
- "--dtype",
- type=to_torch_dtype,
+ "--act-type",
+ action=ToTorchDtype,
required=True,
- help="Available options are ['bfloat16', 'float16']",
+ choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'],
+ )
+ parser.add_argument(
+ "--group-scale-type",
+ action=ToTorchDtype,
+ choices=['bfloat16', 'float16'],
+ )
+ parser.add_argument(
+ "--group-zero-type",
+ action=ToTorchDtype,
+ choices=['bfloat16', 'float16'],
+ )
+ parser.add_argument(
+ "--channel-scale-type",
+ action=ToTorchDtype,
+ choices=['float'],
+ )
+ parser.add_argument(
+ "--token-scale-type",
+ action=ToTorchDtype,
+ choices=['float'],
+ )
+ parser.add_argument(
+ "--out-type",
+ action=ToTorchDtype,
+ choices=['bfloat16', 'float16'],
+ )
+ parser.add_argument(
+ "--group-size",
+ type=int,
+ help="Available options are ['None', '-1', '128'], default=128",
+ default=128,
)
parser.add_argument(
"--sweep-schedules",
diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py
index de608fd05af70..7d0bd84150a27 100644
--- a/benchmarks/kernels/graph_machete_bench.py
+++ b/benchmarks/kernels/graph_machete_bench.py
@@ -20,10 +20,11 @@
args = parser.parse_args()
with open(args.filename, 'rb') as f:
- data: List[TMeasurement] = pickle.load(f)
+ data = pickle.load(f)
+ raw_results: List[TMeasurement] = data["results"]
results = defaultdict(lambda: list())
- for v in data:
+ for v in raw_results:
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
if result is not None:
KN = result.group(1)
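The plotting script above now expects the richer pickle layout that benchmark_machete.py writes (a dict with "args" and "results" rather than a bare list of measurements). A small usage sketch, with the file name supplied on the command line:

```python
import pickle
import sys

# Usage: python inspect_results.py model_bench-<act_type>-<timestamp>.pkl
with open(sys.argv[1], "rb") as f:
    data = pickle.load(f)

print("benchmark args:", data["args"])             # argparse values (minus "func")
print("timings collected:", len(data["results"]))  # List[TMeasurement]
```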
diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py
index 25ec9d6028627..51f24f3ba1774 100644
--- a/benchmarks/kernels/weight_shapes.py
+++ b/benchmarks/kernels/weight_shapes.py
@@ -40,4 +40,10 @@
([8192, 57344], 1),
([28672, 8192], 0),
],
+ "meta-llama/Llama-3.1-405b-hf": [
+ ([16384, 18432], 1),
+ ([16384, 16384], 0),
+ ([16384, 106496], 1),
+ ([53248, 16384], 0),
+ ],
}
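Each entry above is a ([K, N], tp_split_dim) pair; the kernel benchmarks divide the marked dimension by the tensor-parallel size. A hedged sketch of that consumption (the helper name is illustrative, not from the repo):

```python
# Illustrative only: shard the new Llama-3.1-405b shapes the way the
# kernel benchmarks shard WEIGHT_SHAPES entries (split dim // tp_size).
LLAMA_405B_SHAPES = [
    ([16384, 18432], 1),
    ([16384, 16384], 0),
    ([16384, 106496], 1),
    ([53248, 16384], 0),
]

def sharded_kns(shapes, tp_size):
    for (k, n), split_dim in shapes:
        kn = [k, n]
        kn[split_dim] //= tp_size
        yield tuple(kn)

print(list(sharded_kns(LLAMA_405B_SHAPES, tp_size=8)))
```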
diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh
index 1842fab8b2cac..f61fe3ceb978a 100644
--- a/csrc/cutlass_extensions/cute_utils.cuh
+++ b/csrc/cutlass_extensions/cute_utils.cuh
@@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
// is the layout f(x) = x
template <typename Layout>
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
-  if constexpr (std::is_same_v<Layout, void>)
+  if constexpr (std::is_same_v<Layout, void>) {
return true;
- else {
+ } else {
constexpr auto coalesced_layout = coalesce(Layout{});
if constexpr (rank(coalesced_layout) == 1 &&
stride<0>(coalesced_layout) == 1) {
diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
similarity index 99%
rename from csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
rename to csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
index d407d66ab2aa6..7aa87feb4cce2 100644
--- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
+++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
@@ -52,6 +52,7 @@
// clang-format off
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
+#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cute/tensor.hpp"
namespace cutlass::epilogue::threadblock {
diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
rename to csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
new file mode 100644
index 0000000000000..c69e87999ae71
--- /dev/null
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
@@ -0,0 +1,317 @@
+#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
+
+/*
+ This file defines custom epilogues for fusing channel scales, token scales,
+ bias, and activation zero-points onto a GEMM operation using the
+ CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
+
+ Epilogues must contain a public type named EVTCompute of type Sm80EVT,
+ as well as a static prepare_args function that constructs an
+ EVTCompute::Arguments struct.
+*/
+
+namespace vllm::c2x {
+
+using namespace cute;
+
+/*
+ * This class provides the common load descriptors for the
+ * ScaledEpilogue[...] classes
+ */
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBase {
+ protected:
+  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
+
+  template <typename T>
+  using ColOrScalarLoad =
+      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
+          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
+
+  template <typename T>
+  using RowOrScalarLoad =
+      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
+          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+  template <typename T>
+  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
+      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
+
+  template <typename T>
+  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+  template <typename T>
+  using RowOrZeroLoad =
+      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
+          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+  // This utility function constructs the arguments for the load descriptors
+  // from a tensor. It can handle both row and column, as well as row/column or
+  // scalar cases.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(torch::Tensor const& tensor) {
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
+    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
+                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
+      return Arguments{data_ptr, tensor.numel() != 1};
+    } else {
+      // it would technically work but no use case as data_ptr is never nullptr
+      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
+      return Arguments{data_ptr};
+    }
+  }
+
+  // This overload handles the case where there might not be a tensor, in which
+  // case a nullptr is passed and a constant (0) is used.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
+    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
+    return Arguments{data_ptr};
+  }
+};
+
+/*
+ This epilogue function defines a quantized GEMM operation similar to
+ torch._scaled_mm.
+
+ A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
+ per-row. B can be quantized per-tensor or per-column.
+ Any combination of per-tensor and per-row or column is supported.
+ A and B must have symmetric quantization (zero point == 0).
+
+ So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
+ scales are applied elementwise with numpy-style broadcasting.
+
+ ScaleA and ScaleB define the epilogue functions that apply the scales for
+ the A and B operands respectively. These scales may be either per-tensor or
+ per row or column.
+*/
+template
+struct ScaledEpilogue
+ : private ScaledEpilogueBase {
+ private:
+ using SUPER = ScaledEpilogueBase;
+ using Accum = typename SUPER::Accum;
+ using ScaleA = typename SUPER::template ColOrScalarLoad;
+ using ScaleB = typename SUPER::template RowOrScalarLoad;
+
+ using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTCompute0 =
+ cutlass::epilogue::threadblock::Sm80EVT;
+
+ using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+ using EVTCompute =
+ cutlass::epilogue::threadblock::Sm80EVT;
+ using ArgumentType = typename EVTCompute::Arguments;
+
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales) {
+ auto a_args = SUPER::template args_from_tensor(a_scales);
+ auto b_args = SUPER::template args_from_tensor(b_scales);
+
+ typename EVTCompute0::Arguments evt0_args{b_args};
+ return ArgumentType{a_args, evt0_args};
+ }
+};
+
+/*
+ * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
+ * This bias can also be used in the per-tensor azp case, where the activation
+ * zero point (azp) is used to compute an azp correction term,
+ * which is folded into the bias.
+ *
+ * The bias tensor must be per-output channel.
+ * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
+ */
+template
+struct ScaledEpilogueBias
+ : protected ScaledEpilogueBase {
+ protected:
+ using SUPER = ScaledEpilogueBase;
+ using Accum = typename SUPER::Accum;
+ using ScaleA = typename SUPER::template ColOrScalarLoad;
+ using ScaleB = typename SUPER::template RowOrScalarLoad;
+ using Bias = typename SUPER::template RowLoad;
+ using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTCompute0 =
+ cutlass::epilogue::threadblock::Sm80EVT;
+
+ using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiply_add, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+ using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT;
+ using ArgumentType = typename EVTCompute::Arguments;
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales,
+ torch::Tensor const& bias) {
+ auto a_args = SUPER::template args_from_tensor(a_scales);
+ auto b_args = SUPER::template args_from_tensor(b_scales);
+ auto bias_args = SUPER::template args_from_tensor(bias);
+
+ typename EVTCompute0::Arguments evt0_args{b_args};
+ return ArgumentType{a_args, evt0_args, bias_args};
+ }
+};
+
+/*
+ * This epilogue directly supports per-tensor azp in int32 form.
+ * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
+ * term, which should already be multiplied with the scalar azp.
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
+ *
+ * This epilogue also supports bias, which remains per-channel.
+ */
+template
+struct ScaledEpilogueBiasAzp
+ : protected ScaledEpilogueBase {
+ private:
+ using SUPER = ScaledEpilogueBase;
+ using Accum = typename SUPER::Accum;
+ using ScaleA = typename SUPER::template ColOrScalarLoad;
+ using ScaleB = typename SUPER::template RowOrScalarLoad;
+ using Bias = typename SUPER::template RowOrZeroLoad;
+
+ // This is the full AZP term, azp * J @ B, shape (1,n)
+ using AzpWithAdj = typename SUPER::template RowLoad;
+
+ // Compute float(accum - azp_adj), both operands are int32_t
+ using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::minus, float, int32_t,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeAzp =
+ cutlass::epilogue::threadblock::Sm80EVT;
+
+ using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeScaleB =
+ cutlass::epilogue::threadblock::Sm80EVT;
+
+ using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiply_add, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+ using EVTCompute =
+ cutlass::epilogue::threadblock::Sm80EVT;
+
+ using ArgumentType = typename EVTCompute::Arguments;
+
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales,
+ torch::Tensor const& azp_adj,
+ c10::optional const& bias) {
+ auto a_args = SUPER::template args_from_tensor(a_scales);
+ auto b_args = SUPER::template args_from_tensor(b_scales);
+ auto bias_args = SUPER::template args_from_tensor(bias);
+ auto azp_adj_args =
+ SUPER::template args_from_tensor(azp_adj);
+
+ typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
+ return ArgumentType{a_args, evt_scale_b_args, bias_args};
+ }
+};
+
+/*
+ * This epilogue supports per-token azp by computing and applying
+ * the correction term using a rank-1 update. If the term were materialized,
+ * it would require O(m*n) space, and this way it only requires O(m+n) space.
+ * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
+ * point for each row of A.
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
+ *
+ * This epilogue also supports bias, which remains per-channel.
+ */
+template
+struct ScaledEpilogueBiasAzpToken
+ : protected ScaledEpilogueBase {
+ private:
+ using SUPER = ScaledEpilogueBase;
+ using Accum = typename SUPER::Accum;
+ using ScaleA = typename SUPER::template ColOrScalarLoad;
+ using ScaleB = typename SUPER::template RowOrScalarLoad;
+ using Bias = typename SUPER::template RowOrZeroLoad;
+
+ // Per-token azp term, shape (m,1)
+ using Azp = typename SUPER::template ColLoad;
+
+ // This is the AZP adjustment term, J @ B, shape (1,n)
+ using AzpAdj = typename SUPER::template RowLoad;
+
+ // Compute azp * azp_adj
+ using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, int32_t, int32_t,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeAzp =
+ cutlass::epilogue::threadblock::Sm80EVT;
+
+ // Compute float(accum - azp*azp_adj), all operands are int32_t
+ using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::minus, float, int32_t,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeAcc =
+ cutlass::epilogue::threadblock::Sm80EVT;
+
+ using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeScaleB =
+ cutlass::epilogue::threadblock::Sm80EVT;
+
+ using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiply_add, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+ using EVTCompute =
+ cutlass::epilogue::threadblock::Sm80EVT;
+
+ using ArgumentType = typename EVTCompute::Arguments;
+
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales,
+ torch::Tensor const& azp_adj,
+ torch::Tensor const& azp,
+ c10::optional const& bias) {
+ auto a_args = SUPER::template args_from_tensor(a_scales);
+ auto b_args = SUPER::template args_from_tensor(b_scales);
+ auto bias_args = SUPER::template args_from_tensor(bias);
+ auto azp_args = SUPER::template args_from_tensor(azp);
+ auto azp_adj_args =
+ SUPER::template args_from_tensor(azp_adj);
+
+ typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
+ typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
+ return ArgumentType{a_args, evt_scale_b_args, bias_args};
+ }
+};
+
+}; // namespace vllm::c2x
\ No newline at end of file
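The azp epilogues defined above (and their c3x counterparts below) all compute the same arithmetic: D = a_scales * (b_scales * (accum - azp * azp_adj)) + bias, with azp_adj = J @ B. A NumPy sketch of the per-token variant, using arbitrary small shapes purely to check the rank-1 correction against a reference:

```python
import numpy as np

m, n, k = 4, 8, 16
rng = np.random.default_rng(0)

A = rng.integers(-128, 128, (m, k)).astype(np.int64)   # zero-point-shifted activations
B = rng.integers(-128, 128, (k, n)).astype(np.int64)
azp = rng.integers(0, 16, (m, 1)).astype(np.int64)     # per-token activation zero point
a_scales = rng.random((m, 1))
b_scales = rng.random((1, n))
bias = rng.random((1, n))

accum = A @ B                                   # integer accumulator from the GEMM
azp_adj = np.ones((1, k), dtype=np.int64) @ B   # J @ B, shape (1, n)

# Epilogue: D = a_scales * (b_scales * (accum - azp * azp_adj)) + bias
D = a_scales * (b_scales * (accum - azp * azp_adj)) + bias

# Reference: subtract the zero point from A before the GEMM instead.
ref = a_scales * (b_scales * ((A - azp) @ B)) + bias
assert np.allclose(D, ref)
```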
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
new file mode 100644
index 0000000000000..95764ecddc79f
--- /dev/null
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -0,0 +1,315 @@
+#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
+
+/*
+ This file defines custom epilogues for fusing channel scales, token scales,
+ bias, and activation zero-points onto a GEMM operation using the
+ CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.
+
+ Epilogues must contain a public type named EVTCompute of type Sm90EVT,
+ as well as a static prepare_args function that constructs an
+ EVTCompute::Arguments struct.
+*/
+
+namespace vllm::c3x {
+
+using namespace cute;
+
+/*
+ * This class provides the common load descriptors for the
+ * ScaledEpilogue[...] classes
+ */
+template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
+struct ScaledEpilogueBase {
+ protected:
+  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
+
+  template <typename T>
+  using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      Stride<Int<1>, Int<0>, Int<0>>>;
+
+  template <typename T>
+  using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      Stride<Int<0>, Int<1>, Int<0>>>;
+
+  // Don't want to support nullptr by default
+  template <typename T, bool EnableNullPtr = false>
+  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
+
+  // Don't want to support nullptr by default
+  template <typename T, bool EnableNullPtr = false>
+  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
+
+  // This utility function constructs the arguments for the load descriptors
+  // from a tensor. It can handle both row and column, as well as row/column or
+  // scalar cases.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(torch::Tensor const& tensor) {
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
+    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
+                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
+      return Arguments{data_ptr, tensor.numel() != 1};
+    } else {
+      static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
+                    !std::is_same_v<Descriptor, RowLoad<T, true>>);
+      return Arguments{data_ptr};
+    }
+  }
+
+  // This overload handles the case where there might not be a tensor, in which
+  // case a nullptr is passed and a constant (0) is used.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
+    static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
+                  std::is_same_v<Descriptor, RowLoad<T, true>>);
+    return Arguments{data_ptr};
+  }
+};
+
+/*
+ This epilogue function defines a quantized GEMM operation similar to
+ torch.scaled_mm_.
+
+ A and B may be both either int8 or fp8_e4m3. A can be
+ quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
+ Any combination of per-tensor and per-row or column is supported.
+ A and B must have symmetric quantization (zero point == 0).
+
+ So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
+ scales are applied elementwise with numpy-style broadcasting.
+
+ ScaleA and ScaleB define the epilogue functions that apply the scales for
+ the A and B operands respectively. These scales may be either per-tensor or
+ per row or column.
+*/
+template
+struct ScaledEpilogue
+ : private ScaledEpilogueBase {
+ private:
+ using SUPER = ScaledEpilogueBase;
+ using Accum = typename SUPER::Accum;
+ using ScaleA = typename SUPER::template ColOrScalarLoad;
+ using ScaleB = typename SUPER::template RowOrScalarLoad;
+
+ using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTCompute0 =
+ cutlass::epilogue::fusion::Sm90EVT;
+
+ using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::multiplies, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+ using EVTCompute =
+ cutlass::epilogue::fusion::Sm90EVT;
+ using ArgumentType = typename EVTCompute::Arguments;
+
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales) {
+ auto a_args = SUPER::template args_from_tensor(a_scales);
+ auto b_args = SUPER::template args_from_tensor(b_scales);
+
+ typename EVTCompute0::Arguments evt0_args{b_args};
+ return ArgumentType{a_args, evt0_args};
+ }
+};
+
+/*
+ * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
+ * This bias can also be used in the per-tensor azp case, where the activation
+ * zero point (azp) is used to compute an azp correction term,
+ * which is folded into the bias.
+ *
+ * The bias tensor must be per-output channel.
+ * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
+ */
+template
+struct ScaledEpilogueBias
+ : private ScaledEpilogueBase {
+ private:
+ using SUPER = ScaledEpilogueBase;
+ using Accum = typename SUPER::Accum;
+ using ScaleA = typename SUPER::template ColOrScalarLoad;
+ using ScaleB = typename SUPER::template RowOrScalarLoad;
+ using Bias = typename SUPER::template RowLoad;
+
+ using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTCompute0 =
+ cutlass::epilogue::fusion::Sm90EVT;
+
+ using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::multiply_add, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+ using EVTCompute =
+ cutlass::epilogue::fusion::Sm90EVT;
+
+ using ArgumentType = typename EVTCompute::Arguments;
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales,
+ torch::Tensor const& bias) {
+ auto a_args = SUPER::template args_from_tensor(a_scales);
+ auto b_args = SUPER::template args_from_tensor(b_scales);
+ auto bias_args = SUPER::template args_from_tensor(bias);
+
+ typename EVTCompute0::Arguments evt0_args{b_args};
+ return ArgumentType{a_args, evt0_args, bias_args};
+ }
+};
+
+/*
+ * This epilogue directly supports per-tensor azp in int32 form.
+ * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
+ * term, which should already be multiplied with the scalar azp.
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
+ *
+ * This epilogue also supports bias, which remains per-channel.
+ */
+template
+struct ScaledEpilogueBiasAzp
+ : private ScaledEpilogueBase {
+ private:
+ using SUPER = ScaledEpilogueBase;
+ using Accum = typename SUPER::Accum;
+ using ScaleA = typename SUPER::template ColOrScalarLoad;
+ using ScaleB = typename SUPER::template RowOrScalarLoad;
+ using Bias = typename SUPER::template RowLoad;
+
+ // This is the full AZP term, azp * J @ B, shape (1,n)
+ using AzpWithAdj = typename SUPER::template RowLoad;
+
+ // Compute float(accum - azp_adj), both operands are int32_t
+ using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::minus, float, int32_t,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeAzp =
+ cutlass::epilogue::fusion::Sm90EVT;
+
+ using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeScaleB =
+ cutlass::epilogue::fusion::Sm90EVT;
+
+ using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::multiply_add, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+ using EVTCompute =
+ cutlass::epilogue::fusion::Sm90EVT;
+ using ArgumentType = typename EVTCompute::Arguments;
+
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales,
+ torch::Tensor const& azp_adj,
+ c10::optional const& bias) {
+ auto a_args = SUPER::template args_from_tensor(a_scales);
+ auto b_args = SUPER::template args_from_tensor(b_scales);
+ auto bias_args = SUPER::template args_from_tensor(bias);
+ auto azp_adj_args =
+ SUPER::template args_from_tensor(azp_adj);
+
+ typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
+ return ArgumentType{a_args, evt_scale_b_args, bias_args};
+ }
+};
+
+/*
+ * This epilogue supports per-token azp by computing and applying
+ * the correction term using a rank-1 update. If the term were materialized,
+ * it would require O(m*n) space, and this way it only requires O(m+n) space.
+ * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
+ * point for each row of A.
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
+ *
+ * This epilogue also supports bias, which remains per-channel.
+ */
+template
+struct ScaledEpilogueBiasAzpToken
+ : private ScaledEpilogueBase {
+ private:
+ using SUPER = ScaledEpilogueBase;
+ using Accum = typename SUPER::Accum;
+ using ScaleA = typename SUPER::template ColOrScalarLoad;
+ using ScaleB = typename SUPER::template RowOrScalarLoad;
+ using Bias = typename SUPER::template RowLoad;
+
+ // Per-token azp term, shape (m,1)
+ using Azp = typename SUPER::template ColLoad;
+
+ // This is the AZP adjustment term, J @ B, shape (1,n)
+ using AzpAdj = typename SUPER::template RowLoad;
+
+ // Compute azp * azp_adj
+ using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::multiplies, int32_t, int32_t,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeAzp =
+ cutlass::epilogue::fusion::Sm90EVT;
+
+ // Compute float(accum - azp*azp_adj), all operands are int32_t
+ using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::minus, float, int32_t,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeAcc =
+ cutlass::epilogue::fusion::Sm90EVT;
+
+ using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeScaleB =
+ cutlass::epilogue::fusion::Sm90EVT;
+
+ using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::multiply_add, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+ using EVTCompute =
+ cutlass::epilogue::fusion::Sm90EVT;
+ using ArgumentType = typename EVTCompute::Arguments;
+
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales,
+ torch::Tensor const& azp_adj,
+ torch::Tensor const& azp,
+ c10::optional const& bias) {
+ auto a_args = SUPER::template args_from_tensor(a_scales);
+ auto b_args = SUPER::template args_from_tensor(b_scales);
+ auto bias_args = SUPER::template args_from_tensor(bias);
+ auto azp_args = SUPER::template args_from_tensor(azp);
+ auto azp_adj_args =
+ SUPER::template args_from_tensor(azp_adj);
+
+ typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
+ typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
+ return ArgumentType{a_args, evt_scale_b_args, bias_args};
+ }
+};
+
+}; // namespace vllm::c3x
\ No newline at end of file
diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py
index 4fcfcd311aa91..a5beea1a35e49 100644
--- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py
+++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py
@@ -35,6 +35,35 @@ class MixedInputKernelScheduleType(enum.Enum):
}
}
+VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
+ **DataTypeSize, # type: ignore
+ **{
+ VLLMDataType.u4b8: 4,
+ VLLMDataType.u8b128: 8,
+ }
+}
+
+VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
+ VLLMDataType.u4b8: "vllm::kU4B8",
+ VLLMDataType.u8b128: "vllm::kU8B128",
+ DataType.u4: "vllm::kU4",
+ DataType.u8: "vllm::kU8",
+ DataType.s4: "vllm::kS4",
+ DataType.s8: "vllm::kS8",
+ DataType.f16: "vllm::kFloat16",
+ DataType.bf16: "vllm::kBfloat16",
+}
+
+VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
+ DataType.u8: "at::ScalarType::Byte",
+ DataType.s8: "at::ScalarType::Char",
+ DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
+ DataType.s32: "at::ScalarType::Int",
+ DataType.f16: "at::ScalarType::Half",
+ DataType.bf16: "at::ScalarType::BFloat16",
+ DataType.f32: "at::ScalarType::Float",
+}
+
VLLMKernelScheduleTag: Dict[Union[
MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore
diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
index 2ad914f8e9868..90f226cf64c0a 100644
--- a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
+++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
@@ -3,6 +3,7 @@
#include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/vllm_custom_types.cuh"
#include "cutlass_extensions/cute_utils.cuh"
+#include "cutlass_extensions/vllm_type_utils.cuh"
// this file extends:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
@@ -28,8 +29,19 @@ struct InterleavedNumericArrayConverter {
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
- CUTE_INVALID_CONTROL_PATH(
- "InterleavedNumericArrayConverter not implemented\n");
+ if (cute::elect_one_sync()) {
+ if constexpr (std::is_same_v<IlvBlkLayout, void>) {
+ printf(
+ "Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
+ nameof_v, nameof_v, N);
+ } else {
+ printf(
+ "Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
+ "implemented\n",
+ nameof_v, nameof_v, N, size(IlvBlkLayout{}));
+ }
+ __brkpt();
+ }
return {};
}
@@ -56,11 +68,6 @@ struct InterleavedNumericArrayConverter<
result_type operator()(source_type const& s) const { return convert(s); }
};
-// TODO (LucasWilkinson): Implement
-// for Array <= Array
-
-// ....
-
template
struct ArrayConverterPacked32Bit {
using result_type = Array;
@@ -86,14 +93,16 @@ struct ArrayConverterPacked32Bit {
using ScalarConverter = NumericConverter;
template
- CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) {
+ CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) {
if constexpr (sizeof(PackedSrc) == 1) {
- return static_cast(reinterpret_cast(source));
+ return Array{reinterpret_cast(src)};
} else if constexpr (sizeof(PackedSrc) == 2) {
- return static_cast(reinterpret_cast(source));
+ return Array{reinterpret_cast(src)};
+ } else if constexpr (sizeof(PackedSrc) == 4) {
+ return Array{reinterpret_cast(src)};
} else {
- static_assert(sizeof(PackedSrc) == 4);
- return reinterpret_cast(source);
+ static_assert(sizeof(PackedSrc) == 8);
+ return reinterpret_cast const&>(src);
}
}
@@ -110,7 +119,7 @@ struct ArrayConverterPacked32Bit {
static_assert(std::is_same_v);
static_assert(std::is_same_v);
- return RegConvert32bit::template convert(to_reg(source));
+ return RegConvert32bit::template convert(to_regs(source));
}
friend class detail::VectorizedConverter;
@@ -140,6 +149,131 @@ struct ArrayConverterPacked32Bit {
}
};
+// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed
+// into 2 32bit register.
+template <uint8_t LUT0, uint8_t LUT1, uint8_t LUT2, uint8_t LUT3,
+          uint8_t LUT4, uint8_t LUT5, uint8_t LUT6, uint8_t LUT7,
+          uint8_t LUT8, uint8_t LUT9, uint8_t LUT10, uint8_t LUT11,
+          uint8_t LUT12, uint8_t LUT13, uint8_t LUT14, uint8_t LUT15>
+CUTLASS_DEVICE cutlass::AlignedArray<uint32_t, 2> lut_4bit_to_8bit_convert(
+    uint32_t src) {
+  cutlass::AlignedArray<uint32_t, 2> r;
+ // Determines if the value is in the top half of the LUT if set or
+ // (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move
+ // into bit position 0x4 of each nibble so when or'd with final_prmt_base it
+ // selects the correct candidate. When elements in final_prmt_base
+ // are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements
+ // are < 0x4, the low candidate is selected (i.e. LUT[0:7])
+ uint32_t high_bit = (src & 0x88888888) >> 1;
+
+ // `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT
+ // (selects correct high or low candidate)
+ const uint32_t final_prmt_base = 0x32103210;
+
+ // Ignore the high bit when indexing into LUT, for each 4bit value
+ // we index into both the high and low candidates then use
+ // high_bit | final_prmt_base to select the correct candidate
+ uint32_t lut_idx = (src & 0x77777777);
+
+ auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
+ return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) |
+ (uint32_t(d) << 24);
+ };
+
+ static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3);
+ static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7);
+ static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11);
+ static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) {
+ uint32_t final_prmt_idx = final_prmt_base | high_bit;
+
+ // This uses a look up table to convert packed int4s to packed int8s,
+ // using the int4 value as the index to prmt. It first select both the
+ // high and low candidates, then uses the high bit (i.e. `high_bit`) to
+ // select the correct candidate.
+ asm volatile(
+ "{\n"
+ " .reg .b32 low, high;\n"
+ " prmt.b32 low, %1, %2, %5;\n"
+ " prmt.b32 high, %3, %4, %5;\n"
+ " prmt.b32 %0, low, high, %6;\n"
+ "}\n"
+ : "=r"(r[ii])
+ : "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx),
+ "r"(final_prmt_idx));
+ }
+
+ return r;
+};
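A pure-Python model of the conversion implemented above with prmt (a sketch for clarity, not the kernel code): eight packed 4-bit values in one 32-bit word are looked up in a 16-entry byte table and repacked as eight 8-bit values across two 32-bit words.

```python
# Python model of lut_4bit_to_8bit_convert: LUT below is the uint4b8 -> int8
# table used in the converter that follows ([-8 .. 7] as int8 byte values).
LUT = [0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF,
       0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]

def lut_4bit_to_8bit(src: int) -> tuple:
    out = [0, 0]
    for i in range(8):                       # nibble i, little-endian
        nibble = (src >> (4 * i)) & 0xF
        out[i // 4] |= LUT[nibble] << (8 * (i % 4))
    return out[0], out[1]

# Nibble 0x8 maps to 0x00 (zero) and 0xF maps to 0x07 (seven).
lo, hi = lut_4bit_to_8bit(0xF8F8F8F8)
print(hex(lo), hex(hi))  # 0x7000700 0x7000700
```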
+
+// for Array<int8_t, N> <= Array<vllm_uint4b8_t, N>
+template <FloatRoundStyle Round, int N>
+struct NumericArrayConverter<int8_t, vllm_uint4b8_t, N, Round> {
+  using result_type = Array<int8_t, N>;
+  using source_type = Array<vllm_uint4b8_t, N>;
+
+ static FloatRoundStyle const round_style = Round;
+
+ private:
+ struct RegConvert {
+ template <typename PackedResultType>
+ CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
+ // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s
+ auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, //
+ 0xFC, 0xFD, 0xFE, 0xFF, //
+ 0x00, 0x01, 0x02, 0x03, //
+ 0x04, 0x05, 0x06, 0x07>(src_[0]);
+ return reinterpret_cast<PackedResultType&>(r);
+ };
+ };
+
+ public:
+ CUTLASS_DEVICE
+ static result_type convert(source_type const& source) {
+ return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
+ typename source_type::Element, N>::convert(source);
+ }
+
+ CUTLASS_DEVICE
+ result_type operator()(source_type const& s) const { return convert(s); }
+};
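// Illustrative usage sketch (not from the patch): the specialization above
// plugs into the standard CUTLASS converter interface. The function name is
// made up; cutlass::vllm_uint4b8_t is assumed to come from
// vllm_custom_types.cuh, and the converter header modified above must be
// included so this specialization is selected.
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/vllm_custom_types.cuh"

CUTLASS_DEVICE cutlass::Array<int8_t, 8> dequant_8_nibbles(
    cutlass::Array<cutlass::vllm_uint4b8_t, 8> const& src) {
  // Default round style matches the partial specialization's free Round param.
  cutlass::NumericArrayConverter<int8_t, cutlass::vllm_uint4b8_t, 8> converter;
  return converter(src);
}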
+
+// for Array<cutlass::float_e4m3_t, N> <= Array<vllm_uint4b8_t, N>
+template <FloatRoundStyle Round, int N>
+struct NumericArrayConverter<cutlass::float_e4m3_t, vllm_uint4b8_t, N, Round> {
+ using result_type = Array<cutlass::float_e4m3_t, N>;
+ using source_type = Array<vllm_uint4b8_t, N>;
+
+ static FloatRoundStyle const round_style = Round;
+
+ private:
+ struct RegConvert {
+ template <typename PackedResultType>
+ CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
+ // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s
+ auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, //
+ 0xC8, 0xC4, 0xC0, 0xB8, //
+ 0x00, 0x38, 0x40, 0x44, //
+ 0x48, 0x4A, 0x4C, 0x4E>(src_[0]);
+ return reinterpret_cast<PackedResultType&>(r);
+ };
+ };
+
+ public:
+ CUTLASS_DEVICE
+ static result_type convert(source_type const& source) {
+ return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
+ typename source_type::Element, N>::convert(source);
+ }
+
+ CUTLASS_DEVICE
+ result_type operator()(source_type const& s) const { return convert(s); }
+};
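// Illustrative host-side check (not from the patch): the fp8 LUT above holds
// e4m3 bit patterns for [-8 ... 7]. A minimal decoder for *normal* e4m3 values
// (exponent bias 7, 3 mantissa bits; zero/subnormals/NaN not handled) can
// sanity-check a few entries at compile time. The helper name is made up.
#include <cstdint>

constexpr float e4m3_normal_to_float(uint8_t bits) {
  float sign = (bits & 0x80) ? -1.0f : 1.0f;
  int exponent = (bits >> 3) & 0xF;  // biased by 7
  int mantissa = bits & 0x7;         // 3 explicit bits
  return sign * (1.0f + mantissa / 8.0f) * (float(1 << exponent) / 128.0f);
}
static_assert(e4m3_normal_to_float(0x38) == 1.0f);   // LUT9  -> +1
static_assert(e4m3_normal_to_float(0x48) == 4.0f);   // LUT12 -> +4
static_assert(e4m3_normal_to_float(0xD0) == -8.0f);  // LUT0  -> -8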
+
// for Array <= Array
template
struct NumericArrayConverter {
@@ -148,7 +282,8 @@ struct NumericArrayConverter {
struct RegConvert {
template <typename PackedResultType>
- CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
+ CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
+ uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray;
@@ -249,7 +384,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>,
private:
struct RegConvert {
template <typename PackedResultType>
- CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
+ CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
+ uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray;
@@ -338,7 +474,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>,
private:
struct RegConvert {
template <typename PackedResultType>
- CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
+ CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
+ uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray;
@@ -417,7 +554,8 @@ struct NumericArrayConverter {
struct RegConvert {
template <typename PackedResultType>
- CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
+ CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
+ uint32_t src = src_[0];
// Hold output FP16s in reg. We need 1 reg for every 2 elements
using RegArray =
cutlass::AlignedArray {
private:
struct RegConvert {
template <typename PackedResultType>
- CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
+ CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
+ uint32_t src = src_[0];
PackedResultType r;
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
@@ -513,7 +652,8 @@ struct NumericArrayConverter {
private:
struct RegConvert {
template <typename PackedResultType>
- CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) {
+ CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
+ uint32_t src_reg = src_[0];
// Hold output BF16s in reg. We need 1 reg for every 2 elements
using RegArray =
cutlass::AlignedArray, Stride<_4, _1>>,
private:
struct RegConvert {
template <typename PackedResultType>
- CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
+ CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
+ uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray;
@@ -671,7 +812,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>,
private:
struct RegConvert {
template <typename PackedResultType>
- CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
+ CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
+ uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray;
@@ -788,6 +930,61 @@ struct NumericArrayConverter {
#endif
+// for Array<int8_t, N> <= Array<cutlass::half_t, N>
+// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
+template <FloatRoundStyle Round, int N>
+struct NumericArrayConverter<int8_t, cutlass::half_t, N, Round> {
+ using result_type = Array<int8_t, N>;
+ using source_type = Array<cutlass::half_t, N>;
+
+ struct RegConvert {
+ // FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
+ template <typename PackedResultType, int src_regs>
+ CUTLASS_DEVICE static PackedResultType convert(
+ Array<uint32_t, src_regs> src) {
+ // Hold output int8s in reg. We need 1 reg for every 4 elements
+ using RegArray = cutlass::AlignedArray<
+ uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>;
+ RegArray r;
+
+ static constexpr uint32_t MAGIC_BIAS_ = 0x64806480;
+ auto MAGIC_BIAS = *reinterpret_cast<const half2*>(&MAGIC_BIAS_);
+
+ *reinterpret_cast<half2*>(&src[0]) =
+ __hadd2(*reinterpret_cast<half2*>(&src[0]), MAGIC_BIAS);
+
+ if constexpr (src_regs > 1) {
+ *reinterpret_cast<half2*>(&src[1]) =
+ __hadd2(*reinterpret_cast<half2*>(&src[1]), MAGIC_BIAS);
+ }
+
+ static_assert(PackedResultType::kElements <= 4);
+ uint32_t uint8s;
+ static constexpr uint32_t MASK_0246 = 0x6420;
+ static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
+ asm volatile("prmt.b32 %0,%1,%2,%3;\n"
+ : "=r"(uint8s)
+ : "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]),
+ "n"(MASK_0246));
+
+ uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK);
+
+ return reinterpret_cast<PackedResultType&>(int8s);
+ };
+ };
+
+ public:
+ CUTLASS_DEVICE
+ static result_type convert(source_type const& source) {
+ return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
+ typename source_type::Element, N>::convert(source);
+ }
+
+ CUTLASS_DEVICE
+ result_type operator()(source_type const& s) const { return convert(s); }
+};
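// Illustrative host-side check (not from the patch) of the magic-bias trick
// above: 0x6480 is fp16 1152.0 (= 1024 + 128). For an integer-valued fp16 x in
// [-128, 127], x + 1152 lands in [1024, 1279], where the fp16 ulp is exactly 1,
// so the low byte of the bit pattern holds x + 128 and the XOR with 0x80
// recovers the two's-complement int8. The helper name is made up; it uses
// cutlass::half_t host arithmetic rather than the device __hadd2/prmt path.
#include <cstdint>
#include "cutlass/half.h"

inline int8_t fast_fp16_to_int8_reference(cutlass::half_t x) {
  cutlass::half_t biased = x + cutlass::half_t(1152.0f);
  uint16_t bits = reinterpret_cast<uint16_t const&>(biased);  // 0x64YY
  return int8_t(uint8_t(bits & 0xFF) ^ uint8_t(0x80));        // unbias low byte
}
// fast_fp16_to_int8_reference(cutlass::half_t(-7.0f))  == -7
// fast_fp16_to_int8_reference(cutlass::half_t(127.0f)) == 127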
+
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
diff --git a/csrc/cutlass_extensions/vllm_type_utils.cuh b/csrc/cutlass_extensions/vllm_type_utils.cuh
new file mode 100644
index 0000000000000..500ed508c8303
--- /dev/null
+++ b/csrc/cutlass_extensions/vllm_type_utils.cuh
@@ -0,0 +1,42 @@
+#include "cutlass/bfloat16.h"
+#include "cutlass/half.h"
+#include "cuda_bf16.h"
+
+#include "cutlass_extensions/vllm_custom_types.cuh"
+
+namespace cutlass {
+
+template <typename T>
+struct nameof {
+ static constexpr char const* value = "unknown";
+};
+
+template <typename T>
+inline constexpr auto nameof_v = nameof<T>::value;
+
+#define NAMEOF_TYPE(T) \
+ template <> \
+ struct nameof<T> { \
+ static constexpr char const* value = #T; \
+ };
+
+NAMEOF_TYPE(float_e4m3_t)
+NAMEOF_TYPE(float_e5m2_t)
+NAMEOF_TYPE(half_t)
+NAMEOF_TYPE(nv_bfloat16)
+NAMEOF_TYPE(bfloat16_t)
+NAMEOF_TYPE(float)
+
+NAMEOF_TYPE(int4b_t)
+NAMEOF_TYPE(int8_t)
+NAMEOF_TYPE(int32_t)
+NAMEOF_TYPE(int64_t)
+
+NAMEOF_TYPE(vllm_uint4b8_t)
+NAMEOF_TYPE(uint4b_t)
+NAMEOF_TYPE(uint8_t)
+NAMEOF_TYPE(vllm_uint8b128_t)
+NAMEOF_TYPE(uint32_t)
+NAMEOF_TYPE(uint64_t)
+
+}; // namespace cutlass
\ No newline at end of file
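// Illustrative usage sketch (not from the patch): nameof_v<T> yields a
// printable type name for diagnostics and falls back to "unknown" for
// unregistered types. The function below is made up for illustration.
#include <cstdio>
#include "cutlass_extensions/vllm_type_utils.cuh"

template <typename Element>
void print_element_type() {
  std::printf("element type: %s\n", cutlass::nameof_v<Element>);
}
// print_element_type<cutlass::half_t>();  // prints "element type: half_t"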
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
index ee801e16573d4..dbb72e8bbd3f5 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
@@ -8,6 +8,10 @@
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
+
+using namespace vllm;
+
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
@@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
- return vllm::cutlass_gemm_sm75_dispatch