File tree: 3 files changed, +3 -15 lines

language_modeling/src/utils/fast_fast_weight
reinforcement_learning/torchbeast/fast_weight
```diff
@@ -107,15 +107,14 @@ __global__ void fast_weight_forward_kernel(
             // get old value
             v_old = shared_kv[threadIdx.x + sub * blockDim.x] *
                 shared_keys[t*E_block + e];
-            __syncthreads();
 
             atomicAdd(
                 &shared_values_old[m],
                 v_old
             );
-            __syncthreads();
         }
     }
+    __syncthreads();
 
     // compute new value to be inserted
     if (threadIdx.x < M) {
@@ -132,7 +131,6 @@ __global__ void fast_weight_forward_kernel(
         if (e < E) {
             shared_kv[threadIdx.x + sub * blockDim.x] +=
                 shared_keys[t*E_block + e] * shared_values_insert[m];
-            __syncthreads();
 
             res = shared_queries[t*E_block + e]
                 * shared_kv[threadIdx.x + sub * blockDim.x];
@@ -497,15 +495,13 @@ __global__ void fast_weight_backward_value_beta_kernel(
         if (e < E) {
             shared_kv[threadIdx.x + sub * blockDim.x] +=
                 shared_queries[t*E_block + e] * shared_gradout[t*M + m];
-            __syncthreads();
 
             float res = shared_keys[t*E_block + e]
                 * shared_kv[threadIdx.x + sub * blockDim.x];
            atomicAdd(
                 &shared_results[m],
                 res
            );
-            __syncthreads();
         }
     }
     __syncthreads();
```
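
Every hunk in this commit makes the same fix. `__syncthreads()` is a block-wide barrier: CUDA requires every thread of the block to reach the same barrier call. The removed calls sat inside `if (e < E)` branches that only some threads take, which is undefined behavior, typically a hang, sometimes silent corruption. The one added call places a single barrier after the conditionals, where all threads can reach it. A minimal toy kernel, not the repo's code, illustrating the removed and the introduced placements:

```cuda
// Toy kernel (illustrative, not from this repo) showing why the removed
// barriers were a bug. Launch with one block of 256 threads.
__global__ void barrier_placement_demo(const float* in, float* out, int n) {
    __shared__ float buf[256];
    int i = threadIdx.x;

    buf[i] = (i < n) ? in[i] : 0.0f;
    __syncthreads();                 // fine: every thread reaches it

    // WRONG -- the pattern this commit removes:
    // if (i < n) {
    //     buf[i] *= 2.0f;
    //     __syncthreads();          // threads with i >= n never arrive: UB
    // }

    // RIGHT -- the pattern this commit introduces: do the divergent work
    // inside the branch, then let all threads meet at one barrier outside it.
    if (i < n) {
        buf[i] *= 2.0f;
    }
    __syncthreads();                 // reached by all 256 threads

    if (i < n) {
        out[i] = buf[i];
    }
}
```

Launched as `barrier_placement_demo<<<1, 256>>>(d_in, d_out, n)` for any `n <= 256`; with the barrier inside the branch, any `n < 256` can deadlock the block. The second file's copy of the kernels gets the same treatment below.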
```diff
@@ -112,15 +112,14 @@ __global__ void fast_weight_forward_kernel(
         if (e < E) {
             // get old value
             v_old = shared_kv[kv_idx] * shared_keys[e_abs];
-            __syncthreads();
 
             atomicAdd(
                 &shared_v_old[m],
                 v_old
             );
-            __syncthreads();
         }
     }
+    __syncthreads();
 
     // compute new value to be inserted
     if (threadIdx.x < M) {
@@ -138,7 +137,6 @@ __global__ void fast_weight_forward_kernel(
         kv_idx = threadIdx.x + sub * blockDim.x;
         if (e < E) {
             shared_kv[kv_idx] += shared_keys[e_abs] * shared_v_insert[m];
-            __syncthreads();
             res = shared_queries[e_abs] * shared_kv[kv_idx];
             atomicAdd(
                 &shared_results[m],
@@ -512,14 +510,12 @@ __global__ void fast_weight_backward_value_beta_kernel(
         if (e < E) {
             shared_kv[kv_idx] +=
                 shared_queries[e_abs] * shared_gradout[m_abs];
-            __syncthreads();
 
             float res = shared_keys[e_abs] * shared_kv[kv_idx];
             atomicAdd(
                 &shared_results[m],
                 res
             );
-            __syncthreads();
         }
     }
     __syncthreads();
```
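
This second file's version differs only in using precomputed indices (`kv_idx`, `e_abs`, `m_abs`). The code the barrier protects is a shared-memory reduction: each thread contributes one term of the old value `v_old[m] = sum_e kv[m][e] * key[e]` via `atomicAdd`, and a single barrier after the branch publishes the finished sums to the whole block. A standalone sketch of that accumulate-then-sync pattern, with assumed names and sizes; the real kernels additionally tile over sub-blocks and timesteps:

```cuda
// Sketch of the "get old value" reduction (assumptions: one block of
// E_BLOCK * M threads; E_BLOCK pads the true key dimension E, so the
// lanes with e >= E are idle, mirroring the kernels' `if (e < E)` guard).
#define E 10        // true key dimension (assumed)
#define E_BLOCK 16  // padded tile width (assumed)
#define M 16        // value dimension (assumed)

__global__ void old_value_demo(const float* kv, const float* key,
                               float* v_old_out) {
    __shared__ float shared_v_old[M];

    int e = threadIdx.x % E_BLOCK;
    int m = threadIdx.x / E_BLOCK;

    if (threadIdx.x < M) shared_v_old[threadIdx.x] = 0.0f;
    __syncthreads();

    if (e < E) {                     // divergent: padding lanes skip the work
        float v_old = kv[m * E + e] * key[e];
        atomicAdd(&shared_v_old[m], v_old);   // one term of the dot product
    }
    __syncthreads();                 // outside the branch: every thread arrives

    if (threadIdx.x < M) v_old_out[threadIdx.x] = shared_v_old[threadIdx.x];
}
```

Launched as `old_value_demo<<<1, E_BLOCK * M>>>(d_kv, d_key, d_v_old)`; the padding lanes with `e >= E` do no work but still reach both barriers, which is exactly what the moved `__syncthreads()` guarantees.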
```diff
@@ -107,15 +107,14 @@ __global__ void fast_weight_forward_kernel(
             // get old value
             v_old = shared_kv[threadIdx.x + sub * blockDim.x] *
                 shared_keys[t*E_block + e];
-            __syncthreads();
 
             atomicAdd(
                 &shared_values_old[m],
                 v_old
             );
-            __syncthreads();
         }
     }
+    __syncthreads();
 
     // compute new value to be inserted
     if (threadIdx.x < M) {
@@ -132,7 +131,6 @@ __global__ void fast_weight_forward_kernel(
         if (e < E) {
             shared_kv[threadIdx.x + sub * blockDim.x] +=
                 shared_keys[t*E_block + e] * shared_values_insert[m];
-            __syncthreads();
 
             res = shared_queries[t*E_block + e]
                 * shared_kv[threadIdx.x + sub * blockDim.x];
@@ -497,15 +495,13 @@ __global__ void fast_weight_backward_value_beta_kernel(
         if (e < E) {
             shared_kv[threadIdx.x + sub * blockDim.x] +=
                 shared_queries[t*E_block + e] * shared_gradout[t*M + m];
-            __syncthreads();
 
             float res = shared_keys[t*E_block + e]
                 * shared_kv[threadIdx.x + sub * blockDim.x];
            atomicAdd(
                 &shared_results[m],
                 res
            );
-            __syncthreads();
         }
     }
     __syncthreads();
```
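
This final diff is identical to the first: the same kernel source appears to be duplicated under both project directories, so the fix is applied to each copy. The step being synchronized is the fast-weight update itself: a rank-1 update `kv[m][e] += key[e] * v_insert[m]` followed by the readout `out[m] = sum_e query[e] * kv[m][e]`, accumulated with `atomicAdd` as above. A simplified single-timestep sketch, with assumed names and sizes and without the beta/old-value correction the real kernels include:

```cuda
// Single fast-weight step (illustrative): rank-1 update, then readout.
// Launch with one block of E * M threads, one thread per kv cell.
#define E 16   // key/query dimension (assumed)
#define M 16   // value dimension (assumed)

__global__ void fast_weight_step_demo(const float* key, const float* query,
                                      const float* v_insert, float* out) {
    __shared__ float kv[M * E];       // fast weight matrix, one cell per thread
    __shared__ float results[M];

    int e = threadIdx.x % E;
    int m = threadIdx.x / E;

    kv[threadIdx.x] = 0.0f;           // fresh weights for the demo
    if (threadIdx.x < M) results[threadIdx.x] = 0.0f;
    __syncthreads();

    // rank-1 update, then this thread's term of the readout
    kv[m * E + e] += key[e] * v_insert[m];
    float res = query[e] * kv[m * E + e];
    atomicAdd(&results[m], res);
    __syncthreads();                  // single barrier, outside any branch

    if (threadIdx.x < M) out[threadIdx.x] = results[threadIdx.x];
}
```

One barrier after the whole step suffices here because each thread only reads and writes its own `kv` cell; the barrier exists to publish the `results` reduction before the first `M` threads write it out.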