Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 223 additions & 0 deletions challenges/medium/81_int4_matmul/challenge.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
<p>
Implement a weight-only INT4 quantized matrix multiplication (W4A16), a core kernel used in
modern LLM inference. Given a float16 activation matrix <code>x</code> of shape
<code>M &times; K</code> and a weight matrix stored in packed INT4 format, compute the output
matrix <code>y = x &times; W<sup>T</sup></code> of shape <code>M &times; N</code>, where
<code>W</code> is the dequantized float16 weight matrix of shape <code>N &times; K</code>.
</p>

<svg width="700" height="400" viewBox="0 0 700 400" xmlns="http://www.w3.org/2000/svg"
style="display:block; margin:20px auto; font-family:monospace;">
<rect width="700" height="400" fill="#222" rx="10"/>
<defs>
<marker id="arr" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
<path d="M0,0 L0,6 L8,3 z" fill="#aaa"/>
</marker>
</defs>

<!-- ============================================================ -->
<!-- ROW 1: UNPACK — packed byte → two unsigned nibbles → signed -->
<!-- ============================================================ -->
<text x="18" y="20" fill="#666" font-size="10">STEP 1: UNPACK</text>

<!-- Packed byte -->
<text x="80" y="48" fill="#ccc" font-size="11" text-anchor="middle">w_q[n, i]</text>
<rect x="20" y="56" width="120" height="32" fill="#1a3a5c" rx="4" stroke="#4a9edd" stroke-width="1.5"/>
<line x1="80" y1="56" x2="80" y2="88" stroke="#4a9edd" stroke-width="1" stroke-dasharray="3,2"/>
<text x="50" y="77" fill="#4a9edd" font-size="10" text-anchor="middle">hi 7:4</text>
<text x="110" y="77" fill="#7ec87e" font-size="10" text-anchor="middle">lo 3:0</text>

<!-- Arrow right -->
<text x="160" y="77" fill="#aaa" font-size="14" text-anchor="middle">&#x2192;</text>

<!-- Unsigned nibbles -->
<rect x="180" y="56" width="50" height="32" fill="#1a3a5c" rx="4" stroke="#4a9edd" stroke-width="1.5"/>
<text x="205" y="77" fill="#4a9edd" font-size="10" text-anchor="middle">9</text>
<rect x="236" y="56" width="50" height="32" fill="#1a4a1a" rx="4" stroke="#7ec87e" stroke-width="1.5"/>
<text x="261" y="77" fill="#7ec87e" font-size="10" text-anchor="middle">10</text>

<!-- "- 8" arrow -->
<text x="310" y="77" fill="#ccc" font-size="11" text-anchor="middle">&#x2212; 8</text>
<text x="345" y="77" fill="#aaa" font-size="14" text-anchor="middle">&#x2192;</text>

<!-- Signed int4 -->
<rect x="365" y="56" width="50" height="32" fill="#3a2a1a" rx="4" stroke="#e0a040" stroke-width="1.5"/>
<text x="390" y="77" fill="#e0a040" font-size="10" text-anchor="middle">+1</text>
<rect x="421" y="56" width="50" height="32" fill="#3a2a1a" rx="4" stroke="#e0a040" stroke-width="1.5"/>
<text x="446" y="77" fill="#e0a040" font-size="10" text-anchor="middle">+2</text>

