Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions challenges/medium/78_2d_fft/challenge.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
<p>
Compute the 2D Discrete Fourier Transform (2D DFT) of a complex-valued signal stored on the GPU.
Given a 2D complex input signal of shape <code>M &times; N</code>, compute its 2D DFT spectrum
using the row-column decomposition: apply a 1D DFT along each row, then a 1D DFT along each
column of the result. All values are 32-bit floating point.
</p>

<h2>Implementation Requirements</h2>
<ul>
<li>Use only native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in <code>spectrum</code></li>
<li>
The input and output are stored as 1D arrays of interleaved real and imaginary parts in
row-major order: element <code>x[m, n]</code> has its real part at index
<code>2*(m*N + n)</code> and imaginary part at index <code>2*(m*N + n) + 1</code>
</li>
</ul>

<h2>Example</h2>
<p>
Input: <code>M</code> = 2, <code>N</code> = 2<br>
Signal \(x[m, n]\) (real part):
\[
\begin{bmatrix}
1.0 & 0.0 \\
0.0 & 0.0
\end{bmatrix}
\]
Signal \(x[m, n]\) (imaginary part):
\[
\begin{bmatrix}
0.0 & 0.0 \\
0.0 & 0.0
\end{bmatrix}
\]
Output:<br>
Spectrum \(X[k, l]\) (real part):
\[
\begin{bmatrix}
1.0 & 1.0 \\
1.0 & 1.0
\end{bmatrix}
\]
Spectrum \(X[k, l]\) (imaginary part):
\[
\begin{bmatrix}
0.0 & 0.0 \\
0.0 & 0.0
\end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>M</code>, <code>N</code> &le; 4096</li>
<li>Signal values are 32-bit floating point (real and imaginary parts)</li>
<li>Performance is measured with <code>M</code> = 2,048, <code>N</code> = 2,048</li>
</ul>
93 changes: 93 additions & 0 deletions challenges/medium/78_2d_fft/challenge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
def __init__(self):
super().__init__(
name="2D FFT",
atol=1e-02,
rtol=1e-02,
num_gpus=1,
access_tier="free",
)

def reference_impl(self, signal: torch.Tensor, spectrum: torch.Tensor, M: int, N: int):
assert signal.shape == (M * N * 2,)
assert spectrum.shape == (M * N * 2,)
assert signal.dtype == torch.float32
assert spectrum.dtype == torch.float32
assert signal.device == spectrum.device

sig_ri = signal.view(M, N, 2)
sig_c = torch.complex(sig_ri[..., 0].contiguous(), sig_ri[..., 1].contiguous())
spec_c = torch.fft.fft2(sig_c)
spec_ri = torch.stack((spec_c.real, spec_c.imag), dim=-1).contiguous()
spectrum.copy_(spec_ri.view(-1))

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"signal": (ctypes.POINTER(ctypes.c_float), "in"),
"spectrum": (ctypes.POINTER(ctypes.c_float), "out"),
"M": (ctypes.c_int, "in"),
"N": (ctypes.c_int, "in"),
}

def generate_example_test(self) -> Dict[str, Any]:
dtype = torch.float32
M, N = 2, 2
signal = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], device="cuda", dtype=dtype)
spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}

def generate_functional_test(self) -> List[Dict[str, Any]]:
dtype = torch.float32
cases = []

def make_case(M, N, low=-1.0, high=1.0):
signal = torch.empty(M * N * 2, device="cuda", dtype=dtype).uniform_(low, high)
spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}

def make_zero_case(M, N):
signal = torch.zeros(M * N * 2, device="cuda", dtype=dtype)
spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}

def make_impulse_case(M, N):
signal = torch.zeros(M * N * 2, device="cuda", dtype=dtype)
signal[0] = 1.0
spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}

# Edge cases: small sizes
cases.append(make_impulse_case(1, 1))
cases.append(make_zero_case(2, 2))
cases.append(make_case(1, 4))

# Power-of-2 sizes
cases.append(make_case(16, 16))
cases.append(make_case(32, 64))

# Non-power-of-2 sizes
cases.append(make_case(3, 5))
cases.append(make_case(30, 30))

# Mixed positive/negative values
cases.append(make_case(100, 200, low=-5.0, high=5.0))

# Realistic sizes
cases.append(make_case(256, 256))
cases.append(make_case(512, 512))

return cases

def generate_performance_test(self) -> Dict[str, Any]:
dtype = torch.float32
M, N = 2048, 2048
signal = torch.empty(M * N * 2, device="cuda", dtype=dtype).normal_(0.0, 1.0)
spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}
4 changes: 4 additions & 0 deletions challenges/medium/78_2d_fft/starter/starter.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include <cuda_runtime.h>

// signal, spectrum are device pointers
extern "C" void solve(const float* signal, float* spectrum, int M, int N) {}
8 changes: 8 additions & 0 deletions challenges/medium/78_2d_fft/starter/starter.cute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import cutlass
import cutlass.cute as cute


# signal, spectrum are tensors on the GPU
@cute.jit
def solve(signal: cute.Tensor, spectrum: cute.Tensor, M: cute.Int32, N: cute.Int32):
pass
9 changes: 9 additions & 0 deletions challenges/medium/78_2d_fft/starter/starter.jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import jax
import jax.numpy as jnp


# signal is a tensor on GPU
@jax.jit
def solve(signal: jax.Array, M: int, N: int) -> jax.Array:
# return output tensor directly
pass
9 changes: 9 additions & 0 deletions challenges/medium/78_2d_fft/starter/starter.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from gpu.host import DeviceContext
from gpu.id import block_dim, block_idx, thread_idx
from memory import UnsafePointer
from math import ceildiv

# signal, spectrum are device pointers
@export
def solve(signal: UnsafePointer[Float32], spectrum: UnsafePointer[Float32], M: Int32, N: Int32):
pass
6 changes: 6 additions & 0 deletions challenges/medium/78_2d_fft/starter/starter.pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import torch


# signal, spectrum are tensors on the GPU
def solve(signal: torch.Tensor, spectrum: torch.Tensor, M: int, N: int):
pass
8 changes: 8 additions & 0 deletions challenges/medium/78_2d_fft/starter/starter.triton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import torch
import triton
import triton.language as tl


# signal, spectrum are tensors on the GPU
def solve(signal: torch.Tensor, spectrum: torch.Tensor, M: int, N: int):
pass
Loading