Merged
111 changes: 111 additions & 0 deletions challenges/medium/85_lora_linear/challenge.html
@@ -0,0 +1,111 @@
<p>
Implement a LoRA (Low-Rank Adaptation) linear layer forward pass. Given an input matrix
<code>x</code> of shape <code>batch &times; d_in</code>, a base weight matrix <code>W</code> of
shape <code>d_out &times; d_in</code>, a LoRA down-projection matrix <code>A</code> of shape
<code>rank &times; d_in</code>, and a LoRA up-projection matrix <code>B</code> of shape
<code>d_out &times; rank</code>, compute
<code>output = x &times; W<sup>T</sup> + lora_scale &times; (x &times; A<sup>T</sup>) &times; B<sup>T</sup></code>.
All tensors are <code>float32</code>.
</p>

<svg width="680" height="200" viewBox="0 0 680 200" xmlns="http://www.w3.org/2000/svg" style="display:block; margin:20px auto;">
<rect width="680" height="200" fill="#222"/>

<!-- x block -->
<rect x="20" y="70" width="60" height="60" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.5"/>
<text x="50" y="95" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">x</text>
<text x="50" y="112" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">B&times;D_in</text>

<!-- Arrow to W branch -->
<line x1="80" y1="100" x2="110" y2="70" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
<!-- Arrow to A branch -->
<line x1="80" y1="100" x2="110" y2="145" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>

<!-- W block -->
<rect x="112" y="40" width="70" height="55" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.5"/>
<text x="147" y="63" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">W</text>
<text x="147" y="80" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">D_out&times;D_in</text>

<!-- base output: x@W^T -->
<line x1="182" y1="67" x2="225" y2="90" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
<rect x="227" y="70" width="80" height="55" fill="#1a4a2a" stroke="#4aff88" stroke-width="1.5"/>
<text x="267" y="92" text-anchor="middle" fill="#ccc" font-size="11" font-family="monospace">x@W&#x1D57;</text>
<text x="267" y="108" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">B&times;D_out</text>

<!-- A block -->
<rect x="112" y="128" width="70" height="50" fill="#3a1a3a" stroke="#cc88ff" stroke-width="1.5"/>
<text x="147" y="150" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">A</text>
<text x="147" y="167" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">rank&times;D_in</text>

<!-- hidden = x@A^T -->
<line x1="182" y1="153" x2="225" y2="153" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
<rect x="227" y="130" width="60" height="45" fill="#3a1a3a" stroke="#cc88ff" stroke-width="1.5"/>
<text x="257" y="152" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">x@A&#x1D57;</text>
<text x="257" y="167" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">B&times;rank</text>

<!-- B block -->
<rect x="304" y="128" width="70" height="50" fill="#3a1a3a" stroke="#cc88ff" stroke-width="1.5"/>
<text x="339" y="150" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">B</text>
<text x="339" y="167" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">D_out&times;rank</text>

<!-- arrow from hidden to B -->
<line x1="287" y1="153" x2="302" y2="153" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>

<!-- delta = (x@A^T)@B^T -->
<line x1="374" y1="153" x2="415" y2="120" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
<rect x="417" y="95" width="80" height="55" fill="#3a2a1a" stroke="#ffaa44" stroke-width="1.5"/>
<text x="457" y="117" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">&#x3B1;&times;(x@A&#x1D57;)@B&#x1D57;</text>
<text x="457" y="133" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">B&times;D_out</text>

<!-- plus sign -->
<line x1="307" y1="97" x2="415" y2="97" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
<text x="385" y="88" text-anchor="middle" fill="#ffaa44" font-size="20" font-family="monospace">+</text>

<!-- output -->
<line x1="497" y1="122" x2="535" y2="122" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
<rect x="537" y="95" width="80" height="55" fill="#1a4a2a" stroke="#4aff88" stroke-width="1.5"/>
<text x="577" y="117" text-anchor="middle" fill="#ccc" font-size="12" font-family="monospace">output</text>
<text x="577" y="135" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">B&times;D_out</text>

<defs>
<marker id="arr" markerWidth="6" markerHeight="6" refX="5" refY="3" orient="auto">
<path d="M0,0 L6,3 L0,6 Z" fill="#888"/>
</marker>
</defs>
</svg>

<h2>Implementation Requirements</h2>
<ul>
<li>Implement the <code>solve</code> function; do not change its signature.</li>
<li>Do not use external libraries beyond those provided.</li>
<li>Write the result into <code>output</code>.</li>
</ul>

<h2>Examples</h2>
<p>Example 1:</p>
<p>
\[
x = \begin{bmatrix} 1 & 0 & -1 & 2 \\ 0 & 1 & 1 & -1 \end{bmatrix},\quad
W = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \end{bmatrix},\quad
A = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix},\quad
B = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0 & 0 \end{bmatrix}
\]
</p>
<p>With <code>lora_scale</code> = 0.5:</p>
<p>
\[
\text{output} = x W^T + 0.5 \cdot (x A^T) B^T
= \begin{bmatrix} 1 & 0 & -1 \\ 0 & 1 & 1 \end{bmatrix}
+ 0.5 \cdot \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}
= \begin{bmatrix} 1.5 & 0 & -1 \\ 0 & 1.5 & 1 \end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>batch</code> &le; 1,024</li>
<li>1 &le; <code>d_in</code>, <code>d_out</code> &le; 8,192</li>
<li>1 &le; <code>rank</code> &le; 256; <code>rank</code> &lt; min(<code>d_in</code>, <code>d_out</code>)</li>
<li>All tensors are <code>float32</code> on GPU.</li>
<li>Performance is measured with <code>batch</code> = 256, <code>d_in</code> = 4,096, <code>d_out</code> = 4,096, <code>rank</code> = 64</li>
</ul>
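Reviewer note: the worked example above can be checked numerically. A minimal sketch using PyTorch on CPU (the challenge itself runs on GPU tensors):

```python
import torch

# Tensors from Example 1 in the challenge description
x = torch.tensor([[1.0, 0.0, -1.0, 2.0], [0.0, 1.0, 1.0, -1.0]])
W = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
A = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
B = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
lora_scale = 0.5

# output = x @ W^T + lora_scale * (x @ A^T) @ B^T
output = x @ W.t() + lora_scale * (x @ A.t()) @ B.t()
# output is [[1.5, 0.0, -1.0], [0.0, 1.5, 1.0]], matching the example
```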
171 changes: 171 additions & 0 deletions challenges/medium/85_lora_linear/challenge.py
@@ -0,0 +1,171 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="LoRA Linear",
            atol=1e-04,
            rtol=1e-04,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        x: torch.Tensor,
        W: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        output: torch.Tensor,
        batch: int,
        d_in: int,
        d_out: int,
        rank: int,
        lora_scale: float,
    ):
        assert x.shape == (batch, d_in)
        assert W.shape == (d_out, d_in)
        assert A.shape == (rank, d_in)
        assert B.shape == (d_out, rank)
        assert output.shape == (batch, d_out)
        assert x.dtype == W.dtype == A.dtype == B.dtype == output.dtype == torch.float32
        assert x.device.type == "cuda"
        assert W.device.type == "cuda"
        assert A.device.type == "cuda"
        assert B.device.type == "cuda"
        assert output.device.type == "cuda"

        # Base linear: output = x @ W^T
        base = torch.mm(x, W.t())

        # LoRA path: delta = lora_scale * (x @ A^T) @ B^T
        lora_hidden = torch.mm(x, A.t())  # (batch, rank)
        delta = torch.mm(lora_hidden, B.t())  # (batch, d_out)

        output.copy_(base + lora_scale * delta)

    def get_solve_signature(self) -> Dict[str, tuple]:
        return {
            "x": (ctypes.POINTER(ctypes.c_float), "in"),
            "W": (ctypes.POINTER(ctypes.c_float), "in"),
            "A": (ctypes.POINTER(ctypes.c_float), "in"),
            "B": (ctypes.POINTER(ctypes.c_float), "in"),
            "output": (ctypes.POINTER(ctypes.c_float), "out"),
            "batch": (ctypes.c_int, "in"),
            "d_in": (ctypes.c_int, "in"),
            "d_out": (ctypes.c_int, "in"),
            "rank": (ctypes.c_int, "in"),
            "lora_scale": (ctypes.c_float, "in"),
        }

    def _make_test_case(self, batch, d_in, d_out, rank, lora_scale=0.5, zero_x=False):
        dtype = torch.float32
        device = "cuda"
        if zero_x:
            x = torch.zeros(batch, d_in, device=device, dtype=dtype)
        else:
            x = torch.randn(batch, d_in, device=device, dtype=dtype)
        W = torch.randn(d_out, d_in, device=device, dtype=dtype) * 0.02
        A = torch.randn(rank, d_in, device=device, dtype=dtype) * 0.02
        # Random B so the LoRA branch contributes to the output. (Real LoRA training
        # zero-initializes B, but B = 0 here would make delta vanish and leave the
        # LoRA path untested in every generated case.)
        B = torch.randn(d_out, rank, device=device, dtype=dtype) * 0.02
        output = torch.zeros(batch, d_out, device=device, dtype=dtype)
        return {
            "x": x,
            "W": W,
            "A": A,
            "B": B,
            "output": output,
            "batch": batch,
            "d_in": d_in,
            "d_out": d_out,
            "rank": rank,
            "lora_scale": lora_scale,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        dtype = torch.float32
        device = "cuda"
        x = torch.tensor([[1.0, 0.0, -1.0, 2.0], [0.0, 1.0, 1.0, -1.0]], device=device, dtype=dtype)
        W = torch.tensor(
            [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],
            device=device,
            dtype=dtype,
        )
        A = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], device=device, dtype=dtype)
        B = torch.tensor(
            [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],
            device=device,
            dtype=dtype,
        )
        output = torch.zeros(2, 3, device=device, dtype=dtype)
        return {
            "x": x,
            "W": W,
            "A": A,
            "B": B,
            "output": output,
            "batch": 2,
            "d_in": 4,
            "d_out": 3,
            "rank": 2,
            "lora_scale": 0.5,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        torch.manual_seed(42)
        tests = []

        # Edge case: batch=1, tiny dimensions
        tests.append(self._make_test_case(1, 4, 4, 1))

        # Edge case: zero input
        tests.append(self._make_test_case(2, 8, 8, 2, zero_x=True))

        # Edge case: rank=1 (minimum LoRA rank)
        tests.append(self._make_test_case(4, 16, 16, 1))

        # Power-of-2 dimensions
        tests.append(self._make_test_case(16, 64, 64, 8))

        # Power-of-2, non-square
        tests.append(self._make_test_case(32, 128, 64, 16))

        # Non-power-of-2 dimensions
        tests.append(self._make_test_case(30, 100, 100, 4))

        # Non-power-of-2, mixed
        tests.append(self._make_test_case(7, 255, 128, 8))

        # Realistic small: LLM feed-forward style
        tests.append(self._make_test_case(64, 512, 512, 16, lora_scale=0.125))

        # Negative inputs
        tests.append(
            {
                "x": torch.full((4, 32), -1.0, device="cuda", dtype=torch.float32),
                "W": torch.randn(32, 32, device="cuda", dtype=torch.float32) * 0.02,
                "A": torch.randn(8, 32, device="cuda", dtype=torch.float32) * 0.02,
                "B": torch.randn(32, 8, device="cuda", dtype=torch.float32) * 0.02,
                "output": torch.zeros(4, 32, device="cuda", dtype=torch.float32),
                "batch": 4,
                "d_in": 32,
                "d_out": 32,
                "rank": 8,
                "lora_scale": 1.0,
            }
        )

        # Larger realistic: transformer hidden size
        tests.append(self._make_test_case(128, 1024, 1024, 32, lora_scale=0.0625))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        torch.manual_seed(0)
        # LLaMA-style: d_in=d_out=4096, rank=64, batch=256
        return self._make_test_case(256, 4096, 4096, 64, lora_scale=0.015625)
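Reviewer note: the reference computes three matmuls plus a separate scale-and-add. The final add can be folded into `torch.addmm` (`out = beta * input + alpha * mat1 @ mat2`), avoiding one temporary. A hypothetical equivalent sketch (`lora_forward_fused` is an illustrative name, not part of this PR):

```python
import torch

def lora_forward_fused(x, W, A, B, lora_scale):
    # Same math as reference_impl: x @ W^T + lora_scale * (x @ A^T) @ B^T,
    # with the scaled add fused into a single addmm call.
    base = x @ W.t()        # (batch, d_out)
    hidden = x @ A.t()      # (batch, rank)
    return torch.addmm(base, hidden, B.t(), beta=1.0, alpha=lora_scale)
```

Since `rank` is much smaller than `d_in`/`d_out`, the two LoRA matmuls cost O(batch * rank * (d_in + d_out)) versus O(batch * d_in * d_out) for the base GEMM, so the base product dominates the performance test.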
5 changes: 5 additions & 0 deletions challenges/medium/85_lora_linear/starter/starter.cu
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// x, W, A, B, output are device pointers
extern "C" void solve(const float* x, const float* W, const float* A, const float* B, float* output,
int batch, int d_in, int d_out, int rank, float lora_scale) {}
19 changes: 19 additions & 0 deletions challenges/medium/85_lora_linear/starter/starter.cute.py
@@ -0,0 +1,19 @@
import cutlass
import cutlass.cute as cute


# x, W, A, B, output are tensors on the GPU
@cute.jit
def solve(
    x: cute.Tensor,
    W: cute.Tensor,
    A: cute.Tensor,
    B: cute.Tensor,
    output: cute.Tensor,
    batch: cute.Int32,
    d_in: cute.Int32,
    d_out: cute.Int32,
    rank: cute.Int32,
    lora_scale: cute.Float32,
):
    pass
19 changes: 19 additions & 0 deletions challenges/medium/85_lora_linear/starter/starter.jax.py
@@ -0,0 +1,19 @@
import jax
import jax.numpy as jnp


# x, W, A, B are tensors on GPU
@jax.jit
def solve(
    x: jax.Array,
    W: jax.Array,
    A: jax.Array,
    B: jax.Array,
    batch: int,
    d_in: int,
    d_out: int,
    rank: int,
    lora_scale: float,
) -> jax.Array:
    # return output tensor directly
    pass
18 changes: 18 additions & 0 deletions challenges/medium/85_lora_linear/starter/starter.mojo
@@ -0,0 +1,18 @@
from gpu.host import DeviceContext
from memory import UnsafePointer

# x, W, A, B, output are device pointers
@export
def solve(
    x: UnsafePointer[Float32],
    W: UnsafePointer[Float32],
    A: UnsafePointer[Float32],
    B: UnsafePointer[Float32],
    output: UnsafePointer[Float32],
    batch: Int32,
    d_in: Int32,
    d_out: Int32,
    rank: Int32,
    lora_scale: Float32,
):
    pass
17 changes: 17 additions & 0 deletions challenges/medium/85_lora_linear/starter/starter.pytorch.py
@@ -0,0 +1,17 @@
import torch


# x, W, A, B, output are tensors on the GPU
def solve(
    x: torch.Tensor,
    W: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    output: torch.Tensor,
    batch: int,
    d_in: int,
    d_out: int,
    rank: int,
    lora_scale: float,
):
    pass
19 changes: 19 additions & 0 deletions challenges/medium/85_lora_linear/starter/starter.triton.py
@@ -0,0 +1,19 @@
import torch
import triton
import triton.language as tl


# x, W, A, B, output are tensors on the GPU
def solve(
    x: torch.Tensor,
    W: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    output: torch.Tensor,
    batch: int,
    d_in: int,
    d_out: int,
    rank: int,
    lora_scale: float,
):
    pass