diff --git a/paddle/phi/kernels/funcs/pooling.h b/paddle/phi/kernels/funcs/pooling.h index af13745d27eda8..6ea956f22d39a9 100644 --- a/paddle/phi/kernels/funcs/pooling.h +++ b/paddle/phi/kernels/funcs/pooling.h @@ -138,13 +138,14 @@ HOSTDEVICE inline T AdaptEndIndex(T ph, T input_size, T output_size) { /* used for fractional pool to calculate start and end index of each divided * grid */ +template HOSTDEVICE inline float FractionalRationalU( - float u, float alpha, int input, int output, int pool_size = 0) { + float u, float alpha, T input, T output, T pool_size = 0) { if (pool_size > 0) { return u; } - int base = input / output; + T base = input / output; float u_max1 = static_cast(base + 2) / alpha - 1; float u_max2 = static_cast(input + 1 - base) / alpha - @@ -154,24 +155,26 @@ HOSTDEVICE inline float FractionalRationalU( return u * max_u; } -HOSTDEVICE inline int FractionalStartIndex(int idx, - float alpha, - float u, - int pool_size = 0) { +template +HOSTDEVICE inline T FractionalStartIndex(T idx, + float alpha, + float u, + T pool_size = 0) { // paper use ceil instead: static_cast(ceil(alpha * (idx + u) - 1)); - return static_cast((idx + u) * alpha) - static_cast(u * alpha); + return static_cast((idx + u) * alpha) - static_cast(u * alpha); } -HOSTDEVICE inline int FractionalEndIndex(int idx, - float alpha, - float u, - int pool_size = 0) { +template +HOSTDEVICE inline T FractionalEndIndex(T idx, + float alpha, + float u, + T pool_size = 0) { if (pool_size > 0) { - return static_cast((idx + u) * alpha) - static_cast(u * alpha) + + return static_cast((idx + u) * alpha) - static_cast(u * alpha) + pool_size; } // paper use ceil instead: static_cast(ceil(alpha * (idx + 1 + u) - 1)); - return static_cast((idx + 1 + u) * alpha) - static_cast(u * alpha); + return static_cast((idx + 1 + u) * alpha) - static_cast(u * alpha); } /*