Skip to content

Commit

Permalink
Change code style with clang-format
Browse files Browse the repository at this point in the history
  • Loading branch information
anstellaire committed Feb 14, 2025
1 parent 5026363 commit 0c5efd2
Showing 1 changed file with 63 additions and 57 deletions.
120 changes: 63 additions & 57 deletions cpp/src/neighbors/refine/refine_host.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ namespace detail {
// -----------------------------------------------------------------------------

template <typename DC, typename DistanceT, typename DataT>
DistanceT euclidean_distance_squared_generic(DataT const* a, DataT const* b, size_t n) {
DistanceT euclidean_distance_squared_generic(DataT const* a, DataT const* b, size_t n)
{
size_t constexpr max_vreg_len = 512 / (8 * sizeof(DistanceT));

// max_vreg_len is a power of two
size_t n_rounded = n & (0xFFFFFFFF ^ (max_vreg_len - 1));
size_t n_rounded = n & (0xFFFFFFFF ^ (max_vreg_len - 1));
DistanceT distance[max_vreg_len] = {0};

for (size_t i = 0; i < n_rounded; i += max_vreg_len) {
Expand Down Expand Up @@ -70,42 +71,44 @@ struct distance_comp_l2;
struct distance_comp_inner;

// fallback
template<typename DC, typename DistanceT, typename DataT>
DistanceT euclidean_distance_squared(DataT const* a, DataT const* b, size_t n) {
template <typename DC, typename DistanceT, typename DataT>
DistanceT euclidean_distance_squared(DataT const* a, DataT const* b, size_t n)
{
return euclidean_distance_squared_generic<DC, DistanceT, DataT>(a, b, n);
}

#if defined(__arm__) || defined(__aarch64__)

template<>
inline float euclidean_distance_squared<distance_comp_l2, float, float>(
float const* a, float const* b, size_t n) {

template <>
inline float euclidean_distance_squared<distance_comp_l2, float, float>(float const* a,
float const* b,
size_t n)
{
size_t n_rounded = n - (n % 4);

float32x4_t vreg_dsum = vdupq_n_f32(0.f);
for (size_t i = 0; i < n_rounded; i += 4) {
float32x4_t vreg_a = vld1q_f32(&a[i]);
float32x4_t vreg_b = vld1q_f32(&b[i]);
float32x4_t vreg_d = vsubq_f32(vreg_a, vreg_b);
vreg_dsum = vfmaq_f32(vreg_dsum, vreg_d, vreg_d);
vreg_dsum = vfmaq_f32(vreg_dsum, vreg_d, vreg_d);
}

float dsum = vaddvq_f32(vreg_dsum);
for (size_t i = n_rounded; i < n; ++i) {
float d = a[i] - b[i];
dsum += d * d;
float d = a[i] - b[i];
dsum += d * d;
}

return dsum;
}

template<>
template <>
inline float euclidean_distance_squared<distance_comp_l2, float, ::std::int8_t>(
::std::int8_t const* a, ::std::int8_t const* b, size_t n) {

::std::int8_t const* a, ::std::int8_t const* b, size_t n)
{
size_t n_rounded = n - (n % 16);
float dsum = 0.f;
float dsum = 0.f;

if (n_rounded > 0) {
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
Expand All @@ -114,11 +117,11 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::int8_t>(
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;

for (size_t i = 0; i < n_rounded; i += 16) {
int8x16_t vreg_a = vld1q_s8(&a[i]);
int8x16_t vreg_a = vld1q_s8(&a[i]);
int16x8_t vreg_a_s16_0 = vmovl_s8(vget_low_s8(vreg_a));
int16x8_t vreg_a_s16_1 = vmovl_s8(vget_high_s8(vreg_a));

int8x16_t vreg_b = vld1q_s8(&b[i]);
int8x16_t vreg_b = vld1q_s8(&b[i]);
int16x8_t vreg_b_s16_0 = vmovl_s8(vget_low_s8(vreg_b));
int16x8_t vreg_b_s16_1 = vmovl_s8(vget_high_s8(vreg_b));

Expand All @@ -140,23 +143,23 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::int8_t>(
vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);

dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
}

for (size_t i = n_rounded; i < n; ++i) {
float d = a[i] - b[i];
dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
float d = a[i] - b[i];
dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
}

return dsum;
}

template<>
template <>
inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>(
::std::uint8_t const* a, ::std::uint8_t const* b, size_t n) {

::std::uint8_t const* a, ::std::uint8_t const* b, size_t n)
{
size_t n_rounded = n - (n % 16);
float dsum = 0.f;
float dsum = 0.f;

if (n_rounded > 0) {
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
Expand All @@ -165,17 +168,17 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;

for (size_t i = 0; i < n_rounded; i += 16) {
uint8x16_t vreg_a = vld1q_u8(&a[i]);
uint16x8_t vreg_a_u16_0 = vmovl_u8(vget_low_u8(vreg_a));
uint16x8_t vreg_a_u16_1 = vmovl_u8(vget_high_u8(vreg_a));
uint8x16_t vreg_a = vld1q_u8(&a[i]);
uint16x8_t vreg_a_u16_0 = vmovl_u8(vget_low_u8(vreg_a));
uint16x8_t vreg_a_u16_1 = vmovl_u8(vget_high_u8(vreg_a));
float32x4_t vreg_a_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_0)));
float32x4_t vreg_a_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_0)));
float32x4_t vreg_a_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_1)));
float32x4_t vreg_a_fp32_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_1)));

