Commit
layer norm support multistage
marigoold committed Sep 21, 2023
1 parent b93d13f commit c055d70
Showing 2 changed files with 270 additions and 7 deletions.
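The change adds a multistage path to the CUDA layer-norm forward pass: each row's columns are split into chunks, Welford statistics (mean, M2, count) are computed per chunk, and the chunk statistics are then merged before normalization. Below is a minimal host-side sketch of that merge step only, using the standard parallel combination formula; WelfordStat and WelfordMerge are illustrative names written for this note, not symbols from the commit.

#include <cstdio>

// One Welford partial result: running mean, sum of squared deviations (M2), element count.
struct WelfordStat {
  double mean = 0.0;
  double m2 = 0.0;
  double count = 0.0;
};

// Chan et al. combination of two partial results; the output describes the union of both chunks.
WelfordStat WelfordMerge(const WelfordStat& a, const WelfordStat& b) {
  WelfordStat out;
  out.count = a.count + b.count;
  if (out.count == 0.0) { return out; }
  const double delta = b.mean - a.mean;
  out.mean = a.mean + delta * (b.count / out.count);
  out.m2 = a.m2 + b.m2 + delta * delta * (a.count * b.count / out.count);
  return out;
}

int main() {
  // Two column chunks of one row: {1, 2} and {3, 4, 5}.
  WelfordStat lhs{1.5, 0.5, 2.0};
  WelfordStat rhs{4.0, 2.0, 3.0};
  WelfordStat row = WelfordMerge(lhs, rhs);
  // Expected: mean 3.0 and variance M2 / count = 2.0 for {1, 2, 3, 4, 5}.
  std::printf("mean=%f variance=%f\n", row.mean, row.m2 / row.count);
  return 0;
}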
205 changes: 205 additions & 0 deletions oneflow/core/cuda/layer_norm.cuh
@@ -18,6 +18,7 @@ limitations under the License.
#define ONEFLOW_CORE_CUDA_LAYER_NORM_H_

#include <cub/cub.cuh>
#include "oneflow/core/device/cuda_util.h"
#include <math_constants.h>
#include <assert.h>

@@ -864,6 +865,202 @@ inline cudaError_t DispatchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}

// Final pass: normalizes each element with the reduced per-row statistics produced by
// WelfordGlobalAllReduce; any affine scale/shift is applied by the STORE functor.
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void LayerNormBlockPartialAffineImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon,
ComputeType* global_mean, ComputeType* global_m2,
ComputeType* global_count) {
using LoadType = typename LOAD::LoadType;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;

  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
    // The reduced statistics are loop-invariant, so read them once per row instead of per pack.
    ComputeType row_mean = global_mean[blockIdx.x];
    ComputeType row_variance =
        max(Div(global_m2[blockIdx.x], global_count[blockIdx.x]), static_cast<ComputeType>(0.0));
    ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));
    for (int pack_id = threadIdx.x; pack_id < num_packs; pack_id += block_size) {
      LoadType pack[pack_size];
      ComputeType dst_pack[pack_size];
      load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
      for (int i = 0; i < pack_size; ++i) {
        dst_pack[i] = (static_cast<ComputeType>(pack[i]) - row_mean) * row_inv_var;
      }
      store.template store<pack_size>(dst_pack, row, pack_id * pack_size);
    }
  }
}

// Merge pass: combines the per-chunk Welford statistics written by LayerNormBlockPartialImpl
// into a single (mean, m2, count) triple per row, stored in the first slot of the row's slice.
template<typename ComputeType>
__global__ void WelfordGlobalAllReduce(int64_t rows, int64_t cols, ComputeType* global_mean,
                                       ComputeType* global_m2, ComputeType* global_count) {
  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
    ComputeType thread_mean = 0;
    ComputeType thread_m2 = 0;
    ComputeType thread_count = 0;
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
      int idx = row * cols + i;
      // Accumulate every partial result this thread owns rather than keeping only the last read.
      WelfordCombine(global_mean[idx], global_m2[idx], global_count[idx], &thread_mean,
                     &thread_m2, &thread_count);
    }
    ComputeType row_mean = 0;
    ComputeType row_m2 = 0;
    ComputeType row_count = 0;
    __syncthreads();
    WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,
                                       &row_count);
    if (threadIdx.x == 0) {
      int idx = row * cols;
      *(global_mean + idx) = row_mean;
      *(global_m2 + idx) = row_m2;
      *(global_count + idx) = row_count;
    }
  }
}

// Partial pass: block (blockIdx.x, blockIdx.y) accumulates Welford statistics over one
// col_width-wide column chunk of each row it visits and stores them at
// index row * gridDim.y + blockIdx.y of global_mean / global_m2 / global_count.
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,
         int col_width>
__global__ void LayerNormBlockPartialImpl(LOAD load, STORE store, const int64_t rows,
                                          const int64_t cols, ComputeType* global_mean,
                                          ComputeType* global_m2, ComputeType* global_count) {
using LoadType = typename LOAD::LoadType;
assert(col_width % pack_size == 0);

const int num_packs = min(col_width, static_cast<int>(cols - blockIdx.y * col_width)) / pack_size;

for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_mean = 0;
ComputeType thread_m2 = 0;
ComputeType thread_count = 0;
for (int pack_id = threadIdx.x; pack_id < num_packs; pack_id += block_size) {
LoadType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size + blockIdx.y * col_width);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
WelfordCombine(static_cast<ComputeType>(pack[i]), &thread_mean, &thread_m2, &thread_count);
}
}
ComputeType row_mean = 0;
ComputeType row_m2 = 0;
ComputeType row_count = 0;
__syncthreads();

// TODO: bug here
WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,
&row_count);
    if (threadIdx.x == 0) {
      int idx = row * gridDim.y + blockIdx.y;
      *(global_mean + idx) = row_mean;
      *(global_m2 + idx) = row_m2;
      *(global_count + idx) = row_count;
    }
}
}

template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int col_width>
inline cudaError_t LaunchLayerNormBlockPartialImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance,
ComputeType* tmp_buffer) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x, grid_dim_y;
{
cudaError_t err =
GetNumBlocks(LayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>,
block_size, 0, rows, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
}
grid_dim_y = (cols + col_width - 1) / col_width;
dim3 grid(grid_dim_x, grid_dim_y, 1);

  // The workspace holds one (mean, m2, count) triple per (row, column chunk) pair, laid out as
  // three consecutive arrays of rows * grid_dim_y elements each.
  ComputeType* global_mean = tmp_buffer;
  ComputeType* global_m2 = global_mean + rows * grid_dim_y;
  ComputeType* global_count = global_m2 + rows * grid_dim_y;

  LayerNormBlockPartialImpl<LOAD, STORE, ComputeType, pack_size, block_size, col_width>
      <<<grid, block_size, 0, stream>>>(load, store, rows, cols, global_mean, global_m2,
                                        global_count);

  // Merge the grid_dim_y partial results of every row into a single per-row triple.
  WelfordGlobalAllReduce<ComputeType><<<grid_dim_x, grid_dim_y, 0, stream>>>(
      rows, grid_dim_y, global_mean, global_m2, global_count);

LayerNormBlockPartialAffineImpl<LOAD, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, 0, stream>>>(load, store, rows, cols, epsilon, global_mean,
global_m2, global_count);

return cudaPeekAtLastError();
}

template<typename LOAD, typename STORE, typename ComputeType, int col_width>
struct DispatchLayerNormBlockPartialImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance, ComputeType* tmp_buffer) {
if (16 / sizeof(ComputeType) == 4 && cols % 4 == 0 && CanPackAs<LOAD>(load, 4)
&& CanPackAs<STORE>(store, 4)) {
return LaunchLayerNormBlockPartialImpl<LOAD, STORE, ComputeType, 4, col_width>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, tmp_buffer);
} else if (16 / sizeof(ComputeType) == 2 && cols % 2 == 0 && CanPackAs<LOAD>(load, 2)
&& CanPackAs<STORE>(store, 2)) {
return LaunchLayerNormBlockPartialImpl<LOAD, STORE, ComputeType, 2, col_width>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, tmp_buffer);
} else {
return LaunchLayerNormBlockPartialImpl<LOAD, STORE, ComputeType, 1, col_width>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, tmp_buffer);
}
}
};

template<typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormBlockPartialImplColWidth(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance, ComputeType* tmp_buffer) {
// TODO: refine here
if (cols > 16384) {
return DispatchLayerNormBlockPartialImplPackSize<LOAD, STORE, ComputeType, 16384 / 4>()(
stream, load, store, rows, cols, epsilon, mean, inv_variance, tmp_buffer);
} else if (cols > 8192) {
return DispatchLayerNormBlockPartialImplPackSize<LOAD, STORE, ComputeType, 8192 / 4>()(
stream, load, store, rows, cols, epsilon, mean, inv_variance, tmp_buffer);
  } else {
    return DispatchLayerNormBlockPartialImplPackSize<LOAD, STORE, ComputeType, 512>()(
        stream, load, store, rows, cols, epsilon, mean, inv_variance, tmp_buffer);
  }
}

template<typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormBlockPartialImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance,
ComputeType* tmp_buffer) {
return DispatchLayerNormBlockPartialImplColWidth<LOAD, STORE, ComputeType>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, tmp_buffer);
}

template<typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
@@ -897,6 +1094,14 @@ DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t row
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}

// Workspace-backed overload: routes to the multistage block-partial implementation.
template<typename LOAD, typename STORE, typename ComputeType>
cudaError_t DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance, ComputeType* tmp_buffer) {
return DispatchLayerNormBlockPartialImpl<LOAD, STORE, ComputeType>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, tmp_buffer);
}

/*
LayerNormGrad dx:
normalized = (x - mean) * inv_var
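LaunchLayerNormBlockPartialImpl above runs three kernels in sequence: LayerNormBlockPartialImpl fills the workspace with per-(row, chunk) Welford statistics, WelfordGlobalAllReduce merges each row's chunk statistics, and LayerNormBlockPartialAffineImpl normalizes with the merged result. The CPU reference below, written for this note and not part of the commit, performs the same three stages for a single row; ChunkedLayerNormRow and its parameters are illustrative names.

#include <algorithm>
#include <cmath>
#include <vector>

// CPU reference for the chunked ("multistage") layer-norm forward pass applied to one row.
void ChunkedLayerNormRow(const std::vector<float>& x, int col_width, double epsilon,
                         std::vector<float>* y) {
  const int cols = static_cast<int>(x.size());
  const int num_chunks = (cols + col_width - 1) / col_width;
  std::vector<double> mean(num_chunks, 0.0), m2(num_chunks, 0.0), count(num_chunks, 0.0);
  // Stage 1: per-chunk Welford accumulation (the role of LayerNormBlockPartialImpl).
  for (int c = 0; c < num_chunks; ++c) {
    const int end = std::min(cols, (c + 1) * col_width);
    for (int i = c * col_width; i < end; ++i) {
      count[c] += 1.0;
      const double delta = x[i] - mean[c];
      mean[c] += delta / count[c];
      m2[c] += delta * (x[i] - mean[c]);
    }
  }
  // Stage 2: merge the chunk statistics into one triple (the role of WelfordGlobalAllReduce).
  double row_mean = 0.0, row_m2 = 0.0, row_count = 0.0;
  for (int c = 0; c < num_chunks; ++c) {
    const double n = row_count + count[c];
    if (n == 0.0) { continue; }
    const double delta = mean[c] - row_mean;
    row_mean += delta * (count[c] / n);
    row_m2 += m2[c] + delta * delta * (row_count * count[c] / n);
    row_count = n;
  }
  if (row_count == 0.0) { y->clear(); return; }
  // Stage 3: normalize with the merged statistics (the role of LayerNormBlockPartialAffineImpl).
  const double inv_std = 1.0 / std::sqrt(row_m2 / row_count + epsilon);
  y->resize(cols);
  for (int i = 0; i < cols; ++i) { (*y)[i] = static_cast<float>((x[i] - row_mean) * inv_std); }
}

For col_width equal to one of the dispatcher's chunk widths (512, 2048, or 4096), the result matches a plain single-pass layer norm up to floating-point rounding.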
72 changes: 65 additions & 7 deletions oneflow/user/kernels/layer_norm_gpu_kernel.cu
@@ -16,9 +16,11 @@ limitations under the License.

#include "oneflow/core/device/cudnn_util.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/user_op_tensor.h"
#include "oneflow/core/ndarray/ndarray_util.h"
#include "oneflow/core/cuda/atomic.cuh"
#include <cub/cub.cuh>
#include "oneflow/core/kernel/cuda_graph_support.h"
#include "oneflow/core/ep/include/primitive/fill.h"
#include "oneflow/core/ep/include/primitive/matmul.h"
@@ -229,6 +231,21 @@ void LayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances, const
mean->mut_dptr<ComputeType>(), inv_variance->mut_dptr<ComputeType>());
}

template<typename T, bool do_scale, bool do_center>
void LayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size,
const double epsilon, const T* x_ptr, const T* gamma_ptr,
const T* beta_ptr, T* y_ptr, user_op::Tensor* mean,
user_op::Tensor* inv_variance, user_op::Tensor* tmp_buffer) {
CHECK(tmp_buffer->mut_dptr() != nullptr);
using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
cuda::layer_norm::DirectLoad<T, T> load(x_ptr, norm_size);
AffineStore<ComputeType, T, do_scale, do_center> store(y_ptr, norm_size, gamma_ptr, beta_ptr);
cuda::layer_norm::DispatchLayerNorm<decltype(load), decltype(store), ComputeType>(
stream->As<ep::CudaStream>()->cuda_stream(), load, store, num_instances, norm_size, epsilon,
mean->mut_dptr<ComputeType>(), inv_variance->mut_dptr<ComputeType>(),
tmp_buffer->mut_dptr<ComputeType>());
}

template<typename T>
void DispatchLayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances,
const int64_t norm_size, const double epsilon, const T* x_ptr,
@@ -249,6 +266,28 @@ void DispatchLayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances
}
}

template<typename T>
void DispatchLayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances,
const int64_t norm_size, const double epsilon, const T* x_ptr,
const T* gamma_ptr, const T* beta_ptr, T* y_ptr,
user_op::Tensor* mean, user_op::Tensor* inv_variance,
user_op::Tensor* tmp_buffer) {
if (gamma_ptr != nullptr && beta_ptr != nullptr) {
LayerNormForwardGpu<T, true, true>(stream, num_instances, norm_size, epsilon, x_ptr, gamma_ptr,
beta_ptr, y_ptr, mean, inv_variance, tmp_buffer);
} else if (gamma_ptr != nullptr && beta_ptr == nullptr) {
LayerNormForwardGpu<T, true, false>(stream, num_instances, norm_size, epsilon, x_ptr, gamma_ptr,
beta_ptr, y_ptr, mean, inv_variance, tmp_buffer);
} else if (gamma_ptr == nullptr && beta_ptr != nullptr) {
LayerNormForwardGpu<T, false, true>(stream, num_instances, norm_size, epsilon, x_ptr, gamma_ptr,
beta_ptr, y_ptr, mean, inv_variance, tmp_buffer);
} else {
LayerNormForwardGpu<T, false, false>(stream, num_instances, norm_size, epsilon, x_ptr,
gamma_ptr, beta_ptr, y_ptr, mean, inv_variance,
tmp_buffer);
}
}

template<typename T, bool do_scale, bool do_add>
void LayerNormBackwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size,
const T* dy_ptr, const T* x_ptr, const user_op::Tensor* mean,
@@ -321,16 +360,35 @@ class LayerNormGpuKernel final : public user_op::OpKernel, public user_op::CudaG
CHECK_EQ(gamma->shape_view().elem_cnt(), norm_size);
}
if (ctx->has_input("beta", 0)) { beta_ptr = ctx->Tensor4ArgNameAndIndex("beta", 0)->dptr<T>(); }
DispatchLayerNormForwardGpu<T>(ctx->stream(), num_instances, norm_size, epsilon, x->dptr<T>(),
gamma_ptr, beta_ptr, y->mut_dptr<T>(), mean, inv_variance);

const bool enable_partial = ParseBooleanFromEnv("ONEFLOW_ENABLE_PARTIAL_LAYERNORM_IMPL", false);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
if (tmp_buffer != nullptr && enable_partial) {
DispatchLayerNormForwardGpu<T>(ctx->stream(), num_instances, norm_size, epsilon, x->dptr<T>(),
gamma_ptr, beta_ptr, y->mut_dptr<T>(), mean, inv_variance,
tmp_buffer);
} else {
DispatchLayerNormForwardGpu<T>(ctx->stream(), num_instances, norm_size, epsilon, x->dptr<T>(),
gamma_ptr, beta_ptr, y->mut_dptr<T>(), mean, inv_variance);
}
};
};

#define REGISTER_LAYER_NORM_CUDA_KERNEL(dtype) \
REGISTER_USER_KERNEL("layer_norm") \
.SetCreateFn<LayerNormGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("x", 0) == GetDataType<dtype>::value));
#define REGISTER_LAYER_NORM_CUDA_KERNEL(dtype) \
REGISTER_USER_KERNEL("layer_norm") \
.SetCreateFn<LayerNormGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \
const int64_t num_instances = ctx->InputShape("mean", 0).elem_cnt(); \
const int64_t norm_size = ctx->InputShape("x", 0).elem_cnt() / num_instances; \
using ComputeType = typename cuda::layer_norm::DefaultComputeType<dtype>::type; \
if (norm_size <= 4) { /* TODO: refine here */ \
return 0; \
} else { \
return num_instances * 3 * 1024 * sizeof(ComputeType); \
} \
});

REGISTER_LAYER_NORM_CUDA_KERNEL(float)
REGISTER_LAYER_NORM_CUDA_KERNEL(double)

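The new forward path only runs when the kernel context provides a tmp_buffer and the ONEFLOW_ENABLE_PARTIAL_LAYERNORM_IMPL environment variable is set to true; otherwise the original single-pass dispatch is used. The helper below simply restates the workspace size reserved by the SetInferTmpSizeFn above; PartialLayerNormTmpBytes is an illustrative name, and using float as the compute type for float inputs is an assumption about DefaultComputeType, not something this commit states.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Mirrors the registered InferTmpSizeFn: three arrays (mean, m2, count) of up to 1024
// partial Welford results per row, in the compute type.
std::size_t PartialLayerNormTmpBytes(std::int64_t num_instances, std::size_t compute_type_bytes) {
  return static_cast<std::size_t>(num_instances) * 3 * 1024 * compute_type_bytes;
}

int main() {
  // Example: 4096 rows with a float compute type reserves 48 MiB of workspace.
  std::printf("%zu bytes\n", PartialLayerNormTmpBytes(4096, sizeof(float)));
  return 0;
}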