Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
marigoold committed Sep 21, 2023
1 parent c055d70 commit a3ffa55
Showing 1 changed file with 1 addition and 10 deletions.
11 changes: 1 addition & 10 deletions oneflow/core/cuda/layer_norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -914,18 +914,10 @@ __global__ void WelfordGlobalAllReduce(int64_t rows, int64_t cols, ComputeType*
static_cast<int>(row), static_cast<int>(blockIdx.y), static_cast<int>(threadIdx.x),
static_cast<float>(thread_mean), static_cast<float>(thread_m2),
static_cast<float>(thread_count));
// TODO: bug here
WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,
&row_count);
printf("blockIdx.x: %d, blockIdx.y: %d, thread: %d, row mean: %f, row m2: %f, row "
"count: %f\n",
static_cast<int>(row), static_cast<int>(blockIdx.y), static_cast<int>(threadIdx.x),
static_cast<float>(row_mean), static_cast<float>(row_m2), static_cast<float>(row_count));
if (threadIdx.x == 0) {
// printf(
// "blockIdx.x: %d, blockIdx.y: %d, thread: %d, row mean: %f, row m2: %f, row count:
// %f\n", static_cast<int>(row), static_cast<int>(blockIdx.y),
// static_cast<int>(threadIdx.x), static_cast<float>(row_mean),
// static_cast<float>(row_m2), static_cast<float>(row_count));
int idx = row * cols;
*(global_mean + idx) = row_mean;
*(global_m2 + idx) = row_m2;
Expand Down Expand Up @@ -961,7 +953,6 @@ __global__ void LayerNormBlockPartialImpl(LOAD load, STORE store, const int64_t
ComputeType row_count = 0;
__syncthreads();

// TODO: bug here
WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,
&row_count);
if (threadIdx.x == 0) {
Expand Down

0 comments on commit a3ffa55

Please sign in to comment.