Skip to content

Commit 5e2eff8

Browse files
abaybektursun and claude
committed
Triple Loop + Fused Kernels + Parallel Residuals + N-gram Tilt — val_bpb 1.08014
5-seed mean 1.08014 BPB (std=0.0004), best seed 1.07971. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9d070df commit 5e2eff8

File tree

16 files changed

+4005
-0
lines changed

16 files changed

+4005
-0
lines changed

records/track_10min_16mb/2026-04-06_TripleLoop_FusedKernels_Ngram/README.md

Lines changed: 422 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
// CUTLASS 3.x EVT kernel: fused GEMM * elementwise multiply
// Computes: dpre = (go @ down_w.T) * act_grad
// Where act_grad = f'(pre) is pre-computed in the forward pass.
//
// Layout convention:
//   go:       (M, K) bf16 row-major
//   down_w:   (K, N) bf16 row-major — CUTLASS B(N,K) with RowMajor layout
//   act_grad: (M, N) bf16 row-major
//   dpre:     (M, N) bf16 row-major output

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp"
#include "cute/tensor.hpp"
#include "cutlass/util/packed_stride.hpp"
#include <iostream>

using namespace cute;

// --- Type aliases ---

using ElementAcc = float;                    // accumulator precision
using ElementCompute = float;                // epilogue compute precision
using ElementOutput = cutlass::bfloat16_t;   // A/B/C/D element type
using ElementAux = cutlass::bfloat16_t;      // act_grad aux-tensor element type

using namespace cutlass::epilogue::fusion;

// --- Tile / schedule configuration ---

using TileShape = Shape<_128, _256, _64>;
using ClusterShape = Shape<_1, _1, _1>;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;

// --- Resolve AuxLoad types via EpilogueDescriptor ---
// The descriptors derive the stage count, smem layout, and copy op that the
// Sm90AuxLoad visitor node needs for this tile/schedule combination.

using EpiDesc = cutlass::epilogue::collective::detail::EpilogueDescriptor<
    TileShape, EpilogueTile, ElementOutput, ElementOutput, EpilogueSchedule>;

using AuxDesc = cutlass::epilogue::collective::detail::AuxLoadDescriptor<
    EpiDesc, cutlass::layout::RowMajor, ElementAux>;

// --- EVT tree: acc * aux_load (builtin multiply) ---

// Loads act_grad tiles through TMA into shared memory for the epilogue.
using AuxLoad = Sm90AuxLoad<
    AuxDesc::Stages,
    typename EpiDesc::EpilogueTile,
    typename AuxDesc::Element,
    typename AuxDesc::Stride,
    typename AuxDesc::SmemLayoutAtom,
    typename AuxDesc::CopyOpS2R>;

// Compute node: builtin multiply(acc, act_grad)
using Compute = Sm90Compute<
    cutlass::multiplies,
    ElementOutput,
    ElementCompute,
    cutlass::FloatRoundStyle::round_to_nearest>;

// Tree: root = Multiply(child0 = AccFetch, child1 = AuxLoad)
using EVT = Sm90EVT<Compute, Sm90AccFetch, AuxLoad>;

// --- CollectiveBuilder + Kernel type ---

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90,
    cutlass::arch::OpClassTensorOp,
    TileShape,
    ClusterShape,
    EpilogueTile,
    ElementAcc, ElementCompute,
    ElementOutput, cutlass::layout::RowMajor, /* AlignC */ 8,
    ElementOutput, cutlass::layout::RowMajor, /* AlignD */ 8,
    EpilogueSchedule,
    EVT
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90,
    cutlass::arch::OpClassTensorOp,
    ElementOutput, cutlass::layout::RowMajor, /* AlignA */ 8,
    ElementOutput, cutlass::layout::RowMajor, /* AlignB */ 8,
    ElementAcc,
    TileShape, ClusterShape,
    // Let CUTLASS pick the stage count after carving out epilogue smem.
    cutlass::gemm::collective::StageCountAutoCarveout<
        sizeof(typename CollectiveEpilogue::SharedStorage)>,
    cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int, int, int, int>,   // (M, N, K, L) problem shape
    CollectiveMainloop,
    CollectiveEpilogue>;

