IntLayerNorm.cpp
#include "IntLayerNorm.h"
#include <torch/torch.h>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <string>
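// IntLayerNorm.h is not part of this file. As a rough sketch only, the
// declarations this translation unit relies on would look roughly like the
// following (member names are inferred from their uses below; deriving from
// torch::nn::Module is an assumption needed for register_buffer):
//
// struct IntLayerNorm : torch::nn::Module {
//     IntLayerNorm(int Output_bit,
//                  bool Overflow_handling = true,
//                  std::string Quant_mode = "none",
//                  std::string Force_dequant = "none");
//     void fix();
//     void unfix();
//     int set_param(torch::nn::LayerNorm ln);
//     int set_shift(torch::Tensor y_int);
//     torch::Tensor overflow_fallback(torch::Tensor y_int);
//     torch::Tensor forward(torch::Tensor x,
//                           torch::Tensor scaling_factor = torch::empty(0),
//                           torch::Tensor exponents = torch::empty(0));
//
//     std::string quant_mode;
//     bool overflow_handling;
//     int output_bit;
//     double eps;
//     torch::Tensor shift, dim_sqrt, weight, bias;
// };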
// Note: the default arguments (Overflow_handling = true, Quant_mode = "none",
// Force_dequant = "none") must appear in the declaration in IntLayerNorm.h;
// C++ does not allow repeating them in an out-of-line definition.
IntLayerNorm::IntLayerNorm(
    int Output_bit,
    bool Overflow_handling,
    std::string Quant_mode,
    std::string Force_dequant)
{
    quant_mode = Quant_mode;
    // Fall back to floating point when all nonlinear ops, or LayerNorm
    // specifically, are forced to be dequantized.
    if (Force_dequant == "nonlinear" || Force_dequant == "layernorm")
    {
        quant_mode = "none";
    }
    overflow_handling = Overflow_handling;
    shift = register_buffer("shift", torch::zeros(1));
    output_bit = Output_bit;
    dim_sqrt = torch::empty(0);
    //activation = QuantAct(output_bit, quant_mode);
    if (quant_mode != "none" && quant_mode != "symmetric")
    {
        printf("quant mode '%s' is not implemented!\n", quant_mode.c_str());
        exit(-1);
    }
}
// Freeze the shift buffer: disable dynamic overflow handling (e.g. for
// inference after the shift has been calibrated).
void IntLayerNorm::fix()
{
    overflow_handling = false;
}
// Re-enable dynamic shift adjustment on overflow.
void IntLayerNorm::unfix()
{
    overflow_handling = true;
}
// Copy the parameters of a floating-point LayerNorm module. The parameter
// type here must be torch::nn::LayerNorm (not torch::nn::Linear, as
// originally written): only LayerNorm carries normalized_shape and eps.
// The eps, weight, and bias members are all used by forward(), so they are
// assigned here rather than shadowed by locals.
int IntLayerNorm::set_param(torch::nn::LayerNorm ln)
{
    eps = ln->options.eps();
    weight = ln->weight.data().clone();
    bias = ln->bias.data().clone();
    return 0;
}
// Enlarge the shift so that sum((y_int / 2^shift)^2) stays below 2^32,
// i.e. shift >= ceil(log2(sqrt(var_int / 2^32))).
int IntLayerNorm::set_shift(torch::Tensor y_int)
{
    torch::NoGradGuard no_grad;
    // Note: operator^ is bitwise XOR in C++, not exponentiation.
    auto y_sq_int = torch::pow(y_int, 2);
    auto var_int = torch::sum(y_sq_int, 2, true);
    auto int_log = torch::ceil(torch::log2(torch::sqrt(var_int / std::pow(2.0, 32))));
    auto shift2 = int_log.max();
    auto shift_old = shift;
    shift = torch::max(shift, shift2);
    std::cout << "Dynamic shift adjustment: " << shift_old << " -> " << shift << std::endl;
    return 0;
}
// Recompute the shifted integer variance after adjusting the shift, so the
// result fits back into 32 bits.
torch::Tensor IntLayerNorm::overflow_fallback(torch::Tensor y_int)
{
    set_shift(y_int);
    auto y_int_shifted = torch::floor(y_int / torch::pow(2, shift));
    auto y_sq_int = torch::pow(y_int_shifted, 2);
    auto var_int = torch::sum(y_sq_int, 2, true);
    return var_int;
}
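// Why the constants in forward() work out (a sketch, following I-BERT):
// with s the input scaling factor and n the feature size,
//   y_int              ~ (x - mean) / s
//   var_int = sum(y^2) ~ n * var(x) / s^2        (a sum, not a mean)
//   std_int            ~ sqrt(n) * std(x) / s
//   y_int * floor(2^31 / std_int) / 2 ~ (x - mean) / std(x) * 2^30 / sqrt(n)
// so multiplying by the new scaling factor sqrt(n) / 2^30 recovers the
// normalized activation (x - mean) / std(x) before the affine weight/bias.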
// Integer-only LayerNorm over the last dimension (I-BERT style): normalize
// x / scaling_factor with integer arithmetic, then dequantize. The default
// arguments for scaling_factor and exponents belong in the declaration in
// IntLayerNorm.h, not in this definition; exponents is accepted for
// interface compatibility but unused here.
torch::Tensor IntLayerNorm::forward(torch::Tensor x, torch::Tensor scaling_factor, torch::Tensor exponents)
{
    if (quant_mode == "none")
    {
        // Plain floating-point LayerNorm.
        auto mean = x.mean(2, true);
        auto y = x - mean;
        auto var = torch::mean(torch::pow(y, 2), 2, true);
        x = y / torch::sqrt(var + eps);
        x = x * weight + bias;
        return x;
    }
    if (quant_mode != "symmetric")
    {
        printf("unsupported quant mode '%s'!\n", quant_mode.c_str());
        exit(-1);
    }
    if (dim_sqrt.numel() == 0)
    {
        // Cache sqrt(n) for the feature dimension on the first call.
        auto n = torch::tensor(static_cast<float>(x.size(2)), torch::kFloat);
        dim_sqrt = torch::sqrt(n);
    }
    auto x_int = x / scaling_factor;
    auto mean_int = torch::round(x_int.mean(2, true));
    auto y_int = x_int - mean_int;
    // Shift right before squaring so the summed variance fits in 32 bits.
    auto y_int_shifted = torch::floor(y_int / torch::pow(2, shift));
    auto y_sq_int = torch::pow(y_int_shifted, 2);
    auto var_int = torch::sum(y_sq_int, 2, true);
    // Overflow handling in the training stage (calling overflow_fallback
    // when var_int reaches 2^32) is TBD.
    auto std_int = torch::floor(torch::sqrt(var_int)) * torch::pow(2, shift);
    auto factor = torch::floor(std::pow(2.0, 31) / std_int);
    y_int = torch::floor(y_int * factor / 2);
    scaling_factor = dim_sqrt / std::pow(2.0, 30);
    // Fold the affine bias into the integer domain. The bias must first be
    // divided by the weight (bias2 below); the original code computed bias2
    // but then divided the raw bias, dropping the 1/weight factor.
    auto bias2 = bias.detach() / weight.detach();
    auto bias_int = torch::floor(bias2 / scaling_factor);
    y_int = y_int + bias_int;
    scaling_factor = scaling_factor * weight;
    x = y_int * scaling_factor;
    return x;
}
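// A minimal usage sketch (hypothetical: it assumes the header sketched above
// and a working LibTorch build). A symmetric-mode IntLayerNorm is fed an
// activation that is already quantized to a uniform grid, together with its
// scaling factor; the sizes and scale below are illustrative only.
//
// int main() {
//     IntLayerNorm ln(8, /*Overflow_handling=*/true, /*Quant_mode=*/"symmetric");
//     ln.weight = torch::ones(768);
//     ln.bias = torch::zeros(768);
//     auto scale = torch::tensor(0.02f);
//     // Fake an integer-quantized input: round to the grid, keep real scale.
//     auto x = torch::round(torch::randn({1, 128, 768}) / scale) * scale;
//     auto out = ln.forward(x, scale);
//     std::cout << out.sizes() << std::endl;  // [1, 128, 768]
//     return 0;
// }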