Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 0 additions & 126 deletions challenges/medium/81_int4_matmul/challenge.html
Original file line number Diff line number Diff line change
Expand Up @@ -6,132 +6,6 @@
<code>W</code> is the dequantized float16 weight matrix of shape <code>N &times; K</code>.
</p>

<svg width="700" height="400" viewBox="0 0 700 400" xmlns="http://www.w3.org/2000/svg"
style="display:block; margin:20px auto; font-family:monospace;">
<rect width="700" height="400" fill="#222" rx="10"/>
<defs>
<marker id="arr" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
<path d="M0,0 L0,6 L8,3 z" fill="#aaa"/>
</marker>
</defs>

<!-- ============================================================ -->
<!-- ROW 1: UNPACK — packed byte → two unsigned nibbles → signed -->
<!-- ============================================================ -->
<text x="18" y="20" fill="#666" font-size="10">STEP 1: UNPACK</text>

<!-- Packed byte -->
<text x="80" y="48" fill="#ccc" font-size="11" text-anchor="middle">w_q[n, i]</text>
<rect x="20" y="56" width="120" height="32" fill="#1a3a5c" rx="4" stroke="#4a9edd" stroke-width="1.5"/>
<line x1="80" y1="56" x2="80" y2="88" stroke="#4a9edd" stroke-width="1" stroke-dasharray="3,2"/>
<text x="50" y="77" fill="#4a9edd" font-size="10" text-anchor="middle">hi 7:4</text>
<text x="110" y="77" fill="#7ec87e" font-size="10" text-anchor="middle">lo 3:0</text>

<!-- Arrow right -->
<text x="160" y="77" fill="#aaa" font-size="14" text-anchor="middle">&#x2192;</text>

<!-- Unsigned nibbles -->
<rect x="180" y="56" width="50" height="32" fill="#1a3a5c" rx="4" stroke="#4a9edd" stroke-width="1.5"/>
<text x="205" y="77" fill="#4a9edd" font-size="10" text-anchor="middle">9</text>
<rect x="236" y="56" width="50" height="32" fill="#1a4a1a" rx="4" stroke="#7ec87e" stroke-width="1.5"/>
<text x="261" y="77" fill="#7ec87e" font-size="10" text-anchor="middle">10</text>

<!-- "- 8" arrow -->
<text x="310" y="77" fill="#ccc" font-size="11" text-anchor="middle">&#x2212; 8</text>
<text x="345" y="77" fill="#aaa" font-size="14" text-anchor="middle">&#x2192;</text>

<!-- Signed int4 -->
<rect x="365" y="56" width="50" height="32" fill="#3a2a1a" rx="4" stroke="#e0a040" stroke-width="1.5"/>
<text x="390" y="77" fill="#e0a040" font-size="10" text-anchor="middle">+1</text>
<rect x="421" y="56" width="50" height="32" fill="#3a2a1a" rx="4" stroke="#e0a040" stroke-width="1.5"/>
<text x="446" y="77" fill="#e0a040" font-size="10" text-anchor="middle">+2</text>

