Skip to content

Commit 7712471

Browse files
committed
Add NVIDIA GPU implementation for add_rms_norm and make residual_out required.
1 parent 2a432b3 commit 7712471

File tree

8 files changed

+355
-134
lines changed

8 files changed

+355
-134
lines changed

python/infinicore/ops/add_rms_norm.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,43 @@
55
def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
    """
    Fused Add and RMS Normalization.

    Computes ``RMSNorm(a + b) * weight`` in a single fused operator, also
    returning the intermediate sum so it can serve as the residual input of
    the next layer.

    Args:
        a: First input tensor.
        b: Second input tensor.
        weight: Scale weights (one per feature along the last dimension).
        epsilon: Small constant for numerical stability, default is 1e-5.
        out: Optional output tuple ``(y, residual_out)`` for in-place
            operation; both tensors are required (residual_out is no longer
            optional).

    Returns:
        Tuple of (normalized_result, add_result):
        ``(RMSNorm(a + b) * weight, a + b)``.
        The add_result can be used as residual for subsequent layers.
    """
    if out is None:
        # Out-of-place path: the backend allocates both outputs and returns
        # them as a pair of underlying tensors.
        result = _infinicore.add_rms_norm(
            a._underlying, b._underlying, weight._underlying, epsilon
        )
        return (Tensor(result[0]), Tensor(result[1]))

    # In-place path: caller supplies both destination tensors.
    y, residual_out = out
    _infinicore.add_rms_norm_(
        y._underlying,
        residual_out._underlying,
        a._underlying,
        b._underlying,
        weight._underlying,
        epsilon,
    )
    return (y, residual_out)
2736

2837

2938
def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5):
    """In-place Fused Add and RMS Normalization.

    Writes ``RMSNorm(a + b) * weight`` into ``y`` and ``a + b`` into
    ``residual_out``. Both output tensors are required.

    Args:
        y: Destination tensor for the normalized result.
        residual_out: Destination tensor for the raw sum ``a + b``.
        a: First input tensor.
        b: Second input tensor.
        weight: Scale weights.
        epsilon: Small constant for numerical stability, default is 1e-5.
    """
    _infinicore.add_rms_norm_(
        y._underlying,
        residual_out._underlying,
        a._underlying,
        b._underlying,
        weight._underlying,
        epsilon,
    )

src/infiniop/ops/add_rms_norm/add_rms_norm.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
#define DESCRIPTOR(NAMESPACE) \
88
\
9-
namespace op::add_rms_norm::NAMESPACE { \
10-
class Descriptor final : public InfiniopDescriptor { \
9+
namespace op::add_rms_norm::NAMESPACE { \
10+
class Descriptor final : public InfiniopDescriptor { \
1111
struct Opaque; \
1212
Opaque *_opaque; \
1313
AddRMSNormInfo _info; \
@@ -19,7 +19,7 @@
1919
size_t workspace_size, \
2020
infiniDevice_t device_type, \
2121
int device_id) \
22-
: InfiniopDescriptor{device_type, device_id}, \
22+
: InfiniopDescriptor{device_type, device_id}, \
2323
_opaque(opaque), \
2424
_info(info), \
2525
_workspace_size(workspace_size) {} \
@@ -29,24 +29,24 @@
2929
\
3030
size_t workspaceSize() const { return _workspace_size; } \
3131
\
32-
static infiniStatus_t create( \
32+
static infiniStatus_t create( \
3333
infiniopHandle_t handle, \
3434
Descriptor **desc_ptr, \
35-
infiniopTensorDescriptor_t y_desc, \
35+
infiniopTensorDescriptor_t y_desc, \
3636
infiniopTensorDescriptor_t a_desc, \
3737
infiniopTensorDescriptor_t b_desc, \
38-
infiniopTensorDescriptor_t weight_desc, \
39-
float epsilon, \
40-
infiniopTensorDescriptor_t residual_out_desc); \
38+
infiniopTensorDescriptor_t weight_desc, \
39+
float epsilon, \
40+
infiniopTensorDescriptor_t residual_out_desc); \
4141
\
42-
infiniStatus_t calculate( \
42+
infiniStatus_t calculate( \
4343
void *workspace, size_t workspace_size, \
4444
void *y, \
4545
const void *a, \
4646
const void *b, \
4747
const void *weight, \
4848
void *residual_out, \
49-
void *stream) const; \
49+
void *stream) const; \
5050
}; \
5151
}
5252

src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc

Lines changed: 18 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,13 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
3636
const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
3737
const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
3838
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
39-
T *residual_out_ptr = info->has_residual_out ?
40-
(residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1]) : nullptr;
39+
T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];
4140

