Skip to content

Commit

Permalink
Switch from using highest SIMD ABI to GCC function multi-versioning
Browse files Browse the repository at this point in the history
Prevents crashes on older CPUs without wider SIMD instructions.

Workaround lack of target-specific template specialisation in GCC.

Fixes: microsoft#316
  • Loading branch information
pabs3 committed Feb 6, 2023
1 parent cc01d1f commit 23263a5
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 101 deletions.
2 changes: 1 addition & 1 deletion AnnService/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ add_library (DistanceUtils STATIC
)

if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")
target_compile_options(DistanceUtils PRIVATE -mavx2 -mavx -msse -msse2 -mavx512f -mavx512bw -mavx512dq -fPIC)
target_compile_options(DistanceUtils PRIVATE -fPIC)
endif()

add_library (SPTAGLib SHARED ${SRC_FILES} ${HDR_FILES})
Expand Down
99 changes: 68 additions & 31 deletions AnnService/inc/Core/Common/DistanceUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ namespace SPTAG
class DistanceUtils
{
public:
DistanceUtils();

template <typename T>
static float ComputeL2Distance(const T* pX, const T* pY, DimensionType length)
{
Expand All @@ -42,21 +44,34 @@ namespace SPTAG
return diff;
}

static float ComputeL2Distance_SSE(const std::int8_t* pX, const std::int8_t* pY, DimensionType length);
static float ComputeL2Distance_AVX(const std::int8_t* pX, const std::int8_t* pY, DimensionType length);
static float ComputeL2Distance_AVX512(const std::int8_t* pX, const std::int8_t* pY, DimensionType length);

static float ComputeL2Distance_SSE(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
static float ComputeL2Distance_AVX(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
static float ComputeL2Distance_AVX512(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);

static float ComputeL2Distance_SSE(const std::int16_t* pX, const std::int16_t* pY, DimensionType length);
static float ComputeL2Distance_AVX(const std::int16_t* pX, const std::int16_t* pY, DimensionType length);
static float ComputeL2Distance_AVX512(const std::int16_t* pX, const std::int16_t* pY, DimensionType length);

static float ComputeL2Distance_SSE(const float* pX, const float* pY, DimensionType length);
static float ComputeL2Distance_AVX(const float* pX, const float* pY, DimensionType length);
static float ComputeL2Distance_AVX512(const float* pX, const float* pY, DimensionType length);
#ifndef _MSC_VER
/*
GCC cannot yet do target-specific template specialisation.
As a workaround add target-specific non-template functions
that just call the template functions and hope for inlining.
https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81276
*/
f_Naive(static float ComputeL2Distance)(const std::int8_t* pX, const std::int8_t* pY, DimensionType length) { return ComputeL2Distance<std::int8_t>(pX, pY, length); }
f_Naive(static float ComputeL2Distance)(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) { return ComputeL2Distance<std::uint8_t>(pX, pY, length); }
f_Naive(static float ComputeL2Distance)(const std::int16_t* pX, const std::int16_t* pY, DimensionType length) { return ComputeL2Distance<std::int16_t>(pX, pY, length); }
f_Naive(static float ComputeL2Distance)(const float* pX, const float* pY, DimensionType length) { return ComputeL2Distance<float>(pX, pY, length); }
#endif

f_SSE(static float ComputeL2Distance)(const std::int8_t* pX, const std::int8_t* pY, DimensionType length);
f_AVX(static float ComputeL2Distance)(const std::int8_t* pX, const std::int8_t* pY, DimensionType length);
f_AVX512(static float ComputeL2Distance)(const std::int8_t* pX, const std::int8_t* pY, DimensionType length);

f_SSE(static float ComputeL2Distance)(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
f_AVX(static float ComputeL2Distance)(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
f_AVX512(static float ComputeL2Distance)(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);

f_SSE(static float ComputeL2Distance)(const std::int16_t* pX, const std::int16_t* pY, DimensionType length);
f_AVX(static float ComputeL2Distance)(const std::int16_t* pX, const std::int16_t* pY, DimensionType length);
f_AVX512(static float ComputeL2Distance)(const std::int16_t* pX, const std::int16_t* pY, DimensionType length);

f_SSE(static float ComputeL2Distance)(const float* pX, const float* pY, DimensionType length);
f_AVX(static float ComputeL2Distance)(const float* pX, const float* pY, DimensionType length);
f_AVX512(static float ComputeL2Distance)(const float* pX, const float* pY, DimensionType length);

template <typename T>
static float ComputeCosineDistance(const T* pX, const T* pY, DimensionType length)
Expand All @@ -78,22 +93,34 @@ namespace SPTAG
return base * base - diff;
}

static float ComputeCosineDistance_SSE(const std::int8_t* pX, const std::int8_t* pY, DimensionType length);
static float ComputeCosineDistance_AVX(const std::int8_t* pX, const std::int8_t* pY, DimensionType length);
static float ComputeCosineDistance_AVX512(const std::int8_t* pX, const std::int8_t* pY, DimensionType length);

static float ComputeCosineDistance_SSE(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
static float ComputeCosineDistance_AVX(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
static float ComputeCosineDistance_AVX512(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);

static float ComputeCosineDistance_SSE(const std::int16_t* pX, const std::int16_t* pY, DimensionType length);
static float ComputeCosineDistance_AVX(const std::int16_t* pX, const std::int16_t* pY, DimensionType length);
static float ComputeCosineDistance_AVX512(const std::int16_t* pX, const std::int16_t* pY, DimensionType length);

static float ComputeCosineDistance_SSE(const float* pX, const float* pY, DimensionType length);
static float ComputeCosineDistance_AVX(const float* pX, const float* pY, DimensionType length);
static float ComputeCosineDistance_AVX512(const float* pX, const float* pY, DimensionType length);

#ifndef _MSC_VER
/*
GCC cannot yet do target-specific template specialisation.
As a workaround add target-specific non-template functions
that just call the template functions and hope for inlining.
https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81276
*/
f_Naive(static float ComputeCosineDistance)(const std::int8_t* pX, const std::int8_t* pY, DimensionType length) { return ComputeCosineDistance<std::int8_t>(pX, pY, length); }
f_Naive(static float ComputeCosineDistance)(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) { return ComputeCosineDistance<std::uint8_t>(pX, pY, length); }
f_Naive(static float ComputeCosineDistance)(const std::int16_t* pX, const std::int16_t* pY, DimensionType length) { return ComputeCosineDistance<std::int16_t>(pX, pY, length); }
f_Naive(static float ComputeCosineDistance)(const float* pX, const float* pY, DimensionType length) { return ComputeCosineDistance<float>(pX, pY, length); }
#endif

f_SSE(static float ComputeCosineDistance)(const std::int8_t* pX, const std::int8_t* pY, DimensionType length);
f_AVX(static float ComputeCosineDistance)(const std::int8_t* pX, const std::int8_t* pY, DimensionType length);
f_AVX512(static float ComputeCosineDistance)(const std::int8_t* pX, const std::int8_t* pY, DimensionType length);

f_SSE(static float ComputeCosineDistance)(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
f_AVX(static float ComputeCosineDistance)(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
f_AVX512(static float ComputeCosineDistance)(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);

f_SSE(static float ComputeCosineDistance)(const std::int16_t* pX, const std::int16_t* pY, DimensionType length);
f_AVX(static float ComputeCosineDistance)(const std::int16_t* pX, const std::int16_t* pY, DimensionType length);
f_AVX512(static float ComputeCosineDistance)(const std::int16_t* pX, const std::int16_t* pY, DimensionType length);

f_SSE(static float ComputeCosineDistance)(const float* pX, const float* pY, DimensionType length);
f_AVX(static float ComputeCosineDistance)(const float* pX, const float* pY, DimensionType length);
f_AVX512(static float ComputeCosineDistance)(const float* pX, const float* pY, DimensionType length);

template<typename T>
static inline float ComputeDistance(const T* p1, const T* p2, DimensionType length, SPTAG::DistCalcMethod distCalcMethod)
Expand All @@ -118,11 +145,14 @@ namespace SPTAG
template<typename T>
inline DistanceCalcReturn<T> DistanceCalcSelector(SPTAG::DistCalcMethod p_method)
{
#ifdef _MSC_VER
bool isSize4 = (sizeof(T) == 4);
#endif // _MSC_VER
switch (p_method)
{
case SPTAG::DistCalcMethod::InnerProduct:
case SPTAG::DistCalcMethod::Cosine:
#ifdef _MSC_VER
if (InstructionSet::AVX512())
{
return &(DistanceUtils::ComputeCosineDistance_AVX512);
Expand All @@ -138,8 +168,12 @@ namespace SPTAG
else {
return &(DistanceUtils::ComputeCosineDistance);
}
#else // _MSC_VER
return &(DistanceUtils::ComputeCosineDistance);
#endif // !_MSC_VER

case SPTAG::DistCalcMethod::L2:
#ifdef _MSC_VER
if (InstructionSet::AVX512())
{
return &(DistanceUtils::ComputeL2Distance_AVX512);
Expand All @@ -155,6 +189,9 @@ namespace SPTAG
else {
return &(DistanceUtils::ComputeL2Distance);
}
#else // _MSC_VER
return &(DistanceUtils::ComputeL2Distance);
#endif // !_MSC_VER

default:
break;
Expand Down
40 changes: 40 additions & 0 deletions AnnService/inc/Core/Common/InstructionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,37 @@ void cpuid(int info[4], int InfoType);
#define cpuid(info, x) __cpuidex(info, x, 0)
#endif

// MSVC has no attributes, target attribute, ifunc and function multi-versions
#ifdef _MSC_VER
// MSVC SIMD functions just have different names
#define f_Naive(func) func_Naive
#define f_SSE(func) func_SSE
#define f_SSE2(func) func_SSE2
#define f_AVX(func) func_AVX
#define f_AVX2(func) func_AVX2
#define f_AVX512(func) func_AVX512
// Inline functions
#define if_SSE(func) func
#define if_SSE2(func) func
#define if_AVX(func) func
#define if_AVX2(func) func
#define if_AVX512(func) func
#else
// GCC SIMD functions use function multi-versioning
#define f_Naive(func) __attribute__ ((target ("default"))) func
#define f_SSE(func) __attribute__ ((target ("sse"))) func
#define f_SSE2(func) __attribute__ ((target ("sse2"))) func
#define f_AVX(func) __attribute__ ((target ("avx"))) func
#define f_AVX2(func) __attribute__ ((target ("avx2"))) func
#define f_AVX512(func) __attribute__ ((target ("avx512f,avx512bw,avx512dq"))) func
// Inline functions
#define if_SSE(func) f_SSE(func)
#define if_SSE2(func) f_SSE2(func)
#define if_AVX(func) f_AVX(func)
#define if_AVX2(func) f_AVX2(func)
#define if_AVX512(func) f_AVX512(func)
#endif

namespace SPTAG {
namespace COMMON {

Expand Down Expand Up @@ -47,6 +78,15 @@ namespace SPTAG {
bool HW_AVX;
bool HW_AVX2;
bool HW_AVX512;
#ifndef _MSC_VER
private:
f_Naive(void Initialise)(void);
f_SSE(void Initialise)(void);
f_SSE2(void Initialise)(void);
f_AVX(void Initialise)(void);
f_AVX2(void Initialise)(void);
f_AVX512(void Initialise)(void);
#endif // !_MSC_VER
};
};
}
Expand Down
45 changes: 33 additions & 12 deletions AnnService/inc/Core/Common/SIMDUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@ namespace SPTAG
{
namespace COMMON
{
#ifdef _MSC_VER
template <typename T>
using SumCalcReturn = void(*)(T*, const T*, DimensionType);
template<typename T>
inline SumCalcReturn<T> SumCalcSelector();
#endif // _MSC_VER

class SIMDUtils
{
public:
SIMDUtils();

template <typename T>
static void ComputeSum_Naive(T* pX, const T* pY, DimensionType length)
{
Expand All @@ -31,30 +35,46 @@ namespace SPTAG
}
}

static void ComputeSum_SSE(std::int8_t* pX, const std::int8_t* pY, DimensionType length);
static void ComputeSum_AVX(std::int8_t* pX, const std::int8_t* pY, DimensionType length);
static void ComputeSum_AVX512(std::int8_t* pX, const std::int8_t* pY, DimensionType length);
#ifndef _MSC_VER
/*
GCC cannot yet do target-specific template specialisation.
As a workaround add target-specific non-template functions
that just call the template functions and hope for inlining.
https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81276
*/
f_Naive(static void ComputeSum)(std::int8_t* pX, const std::int8_t* pY, DimensionType length) { ComputeSum_Naive<std::int8_t>(pX, pY, length); }
f_Naive(static void ComputeSum)(std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) { ComputeSum_Naive<std::uint8_t>(pX, pY, length); }
f_Naive(static void ComputeSum)(std::int16_t* pX, const std::int16_t* pY, DimensionType length) { ComputeSum_Naive<std::int16_t>(pX, pY, length); }
f_Naive(static void ComputeSum)(float* pX, const float* pY, DimensionType length) { ComputeSum_Naive<float>(pX, pY, length); }
#endif // !_MSC_VER

f_SSE(static void ComputeSum)(std::int8_t* pX, const std::int8_t* pY, DimensionType length);
f_AVX(static void ComputeSum)(std::int8_t* pX, const std::int8_t* pY, DimensionType length);
f_AVX512(static void ComputeSum)(std::int8_t* pX, const std::int8_t* pY, DimensionType length);

static void ComputeSum_SSE(std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
static void ComputeSum_AVX(std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
static void ComputeSum_AVX512(std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
f_SSE(static void ComputeSum)(std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
f_AVX(static void ComputeSum)(std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);
f_AVX512(static void ComputeSum)(std::uint8_t* pX, const std::uint8_t* pY, DimensionType length);

static void ComputeSum_SSE(std::int16_t* pX, const std::int16_t* pY, DimensionType length);
static void ComputeSum_AVX(std::int16_t* pX, const std::int16_t* pY, DimensionType length);
static void ComputeSum_AVX512(std::int16_t* pX, const std::int16_t* pY, DimensionType length);
f_SSE(static void ComputeSum)(std::int16_t* pX, const std::int16_t* pY, DimensionType length);
f_AVX(static void ComputeSum)(std::int16_t* pX, const std::int16_t* pY, DimensionType length);
f_AVX512(static void ComputeSum)(std::int16_t* pX, const std::int16_t* pY, DimensionType length);

static void ComputeSum_SSE(float* pX, const float* pY, DimensionType length);
static void ComputeSum_AVX(float* pX, const float* pY, DimensionType length);
static void ComputeSum_AVX512(float* pX, const float* pY, DimensionType length);
f_SSE(static void ComputeSum)(float* pX, const float* pY, DimensionType length);
f_AVX(static void ComputeSum)(float* pX, const float* pY, DimensionType length);
f_AVX512(static void ComputeSum)(float* pX, const float* pY, DimensionType length);

#ifdef _MSC_VER
template<typename T>
static inline void ComputeSum(T* p1, const T* p2, DimensionType length)
{
auto func = SumCalcSelector<T>();
return func(p1, p2, length);
}
#endif // _MSC_VER
};

#ifdef _MSC_VER
template<typename T>
inline SumCalcReturn<T> SumCalcSelector()
{
Expand All @@ -73,6 +93,7 @@ namespace SPTAG
}
return &(SIMDUtils::ComputeSum_Naive);
}
#endif // _MSC_VER
}
}

Expand Down
Loading

0 comments on commit 23263a5

Please sign in to comment.