Skip to content

Commit

Permalink
Make things compile with AVX2
Browse files Browse the repository at this point in the history
  • Loading branch information
eggrobin committed Mar 31, 2024
1 parent 69dd1fb commit aa22a46
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 138 deletions.
169 changes: 42 additions & 127 deletions Principia.sln

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions base/macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ char const* const Architecture = "x86-64";
#define PRINCIPIA_USE_FMA_IF_AVAILABLE() !_DEBUG
#endif

// Intended for use in |#if| directives: nonzero when |_DEBUG| is not set and the
// compiler defines |__AVX2__| to a nonzero value.  In a preprocessor
// conditional an identifier that is not defined evaluates to 0, so this is 0
// for compilers that do not target AVX2.
#define PRINCIPIA_USE_AVX2_INTRINSICS() (!_DEBUG && __AVX2__)

#ifndef PRINCIPIA_CONFIGURABLE_TEST_SUFFIX
#define PRINCIPIA_CONFIGURABLE_TEST_SUFFIX
#endif
Expand Down
11 changes: 10 additions & 1 deletion geometry/r3_element.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,20 @@ struct SphericalCoordinates;
// space over ℝ, represented by |double|. |R3Element| is the underlying data
// type for more advanced strongly typed structures such as |Multivector|.
template<typename Scalar>
struct alignas(16) R3Element final {
struct
#if PRINCIPIA_USE_AVX2_INTRINSICS()
alignas(32)
#else
alignas(16)
#endif
R3Element final {
public:
constexpr R3Element();
constexpr explicit R3Element(uninitialized_t);
R3Element(double const (&xyz)[3]) requires std::is_same_v<Scalar, double>;
R3Element(Scalar const& x, Scalar const& y, Scalar const& z);
R3Element(__m128d xy, __m128d zt);
R3Element(__m256d xyzt);

Scalar& operator[](int index);
Scalar const& operator[](int index) const;
Expand Down Expand Up @@ -71,6 +79,7 @@ struct alignas(16) R3Element final {
__m128d xy;
__m128d zt;
};
__m256d xyzt;
};
};

Expand Down
96 changes: 89 additions & 7 deletions geometry/r3_element_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ constexpr R3Element<Scalar>::R3Element(uninitialized_t) {
"R3Element has a nonstandard layout");
}

template<typename Scalar>
R3Element<Scalar>::R3Element(double const (&xyz)[3])
requires std::is_same_v<Scalar, double>
: x(xyz[0]), y(xyz[1]), z(xyz[2]) {}

template<typename Scalar>
R3Element<Scalar>::R3Element(Scalar const& x,
Scalar const& y,
Expand All @@ -50,6 +55,13 @@ R3Element<Scalar>::R3Element(__m128d const xy, __m128d const zt)
"R3Element has a nonstandard layout");
}

// Constructs the element from a 256-bit vector, which is stored as is in the
// |xyzt| member.  Also checks at compile time that the class has standard
// layout.
template<typename Scalar>
R3Element<Scalar>::R3Element(__m256d const xyzt) : xyzt(xyzt) {
  static_assert(std::is_standard_layout_v<R3Element>,
                "R3Element has a nonstandard layout");
}