<text x="540" y="77" fill="#888" font-size="10" text-anchor="middle">signed int4 [&#x2212;8, 7]</text>

<!-- ============================================================ -->
<!-- ROW 2: GROUP-WISE SCALING — show K=8, group_size=4 -->
<!-- ============================================================ -->
<text x="18" y="112" fill="#666" font-size="10">STEP 2: DEQUANTIZE (example: one row n, K=8, group_size=4)</text>

<!-- K-axis label -->
<text x="350" y="136" fill="#888" font-size="10" text-anchor="middle">k &#x2192;</text>

<!-- Group 0 bracket + cells -->
<text x="145" y="136" fill="#c060e0" font-size="9" text-anchor="middle">group 0: scale[n, 0]</text>
<rect x="58" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="81" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+1</text>
<rect x="108" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="131" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+2</text>
<rect x="158" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="181" y="161" fill="#e0a040" font-size="10" text-anchor="middle">&#x2212;1</text>
<rect x="208" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="231" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+3</text>
<!-- Group 0 bracket -->
<rect x="56" y="140" width="200" height="32" rx="4" fill="none" stroke="#c060e0" stroke-width="1.5" stroke-dasharray="4,2"/>

<!-- Group 1 bracket + cells -->
<text x="385" y="136" fill="#c060e0" font-size="9" text-anchor="middle">group 1: scale[n, 1]</text>
<rect x="298" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="321" y="161" fill="#e0a040" font-size="10" text-anchor="middle">0</text>
<rect x="348" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="371" y="161" fill="#e0a040" font-size="10" text-anchor="middle">&#x2212;3</text>
<rect x="398" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="421" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+7</text>
<rect x="448" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
<text x="471" y="161" fill="#e0a040" font-size="10" text-anchor="middle">&#x2212;2</text>
<!-- Group 1 bracket -->
<rect x="296" y="140" width="200" height="32" rx="4" fill="none" stroke="#c060e0" stroke-width="1.5" stroke-dasharray="4,2"/>

<!-- "int4" label on left -->
<text x="30" y="161" fill="#e0a040" font-size="9">int4</text>

<!-- Multiply arrows down -->
<text x="156" y="190" fill="#ccc" font-size="12" text-anchor="middle">&#xd7; scale[n, 0]</text>
<text x="396" y="190" fill="#ccc" font-size="12" text-anchor="middle">&#xd7; scale[n, 1]</text>
<line x1="156" y1="172" x2="156" y2="198" stroke="#aaa" stroke-width="1" stroke-dasharray="3,2"/>
<line x1="396" y1="172" x2="396" y2="198" stroke="#aaa" stroke-width="1" stroke-dasharray="3,2"/>

<!-- Dequantized row -->
<text x="30" y="217" fill="#40c080" font-size="9">fp16</text>
<rect x="56" y="202" width="200" height="28" fill="#1a3a2a" rx="4" stroke="#40c080" stroke-width="1.5"/>
<text x="156" y="221" fill="#40c080" font-size="10" text-anchor="middle">W[n, 0..3] float16</text>
<rect x="296" y="202" width="200" height="28" fill="#1a3a2a" rx="4" stroke="#40c080" stroke-width="1.5"/>
<text x="396" y="221" fill="#40c080" font-size="10" text-anchor="middle">W[n, 4..7] float16</text>

<!-- Formula -->
<text x="275" y="252" fill="#ccc" font-size="10" text-anchor="middle">W[n, k] = (nibble &#x2212; 8) &#xd7; scales[n, k // group_size]</text>

<!-- ============================================================ -->
<!-- ROW 3: MATMUL -->
<!-- ============================================================ -->
<text x="18" y="280" fill="#666" font-size="10">STEP 3: MATMUL</text>

<!-- x box -->
<rect x="60" y="296" width="80" height="60" fill="#1a3a5c" rx="4" stroke="#4a9edd" stroke-width="1.5"/>
<text x="100" y="322" fill="#4a9edd" font-size="10" text-anchor="middle">x [M&#xd7;K]</text>
<text x="100" y="340" fill="#4a9edd" font-size="9" text-anchor="middle">float16</text>

<!-- multiply sign -->
<text x="162" y="330" fill="#ccc" font-size="16" text-anchor="middle">&#xd7;</text>

<!-- W^T box -->
<rect x="185" y="296" width="100" height="60" fill="#1a3a2a" rx="4" stroke="#40c080" stroke-width="1.5"/>
<text x="235" y="322" fill="#40c080" font-size="10" text-anchor="middle">W&#x1d40; [K&#xd7;N]</text>
<text x="235" y="340" fill="#40c080" font-size="9" text-anchor="middle">float16</text>

<!-- equals sign -->
<text x="310" y="330" fill="#ccc" font-size="16" text-anchor="middle">=</text>

<!-- y output box -->
<rect x="335" y="296" width="90" height="60" fill="#3a1a1a" rx="4" stroke="#e05050" stroke-width="1.5"/>
<text x="380" y="322" fill="#e05050" font-size="10" text-anchor="middle">y [M&#xd7;N]</text>
<text x="380" y="340" fill="#e05050" font-size="9" text-anchor="middle">float16</text>

<!-- Arrow from dequant to W^T -->
<line x1="235" y1="240" x2="235" y2="294" stroke="#40c080" stroke-width="1.5" stroke-dasharray="4,2" marker-end="url(#arr)"/>
<text x="260" y="270" fill="#40c080" font-size="9">dequantized</text>
</svg>

<p>
<strong>Packing format:</strong> Each byte of <code>w_q</code> stores two INT4 weights. The
high nibble (bits 7&ndash;4) holds weight <code>w[n, 2i]</code> and the low nibble (bits
Expand Down
Loading