-
Notifications
You must be signed in to change notification settings - Fork 0
/
softmax.cu
112 lines (83 loc) · 2.88 KB
/
softmax.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <float.h>
#include <math.h>
#include <stdio.h>
// Error-checking macro: reads (and clears) the sticky CUDA error state via
// cudaGetLastError() and aborts with the given message plus file/line context
// on failure. Intended to be placed after every CUDA runtime call and after
// each kernel launch (launch-configuration errors only surface this way).
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
// const size_t DSIZE = 16384; // matrix side dimension
const size_t DSIZE = 6; // matrix side dimension (matrix is DSIZE x DSIZE, row-major)
const int block_size = 32; // threads per block; CUDA maximum is 1024
const float element_val = 100; // constant fill value for the test matrix
// Row-wise, in-place softmax over a ds x ds row-major matrix A.
//
// Launch layout expected: one block per row (gridDim.x == ds) and
// blockDim.x == block_size threads cooperating on that row. blockDim.x must
// be a power of two (the tree reductions below halve it each step) and must
// not exceed block_size (the static shared-memory array size).
//
// Fixes vs. the original:
//  - the max accumulator is seeded with -FLT_MAX, not 0.0f, so all-negative
//    rows reduce to the true maximum and the max-shift actually prevents
//    expf overflow for large-magnitude inputs;
//  - partial sums accumulate in a per-thread register; the old
//    atomicAdd(&sdata[idx], ...) was needless (each thread owns its slot).
__global__ void softmax_max(float *A, size_t ds) {
    int idx = threadIdx.x;
    __shared__ float sdata[block_size];
    __shared__ float max_val;

    int total_elements = ds * ds;       // guard: never read past the matrix
    int start_index = blockIdx.x * ds;  // first element of this block's row
    int end_index = start_index + ds;   // one past the last element of the row

    // Phase 1: per-thread running maximum over the row's strided elements.
    float thread_max = -FLT_MAX;
    for (int index = start_index + idx; index < end_index; index += blockDim.x) {
        if (index < total_elements) thread_max = max(A[index], thread_max);
    }
    sdata[idx] = thread_max;

    // Tree reduction for the row maximum; sync before each step so the
    // previous step's writes are visible to the reading threads.
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        __syncthreads();
        if (idx < s) sdata[idx] = max(sdata[idx], sdata[idx + s]);
    }
    __syncthreads();
    if (idx == 0) max_val = sdata[0];
    __syncthreads();

    // Phase 2: exponentiate shifted by the row max (numerical stability;
    // the shift cancels in the final normalization) and accumulate a
    // per-thread partial sum in a register.
    float partial_sum = 0.0f;
    for (int index = start_index + idx; index < end_index; index += blockDim.x) {
        if (index < total_elements) {
            float val = expf(A[index] - max_val);
            A[index] = val;
            partial_sum += val;
        }
    }
    sdata[idx] = partial_sum;
    __syncthreads();

    // Tree reduction for the row sum.
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (idx < s) sdata[idx] += sdata[idx + s];
        __syncthreads();
    }

    // Phase 3: normalize the row by its sum (sdata[0] holds the total).
    for (int index = start_index + idx; index < end_index; index += blockDim.x) {
        if (index < total_elements) A[index] /= sdata[0];
    }
}
// Test driver: fills a DSIZE x DSIZE matrix with a constant, runs the
// row-wise softmax kernel, and checks every entry equals 1/DSIZE (the
// softmax of a constant row is uniform). Returns 0 on success, -1 on
// mismatch; CUDA API failures abort via cudaCheckErrors.
int main(){
    float *h_A, *d_A;
    const size_t bytes = DSIZE * DSIZE * sizeof(float);

    h_A = new float[DSIZE * DSIZE];
    for (int i = 0; i < DSIZE * DSIZE; i++) h_A[i] = element_val;

    cudaMalloc(&d_A, bytes);
    cudaCheckErrors("cudaMalloc failure");
    cudaMemcpy(d_A, h_A, bytes, cudaMemcpyHostToDevice);
    cudaCheckErrors("cudaMemcpy H2D failure");

    // One block per row; block_size threads cooperate on each row.
    softmax_max<<<DSIZE, block_size>>>(d_A, DSIZE);
    cudaCheckErrors("kernel launch failure");

    // cudaMemcpy is blocking, so asynchronous kernel-execution errors also
    // surface at this check.
    cudaMemcpy(h_A, d_A, bytes, cudaMemcpyDeviceToHost);
    cudaCheckErrors("cudaMemcpy D2H failure");

    // fabsf, not abs: integer abs() truncates the float difference toward
    // zero and would mask real mismatches.
    const float expected = 1.0f / (float)DSIZE;
    int ret = 0;
    for (int i = 0; i < DSIZE * DSIZE; i++) {
        printf("h_A[%d]: %.8f\n", i, h_A[i]);
        if (fabsf(h_A[i] - expected) > 0.00001f) {
            printf("results mismatch at %d, was: %.10f, should be: %.10f\n",
                   i, h_A[i], expected);
            ret = -1;
            break;
        }
    }
    if (ret == 0) printf("softmax correct!\n");

    // Release both allocations on every exit path (the original leaked both).
    cudaFree(d_A);
    delete[] h_A;
    return ret;
}