Commit

intermediate push
Signed-off-by: liqunfu <[email protected]>
liqunfu committed Nov 18, 2024
1 parent 076998c commit 99aec95
Showing 5 changed files with 224 additions and 26 deletions.
152 changes: 152 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
@@ -327,6 +327,158 @@ QuantizeARow_CompInt8_avx512(
}
}

__m512
ComputeMulScal(const float* a_ptr, size_t step, float& scale)
{
const __m512 signBit = _mm512_set1_ps(-0.0f);
__m512 maxAbs = _mm512_setzero_ps();

for (size_t kk = 0; kk < step; kk += 16) {
const size_t klen = std::min(size_t(16), step - kk);

uint32_t mask = 0xffff >> (16 - klen);
__m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), a_ptr + kk);

// Compute max(abs(e)) for the block
maxAbs = _mm512_max_ps(maxAbs, _mm512_andnot_ps(signBit, v0));
}

__m256 max8 =
_mm256_max_ps(_mm512_extractf32x8_ps(maxAbs, 1), _mm512_extractf32x8_ps(maxAbs, 0));
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max8, 1), _mm256_castps256_ps128(max8));
max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
const float maxScalar = _mm_cvtss_f32(max4);

// Quantize these floats
scale = maxScalar / 127.f;

const float inverse_scale = (maxScalar != 0.0f) ? 127.f / maxScalar : 0.0f;
return _mm512_set1_ps(inverse_scale);
}
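For reference, a scalar sketch of what ComputeMulScal computes (illustrative only; ComputeMulScalRef is a made-up name, not part of this change, and assumes <algorithm> and <cmath>):

// Scalar sketch of the per-block scale computation (illustrative).
// Returns the multiplier 127/maxAbs and writes the dequantization scale.
float ComputeMulScalRef(const float* a_ptr, size_t step, float& scale)
{
    float maxAbs = 0.0f;
    for (size_t kk = 0; kk < step; ++kk) {
        maxAbs = std::max(maxAbs, std::fabs(a_ptr[kk]));
    }
    scale = maxAbs / 127.f;  // quantized values stay within [-127, 127]
    return (maxAbs != 0.0f) ? 127.f / maxAbs : 0.0f;
}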

void
QuantizeInt8ComputeBlksum(const float* a_ptr, size_t step, const __m512& mul, float scale, __m256i& i0_32_epi8, float& blksum)
{
const __m256i one_16_epi16 = _mm256_set1_epi16(1);
__m256i sum_16_epi16 = _mm256_setzero_si256();
__m128i i_16_epi8[2] = { _mm_setzero_si128(), _mm_setzero_si128() }; // zero-filled so partial blocks leave the unused lane cleared
int index = 0;
for (size_t kk = 0; kk < step; kk += 16, index++) {
const size_t klen = std::min(size_t(16), step - kk);

uint32_t mask = 0xffff >> (16 - klen);
__m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), a_ptr + kk);
v0 = _mm512_mul_ps(v0, mul);

// Round to nearest integer (ties to even), suppressing exceptions
v0 = _mm512_roundscale_ps(v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

// Convert floats to integers
__m512i i0 = _mm512_cvtps_epi32(v0);

// Convert int32 to int8
i_16_epi8[index] = _mm512_cvtepi32_epi8(i0);

// accumulate Sum(a_i)
__m256i i_16_epi16 = _mm256_cvtepi8_epi16(i_16_epi8[index]);
sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16);
}
i0_32_epi8 = _mm256_set_m128i(i_16_epi8[1], i_16_epi8[0]); // first argument is the high 128 bits, so block 0 lands in the low lane
const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16);
blksum = scale * hsum_8_epi32(sum_8_epi32);
}
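A scalar sketch of the quantize-and-blksum step (illustrative; the name and the memory destination are inventions here, since the real helper keeps the block in a register; std::nearbyint under the default rounding mode matches round-to-nearest-even):

// Scalar sketch (illustrative; not part of the change).
void QuantizeInt8ComputeBlksumRef(const float* a_ptr, size_t step, float inv_scale, float scale, int8_t* out, float& blksum)
{
    int32_t sum = 0;
    for (size_t kk = 0; kk < step; ++kk) {
        const int32_t q = static_cast<int32_t>(std::nearbyint(a_ptr[kk] * inv_scale));
        out[kk] = static_cast<int8_t>(q);
        sum += q;  // accumulate Sum(a_i) in quantized form
    }
    blksum = scale * static_cast<float>(sum);  // scale_k * Sum_blklen(a_i)
}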

void
Quantize1BlkBlkLen32(const float* a_ptr, size_t step, __m256i& i_32_epi8, float& scale, float& blksum)
{
// quantize 32 floats to 32 int8s in i_32_epi8 and compute the block sum
__m512 mul = ComputeMulScal(a_ptr, step, scale);
QuantizeInt8ComputeBlksum(a_ptr, step, mul, scale, i_32_epi8, blksum);
}

void
store_4blk_blklen32_interleaved(__m256i i_32_epi8[4], int8_t* blob)
{
// 0 1 2 3 32 33 34 35 64 65 66 67 96 97 98 99
// 4 5 6 7 36 37 38 39 68 69 70 71 100 101 102 103
// 8 9 10 11 40 41 42 43 72 73 74 75 104 105 106 107
// 12 13 14 15 44 45 46 47 76 77 78 79 108 109 110 111
//
// 16 17 18 19 48 49 50 51 80 81 82 83 112 113 114 115
// 20 21 22 23 52 53 54 55 84 85 86 87 116 117 118 119
// 24 25 26 27 56 57 58 59 88 89 90 91 120 121 122 123
// 28 29 30 31 60 61 62 63 92 93 94 95 124 125 126 127

// Interleave and store i_32_epi8[4] in the specified layout
__m256i a0_lower = _mm256_permute2x128_si256(i_32_epi8[0], i_32_epi8[1], 0x20);
__m256i a0_higher = _mm256_permute2x128_si256(i_32_epi8[0], i_32_epi8[1], 0x31);
__m256i a1_lower = _mm256_permute2x128_si256(i_32_epi8[2], i_32_epi8[3], 0x20);
__m256i a1_higher = _mm256_permute2x128_si256(i_32_epi8[2], i_32_epi8[3], 0x31);

__m512i a_lower = _mm512_inserti64x4(_mm512_castsi256_si512(a0_lower), a1_lower, 1);
__m512i a_higher = _mm512_inserti64x4(_mm512_castsi256_si512(a0_higher), a1_higher, 1);

__m512i idx = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
__m512i a_lower_interleaved = _mm512_permutexvar_epi32(idx, a_lower);
__m512i a_higher_interleaved = _mm512_permutexvar_epi32(idx, a_higher);

_mm512_storeu_si512(reinterpret_cast<__m512i*>(blob + 0 * 64), a_lower_interleaved);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(blob + 1 * 64), a_higher_interleaved);
}
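A scalar model of the same layout may help when reading the permutes above (illustrative; indices match the layout table, with src[blk][j] the j-th int8 of block blk):

// Scalar model of the interleaved store (illustrative; not part of the change).
// Each of the 4 blocks contributes its g-th 4-byte group to the g-th 16-byte row.
void store_4blk_blklen32_interleaved_ref(const int8_t src[4][32], int8_t* blob)
{
    for (size_t g = 0; g < 8; ++g) {
        for (size_t blk = 0; blk < 4; ++blk) {
            for (size_t b = 0; b < 4; ++b) {
                blob[g * 16 + blk * 4 + b] = src[blk][g * 4 + b];
            }
        }
    }
}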

void MLASCALL
QuantizeARow_CompInt8_avx512_blklen32(
const float* A,
size_t CountK,
std::byte* QuantA,
float* QuantAScale,
float* AScaledBlkSum // scale_k * Sum_blklen(a_i)
)
{
const size_t BlkLen = 32;
const size_t SubBlkLen = 4 * BlkLen; // process 128 elements at a time, then handle the remainder

const float* a_ptr = A;
int8_t* quant_a_ptr = reinterpret_cast<int8_t*>(QuantA);
float* scale_ptr = QuantAScale;
float* blksum_ptr = AScaledBlkSum;

size_t k_remaining = CountK;

for (; k_remaining >= SubBlkLen; k_remaining -= SubBlkLen) {
__m256i i_32_epi8[4];
float scale[4];
float blksum[4];
for (int i = 0; i < 4; i++) {
Quantize1BlkBlkLen32(a_ptr + i * BlkLen, BlkLen, i_32_epi8[i], scale[i], blksum[i]);
}
a_ptr += SubBlkLen;
store_4blk_blklen32_interleaved(i_32_epi8, quant_a_ptr);
quant_a_ptr += BlkLen * 4;
std::copy(scale, scale + 4, scale_ptr);
scale_ptr += 4;
std::copy(blksum, blksum + 4, blksum_ptr);
blksum_ptr += 4;
}

while (k_remaining > 0) {
const size_t step = std::min(BlkLen, k_remaining);
__m256i i_32_epi8;
float scale;
float blksum;
Quantize1BlkBlkLen32(a_ptr, step, i_32_epi8, scale, blksum);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(quant_a_ptr), i_32_epi8);
a_ptr += step;
quant_a_ptr += BlkLen;
*scale_ptr = scale;
scale_ptr++;
*blksum_ptr = blksum;
blksum_ptr++;
k_remaining -= step;
}
}
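A hypothetical call site (the buffer sizing is an assumption: one int8 per element padded to a whole block, since the tail store writes a full 32-byte block, plus one scale and one blksum per block; assumes <vector> and <cstddef>):

// Hypothetical usage sketch (names and sizing are assumptions, not from the change).
void QuantizeARowExample(const float* A, size_t CountK)
{
    const size_t BlkLen = 32;
    const size_t BlkCount = (CountK + BlkLen - 1) / BlkLen;
    std::vector<std::byte> quant_a(BlkCount * BlkLen);  // padded per block
    std::vector<float> scales(BlkCount);
    std::vector<float> blksums(BlkCount);
    QuantizeARow_CompInt8_avx512_blklen32(A, CountK, quant_a.data(), scales.data(), blksums.data());
}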

static void
SQ4BitGemmPackQuantBDataAndBlkSum512(
size_t N,
20 changes: 18 additions & 2 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h
@@ -18,6 +16,16 @@ load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi
bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127
}

static MLAS_FORCEINLINE
__m512 load_4blksum_512(const float* BlksumPtr)
{
// Load four per-block sums into a __m128 register
__m128 blksum4_4_ps = _mm_loadu_ps(BlksumPtr);

// Place them in the lowest 128 bits of a zeroed __m512 register
return _mm512_insertf32x4(_mm512_setzero_ps(), blksum4_4_ps, 0);
}

static MLAS_FORCEINLINE void
accumulate_blklen32_r1c1blk4_avx512(
const __m512i& av0_64_epi8,
@@ -103,11 +113,13 @@ accumulate_blklen32_r1c1blk4_avx512vnni(
const std::byte* QuantBDataPtr,
const float* scale_a,
const float* scale_b,
//const float* blksum_a,
//const float* blksum_b,
__m512& acc0
)
{
__m512i bv0_64_epi8, bv1_64_epi8;
load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8);
load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); // 0000111122223333 x 4 (64 unsigned int8)

const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123
{
@@ -120,6 +132,10 @@

const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);
acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0);

//const __m512 blksum_a0_ps = load_4blksum_512(blksum_a); // 0123000000000000
//const __m512 blksum_b0_ps = load_4blksum_512(blksum_b); // 0123000000000000
//acc0 = _mm512_fmadd_ps(blksum_a0_ps, blksum_b0_ps, acc0);
}
}
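The commented-out lines above read as a per-block correction term, blksum_a * blksum_b, folded into the accumulator once both operands carry block sums. In scalar form (an assumption inferred from the commented code; it is not active in this change):

// Scalar sketch of the blksum correction (assumption; the vector form is still commented out).
float accumulate_blksum_ref(const float* blksum_a, const float* blksum_b, size_t blk_count, float acc)
{
    for (size_t k = 0; k < blk_count; ++k) {
        acc += blksum_a[k] * blksum_b[k];
    }
    return acc;
}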

@@ -138,7 +154,7 @@ accumulate_blklen32_r2c1blk4_avx512vnni(
)
{
__m512i bv0_64_epi8, bv1_64_epi8;
load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); // 0000111122223333 x 4
load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); // 0000111122223333 x 4 (64 unsigned int8)

const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123
{
24 changes: 24 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h
@@ -193,10 +193,34 @@ PackQuantB(
size_t PackBytePairCount = SubBlkBytePairCount;
size_t PackDataSize = SubBlkDataSize;

auto pack_4blk_blklen32_512 = [](
const std::byte* QuantBData, std::byte* PackedQuantBData,
size_t pack_byte_pair_count, size_t pack_data_size) {
for (size_t byte_pair_idx = 0; byte_pair_idx < pack_byte_pair_count; ++byte_pair_idx) {
// dst: |0~3/16~19|32~35/48~51|64~67/80~83|96~99/112~115| 16 bytes
// |4~7/20~23|36~39/52~55|68~71/84~87|100~103/116~119| 16 bytes
// |8~11/24~27|40~43/56~59|72~75/88~91|104~107/120~123| 16 bytes
// |12~15/28~31|44~47/60~63|76~79/92~95|108~111/124~127| 16 bytes
// 64 bytes (512bits) total of 128 4bit weights, unpack to 2 512 registers, each as:
// 0000111122223333 0000111122223333 0000111122223333 0000111122223333 (64 unsigned int8)
//
// TODO: interleave the four blocks with a 512-bit permute
// (e.g. _mm512_permutexvar_epi32) to match the layout above.

const std::byte src0 = QuantBData[byte_pair_idx];
const std::byte src1 = QuantBData[byte_pair_idx + pack_data_size / 2];

std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx];
std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1];

dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4);
dst1 = (src0 >> 4) | ((src1 >> 4) << 4);
} };
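In scalar form, the byte-pair repack used by both lambdas looks like this (illustrative; the helper name is made up here):

// Scalar sketch of the byte-pair pack (illustrative; not part of the change).
// Pairs the low nibbles of src0/src1 into dst[2i] and the high nibbles into
// dst[2i+1], so the kernel can unpack two blocks with one mask and one shift.
void pack_byte_pair_ref(const uint8_t* src, uint8_t* dst, size_t byte_pair_count, size_t data_size)
{
    for (size_t i = 0; i < byte_pair_count; ++i) {
        const uint8_t s0 = src[i];
        const uint8_t s1 = src[i + data_size / 2];
        dst[2 * i] = static_cast<uint8_t>((s0 & 0x0F) | ((s1 & 0x0F) << 4));
        dst[2 * i + 1] = static_cast<uint8_t>((s0 >> 4) | ((s1 >> 4) << 4));
    }
}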

auto pack_subblk = [](
const std::byte* QuantBData, std::byte* PackedQuantBData,
size_t pack_byte_pair_count, size_t pack_data_size) {
for (size_t byte_pair_idx = 0; byte_pair_idx < pack_byte_pair_count; ++byte_pair_idx) {
// (avx256)dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 |
// (avx512)dst: | v0 v64 | v1 v65 | ... | v62 v126 | v63 v127 |
const std::byte src0 = QuantBData[byte_pair_idx];
const std::byte src1 = QuantBData[byte_pair_idx + pack_data_size / 2];

50 changes: 26 additions & 24 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -365,30 +365,32 @@ TEST(MatMulNBits, Float32_Accuracy1) {
}

TEST(MatMulNBits, Float32_Accuracy4) {
TestMatMulNBitsTyped<float, 1, 1, 16, 16, 4>();
TestMatMulNBitsTyped<float, 1, 2, 16, 16, 4>();
TestMatMulNBitsTyped<float, 1, 32, 16, 16, 4>();
TestMatMulNBitsTyped<float, 1, 32, 32, 16, 4>();
TestMatMulNBitsTyped<float, 1, 32, 16, 128, 4>();
TestMatMulNBitsTyped<float, 1, 288, 16, 16, 4>();
TestMatMulNBitsTyped<float, 1, 288, 1024, 16, 4>();
TestMatMulNBitsTyped<float, 1, 288, 1024, 128, 4>();
TestMatMulNBitsTyped<float, 1, 288, 93, 32, 4>();
TestMatMulNBitsTyped<float, 1, 288, 93, 128, 4>();
TestMatMulNBitsTyped<float, 1, 288, 1234, 16, 4>();
TestMatMulNBitsTyped<float, 2, 1, 16, 16, 4>();
TestMatMulNBitsTyped<float, 2, 2, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 1, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 2, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 32, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 32, 32, 16, 4>();
TestMatMulNBitsTyped<float, 100, 32, 16, 128, 4>();
TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>();
TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>();
TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>();
TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>();
TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 1, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 2, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 32, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 32, 32, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 32, 16, 128, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 1024, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 1024, 128, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 93, 32, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 93, 128, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 1234, 16, 4>();
//TestMatMulNBitsTyped<float, 2, 1, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 2, 2, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 1, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 2, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 32, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 32, 32, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 32, 16, 128, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>();
TestMatMulNBitsTyped<float, 2, 4, 128, 32, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 1234, 32, 4>();
}

#ifdef MLAS_TARGET_AMD64_IX86
4 changes: 4 additions & 0 deletions onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
@@ -407,6 +407,10 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture<MlasSQNBitGemmTest<Blk
tests_registered += RegisterSingleTest(1, 527, 2131, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(11, 527, 2131, ComputeType, WithThreadpool, Symmetric, true);
// tests_registered += RegisterSingleTest(1001, 1027, 1031, ComputeType, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(2, 4, 128, ComputeType, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(2, 4, 128, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(3, 4, 128, ComputeType, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(3, 4, 128, ComputeType, WithThreadpool, Symmetric, true);
}
}
}
