Skip to content

Commit eb650bb

Browse files
committed
fast weight kernel, remove extra syncthreads
1 parent daac780 commit eb650bb

File tree

3 files changed

+3
-15
lines changed

3 files changed

+3
-15
lines changed

algorithmic/fast_weight/fast_weight_cuda.cu

+1-5
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,14 @@ __global__ void fast_weight_forward_kernel(
107107
// get old value
108108
v_old = shared_kv[threadIdx.x + sub * blockDim.x] *
109109
shared_keys[t*E_block + e];
110-
__syncthreads();
111110

112111
atomicAdd(
113112
&shared_values_old[m],
114113
v_old
115114
);
116-
__syncthreads();
117115
}
118116
}
117+
__syncthreads();
119118

120119
// compute new value to be inserted
121120
if (threadIdx.x < M) {
@@ -132,7 +131,6 @@ __global__ void fast_weight_forward_kernel(
132131
if (e < E) {
133132
shared_kv[threadIdx.x + sub * blockDim.x] +=
134133
shared_keys[t*E_block + e] * shared_values_insert[m];
135-
__syncthreads();
136134

137135
res = shared_queries[t*E_block + e]
138136
* shared_kv[threadIdx.x + sub * blockDim.x];
@@ -497,15 +495,13 @@ __global__ void fast_weight_backward_value_beta_kernel(
497495
if (e < E) {
498496
shared_kv[threadIdx.x + sub * blockDim.x] +=
499497
shared_queries[t*E_block + e] * shared_gradout[t*M + m];
500-
__syncthreads();
501498

502499
float res = shared_keys[t*E_block + e]
503500
* shared_kv[threadIdx.x + sub * blockDim.x];
504501
atomicAdd(
505502
&shared_results[m],
506503
res
507504
);
508-
__syncthreads();
509505
}
510506
}
511507
__syncthreads();

language_modeling/src/utils/fast_fast_weight/fast_weight_cuda.cu

+1-5
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,14 @@ __global__ void fast_weight_forward_kernel(
112112
if (e < E) {
113113
// get old value
114114
v_old = shared_kv[kv_idx] * shared_keys[e_abs];
115-
__syncthreads();
116115

117116
atomicAdd(
118117
&shared_v_old[m],
119118
v_old
120119
);
121-
__syncthreads();
122120
}
123121
}
122+
__syncthreads();
124123

125124
// compute new value to be inserted
126125
if (threadIdx.x < M) {
@@ -138,7 +137,6 @@ __global__ void fast_weight_forward_kernel(
138137
kv_idx = threadIdx.x + sub * blockDim.x;
139138
if (e < E) {
140139
shared_kv[kv_idx] += shared_keys[e_abs] * shared_v_insert[m];
141-
__syncthreads();
142140
res = shared_queries[e_abs] * shared_kv[kv_idx];
143141
atomicAdd(
144142
&shared_results[m],
@@ -512,14 +510,12 @@ __global__ void fast_weight_backward_value_beta_kernel(
512510
if (e < E) {
513511
shared_kv[kv_idx] +=
514512
shared_queries[e_abs] * shared_gradout[m_abs];
515-
__syncthreads();
516513

517514
float res = shared_keys[e_abs] * shared_kv[kv_idx];
518515
atomicAdd(
519516
&shared_results[m],
520517
res
521518
);
522-
__syncthreads();
523519
}
524520
}
525521
__syncthreads();

reinforcement_learning/torchbeast/fast_weight/fast_weight_cuda.cu

+1-5
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,14 @@ __global__ void fast_weight_forward_kernel(
107107
// get old value
108108
v_old = shared_kv[threadIdx.x + sub * blockDim.x] *
109109
shared_keys[t*E_block + e];
110-
__syncthreads();
111110

112111
atomicAdd(
113112
&shared_values_old[m],
114113
v_old
115114
);
116-
__syncthreads();
117115
}
118116
}
117+
__syncthreads();
119118

120119
// compute new value to be inserted
121120
if (threadIdx.x < M) {
@@ -132,7 +131,6 @@ __global__ void fast_weight_forward_kernel(
132131
if (e < E) {
133132
shared_kv[threadIdx.x + sub * blockDim.x] +=
134133
shared_keys[t*E_block + e] * shared_values_insert[m];
135-
__syncthreads();
136134

137135
res = shared_queries[t*E_block + e]
138136
* shared_kv[threadIdx.x + sub * blockDim.x];
@@ -497,15 +495,13 @@ __global__ void fast_weight_backward_value_beta_kernel(
497495
if (e < E) {
498496
shared_kv[threadIdx.x + sub * blockDim.x] +=
499497
shared_queries[t*E_block + e] * shared_gradout[t*M + m];
500-
__syncthreads();
501498

502499
float res = shared_keys[t*E_block + e]
503500
* shared_kv[threadIdx.x + sub * blockDim.x];
504501
atomicAdd(
505502
&shared_results[m],
506503
res
507504
);
508-
__syncthreads();
509505
}
510506
}
511507
__syncthreads();

0 commit comments

Comments (0)