diff --git a/challenges/medium/85_lora_linear/challenge.html b/challenges/medium/85_lora_linear/challenge.html new file mode 100644 index 00000000..9950c4a1 --- /dev/null +++ b/challenges/medium/85_lora_linear/challenge.html @@ -0,0 +1,111 @@ +

+ Implement a LoRA (Low-Rank Adaptation) linear layer forward pass. Given an input matrix + x of shape batch × d_in, a base weight matrix W of + shape d_out × d_in, a LoRA down-projection matrix A of shape + rank × d_in, and a LoRA up-projection matrix B of shape + d_out × rank, compute + output = x × WT + lora_scale × (x × AT) × BT. + All tensors are float32. +

+ + + + + + + x + B×D_in + + + + + + + + + W + D_out×D_in + + + + + x@Wᵗ + B×D_out + + + + A + rank×D_in + + + + + x@Aᵗ + B×rank + + + + B + D_out×rank + + + + + + + + α×(x@Aᵗ)@Bᵗ + B×D_out + + + + + + + + + + output + B×D_out + + + + + + + + +

Implementation Requirements

+ + +

Examples

+

Example 1:

+

+\[ +x = \begin{bmatrix} 1 & 0 & -1 & 2 \\ 0 & 1 & 1 & -1 \end{bmatrix},\quad +W = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \end{bmatrix},\quad +A = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix},\quad +B = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0 & 0 \end{bmatrix} +\] +

+

With lora_scale = 0.5:

+

+\[ +\text{output} = x W^T + 0.5 \cdot (x A^T) B^T += \begin{bmatrix} 1 & 0 & -1 \\ 0 & 1 & 1 \end{bmatrix} ++ 0.5 \cdot \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix} += \begin{bmatrix} 1.5 & 0 & -1 \\ 0 & 1.5 & 1 \end{bmatrix} +\] +

+ +

Constraints