uint8x16_t vreg_b = vld1q_u8(&b[i]);
uint16x8_t vreg_b_u16_0 = vmovl_u8(vget_low_u8(vreg_b));
uint16x8_t vreg_b_u16_1 = vmovl_u8(vget_high_u8(vreg_b));
uint8x16_t vreg_b = vld1q_u8(&b[i]);
uint16x8_t vreg_b_u16_0 = vmovl_u8(vget_low_u8(vreg_b));
uint16x8_t vreg_b_u16_1 = vmovl_u8(vget_high_u8(vreg_b));
float32x4_t vreg_b_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_b_u16_0)));
float32x4_t vreg_b_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_b_u16_0)));
float32x4_t vreg_b_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_b_u16_1)));
Expand All @@ -196,45 +199,46 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>
vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);

dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
}

for (size_t i = n_rounded; i < n; ++i) {
float d = a[i] - b[i];
dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
float d = a[i] - b[i];
dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
}

return dsum;
}

template<>
inline float euclidean_distance_squared<distance_comp_inner, float, float>(
float const* a, float const* b, size_t n) {

template <>
inline float euclidean_distance_squared<distance_comp_inner, float, float>(float const* a,
float const* b,
size_t n)
{
size_t n_rounded = n - (n % 4);

float32x4_t vreg_dsum = vdupq_n_f32(0.f);
for (size_t i = 0; i < n_rounded; i += 4) {
float32x4_t vreg_a = vld1q_f32(&a[i]);
float32x4_t vreg_b = vld1q_f32(&b[i]);
vreg_a = vnegq_f32(vreg_a);
vreg_dsum = vfmaq_f32(vreg_dsum, vreg_a, vreg_b);
vreg_a = vnegq_f32(vreg_a);
vreg_dsum = vfmaq_f32(vreg_dsum, vreg_a, vreg_b);
}

float dsum = vaddvq_f32(vreg_dsum);
for (size_t i = n_rounded; i < n; ++i) {
dsum += -a[i] * b[i];
dsum += -a[i] * b[i];
}

return dsum;
}

template<>
template <>
inline float euclidean_distance_squared<distance_comp_inner, float, ::std::int8_t>(
::std::int8_t const* a, ::std::int8_t const* b, size_t n) {

::std::int8_t const* a, ::std::int8_t const* b, size_t n)
{
size_t n_rounded = n - (n % 16);
float dsum = 0.f;
float dsum = 0.f;

if (n_rounded > 0) {
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
Expand All @@ -243,11 +247,11 @@ inline float euclidean_distance_squared<distance_comp_inner, float, ::std::int8_
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;

for (size_t i = 0; i < n_rounded; i += 16) {
int8x16_t vreg_a = vld1q_s8(&a[i]);
int8x16_t vreg_a = vld1q_s8(&a[i]);
int16x8_t vreg_a_s16_0 = vmovl_s8(vget_low_s8(vreg_a));
int16x8_t vreg_a_s16_1 = vmovl_s8(vget_high_s8(vreg_a));

int8x16_t vreg_b = vld1q_s8(&b[i]);
int8x16_t vreg_b = vld1q_s8(&b[i]);
int16x8_t vreg_b_s16_0 = vmovl_s8(vget_low_s8(vreg_b));
int16x8_t vreg_b_s16_1 = vmovl_s8(vget_high_s8(vreg_b));

Expand All @@ -269,20 +273,22 @@ inline float euclidean_distance_squared<distance_comp_inner, float, ::std::int8_
vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);

dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
}

for (size_t i = n_rounded; i < n; ++i) {
dsum += -a[i] * b[i];
dsum += -a[i] * b[i];
}

return dsum;
}

template<>
inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8_t>(::std::uint8_t const* a, ::std::uint8_t const* b, size_t n) {
template <>
inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8_t>(
::std::uint8_t const* a, ::std::uint8_t const* b, size_t n)
{
size_t n_rounded = n - (n % 16);
float dsum = 0.f;
float dsum = 0.f;

if (n_rounded > 0) {
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
Expand All @@ -291,11 +297,11 @@ inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;

for (size_t i = 0; i < n_rounded; i += 16) {
uint8x16_t vreg_a = vld1q_u8(&a[i]);
uint8x16_t vreg_a = vld1q_u8(&a[i]);
uint16x8_t vreg_a_u16_0 = vmovl_u8(vget_low_u8(vreg_a));
uint16x8_t vreg_a_u16_1 = vmovl_u8(vget_high_u8(vreg_a));

uint8x16_t vreg_b = vld1q_u8(&b[i]);
uint8x16_t vreg_b = vld1q_u8(&b[i]);
uint16x8_t vreg_b_u16_0 = vmovl_u8(vget_low_u8(vreg_b));
uint16x8_t vreg_b_u16_1 = vmovl_u8(vget_high_u8(vreg_b));

Expand All @@ -317,17 +323,17 @@ inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8
vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);

dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
}

for (size_t i = n_rounded; i < n; ++i) {
dsum += -a[i] * b[i];
dsum += -a[i] * b[i];
}

return dsum;
}

#endif // defined(__arm__) || defined(__aarch64__)
#endif // defined(__arm__) || defined(__aarch64__)

// -----------------------------------------------------------------------------
// Refine kernel
Expand Down Expand Up @@ -421,7 +427,7 @@ template <typename DC, typename IdxT, typename DataT, typename DistanceT, typena
distance = std::numeric_limits<DistanceT>::max();
} else {
const DataT* row = dataset.data_handle() + dim * id;
distance = euclidean_distance_squared<DC, DistanceT, DataT>(query, row, dim);
distance = euclidean_distance_squared<DC, DistanceT, DataT>(query, row, dim);
}
refined_pairs[tid][j] = std::make_tuple(distance, id);
}
Expand Down

0 comments on commit 0c5efd2

Please sign in to comment.