4241
// Compute add(a, b) once and store it
4342
T sum_squared = (T)0;
4443
for (size_t k = 0; k < dim; k++) {
4544
T sum_val = a_ptr[k] + b_ptr[k];
46-
if (residual_out_ptr != nullptr) {
47-
residual_out_ptr[k] = sum_val; // Store add result
48-
}
45+
residual_out_ptr[k] = sum_val; // Store add result
4946
sum_squared += sum_val * sum_val;
5047
}
5148

@@ -54,18 +51,9 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
5451
T rms = (T)1 / std::sqrt(sum_squared / (T)(dim) + (T)(info->epsilon));
5552

5653
// Apply normalization: y = (a + b) * w * rms
57-
// Reuse the stored sum values if residual_out was computed, otherwise recompute
58-
if (residual_out_ptr != nullptr) {
59-
// Reuse stored values
60-
for (size_t k = 0; k < dim; k++) {
61-
y_ptr[k] = residual_out_ptr[k] * w[k] * rms;
62-
}
63-
} else {
64-
// Recompute sum
65-
for (size_t k = 0; k < dim; k++) {
66-
T sum_val = a_ptr[k] + b_ptr[k];
67-
y_ptr[k] = sum_val * w[k] * rms;
68-
}
54+
// Reuse stored values from residual_out
55+
for (size_t k = 0; k < dim; k++) {
56+
y_ptr[k] = residual_out_ptr[k] * w[k] * rms;
6957
}
7058
}
7159

@@ -90,52 +78,32 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
9078
const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
9179
const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
9280
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
93-
T *residual_out_ptr = info->has_residual_out ?
94-
(residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1]) : nullptr;
81+
T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];
9582

9683
// Compute sum of squares for RMS normalization and store add result
9784
float sum_squared = 0.0f;
9885
for (size_t k = 0; k < dim; k++) {
9986
float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
100-
if (residual_out_ptr != nullptr) {
101-
residual_out_ptr[k] = utils::cast<T>(sum_val); // Store add result
102-
}
87+
residual_out_ptr[k] = utils::cast<T>(sum_val); // Store add result
10388
sum_squared += sum_val * sum_val;
10489
}
10590

10691
// Compute RMS: 1 / (sqrt(sum/dim + eps))
10792
float rms = 1.f / std::sqrt(sum_squared / (float)(dim) + info->epsilon);
10893

10994
// Apply normalization: y = (a + b) * w * rms
110-
// Reuse stored values if residual_out was computed, otherwise recompute
111-
if (residual_out_ptr != nullptr) {
112-
// Reuse stored values
113-
for (size_t k = 0; k < dim; k++) {
114-
float sum_val = utils::cast<float>(residual_out_ptr[k]);
115-
float val;
116-
if constexpr (std::is_same<Tw, float>::value) {
117-
val = sum_val * w[k] * rms;
118-
} else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
119-
val = sum_val * utils::cast<float>(w[k]) * rms;
120-
} else {
121-
std::abort();
122-
}
123-
y_ptr[k] = utils::cast<T>(val);
124-
}
125-
} else {
126-
// Recompute sum
127-
for (size_t k = 0; k < dim; k++) {
128-
float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
129-
float val;
130-
if constexpr (std::is_same<Tw, float>::value) {
131-
val = sum_val * w[k] * rms;
132-
} else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
133-
val = sum_val * utils::cast<float>(w[k]) * rms;
134-
} else {
135-
std::abort();
136-
}
137-
y_ptr[k] = utils::cast<T>(val);
95+
// Reuse stored values from residual_out
96+
for (size_t k = 0; k < dim; k++) {
97+
float sum_val = utils::cast<float>(residual_out_ptr[k]);
98+
float val;
99+
if constexpr (std::is_same<Tw, float>::value) {
100+
val = sum_val * w[k] * rms;
101+
} else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
102+
val = sum_val * utils::cast<float>(w[k]) * rms;
103+
} else {
104+
std::abort();
138105
}
106+
y_ptr[k] = utils::cast<T>(val);
139107
}
140108
}
141109

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
#define __ADD_RMS_NORM_CUDA_KERNEL_H__