+ diff --git a/challenges/medium/85_lora_linear/challenge.py b/challenges/medium/85_lora_linear/challenge.py new file mode 100644 index 00000000..5c4c8c98 --- /dev/null +++ b/challenges/medium/85_lora_linear/challenge.py @@ -0,0 +1,171 @@ +import ctypes +from typing import Any, Dict, List + +import torch +from core.challenge_base import ChallengeBase + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="LoRA Linear", + atol=1e-04, + rtol=1e-04, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + x: torch.Tensor, + W: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + output: torch.Tensor, + batch: int, + d_in: int, + d_out: int, + rank: int, + lora_scale: float, + ): + assert x.shape == (batch, d_in) + assert W.shape == (d_out, d_in) + assert A.shape == (rank, d_in) + assert B.shape == (d_out, rank) + assert output.shape == (batch, d_out) + assert x.dtype == W.dtype == A.dtype == B.dtype == output.dtype == torch.float32 + assert x.device.type == "cuda" + assert W.device.type == "cuda" + assert A.device.type == "cuda" + assert B.device.type == "cuda" + assert output.device.type == "cuda" + + # Base linear: output = x @ W^T + base = torch.mm(x, W.t()) + + # LoRA path: delta = lora_scale * (x @ A^T) @ B^T + lora_hidden = torch.mm(x, A.t()) # (batch, rank) + delta = torch.mm(lora_hidden, B.t()) # (batch, d_out) + + output.copy_(base + lora_scale * delta) + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "x": (ctypes.POINTER(ctypes.c_float), "in"), + "W": (ctypes.POINTER(ctypes.c_float), "in"), + "A": (ctypes.POINTER(ctypes.c_float), "in"), + "B": (ctypes.POINTER(ctypes.c_float), "in"), + "output": (ctypes.POINTER(ctypes.c_float), "out"), + "batch": (ctypes.c_int, "in"), + "d_in": (ctypes.c_int, "in"), + "d_out": (ctypes.c_int, "in"), + "rank": (ctypes.c_int, "in"), + "lora_scale": (ctypes.c_float, "in"), + } + + def _make_test_case(self, batch, d_in, d_out, rank, lora_scale=0.5, zero_x=False): + dtype = torch.float32 + device = "cuda" + if zero_x: + x = torch.zeros(batch, d_in, device=device, dtype=dtype) + else: + x = torch.randn(batch, d_in, device=device, dtype=dtype) + W = torch.randn(d_out, d_in, device=device, dtype=dtype) * 0.02 + A = torch.randn(rank, d_in, device=device, dtype=dtype) * 0.02 + B = torch.zeros(d_out, rank, device=device, dtype=dtype) + output = torch.zeros(batch, d_out, device=device, dtype=dtype) + return { + "x": x, + "W": W, + "A": A, + "B": B, + "output": output, + "batch": batch, + "d_in": d_in, + "d_out": d_out, + "rank": rank, + "lora_scale": lora_scale, + } + + def generate_example_test(self) -> Dict[str, Any]: + dtype = torch.float32 + device = "cuda" + x = torch.tensor([[1.0, 0.0, -1.0, 2.0], [0.0, 1.0, 1.0, -1.0]], device=device, dtype=dtype) + W = torch.tensor( + [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], + device=device, + dtype=dtype, + ) + A = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], device=device, dtype=dtype) + B = torch.tensor( + [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]], + device=device, + dtype=dtype, + ) + output = torch.zeros(2, 3, device=device, dtype=dtype) + return { + "x": x, + "W": W, + "A": A, + "B": B, + "output": output, + "batch": 2, + "d_in": 4, + "d_out": 3, + "rank": 2, + "lora_scale": 0.5, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + torch.manual_seed(42) + tests = [] + + # Edge case: batch=1, tiny dimensions + tests.append(self._make_test_case(1, 4, 4, 1)) + + # Edge case: zero input + tests.append(self._make_test_case(2, 8, 8, 2, zero_x=True)) + + # Edge case: rank=1 (minimum LoRA rank) + tests.append(self._make_test_case(4, 16, 16, 1)) + + # Power-of-2 dimensions + tests.append(self._make_test_case(16, 64, 64, 8)) + + # Power-of-2, non-square + tests.append(self._make_test_case(32, 128, 64, 16)) + + # Non-power-of-2 dimensions + tests.append(self._make_test_case(30, 100, 100, 4)) + + # Non-power-of-2, mixed + tests.append(self._make_test_case(7, 255, 128, 8)) + + # Realistic small: LLM feed-forward style + tests.append(self._make_test_case(64, 512, 512, 16, lora_scale=0.125)) + + # Negative inputs + tests.append( + { + "x": torch.full((4, 32), -1.0, device="cuda", dtype=torch.float32), + "W": torch.randn(32, 32, device="cuda", dtype=torch.float32) * 0.02, + "A": torch.randn(8, 32, device="cuda", dtype=torch.float32) * 0.02, + "B": torch.randn(32, 8, device="cuda", dtype=torch.float32) * 0.02, + "output": torch.zeros(4, 32, device="cuda", dtype=torch.float32), + "batch": 4, + "d_in": 32, + "d_out": 32, + "rank": 8, + "lora_scale": 1.0, + } + ) + + # Larger realistic: transformer hidden size + tests.append(self._make_test_case(128, 1024, 1024, 32, lora_scale=0.0625)) + + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + torch.manual_seed(0) + # LLaMA-style: d_in=d_out=4096, rank=64, batch=256 + return self._make_test_case(256, 4096, 4096, 64, lora_scale=0.015625) diff --git a/challenges/medium/85_lora_linear/starter/starter.cu b/challenges/medium/85_lora_linear/starter/starter.cu new file mode 100644 index 00000000..93ac644b --- /dev/null +++ b/challenges/medium/85_lora_linear/starter/starter.cu @@ -0,0 +1,5 @@ +#include + +// x, W, A, B, output are device pointers +extern "C" void solve(const float* x, const float* W, const float* A, const float* B, float* output, + int batch, int d_in, int d_out, int rank, float lora_scale) {} diff --git a/challenges/medium/85_lora_linear/starter/starter.cute.py b/challenges/medium/85_lora_linear/starter/starter.cute.py new file mode 100644 index 00000000..e79e3eb0 --- /dev/null +++ b/challenges/medium/85_lora_linear/starter/starter.cute.py @@ -0,0 +1,19 @@ +import cutlass +import cutlass.cute as cute + + +# x, W, A, B, output are tensors on the GPU +@cute.jit +def solve( + x: cute.Tensor, + W: cute.Tensor, + A: cute.Tensor, + B: cute.Tensor, + output: cute.Tensor, + batch: cute.Int32, + d_in: cute.Int32, + d_out: cute.Int32, + rank: cute.Int32, + lora_scale: cute.Float32, +): + pass diff --git a/challenges/medium/85_lora_linear/starter/starter.jax.py b/challenges/medium/85_lora_linear/starter/starter.jax.py new file mode 100644 index 00000000..14bbd946 --- /dev/null +++ b/challenges/medium/85_lora_linear/starter/starter.jax.py @@ -0,0 +1,19 @@ +import jax +import jax.numpy as jnp + + +# x, W, A, B are tensors on GPU +@jax.jit +def solve( + x: jax.Array, + W: jax.Array, + A: jax.Array, + B: jax.Array, + batch: int, + d_in: int, + d_out: int, + rank: int, + lora_scale: float, +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/medium/85_lora_linear/starter/starter.mojo b/challenges/medium/85_lora_linear/starter/starter.mojo new file mode 100644 index 00000000..31594f76 --- /dev/null +++ b/challenges/medium/85_lora_linear/starter/starter.mojo @@ -0,0 +1,18 @@ +from gpu.host import DeviceContext +from memory import UnsafePointer + +# x, W, A, B, output are device pointers +@export +def solve( + x: UnsafePointer[Float32], + W: UnsafePointer[Float32], + A: UnsafePointer[Float32], + B: UnsafePointer[Float32], + output: UnsafePointer[Float32], + batch: Int32, + d_in: Int32, + d_out: Int32, + rank: Int32, + lora_scale: Float32, +): + pass diff --git a/challenges/medium/85_lora_linear/starter/starter.pytorch.py b/challenges/medium/85_lora_linear/starter/starter.pytorch.py new file mode 100644 index 00000000..01208973 --- /dev/null +++ b/challenges/medium/85_lora_linear/starter/starter.pytorch.py @@ -0,0 +1,17 @@ +import torch + + +# x, W, A, B, output are tensors on the GPU +def solve( + x: torch.Tensor, + W: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + output: torch.Tensor, + batch: int, + d_in: int, + d_out: int, + rank: int, + lora_scale: float, +): + pass diff --git a/challenges/medium/85_lora_linear/starter/starter.triton.py b/challenges/medium/85_lora_linear/starter/starter.triton.py new file mode 100644 index 00000000..1e689ea6 --- /dev/null +++ b/challenges/medium/85_lora_linear/starter/starter.triton.py @@ -0,0 +1,19 @@ +import torch +import triton +import triton.language as tl + + +# x, W, A, B, output are tensors on the GPU +def solve( + x: torch.Tensor, + W: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + output: torch.Tensor, + batch: int, + d_in: int, + d_out: int, + rank: int, + lora_scale: float, +): + pass