diff --git a/challenges/medium/78_2d_fft/challenge.html b/challenges/medium/78_2d_fft/challenge.html
new file mode 100644
index 00000000..f73cd49c
--- /dev/null
+++ b/challenges/medium/78_2d_fft/challenge.html
@@ -0,0 +1,59 @@
+
+ Compute the 2D Discrete Fourier Transform (2D DFT) of a complex-valued signal stored on the GPU.
+ Given a 2D complex input signal of shape M × N, compute its 2D DFT spectrum
+ using the row-column decomposition: apply a 1D DFT along each row, then a 1D DFT along each
+ column of the result. The transform is the unnormalized forward DFT,
+ \(X[k, l] = \sum_{m=0}^{M-1} \sum_{n=0}^{N-1} x[m, n]\, e^{-2\pi i \left(\frac{km}{M} + \frac{ln}{N}\right)}\).
+ All values are 32-bit floating point.
+
+
+Implementation Requirements
+
+ - Use only native features (external libraries are not permitted)
+ - The solve function signature must remain unchanged
+ - The final result must be stored in spectrum
+ - The input and output are stored as 1D arrays of interleaved real and imaginary parts in
+ row-major order: element x[m, n] has its real part at index 2*(m*N + n) and its imaginary
+ part at index 2*(m*N + n) + 1
+
+
+
+Example
+
+Input: M = 2, N = 2
+Signal \(x[m, n]\) (real part):
+\[
+\begin{bmatrix}
+1.0 & 0.0 \\
+0.0 & 0.0
+\end{bmatrix}
+\]
+Signal \(x[m, n]\) (imaginary part):
+\[
+\begin{bmatrix}
+0.0 & 0.0 \\
+0.0 & 0.0
+\end{bmatrix}
+\]
+Output:
+Spectrum \(X[k, l]\) (real part):
+\[
+\begin{bmatrix}
+1.0 & 1.0 \\
+1.0 & 1.0
+\end{bmatrix}
+\]
+Spectrum \(X[k, l]\) (imaginary part):
+\[
+\begin{bmatrix}
+0.0 & 0.0 \\
+0.0 & 0.0
+\end{bmatrix}
+\]
+
+
+Constraints
+
+ - 1 ≤ M, N ≤ 4096
+ - Signal values are 32-bit floating point (real and imaginary parts)
+ - Performance is measured with M = 2,048, N = 2,048
+
diff --git a/challenges/medium/78_2d_fft/challenge.py b/challenges/medium/78_2d_fft/challenge.py
new file mode 100644
index 00000000..d17cc21a
--- /dev/null
+++ b/challenges/medium/78_2d_fft/challenge.py
@@ -0,0 +1,93 @@
+import ctypes
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+ def __init__(self):
+ super().__init__(
+ name="2D FFT",
+ atol=1e-02,
+ rtol=1e-02,
+ num_gpus=1,
+ access_tier="free",
+ )
+
+ def reference_impl(self, signal: torch.Tensor, spectrum: torch.Tensor, M: int, N: int):
+ assert signal.shape == (M * N * 2,)
+ assert spectrum.shape == (M * N * 2,)
+ assert signal.dtype == torch.float32
+ assert spectrum.dtype == torch.float32
+ assert signal.device == spectrum.device
+
+ sig_ri = signal.view(M, N, 2)
+ sig_c = torch.complex(sig_ri[..., 0].contiguous(), sig_ri[..., 1].contiguous())
+ spec_c = torch.fft.fft2(sig_c)
+ spec_ri = torch.stack((spec_c.real, spec_c.imag), dim=-1).contiguous()
+ spectrum.copy_(spec_ri.view(-1))
+
+ def get_solve_signature(self) -> Dict[str, tuple]:
+ return {
+ "signal": (ctypes.POINTER(ctypes.c_float), "in"),
+ "spectrum": (ctypes.POINTER(ctypes.c_float), "out"),
+ "M": (ctypes.c_int, "in"),
+ "N": (ctypes.c_int, "in"),
+ }
+
+ def generate_example_test(self) -> Dict[str, Any]:
+ dtype = torch.float32
+ M, N = 2, 2
+ signal = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], device="cuda", dtype=dtype)
+ spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
+ return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}
+
+ def generate_functional_test(self) -> List[Dict[str, Any]]:
+ dtype = torch.float32
+ cases = []
+
+ def make_case(M, N, low=-1.0, high=1.0):
+ signal = torch.empty(M * N * 2, device="cuda", dtype=dtype).uniform_(low, high)
+ spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
+ return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}
+
+ def make_zero_case(M, N):
+ signal = torch.zeros(M * N * 2, device="cuda", dtype=dtype)
+ spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
+ return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}
+
+ def make_impulse_case(M, N):
+ signal = torch.zeros(M * N * 2, device="cuda", dtype=dtype)
+ signal[0] = 1.0
+ spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
+ return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}
+
+ # Edge cases: small sizes
+ cases.append(make_impulse_case(1, 1))
+ cases.append(make_zero_case(2, 2))
+ cases.append(make_case(1, 4))
+
+ # Power-of-2 sizes
+ cases.append(make_case(16, 16))
+ cases.append(make_case(32, 64))
+
+ # Non-power-of-2 sizes
+ cases.append(make_case(3, 5))
+ cases.append(make_case(30, 30))
+
+ # Mixed positive/negative values
+ cases.append(make_case(100, 200, low=-5.0, high=5.0))
+
+ # Realistic sizes
+ cases.append(make_case(256, 256))
+ cases.append(make_case(512, 512))
+
+ return cases
+
+ def generate_performance_test(self) -> Dict[str, Any]:
+ dtype = torch.float32
+ M, N = 2048, 2048
+ signal = torch.empty(M * N * 2, device="cuda", dtype=dtype).normal_(0.0, 1.0)
+ spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
+ return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}
diff --git a/challenges/medium/78_2d_fft/starter/starter.cu b/challenges/medium/78_2d_fft/starter/starter.cu
new file mode 100644
index 00000000..852e3301
--- /dev/null
+++ b/challenges/medium/78_2d_fft/starter/starter.cu
@@ -0,0 +1,4 @@
+#include <cuda_runtime.h>
+
+// signal, spectrum are device pointers to 2*M*N floats each: the M x N complex grid is stored
+// row-major as interleaved (re, im) pairs, and the result must be written to spectrum
+extern "C" void solve(const float* signal, float* spectrum, int M, int N) {}
diff --git a/challenges/medium/78_2d_fft/starter/starter.cute.py b/challenges/medium/78_2d_fft/starter/starter.cute.py
new file mode 100644
index 00000000..1be0c7c7
--- /dev/null
+++ b/challenges/medium/78_2d_fft/starter/starter.cute.py
@@ -0,0 +1,8 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# signal, spectrum are flat float32 GPU tensors of 2*M*N interleaved (re, im) values, row-major
+@cute.jit
+def solve(signal: cute.Tensor, spectrum: cute.Tensor, M: cute.Int32, N: cute.Int32):
+ pass
diff --git a/challenges/medium/78_2d_fft/starter/starter.jax.py b/challenges/medium/78_2d_fft/starter/starter.jax.py
new file mode 100644
index 00000000..d00f2299
--- /dev/null
+++ b/challenges/medium/78_2d_fft/starter/starter.jax.py
@@ -0,0 +1,9 @@
+import jax
+import jax.numpy as jnp
+
+
+# signal is a flat float32 GPU array of 2*M*N interleaved (re, im) values, row-major
+@jax.jit
+def solve(signal: jax.Array, M: int, N: int) -> jax.Array:
+ # return the spectrum directly, as a flat array in the same interleaved layout
+ pass
diff --git a/challenges/medium/78_2d_fft/starter/starter.mojo b/challenges/medium/78_2d_fft/starter/starter.mojo
new file mode 100644
index 00000000..e7c944b6
--- /dev/null
+++ b/challenges/medium/78_2d_fft/starter/starter.mojo
@@ -0,0 +1,9 @@
+from gpu.host import DeviceContext
+from gpu.id import block_dim, block_idx, thread_idx
+from memory import UnsafePointer
+from math import ceildiv
+
+# signal, spectrum are device pointers to 2*M*N Float32 values: interleaved (re, im), row-major
+@export
+def solve(signal: UnsafePointer[Float32], spectrum: UnsafePointer[Float32], M: Int32, N: Int32):
+ pass
diff --git a/challenges/medium/78_2d_fft/starter/starter.pytorch.py b/challenges/medium/78_2d_fft/starter/starter.pytorch.py
new file mode 100644
index 00000000..66070e0a
--- /dev/null
+++ b/challenges/medium/78_2d_fft/starter/starter.pytorch.py
@@ -0,0 +1,6 @@
+import torch
+
+
+# signal, spectrum are flat float32 GPU tensors of 2*M*N interleaved (re, im) values, row-major
+def solve(signal: torch.Tensor, spectrum: torch.Tensor, M: int, N: int):
+ pass
diff --git a/challenges/medium/78_2d_fft/starter/starter.triton.py b/challenges/medium/78_2d_fft/starter/starter.triton.py
new file mode 100644
index 00000000..d8843c46
--- /dev/null
+++ b/challenges/medium/78_2d_fft/starter/starter.triton.py
@@ -0,0 +1,8 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# signal, spectrum are flat float32 GPU tensors of 2*M*N interleaved (re, im) values, row-major
+def solve(signal: torch.Tensor, spectrum: torch.Tensor, M: int, N: int):
+ pass