template<typename Scalar>
Scalar& R3Element<Scalar>::operator[](int const index) {
switch (index) {
Expand Down Expand Up @@ -111,7 +123,10 @@ R3Element<Scalar>& R3Element<Scalar>::operator-=(

template<typename Scalar>
R3Element<Scalar>& R3Element<Scalar>::operator*=(double const right) {
#if PRINCIPIA_USE_SSE3_INTRINSICS()
#if PRINCIPIA_USE_AVX2_INTRINSICS()
__m256d const right_256d = ToM256D(right);
xyzt = _mm256_mul_pd(xyzt, right_256d);
#elif PRINCIPIA_USE_SSE3_INTRINSICS()
__m128d const right_128d = ToM128D(right);
xy = _mm_mul_pd(xy, right_128d);
zt = _mm_mul_sd(zt, right_128d);
Expand All @@ -125,7 +140,10 @@ R3Element<Scalar>& R3Element<Scalar>::operator*=(double const right) {

template<typename Scalar>
R3Element<Scalar>& R3Element<Scalar>::operator/=(double const right) {
#if PRINCIPIA_USE_SSE3_INTRINSICS()
#if PRINCIPIA_USE_AVX2_INTRINSICS()
__m256d const right_256d = ToM256D(right);
xyzt = _mm256_div_pd(xyzt, right_256d);
#elif PRINCIPIA_USE_SSE3_INTRINSICS()
__m128d const right_128d = ToM128D(right);
xy = _mm_div_pd(xy, right_128d);
zt = _mm_div_sd(zt, right_128d);
Expand Down Expand Up @@ -229,7 +247,9 @@ R3Element<Scalar> operator-(R3Element<Scalar> const& right) {
template<typename Scalar>
R3Element<Scalar> operator+(R3Element<Scalar> const& left,
R3Element<Scalar> const& right) {
#if PRINCIPIA_USE_SSE3_INTRINSICS()
#if PRINCIPIA_USE_AVX2_INTRINSICS()
return R3Element<Scalar>(_mm256_add_pd(left.xyzt, right.xyzt));
#elif PRINCIPIA_USE_SSE3_INTRINSICS()
return R3Element<Scalar>(_mm_add_pd(left.xy, right.xy),
_mm_add_sd(left.zt, right.zt));
#else
Expand All @@ -242,7 +262,9 @@ R3Element<Scalar> operator+(R3Element<Scalar> const& left,
template<typename Scalar>
R3Element<Scalar> operator-(R3Element<Scalar> const& left,
R3Element<Scalar> const& right) {
#if PRINCIPIA_USE_SSE3_INTRINSICS()
#if PRINCIPIA_USE_AVX2_INTRINSICS()
return R3Element<Scalar>(_mm256_sub_pd(left.xyzt, right.xyzt));
#elif PRINCIPIA_USE_SSE3_INTRINSICS()
return R3Element<Scalar>(_mm_sub_pd(left.xy, right.xy),
_mm_sub_sd(left.zt, right.zt));
#else
Expand All @@ -257,7 +279,11 @@ template<typename LScalar, typename RScalar>
R3Element<Product<LScalar, RScalar>> operator*(
LScalar const& left,
R3Element<RScalar> const& right) {
#if PRINCIPIA_USE_SSE3_INTRINSICS()
#if PRINCIPIA_USE_AVX2_INTRINSICS()
__m256d const left_256d = ToM256D(left);
return R3Element<Product<LScalar, RScalar>>(
_mm256_mul_pd(left_256d, right.xyzt));
#elif PRINCIPIA_USE_SSE3_INTRINSICS()
__m128d const left_128d = ToM128D(left);
return R3Element<Product<LScalar, RScalar>>(_mm_mul_pd(right.xy, left_128d),
_mm_mul_sd(right.zt, left_128d));
Expand All @@ -272,7 +298,11 @@ template<typename LScalar, typename RScalar>
requires convertible_to_quantity<RScalar>
R3Element<Product<LScalar, RScalar>> operator*(R3Element<LScalar> const& left,
RScalar const& right) {
#if PRINCIPIA_USE_SSE3_INTRINSICS()
#if PRINCIPIA_USE_AVX2_INTRINSICS()
__m256d const right_256d = ToM256D(right);
return R3Element<Product<LScalar, RScalar>>(
_mm256_mul_pd(left.xyzt, right_256d));
#elif PRINCIPIA_USE_SSE3_INTRINSICS()
__m128d const right_128d = ToM128D(right);
return R3Element<Product<LScalar, RScalar>>(_mm_mul_pd(left.xy, right_128d),
_mm_mul_sd(left.zt, right_128d));
Expand All @@ -287,7 +317,11 @@ template<typename LScalar, typename RScalar>
requires convertible_to_quantity<RScalar>
R3Element<Quotient<LScalar, RScalar>> operator/(R3Element<LScalar> const& left,
RScalar const& right) {
#if PRINCIPIA_USE_SSE3_INTRINSICS()
#if PRINCIPIA_USE_AVX2_INTRINSICS()
__m256d const right_256d = ToM256D(right);
return R3Element<Quotient<LScalar, RScalar>>(
_mm256_div_pd(left.xyzt, right_256d));
#elif PRINCIPIA_USE_SSE3_INTRINSICS()
__m128d const right_128d = ToM128D(right);
return R3Element<Quotient<LScalar, RScalar>>(_mm_div_pd(left.xy, right_128d),
_mm_div_sd(left.zt, right_128d));
Expand All @@ -305,9 +339,15 @@ R3Element<Product<LScalar, RScalar>> FusedMultiplyAdd(
RScalar const& b,
R3Element<Product<LScalar, RScalar>> const& c) {
if constexpr (CanEmitFMAInstructions) {
#if PRINCIPIA_USE_AVX2_INTRINSICS()
__m256d const b_256d = ToM256D(b);
return R3Element<Product<LScalar, RScalar>>(
_mm256_fmadd_pd(a.xyzt, b_256d, c.xyzt));
#else
__m128d const b_128d = ToM128D(b);
return R3Element<Product<LScalar, RScalar>>(
_mm_fmadd_pd(a.xy, b_128d, c.xy), _mm_fmadd_sd(a.zt, b_128d, c.zt));
#endif
} else {
LOG(FATAL) << "Clang cannot use FMA without VEX-encoding everything";
}
Expand All @@ -320,9 +360,15 @@ R3Element<Product<LScalar, RScalar>> FusedMultiplySubtract(
RScalar const& b,
R3Element<Product<LScalar, RScalar>> const& c) {
if constexpr (CanEmitFMAInstructions) {
#if PRINCIPIA_USE_AVX2_INTRINSICS()
__m256d const b_256d = ToM256D(b);
return R3Element<Product<LScalar, RScalar>>(
_mm256_fmsub_pd(a.xyzt, b_256d, c.xyzt));
#else
__m128d const b_128d = ToM128D(b);
return R3Element<Product<LScalar, RScalar>>(
_mm_fmsub_pd(a.xy, b_128d, c.xy), _mm_fmsub_sd(a.zt, b_128d, c.zt));
#endif
} else {
LOG(FATAL) << "Clang cannot use FMA without VEX-encoding everything";
}
Expand All @@ -335,9 +381,15 @@ R3Element<Product<LScalar, RScalar>> FusedNegatedMultiplyAdd(
RScalar const& b,
R3Element<Product<LScalar, RScalar>> const& c) {
if constexpr (CanEmitFMAInstructions) {
#if PRINCIPIA_USE_AVX2_INTRINSICS()
__m256d const b_256d = ToM256D(b);
return R3Element<Product<LScalar, RScalar>>(
_mm256_fnmadd_pd(a.xyzt, b_256d, c.xyzt));
#else
__m128d const b_128d = ToM128D(b);
return R3Element<Product<LScalar, RScalar>>(
_mm_fnmadd_pd(a.xy, b_128d, c.xy), _mm_fnmadd_sd(a.zt, b_128d, c.zt));
#endif
} else {
LOG(FATAL) << "Clang cannot use FMA without VEX-encoding everything";
}
Expand All @@ -350,9 +402,15 @@ R3Element<Product<LScalar, RScalar>> FusedNegatedMultiplySubtract(
RScalar const& b,
R3Element<Product<LScalar, RScalar>> const& c) {
if constexpr (CanEmitFMAInstructions) {
#if PRINCIPIA_USE_AVX2_INTRINSICS()
__m256d const b_256d = ToM256D(b);
return R3Element<Product<LScalar, RScalar>>(
_mm256_fnmsub_pd(a.xyzt, b_256d, c.xyzt));
#else
__m128d const b_128d = ToM128D(b);
return R3Element<Product<LScalar, RScalar>>(
_mm_fnmsub_pd(a.xy, b_128d, c.xy), _mm_fnmsub_sd(a.zt, b_128d, c.zt));
#endif
} else {
LOG(FATAL) << "Clang cannot use FMA without VEX-encoding everything";
}
Expand All @@ -365,9 +423,15 @@ R3Element<Product<LScalar, RScalar>> FusedMultiplyAdd(
R3Element<RScalar> const& b,
R3Element<Product<LScalar, RScalar>> const& c) {
if constexpr (CanEmitFMAInstructions) {
#if PRINCIPIA_USE_AVX2_INTRINSICS()
__m256d const a_256d = ToM256D(a);
return R3Element<Product<LScalar, RScalar>>(
_mm256_fmadd_pd(a_256d, b.xyzt, c.xyzt));
#else
__m128d const a_128d = ToM128D(a);
return R3Element<Product<LScalar, RScalar>>(
_mm_fmadd_pd(a_128d, b.xy, c.xy), _mm_fmadd_sd(a_128d, b.zt, c.zt));
#endif
} else {
LOG(FATAL) << "Clang cannot use FMA without VEX-encoding everything";
}
Expand All @@ -380,9 +444,15 @@ R3Element<Product<LScalar, RScalar>> FusedMultiplySubtract(
R3Element<RScalar> const& b,
R3Element<Product<LScalar, RScalar>> const& c) {
if constexpr (CanEmitFMAInstructions) {
#if PRINCIPIA_USE_AVX2_INTRINSICS()
__m256d const a_256d = ToM256D(a);
return R3Element<Product<LScalar, RScalar>>(
_mm256_fmsub_pd(a_256d, b.xyzt, c.xyzt));
#else
__m128d const a_128d = ToM128D(a);
return R3Element<Product<LScalar, RScalar>>(
_mm_fmsub_pd(a_128d, b.xy, c.xy), _mm_fmsub_sd(a_128d, b.zt, c.zt));
#endif
} else {
LOG(FATAL) << "Clang cannot use FMA without VEX-encoding everything";
}
Expand All @@ -395,9 +465,15 @@ R3Element<Product<LScalar, RScalar>> FusedNegatedMultiplyAdd(
R3Element<RScalar> const& b,
R3Element<Product<LScalar, RScalar>> const& c) {
if constexpr (CanEmitFMAInstructions) {
#if PRINCIPIA_USE_AVX2_INTRINSICS()
__m256d const a_256d = ToM256D(a);
return R3Element<Product<LScalar, RScalar>>(
_mm256_fnmadd_pd(a_256d, b.xyzt, c.xyzt));
#else
__m128d const a_128d = ToM128D(a);
return R3Element<Product<LScalar, RScalar>>(
_mm_fnmadd_pd(a_128d, b.xy, c.xy), _mm_fnmadd_sd(a_128d, b.zt, c.zt));
#endif
} else {
LOG(FATAL) << "Clang cannot use FMA without VEX-encoding everything";
}
Expand All @@ -410,9 +486,15 @@ R3Element<Product<LScalar, RScalar>> FusedNegatedMultiplySubtract(
R3Element<RScalar> const& b,
R3Element<Product<LScalar, RScalar>> const& c) {
if constexpr (CanEmitFMAInstructions) {
#if PRINCIPIA_USE_AVX2_INTRINSICS()
  // 256-bit path: a single fused operation, -(a * b) - c, over the whole of
  // |xyzt|.  The |_mm256_| form of the intrinsic is required here because
  // |_mm_fnmsub_pd| only accepts |__m128d| operands, whereas |a_256d| and the
  // |xyzt| members are |__m256d|.
  __m256d const a_256d = ToM256D(a);
  return R3Element<Product<LScalar, RScalar>>(
      _mm256_fnmsub_pd(a_256d, b.xyzt, c.xyzt));
#else
  // 128-bit path: a packed operation on |xy| and a scalar one on |zt|.
  __m128d const a_128d = ToM128D(a);
  return R3Element<Product<LScalar, RScalar>>(
      _mm_fnmsub_pd(a_128d, b.xy, c.xy), _mm_fnmsub_sd(a_128d, b.zt, c.zt));
#endif
} else {
  LOG(FATAL) << "Clang cannot use FMA without VEX-encoding everything";
}
Expand Down
6 changes: 3 additions & 3 deletions numerics/numerics.vcxproj.filters
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,6 @@
<ClInclude Include="apodization.hpp">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="apodization_body.hpp">
<Filter>Source Files</Filter>
</ClInclude>
<ClInclude Include="fast_fourier_transform.hpp">
<Filter>Header Files</Filter>
</ClInclude>
Expand Down Expand Up @@ -388,6 +385,9 @@
<ClCompile Include="matrix_views_test.cpp">
<Filter>Test Files</Filter>
</ClCompile>
<ClCompile Include="apodization.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<Text Include="xgscd.proto.txt">
Expand Down
10 changes: 10 additions & 0 deletions principia.props
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,18 @@
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release AVX2|x64">
<Configuration>Release AVX2</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>

<!--In the "Release AVX2" configuration, ask the C++ compiler to target the AVX2 instruction set.-->
<ItemDefinitionGroup Condition="'$(Configuration)'=='Release AVX2'">
<ClCompile>
<EnableEnhancedInstructionSet>AdvancedVectorExtensions2</EnableEnhancedInstructionSet>
</ClCompile>
</ItemDefinitionGroup>

<!--Microsoft C++ stuff.-->
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />

Expand Down
7 changes: 7 additions & 0 deletions quantities/quantities.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ class Quantity final {

template<typename Dimensions>
friend __m128d ToM128D(Quantity<Dimensions> x);

template<typename Dimensions>
friend __m256d ToM256D(Quantity<Dimensions> x);
};

template<typename LDimensions, typename RDimensions>
Expand All @@ -122,8 +125,11 @@ operator/(double, Quantity<RDimensions> const&);
template<typename Q>
constexpr Q SIUnit() { return Q(1); };

// Conversions of a scalar to a vector of doubles: |ToM256D| returns a 256-bit
// vector and |ToM128D| a 128-bit one.  The overloads taking a |Quantity| are
// friends of that class.
inline __m256d ToM256D(double x);
inline __m128d ToM128D(double x);

template<typename Dimensions>
__m256d ToM256D(Quantity<Dimensions> x);
template<typename Dimensions>
__m128d ToM128D(Quantity<Dimensions> x);

Expand Down Expand Up @@ -169,6 +175,7 @@ using internal::Quantity;
using internal::Temperature;
using internal::Time;
using internal::ToM128D;
using internal::ToM256D;

} // namespace _quantities
} // namespace quantities
Expand Down
9 changes: 9 additions & 0 deletions quantities/quantities_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ constexpr Quotient<double, Quantity<RDimensions>> operator/(
return Quotient<double, Quantity<RDimensions>>(left / right.magnitude_);
}

// Returns a 256-bit vector whose four double-precision lanes all hold |x|.
inline __m256d ToM256D(double const x) {
  __m256d const all_lanes = _mm256_set1_pd(x);
  return all_lanes;
}

// Returns a 128-bit vector whose two double-precision lanes both hold |x|.
inline __m128d ToM128D(double const x) {
  __m128d const both_lanes = _mm_set1_pd(x);
  return both_lanes;
}
Expand All @@ -136,6 +140,11 @@ __m128d ToM128D(Quantity<Dimensions> const x) {
return _mm_set1_pd(x.magnitude_);
}

// Returns a 256-bit vector whose four double-precision lanes all hold the
// |magnitude_| member of |x|.
template<typename Dimensions>
__m256d ToM256D(Quantity<Dimensions> const x) {
  double const magnitude = x.magnitude_;
  return _mm256_set1_pd(magnitude);
}

template<typename Q>
constexpr bool IsFinite(Q const& x) {
return std::isfinite(x / SIUnit<Q>());
Expand Down

0 comments on commit aa22a46

Please sign in to comment.