Skip to content

Commit

Permalink
Implemented the missing AVX512BF16 intrinsics
Browse files Browse the repository at this point in the history
  • Loading branch information
sayantn committed Jul 1, 2024
1 parent 5a4089c commit bcf81e8
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 16 deletions.
15 changes: 0 additions & 15 deletions crates/core_arch/missing-x86.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,27 +147,12 @@
</p></details>


<details><summary>["AVX512_BF16", "AVX512F"]</summary><p>

* [ ] [`_mm512_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_cvtpbh_ps)
* [ ] [`_mm512_mask_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_mask_cvtpbh_ps)
* [ ] [`_mm512_maskz_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_maskz_cvtpbh_ps)
* [ ] [`_mm_cvtsbh_ss`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtsbh_ss)
</p></details>


<details><summary>["AVX512_BF16", "AVX512VL"]</summary><p>

* [ ] [`_mm256_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_cvtpbh_ps)
* [ ] [`_mm256_mask_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_mask_cvtpbh_ps)
* [ ] [`_mm256_maskz_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_maskz_cvtpbh_ps)
* [ ] [`_mm_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_pbh)
* [ ] [`_mm_cvtness_sbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtness_sbh)
* [ ] [`_mm_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtpbh_ps)
* [ ] [`_mm_mask_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mask_cvtneps_pbh)
* [ ] [`_mm_mask_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mask_cvtpbh_ps)
* [ ] [`_mm_maskz_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_maskz_cvtneps_pbh)
* [ ] [`_mm_maskz_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_maskz_cvtpbh_ps)
</p></details>


Expand Down
185 changes: 185 additions & 0 deletions crates/core_arch/src/x86/avx512bf16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,85 @@ pub unsafe fn _mm512_maskz_dpbf16_ps(
transmute(simd_select_bitmask(k, rst, zero))
}

#[inline]
#[target_feature(enable = "avx512bf16,avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm512_cvtpbh_ps(a: __m256bh) -> __m512 {
_mm512_castsi512_ps(_mm512_slli_epi32::<16>(_mm512_cvtepi16_epi32(transmute(a))))
}

#[inline]
#[target_feature(enable = "avx512bf16,avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm512_mask_cvtpbh_ps(src: __m512, k: __mmask16, a: __m256bh) -> __m512 {
let cvt = _mm512_cvtpbh_ps(a);
transmute(simd_select_bitmask(k, cvt.as_f32x16(), src.as_f32x16()))
}

#[inline]
#[target_feature(enable = "avx512bf16,avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm512_maskz_cvtpbh_ps(k: __mmask16, a: __m256bh) -> __m512 {
let cvt = _mm512_cvtpbh_ps(a);
let zero = _mm512_setzero_ps();
transmute(simd_select_bitmask(k, cvt.as_f32x16(), zero.as_f32x16()))
}

#[inline]
#[target_feature(enable = "avx512bf16,avx512vl")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm256_cvtpbh_ps(a: __m128bh) -> __m256 {
_mm256_castsi256_ps(_mm256_slli_epi32::<16>(_mm256_cvtepi16_epi32(transmute(a))))
}

#[inline]
#[target_feature(enable = "avx512bf16,avx512vl")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm256_mask_cvtpbh_ps(src: __m256, k: __mmask8, a: __m128bh) -> __m256 {
let cvt = _mm256_cvtpbh_ps(a);
transmute(simd_select_bitmask(k, cvt.as_f32x8(), src.as_f32x8()))
}

#[inline]
#[target_feature(enable = "avx512bf16,avx512vl")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm256_maskz_cvtpbh_ps(k: __mmask8, a: __m128bh) -> __m256 {
let cvt = _mm256_cvtpbh_ps(a);
let zero = _mm256_setzero_ps();
transmute(simd_select_bitmask(k, cvt.as_f32x8(), zero.as_f32x8()))
}

#[inline]
#[target_feature(enable = "avx512bf16,avx512vl")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_cvtpbh_ps(a: __m128bh) -> __m128 {
_mm_castsi128_ps(_mm_slli_epi32::<16>(_mm_cvtepi16_epi32(transmute(a))))
}

#[inline]
#[target_feature(enable = "avx512bf16,avx512vl")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_mask_cvtpbh_ps(src: __m128, k: __mmask8, a: __m128bh) -> __m128 {
let cvt = _mm_cvtpbh_ps(a);
transmute(simd_select_bitmask(k, cvt.as_f32x4(), src.as_f32x4()))
}

#[inline]
#[target_feature(enable = "avx512bf16,avx512vl")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_maskz_cvtpbh_ps(k: __mmask8, a: __m128bh) -> __m128 {
let cvt = _mm_cvtpbh_ps(a);
let zero = _mm_setzero_ps();
transmute(simd_select_bitmask(k, cvt.as_f32x4(), zero.as_f32x4()))
}

#[inline]
#[target_feature(enable = "avx512bf16,avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_cvtsbh_ss(a: u16) -> f32 {
f32::from_bits((a as u32) << 16)
}

#[cfg(test)]
mod tests {
use crate::{core_arch::x86::*, mem::transmute};
Expand Down Expand Up @@ -1592,4 +1671,110 @@ mod tests {
];
assert_eq!(result, expected_result);
}

const BF16_ONE: u16 = 0b0_01111111_0000000;
const BF16_TWO: u16 = 0b0_10000000_0000000;
const BF16_THREE: u16 = 0b0_10000000_1000000;
const BF16_FOUR: u16 = 0b0_10000001_0000000;
const BF16_FIVE: u16 = 0b0_10000001_0100000;
const BF16_SIX: u16 = 0b0_10000001_1000000;
const BF16_SEVEN: u16 = 0b0_10000001_1100000;
const BF16_EIGHT: u16 = 0b0_10000010_0000000;

#[simd_test(enable = "avx512bf16")]
unsafe fn test_mm512_cvtpbh_ps() {
let a = __m256bh(
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
);
let r = _mm512_cvtpbh_ps(a);
let e = _mm512_setr_ps(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
assert_eq_m512(r, e);
}

#[simd_test(enable = "avx512bf16")]
unsafe fn test_mm512_mask_cvtpbh_ps() {
let a = __m256bh(
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
);
let src = _mm512_setr_ps(9., 10., 11., 12., 13., 14., 15., 16., 9., 10., 11., 12., 13., 14., 15., 16.);
let k = 0b1010_1010_1010_1010;
let r = _mm512_mask_cvtpbh_ps(src, k, a);
let e = _mm512_setr_ps(9., 2., 11., 4., 13., 6., 15., 8., 9., 2., 11., 4., 13., 6., 15., 8.);
assert_eq_m512(r, e);
}

#[simd_test(enable = "avx512bf16")]
unsafe fn test_mm512_maskz_cvtpbh_ps() {
let a = __m256bh(
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
);
let k = 0b1010_1010_1010_1010;
let r = _mm512_maskz_cvtpbh_ps(k, a);
let e = _mm512_setr_ps(0., 2., 0., 4., 0., 6., 0., 8., 0., 2., 0., 4., 0., 6., 0., 8.);
assert_eq_m512(r, e);
}

#[simd_test(enable = "avx512bf16,avx512vl")]
unsafe fn test_mm256_cvtpbh_ps() {
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT);
let r = _mm256_cvtpbh_ps(a);
let e = _mm256_setr_ps(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
assert_eq_m256(r, e);
}

#[simd_test(enable = "avx512bf16,avx512vl")]
unsafe fn test_mm256_mask_cvtpbh_ps() {
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT);
let src = _mm256_setr_ps(9., 10., 11., 12., 13., 14., 15., 16.);
let k = 0b1010_1010;
let r = _mm256_mask_cvtpbh_ps(src, k, a);
let e = _mm256_setr_ps(9., 2., 11., 4., 13., 6., 15., 8.);
assert_eq_m256(r, e);
}

#[simd_test(enable = "avx512bf16,avx512vl")]
unsafe fn test_mm256_maskz_cvtpbh_ps() {
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT);
let k = 0b1010_1010;
let r = _mm256_maskz_cvtpbh_ps(k, a);
let e = _mm256_setr_ps(0., 2., 0., 4., 0., 6., 0., 8.);
assert_eq_m256(r, e);
}

#[simd_test(enable = "avx512bf16,avx512vl")]
unsafe fn test_mm_cvtpbh_ps() {
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, 0, 0, 0, 0);
let r = _mm_cvtpbh_ps(a);
let e = _mm_setr_ps(1.0, 2.0, 3.0, 4.0);
assert_eq_m128(r, e);
}

#[simd_test(enable = "avx512bf16,avx512vl")]
unsafe fn test_mm_mask_cvtpbh_ps() {
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, 0, 0, 0, 0);
let src = _mm_setr_ps(9., 10., 11., 12.);
let k = 0b1010;
let r = _mm_mask_cvtpbh_ps(src, k, a);
let e = _mm_setr_ps(9., 2., 11., 4.);
assert_eq_m128(r, e);
}

#[simd_test(enable = "avx512bf16,avx512vl")]
unsafe fn test_mm_maskz_cvtpbh_ps() {
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, 0, 0, 0, 0);
let k = 0b1010;
let r = _mm_maskz_cvtpbh_ps(k, a);
let e = _mm_setr_ps(0., 2., 0., 4.);
assert_eq_m128(r, e);
}

#[simd_test(enable = "avx512bf16")]
unsafe fn test_mm_cvtsbh_ss() {
let r = _mm_cvtsbh_ss(BF16_ONE);
assert_eq!(r, 1.);
}

}
2 changes: 1 addition & 1 deletion crates/stdarch-verify/tests/x86-intel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ fn equate(
(&Type::PrimSigned(32), "__int32" | "const int" | "int") => {}
(&Type::PrimSigned(64), "__int64" | "long long") => {}
(&Type::PrimUnsigned(8), "unsigned char") => {}
(&Type::PrimUnsigned(16), "unsigned short") => {}
(&Type::PrimUnsigned(16), "unsigned short" | "__bfloat16") => {}
(
&Type::PrimUnsigned(32),
"unsigned __int32" | "unsigned int" | "unsigned long" | "const unsigned int",
Expand Down

0 comments on commit bcf81e8

Please sign in to comment.