#include <cub/block/block_reduce.cuh>

/// Fused Add + RMSNorm device routine.
///
/// Each CUDA block processes one (batch, head) row of length `dim`:
///   residual_out = a + b
///   y = (a + b) * w * rsqrt(mean((a + b)^2) + epsilon)
///
/// Template parameters:
///   BLOCK_SIZE - number of threads per block (must match launch config).
///   Tcompute   - accumulation type (e.g. float) for the reduction.
///   Tdata      - storage type of y/residual_out/a/b (e.g. half, bf16).
///   Tweight    - storage type of the weight vector.
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    size_t dim,
    float epsilon) {
    // Each block takes care of one head in one batch; blockIdx.x encodes
    // both indices. Each thread strides over the row in steps of BLOCK_SIZE.
    size_t batch_idx = blockIdx.x / nhead;
    size_t head_idx = blockIdx.x % nhead;

    auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
    auto a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
    auto b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
    auto w_ptr = w;
    Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;

    // Pass 1: compute add(a, b), store it into residual_out, and accumulate
    // the per-thread sum of squares in the compute precision.
    Tcompute sum_squared = 0;
    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
        residual_out_ptr[i] = Tdata(sum_val); // Store add result
        sum_squared += sum_val * sum_val;
    }

    // Block-wide reduction of the sum of squares (result valid on thread 0).
    using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    sum_squared = BlockReduce(temp_storage).Sum(sum_squared);

    // Thread 0 computes RMS = 1/sqrt(ss/dim + epsilon); broadcast through
    // shared memory, with __syncthreads() making it visible to all threads.
    __shared__ Tcompute rms;
    if (threadIdx.x == 0) {
        rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
    }
    __syncthreads();

    // Pass 2: apply normalization y = (a + b) * w * rms, re-reading the sums
    // stored in residual_out instead of recomputing a + b.
    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        Tcompute sum_val = Tcompute(residual_out_ptr[i]); // Reuse stored value
        y_ptr[i] = Tdata(sum_val * Tcompute(w_ptr[i]) * rms);
    }
}

#endif