using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
103+
// --- Host launcher ---
104+
105+
void launch_gemm_mul(
106+
void const* ptr_go, // (M, K) bf16 row-major
107+
void const* ptr_down_w, // (K, N) bf16 row-major = RowMajor B(N,K) for CUTLASS
108+
void const* ptr_act_grad, // (M, N) bf16 row-major
109+
void* ptr_dpre, // (M, N) bf16 row-major output
110+
int M, int N, int K,
111+
cudaStream_t stream)
112+
{
113+
using StrideA = cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>;
114+
using StrideB = cutlass::gemm::TagToStrideB_t<cutlass::layout::RowMajor>;
115+
using StrideC = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
116+
117+
int L = 1;
118+
auto prob_shape = make_shape(M, N, K, L);
119+
120+
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
121+
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
122+
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
123+
auto stride_Aux = cutlass::make_cute_packed_stride(
124+
typename AuxDesc::Stride{}, cute::make_shape(M, N, L));
125+
126+
typename EVT::Arguments evt_args {
127+
{}, // Sm90AccFetch: no args
128+
{ // Sm90AuxLoad: pointer + null_default + stride
129+
static_cast<ElementAux const*>(ptr_act_grad),
130+
ElementAux(0),
131+
stride_Aux
132+
},
133+
{} // Sm90Compute (multiplies): no args
134+
};
135+
136+
typename GemmOp::Arguments args {
137+
cutlass::gemm::GemmUniversalMode::kGemm,
138+
prob_shape,
139+
{ // Mainloop
140+
static_cast<ElementOutput const*>(ptr_go),
141+
stride_A,
142+
static_cast<ElementOutput const*>(ptr_down_w),
143+
stride_B,
144+
},
145+
{ // Epilogue: {thread_args, ptr_C, stride_C, ptr_D, stride_D}
146+
evt_args,
147+
static_cast<ElementOutput const*>(ptr_dpre), // ptr_C (unused but TMA needs valid ptr)
148+
stride_C,
149+
static_cast<ElementOutput*>(ptr_dpre), // ptr_D (output)
150+
stride_C,
151+
}
152+
};
153+
154+
GemmOp gemm_op;
155+
size_t workspace_size = GemmOp::get_workspace_size(args);
156+
void* workspace = nullptr;
157+
if (workspace_size > 0) {
158+
cudaMalloc(&workspace, workspace_size);
159+
}
160+
161+
auto status = gemm_op.initialize(args, workspace, stream);
162+
if (status != cutlass::Status::kSuccess) {
163+
std::cerr << "CUTLASS initialize failed: " << cutlassGetStatusString(status) << std::endl;
164+
if (workspace) cudaFree(workspace);
165+
exit(EXIT_FAILURE);
166+
}
167+
168+
status = gemm_op.run(stream);
169+
if (status != cutlass::Status::kSuccess) {
170+
cudaError_t cuda_err = cudaStreamSynchronize(stream);
171+
std::cerr << "CUTLASS run failed: " << cutlassGetStatusString(status)
172+
<< " CUDA: " << cudaGetErrorString(cuda_err) << std::endl;
173+
if (workspace) cudaFree(workspace);
174+
exit(EXIT_FAILURE);
175+
}
176+
177+
if (workspace) cudaFree(workspace);
178+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// PyTorch C++ extension: CUTLASS EVT fused GEMM * elementwise multiply
2+
// dpre = (go @ down_w.T) * act_grad
3+
// Pass down_w directly (K, N) — NOT down_w.T.contiguous()
4+
5+
#include <torch/extension.h>
6+
#include <c10/cuda/CUDAStream.h>
7+
8+
void launch_gemm_mul(
9+
void const*, void const*, void const*, void*, int, int, int, cudaStream_t);
10+
11+
at::Tensor gemm_mul(at::Tensor go, at::Tensor down_w, at::Tensor act_grad) {
12+
TORCH_CHECK(go.is_cuda() && go.is_contiguous());
13+
TORCH_CHECK(down_w.is_cuda() && down_w.is_contiguous());
14+
TORCH_CHECK(act_grad.is_cuda() && act_grad.is_contiguous());
15+
TORCH_CHECK(go.scalar_type() == at::kBFloat16);
16+
TORCH_CHECK(down_w.scalar_type() == at::kBFloat16);
17+
TORCH_CHECK(act_grad.scalar_type() == at::kBFloat16);
18+
19+
int M = go.size(0);
20+
int K = go.size(1);
21+
int N = down_w.size(1); // down_w is (K, N) row-major
22+
23+
TORCH_CHECK(down_w.size(0) == K,
24+
"K mismatch: go has K=", K, " but down_w has size(0)=", down_w.size(0));
25+
TORCH_CHECK(act_grad.size(0) == M && act_grad.size(1) == N,
26+
"act_grad shape must be (M, N), got (", act_grad.size(0), ", ", act_grad.size(1), ")");
27+
28+
at::Tensor dpre = at::empty({M, N}, go.options());
29+
30+
launch_gemm_mul(
31+
go.data_ptr(), down_w.data_ptr(), act_grad.data_ptr(), dpre.data_ptr(),
32+
M, N, K,
33+
at::cuda::getCurrentCUDAStream());
34+
35+
return dpre;
36+
}
37+
38+
// Register the op schema with the dispatcher: torch.ops.cutlass_evt.gemm_mul.
TORCH_LIBRARY(cutlass_evt, m) {
    m.def("gemm_mul(Tensor go, Tensor down_w, Tensor act_grad) -> Tensor");
}

// Bind the CUDA implementation to that schema.
TORCH_LIBRARY_IMPL(cutlass_evt, CUDA, m) {
    m.impl("gemm_mul", &gemm_mul);
}

// Empty pybind11 module: `import cutlass_evt_fusion` only needs to load the
// shared object; the op itself is reached via torch.ops.cutlass_evt.gemm_mul.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from setuptools import setup
2+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3+
import os
4+
5+
CUTLASS_PATH = os.environ.get("CUTLASS_PATH", "/opt/cutlass")
6+
7+
setup(
8+
name="cutlass_evt_fusion",
9+
ext_modules=[
10+
CUDAExtension(
11+
name="cutlass_evt_fusion",
12+
sources=[
13+
"csrc/gemm_act_grad.cu",
14+
"csrc/torch_binding.cpp",
15+
],
16+
include_dirs=[
17+
f"{CUTLASS_PATH}/include",
18+
f"{CUTLASS_PATH}/tools/util/include",
19+
],
20+
extra_compile_args={
21+
"nvcc": [
22+
"-std=c++17",
23+
"-arch=sm_90a",
24+
"-O3",
25+
"--use_fast_math",
26+
"--expt-relaxed-constexpr",
27+
"-DNDEBUG",
28+
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
29+
],
30+
},
31+
),
32+
],
33+
cmdclass={"build_ext": BuildExtension},
34+
)

0 commit comments

Comments
 (0)