<text x="540" y="77" fill="#888" font-size="10" text-anchor="middle">signed int4 [&#x2212;8, 7]</text>

<!-- ============================================================ -->
<!-- ROW 2: GROUP-WISE SCALING — show K=8, group_size=4 -->
<!-- ============================================================ -->
<text x="18" y="112" fill="#666" font-size="10">STEP 2: DEQUANTIZE (example: one row n, K=8, group_size=4)</text>

<!-- K-axis label -->
<text x="350" y="136" fill="#888" font-size="10" text-anchor="middle">k &#x2192;</text>

<!-- Group 0 bracket + cells -->
<text x="145" y="136" fill="#c060e0" font-size="9" text-anchor="middle">group 0: scale[n, 0]</text>
<rect x="58" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="81" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+1</text>
<rect x="108" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="131" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+2</text>
<rect x="158" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="181" y="161" fill="#e0a040" font-size="10" text-anchor="middle">&#x2212;1</text>
<rect x="208" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="231" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+3</text>
<!-- Group 0 bracket -->
<rect x="56" y="140" width="200" height="32" rx="4" fill="none" stroke="#c060e0" stroke-width="1.5" stroke-dasharray="4,2"/>

<!-- Group 1 bracket + cells -->
<text x="385" y="136" fill="#c060e0" font-size="9" text-anchor="middle">group 1: scale[n, 1]</text>
<rect x="298" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="321" y="161" fill="#e0a040" font-size="10" text-anchor="middle">0</text>
<rect x="348" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="371" y="161" fill="#e0a040" font-size="10" text-anchor="middle">&#x2212;3</text>
<rect x="398" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="421" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+7</text>
<rect x="448" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="471" y="161" fill="#e0a040" font-size="10" text-anchor="middle">&#x2212;2</text>
<!-- Group 1 bracket -->
<rect x="296" y="140" width="200" height="32" rx="4" fill="none" stroke="#c060e0" stroke-width="1.5" stroke-dasharray="4,2"/>

<!-- "int4" label on left -->
<text x="30" y="161" fill="#e0a040" font-size="9">int4</text>

<!-- Multiply arrows down -->
<text x="156" y="190" fill="#ccc" font-size="12" text-anchor="middle">&#xd7; scale[n, 0]</text>
<text x="396" y="190" fill="#ccc" font-size="12" text-anchor="middle">&#xd7; scale[n, 1]</text>
<line x1="156" y1="172" x2="156" y2="198" stroke="#aaa" stroke-width="1" stroke-dasharray="3,2"/>
<line x1="396" y1="172" x2="396" y2="198" stroke="#aaa" stroke-width="1" stroke-dasharray="3,2"/>

<!-- Dequantized row -->
<text x="30" y="217" fill="#40c080" font-size="9">fp16</text>
<rect x="56" y="202" width="200" height="28" fill="#1a3a2a" rx="4" stroke="#40c080" stroke-width="1.5"/>
<text x="156" y="221" fill="#40c080" font-size="10" text-anchor="middle">W[n, 0..3] float16</text>
<rect x="296" y="202" width="200" height="28" fill="#1a3a2a" rx="4" stroke="#40c080" stroke-width="1.5"/>
<text x="396" y="221" fill="#40c080" font-size="10" text-anchor="middle">W[n, 4..7] float16</text>

<!-- Formula -->
<text x="275" y="252" fill="#ccc" font-size="10" text-anchor="middle">W[n, k] = (nibble &#x2212; 8) &#xd7; scales[n, k // group_size]</text>

<!-- ============================================================ -->
<!-- ROW 3: MATMUL -->
<!-- ============================================================ -->
<text x="18" y="280" fill="#666" font-size="10">STEP 3: MATMUL</text>

<!-- x box -->
<rect x="60" y="296" width="80" height="60" fill="#1a3a5c" rx="4" stroke="#4a9edd" stroke-width="1.5"/>
<text x="100" y="322" fill="#4a9edd" font-size="10" text-anchor="middle">x [M&#xd7;K]</text>
<text x="100" y="340" fill="#4a9edd" font-size="9" text-anchor="middle">float16</text>

<!-- multiply sign -->
<text x="162" y="330" fill="#ccc" font-size="16" text-anchor="middle">&#xd7;</text>

<!-- W^T box -->
<rect x="185" y="296" width="100" height="60" fill="#1a3a2a" rx="4" stroke="#40c080" stroke-width="1.5"/>
<text x="235" y="322" fill="#40c080" font-size="10" text-anchor="middle">W&#x1d40; [K&#xd7;N]</text>
<text x="235" y="340" fill="#40c080" font-size="9" text-anchor="middle">float16</text>

<!-- equals sign -->
<text x="310" y="330" fill="#ccc" font-size="16" text-anchor="middle">=</text>

<!-- y output box -->
<rect x="335" y="296" width="90" height="60" fill="#3a1a1a" rx="4" stroke="#e05050" stroke-width="1.5"/>
<text x="380" y="322" fill="#e05050" font-size="10" text-anchor="middle">y [M&#xd7;N]</text>
<text x="380" y="340" fill="#e05050" font-size="9" text-anchor="middle">float16</text>

<!-- Arrow from dequant to W^T -->
<line x1="235" y1="240" x2="235" y2="294" stroke="#40c080" stroke-width="1.5" stroke-dasharray="4,2" marker-end="url(#arr)"/>
<text x="260" y="270" fill="#40c080" font-size="9">dequantized</text>
</svg>

<p>
<strong>Packing format:</strong> Each byte of <code>w_q</code> stores two INT4 weights. The
high nibble (bits 7&ndash;4) holds weight <code>w[n, 2i]</code> and the low nibble (bits
3&ndash;0) holds <code>w[n, 2i+1]</code>. INT4 values are stored unsigned in the range
[0,&nbsp;15] with an offset of 8, so the signed weight is <code>nibble&nbsp;&minus;&nbsp;8</code>,
giving values in [&minus;8,&nbsp;7].
</p>

<p>
<strong>Dequantization:</strong> Weights are dequantized group-wise. Each contiguous block of
<code>group_size</code> weights along the <code>K</code> dimension shares one float16 scale:
</p>
<pre>
W[n, k] = (w_q_nibble[n, k] - 8) * scales[n, k // group_size]
</pre>

<h2>Implementation Requirements</h2>
<ul>
<li>Use only native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in <code>y</code></li>
</ul>

<h2>Example</h2>
<p>
Input (<code>M</code> = 2, <code>N</code> = 4, <code>K</code> = 4, <code>group_size</code> = 2):
</p>
<p>
Activations \(x\) (float16, \(2 \times 4\)):
\[
\begin{bmatrix}
1.0 & 0.0 & 1.0 & 0.0 \\
0.0 & 1.0 & 0.0 & 1.0
\end{bmatrix}
\]
Packed weights \(w\_q\) (uint8, \(4 \times 2\)), followed by the corresponding signed INT4 weight matrix:
\[
\begin{bmatrix}
\texttt{0x99} & \texttt{0x99} \\
\texttt{0xAA} & \texttt{0xAA} \\
\texttt{0x77} & \texttt{0x77} \\
\texttt{0x88} & \texttt{0x88}
\end{bmatrix}
\;\Rightarrow\;
W_{\text{int4}} =
\begin{bmatrix}
1 & 1 & 1 & 1 \\
2 & 2 & 2 & 2 \\
-1 & -1 & -1 & -1 \\
0 & 0 & 0 & 0
\end{bmatrix}
\]
Scales (float16, \(4 \times 2\), all entries 0.5):
\[
\begin{bmatrix}
0.5 & 0.5 \\
0.5 & 0.5 \\
0.5 & 0.5 \\
0.5 & 0.5
\end{bmatrix}
\;\Rightarrow\;
W_{\text{dequant}} =
\begin{bmatrix}
0.5 & 0.5 & 0.5 & 0.5 \\
1.0 & 1.0 & 1.0 & 1.0 \\
-0.5 & -0.5 & -0.5 & -0.5 \\
0.0 & 0.0 & 0.0 & 0.0
\end{bmatrix}
\]
Output \(y = x \times W^T\) (float16, \(2 \times 4\)):
\[
\begin{bmatrix}
1.0 & 2.0 & -1.0 & 0.0 \\
1.0 & 2.0 & -1.0 & 0.0
\end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>M</code>, <code>N</code> &le; 8,192</li>
<li>1 &le; <code>K</code> &le; 8,192</li>
<li><code>K</code> is divisible by <code>2</code> and by <code>group_size</code></li>
<li><code>group_size</code> &isin; {2, 4, 8, 16, 32, 64, 128}</li>
<li>All tensors are stored in row-major order</li>
<li>Input dtype: <code>x</code> and <code>scales</code> are float16; <code>w_q</code> is uint8</li>
<li>Output dtype: <code>y</code> is float16</li>
<li>Performance is measured with <code>M</code> = 4,096, <code>N</code> = 4,096, <code>K</code> = 4,096, <code>group_size</code> = 128</li>
</ul>
157 changes: 157 additions & 0 deletions challenges/medium/81_int4_matmul/challenge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    """INT4 weight-only quantized matmul (W4A16) challenge.

    Activations ``x`` are float16 [M, K]; weights arrive packed two
    INT4 nibbles per uint8 byte in ``w_q`` [N, K//2] with per-group
    float16 ``scales`` [N, K//group_size]; the output is
    ``y = x @ W^T`` in float16 [M, N].
    """

    def __init__(self):
        super().__init__(
            name="INT4 Weight-Only Quantized MatMul",
            atol=1e-02,
            rtol=1e-02,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        x: torch.Tensor,
        w_q: torch.Tensor,
        scales: torch.Tensor,
        y: torch.Tensor,
        M: int,
        N: int,
        K: int,
        group_size: int,
    ):
        """Dequantize the packed INT4 weights and write y = x @ W^T into ``y``."""
        # Contract checks: shapes, dtypes, and CUDA residency.
        assert x.shape == (M, K)
        assert w_q.shape == (N, K // 2)
        assert scales.shape == (N, K // group_size)
        assert y.shape == (M, N)
        assert x.dtype == torch.float16
        assert w_q.dtype == torch.uint8
        assert scales.dtype == torch.float16
        assert y.dtype == torch.float16
        assert x.device.type == "cuda"
        assert w_q.device.type == "cuda"
        assert scales.device.type == "cuda"
        assert y.device.type == "cuda"

        # Each byte packs two unsigned nibbles with a +8 offset:
        #   bits 7:4 -> w[n, 2*i], bits 3:0 -> w[n, 2*i+1],
        # and the signed value is nibble - 8, i.e. the range [-8, 7].
        hi = ((w_q >> 4) & 0xF).to(torch.int32) - 8  # [N, K//2]
        lo = (w_q & 0xF).to(torch.int32) - 8         # [N, K//2]

        # Interleave hi/lo along the last axis to rebuild the [N, K] matrix.
        w_int = torch.stack([hi, lo], dim=-1).reshape(N, K)

        # Group-wise dequantization: each contiguous run of group_size
        # weights along K shares one float16 scale.
        n_groups = K // group_size
        grouped = w_int.reshape(N, n_groups, group_size).float()
        w_dequant = (grouped * scales.float().unsqueeze(-1)).reshape(N, K)

        # Accumulate the matmul in fp32, then store back as fp16.
        y.copy_((x.float() @ w_dequant.T).half())

    def get_solve_signature(self) -> Dict[str, tuple]:
        """C ABI of the student ``solve``; fp16 buffers travel as uint16 pointers."""
        half_ptr = ctypes.POINTER(ctypes.c_uint16)
        return {
            "x": (half_ptr, "in"),
            "w_q": (ctypes.POINTER(ctypes.c_uint8), "in"),
            "scales": (half_ptr, "in"),
            "y": (half_ptr, "out"),
            "M": (ctypes.c_int, "in"),
            "N": (ctypes.c_int, "in"),
            "K": (ctypes.c_int, "in"),
            "group_size": (ctypes.c_int, "in"),
        }

    def _make_test_case(self, M: int, N: int, K: int, group_size: int, zero_x: bool = False):
        """Build one random test case; ``zero_x`` forces an all-zero activation matrix."""
        device = "cuda"
        x = (
            torch.zeros(M, K, device=device, dtype=torch.float16)
            if zero_x
            else torch.randn(M, K, device=device, dtype=torch.float16)
        )
        # Random packed weights: every byte carries two nibbles in [0, 15].
        w_q = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=device)
        # Small positive scales keep the dequantized magnitudes reasonable.
        scales = torch.rand(N, K // group_size, device=device, dtype=torch.float16) * 0.1 + 0.01
        y = torch.empty(M, N, device=device, dtype=torch.float16)
        return {
            "x": x,
            "w_q": w_q,
            "scales": scales,
            "y": y,
            "M": M,
            "N": N,
            "K": K,
            "group_size": group_size,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        """Reproduce the worked example from the challenge description."""
        device = "cuda"
        M, N, K, group_size = 2, 4, 4, 2

        x = torch.tensor(
            [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]],
            device=device,
            dtype=torch.float16,
        )
        # Pack each row's constant signed weight into two identical bytes.
        # signed w -> unsigned nibble (w + 8) -> byte with that nibble in
        # both halves: 1->0x99, 2->0xAA, -1->0x77, 0->0x88.
        packed_rows = []
        for w in (1, 2, -1, 0):
            nibble = w + 8
            byte = (nibble << 4) | nibble
            packed_rows.append([byte, byte])
        w_q = torch.tensor(packed_rows, dtype=torch.uint8, device=device)
        # group_size=2 gives two groups per row; every scale is 0.5.
        scales = torch.full((N, K // group_size), 0.5, device=device, dtype=torch.float16)
        y = torch.empty(M, N, device=device, dtype=torch.float16)

        return {
            "x": x,
            "w_q": w_q,
            "scales": scales,
            "y": y,
            "M": M,
            "N": N,
            "K": K,
            "group_size": group_size,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        """Deterministic sweep from tiny edge cases up to realistic LLM shapes."""
        torch.manual_seed(42)
        cases = [
            # Edge cases — tiny K, small group_size
            (1, 2, 4, 2, True),
            (2, 4, 4, 2, False),
            (3, 5, 8, 4, False),
            # Power-of-2 sizes
            (16, 16, 32, 16, False),
            (32, 64, 64, 32, False),
            (64, 128, 128, 64, False),
            # Non-power-of-2 sizes
            (30, 50, 64, 32, False),
            (100, 200, 128, 64, False),
            (255, 100, 128, 64, False),
            # Realistic LLM inference sizes
            (128, 256, 512, 128, False),
        ]
        return [
            self._make_test_case(M, N, K, g, zero_x=z) for (M, N, K, g, z) in cases
        ]

    def generate_performance_test(self) -> Dict[str, Any]:
        """Timing shape: a typical 4096x4096 LLM weight matrix, group_size=128."""
        torch.manual_seed(0)
        return self._make_test_case(4096, 4096, 4096, 128)
7 changes: 7 additions & 0 deletions challenges/medium/81_int4_matmul/starter/starter.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdint.h>

// Starter stub for the INT4 weight-only quantized matmul (W4A16) challenge.
// Compute y = x @ dequant(w_q, scales)^T, where:
//   x      : [M, K]            float16 activations
//   w_q    : [N, K/2]          packed weights, two INT4 nibbles per byte
//                              (bits 7:4 -> w[n, 2i], bits 3:0 -> w[n, 2i+1],
//                               signed value = nibble - 8)
//   scales : [N, K/group_size] float16 per-group scales
//   y      : [M, N]            float16 output
// x, w_q, scales, y are device pointers
extern "C" void solve(const __half* x, const uint8_t* w_q, const __half* scales, __half* y, int M,
int N, int K, int group_size) {}
Loading
Loading