src/infiniop/ops/add_rms_norm/info.h

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ class AddRMSNormInfo {
3434

3535
auto atype = y_desc->dtype();
3636
auto wtype = weight_desc->dtype();
37-
37+
3838
// Check that all input tensors have the same dtype
3939
if (a_desc->dtype() != atype || b_desc->dtype() != atype) {
4040
return INFINI_STATUS_BAD_TENSOR_DTYPE;
4141
}
42-
42+
4343
if (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) {
4444
// For half-precision types (FP16/BF16), weights can be the same half-precision type or FP32
4545
if (wtype != atype && wtype != INFINI_DTYPE_F32 && wtype != INFINI_DTYPE_BF16 && wtype != INFINI_DTYPE_F16) {
@@ -71,52 +71,46 @@ class AddRMSNormInfo {
7171
batch = y_desc->dim(0);
7272
dim = y_desc->dim(1);
7373

74-
if (a_desc->dim(0) != batch || a_desc->dim(1) != dim ||
75-
b_desc->dim(0) != batch || b_desc->dim(1) != dim ||
76-
weight_desc->dim(0) != dim) {
74+
if (a_desc->dim(0) != batch || a_desc->dim(1) != dim || b_desc->dim(0) != batch || b_desc->dim(1) != dim || weight_desc->dim(0) != dim) {
7775
return INFINI_STATUS_BAD_TENSOR_SHAPE;
7876
}
7977
} else if (y_ndim == 3) {
8078
batch = y_desc->dim(0);
8179
nhead = y_desc->dim(1);
8280
dim = y_desc->dim(2);
8381

84-
if (a_desc->dim(0) != batch || a_desc->dim(1) != nhead || a_desc->dim(2) != dim ||
85-
b_desc->dim(0) != batch || b_desc->dim(1) != nhead || b_desc->dim(2) != dim ||
86-
weight_desc->dim(0) != dim) {
82+
if (a_desc->dim(0) != batch || a_desc->dim(1) != nhead || a_desc->dim(2) != dim || b_desc->dim(0) != batch || b_desc->dim(1) != nhead || b_desc->dim(2) != dim || weight_desc->dim(0) != dim) {
8783
return INFINI_STATUS_BAD_TENSOR_SHAPE;
8884
}
8985
} else {
9086
return INFINI_STATUS_BAD_TENSOR_SHAPE;
9187
}
9288

9389
// Check contiguity of the last dimension
94-
if (y_desc->stride(y_ndim - 1) != 1 ||
95-
a_desc->stride(a_ndim - 1) != 1 ||
96-
b_desc->stride(b_ndim - 1) != 1 ||
97-
weight_desc->stride(w_ndim - 1) != 1) {
90+
if (y_desc->stride(y_ndim - 1) != 1 || a_desc->stride(a_ndim - 1) != 1 || b_desc->stride(b_ndim - 1) != 1 || weight_desc->stride(w_ndim - 1) != 1) {
9891
return INFINI_STATUS_BAD_TENSOR_STRIDES;
9992
}
10093

101-
// Check residual_out_desc if provided
102-
bool has_residual_out = (residual_out_desc != nullptr);
103-
if (has_residual_out) {
104-
const size_t residual_out_ndim = residual_out_desc->ndim();
105-
if (residual_out_ndim != y_ndim) {
94+
// residual_out_desc is required (always needed for fused operator)
95+
if (residual_out_desc == nullptr) {
96+
return INFINI_STATUS_BAD_PARAM;
97+
}
98+
99+
const size_t residual_out_ndim = residual_out_desc->ndim();
100+
if (residual_out_ndim != y_ndim) {
101+
return INFINI_STATUS_BAD_TENSOR_SHAPE;
102+
}
103+
if (residual_out_desc->dtype() != atype) {
104+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
105+
}
106+
// Check shape matches
107+
for (size_t i = 0; i < y_ndim; i++) {
108+
if (residual_out_desc->dim(i) != y_desc->dim(i)) {
106109
return INFINI_STATUS_BAD_TENSOR_SHAPE;
107110
}
108-
if (residual_out_desc->dtype() != atype) {
109-
return INFINI_STATUS_BAD_TENSOR_DTYPE;
110-
}
111-
// Check shape matches
112-
for (size_t i = 0; i < y_ndim; i++) {
113-
if (residual_out_desc->dim(i) != y_desc->dim(i)) {
114-
return INFINI_STATUS_BAD_TENSOR_SHAPE;
115-
}
116-
}
117-
if (residual_out_desc->stride(residual_out_ndim - 1) != 1) {
118-
return INFINI_STATUS_BAD_TENSOR_STRIDES;
119-
}
111+
}
112+
if (residual_out_desc->stride(residual_out_ndim - 1) != 1) {
113+
return INFINI_STATUS_BAD_TENSOR_STRIDES;
120114
}
121115

122116
AddRMSNormInfo info;
@@ -127,10 +121,8 @@ class AddRMSNormInfo {
127121
info.y_strides = y_desc->strides();
128122
info.a_strides = a_desc->strides();
129123
info.b_strides = b_desc->strides();
130-
info.has_residual_out = has_residual_out;
131-
if (has_residual_out) {
132-
info.residual_out_strides = residual_out_desc->strides();
133-
}
124+
info.has_residual_out = true; // Always true now
125+
info.residual_out_strides = residual_out_desc->strides();
134126
return utils::Result<AddRMSNormInfo>(info);
135127
}
136128
};

0 commit comments

Comments
 (0)