From bb1195313735b9a5a099c2cca8d33cacff0c8ef2 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 10 Apr 2024 11:46:59 -0700 Subject: [PATCH] initial int8 matrixperf sample --- samples/99_matrixexperimentsi8/CMakeLists.txt | 11 + samples/99_matrixexperimentsi8/main.cpp | 908 +++++++++++++++++ .../matrix_helpers_i8.cl | 914 ++++++++++++++++++ .../matrix_kernel_tiled_i8.cl | 752 ++++++++++++++ .../matrix_kernels_i8.cl | 613 ++++++++++++ samples/CMakeLists.txt | 1 + 6 files changed, 3199 insertions(+) create mode 100644 samples/99_matrixexperimentsi8/CMakeLists.txt create mode 100644 samples/99_matrixexperimentsi8/main.cpp create mode 100644 samples/99_matrixexperimentsi8/matrix_helpers_i8.cl create mode 100644 samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl create mode 100644 samples/99_matrixexperimentsi8/matrix_kernels_i8.cl diff --git a/samples/99_matrixexperimentsi8/CMakeLists.txt b/samples/99_matrixexperimentsi8/CMakeLists.txt new file mode 100644 index 0000000..b97f9c7 --- /dev/null +++ b/samples/99_matrixexperimentsi8/CMakeLists.txt @@ -0,0 +1,11 @@ +# Copyright (c) 2019-2024 Ben Ashbaugh +# +# SPDX-License-Identifier: MIT + +add_opencl_sample( + TEST + NUMBER 99 + TARGET matrixexperimentsi8 + VERSION 200 # for clSetKernelExecInfo + SOURCES main.cpp + KERNELS matrix_helpers_i8.cl matrix_kernels_i8.cl matrix_kernel_tiled_i8.cl) diff --git a/samples/99_matrixexperimentsi8/main.cpp b/samples/99_matrixexperimentsi8/main.cpp new file mode 100644 index 0000000..ff99d3c --- /dev/null +++ b/samples/99_matrixexperimentsi8/main.cpp @@ -0,0 +1,908 @@ +/* +// Copyright (c) 2019-2024 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "util.hpp" + +using test_clock = std::chrono::high_resolution_clock; + +bool zeroData = false; +bool identityData = false; +bool fixedData = false; +bool validate = false; +bool emulate = false; +bool wallclock = false; +bool skipinit = false; +bool roundRobin = false; +int testIterations = 16; +float threshold = 0.01f; + +std::string makeTestName( + const std::string &func, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + int tM, int tN, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + int tM, int tN, + int MM, int NN, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +static size_t findMinSubGroupSize(cl::Device& device) +{ + auto s = device.getInfo(); + auto it = std::min_element(std::begin(s), std::end(s)); + if (it != std::end(s)) { + return *it; + } + return 0; +} + +static void setRoundRobin(cl::Kernel& kernel) +{ + constexpr cl_kernel_exec_info CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL = 0x10025; + constexpr cl_uint CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL = 0x10023; + const cl_uint policy = CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL; + clSetKernelExecInfo( + kernel(), + CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL, + sizeof(policy), + &policy); +} + +template +static void fill_matrix(std::vector& M, size_t numRows, 
size_t numCols) +{ + if (zeroData) { + std::generate(std::begin(M), std::end(M), [&]{ return 0; }); + } + else if (identityData) { + std::generate(std::begin(M), std::end(M), [&]{ return 1; }); + } else if (fixedData) { + for (size_t r = 0; r < numRows; r++) { + for (size_t c = 0; c < numCols; c++) { + M[r * numCols + c] = static_cast(r + c); + } + } + } else { + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_int_distribution dist(-64, 64); + std::generate(std::begin(M), std::end(M), [&]{ return dist(rng); }); + } +} + +template +static void vnni_matrix( + std::vector &dst, const std::vector &src, + size_t numRows, size_t numCols, size_t factor) +{ + for (size_t r = 0; r < numRows / factor; r++) { + for (size_t c = 0; c < numCols; c++) { + for (size_t k = 0; k < factor; k++) { + dst[r * numCols * factor + c * factor + k] = + src[(r * factor + k) * numCols + c]; + } + } + } +} + +template +static void compute_reference( + std::vector& C, + const std::vector& A, const std::vector& B, + size_t M, size_t N, size_t K) +{ + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + DstT sum = 0; + for (size_t k = 0; k < K; k++) { + sum = A[m * K + k] * B[k * N + n] + sum; + } + C[m * N + n] = sum; + } + } +} + +template +void check_results( + size_t M, + size_t N, + const std::vector& C, + const std::vector& C_ref) +{ + float err = 0.f; + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + auto index = m * N + n; + if (C[index] != C_ref[index]) { + std::cerr << "Error at m = " << m << ", n = " << n + << ": Wanted " + << C_ref[index] << ", got " << C[index] << std::endl; + return; + } + } + } +} + +static float hw_time(cl::Event& event) +{ + auto ns = event.getProfilingInfo() - + event.getProfilingInfo(); + return ns / 1e9f; +} + +static void i8_naive( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, M, N, K).c_str()); fflush(stdout); + + cl::Kernel kernel{program, "i8_naive"}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_rowmajor_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_rowmajor_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_vnni( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_vnni"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_vnni_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_vnni_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_blockread_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_blockread_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_blockread_rowmajor_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_blockread_rowmajor_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_blockread_vnni( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_blockread_vnni"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_blockread_vnni_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_blockread_vnni_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +int main(int argc, char** argv) +{ + int platformIndex = 0; + int deviceIndex = 0; + + std::string fileName("matrix_kernels_i8.cl"); + std::string buildOptions; + size_t matrixSize = 512; + + size_t mask = ~0; + + { + popl::OptionParser op("Supported Options"); + op.add>("p", "platform", "Platform Index", platformIndex, &platformIndex); + op.add>("d", "device", "Device Index", deviceIndex, &deviceIndex); + op.add>("", "file", "Kernel File Name", fileName, &fileName); + op.add>("", "options", "Program Build Options", buildOptions, &buildOptions); + op.add>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize); + op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); + op.add("", "validate", "Validate Results", &validate); + op.add("", "zero", "Use Zero Data", &zeroData); + op.add("", "identity", "Use Identity Data", &identityData); + op.add("", "fixed", "Use Fixed Data", &fixedData); + op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); + op.add("", "wallclock", "Measure Wallclock Time", &wallclock); + op.add("", "skipinit", "Do Not Initialize Buffers", &skipinit); + op.add("", "roundrobin", "Use Round Robin Scheduling", &roundRobin); + op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); + op.add, popl::Attribute::advanced>("", "mask", "Test Mask", mask, &mask); + bool printUsage = false; + try { + op.parse(argc, argv); + } catch (std::exception& e) { + fprintf(stderr, "Error: %s\n\n", e.what()); + printUsage = true; + } + + if (printUsage || !op.unknown_options().empty() || !op.non_option_args().empty()) { + fprintf(stderr, + "Usage: matrixexperimentsi8 [options]\n" + "%s", op.help().c_str()); + return -1; + } + } + + std::vector platforms; + cl::Platform::get(&platforms); + if (platformIndex >= platforms.size()) { + printf("Requested platform index is %d, but only %zu platforms were found.\n", + platformIndex, platforms.size()); + return -1; + } + + printf("Running on platform: %s\n", + platforms[platformIndex].getInfo().c_str() ); + + std::vector devices; + platforms[platformIndex].getDevices(CL_DEVICE_TYPE_ALL, &devices); + if (deviceIndex >= devices.size()) { + printf("Requested device index is %d, but only %zu devices were found.\n", + deviceIndex, devices.size()); + } + + cl::Device& device = devices[deviceIndex]; + printf("Running on device: %s (%uCUs, %uMHz)\n", + device.getInfo().c_str(), + device.getInfo(), + device.getInfo()); + printf("Running on drivers: %s\n", + device.getInfo().c_str()); + + auto minSubGroupSize = findMinSubGroupSize(device); + + bool has_simd8 = minSubGroupSize == 8; + bool emulate_tN8 = true; + bool emulate_tN16 = true; + if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) { + printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize); + switch(minSubGroupSize) { + case 8: emulate_tN8 = false; break; + case 16: emulate_tN16 = false; break; + default: break; + } + } + + buildOptions += " -DHAS_SIMD8=" + std::to_string(has_simd8); + buildOptions += " -DEMULATE_tN8=" + std::to_string(emulate_tN8); + buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16); + + printf("Config:\n"); + printf("\tTest Iterations: %d\n", testIterations); + printf("\tValidating 
data?: %s\n", validate ? "true" : "false"); + printf("\tFixed data?: %s\n", fixedData ? "true" : "false"); + printf("\tWallclock time?: %s\n", wallclock ? "true" : "false"); + printf("\tEmulate dpas for tN=8?: %s\n", emulate_tN8 ? "true" : "false"); + printf("\tEmulate dpas for tN=16?: %s\n", emulate_tN16 ? "true" : "false"); + + cl::Context context{device}; + cl::CommandQueue queue{context, device, CL_QUEUE_PROFILING_ENABLE}; + + printf("Reading program source from file: %s\n", fileName.c_str() ); + std::string kernelString = readStringFromFile(fileName.c_str()); + + printf("Building program with build options: %s\n", + buildOptions.empty() ? "(none)" : buildOptions.c_str() ); + cl::Program program{ context, kernelString }; + program.build(buildOptions.c_str()); + for( auto& device : program.getInfo() ) + { + printf("Program build log for device %s:\n", + device.getInfo().c_str() ); + printf("%s\n", + program.getBuildInfo(device).c_str() ); + } + + const auto M = matrixSize; + const auto N = matrixSize; + const auto K = matrixSize; + + std::vector A_vec(M * K); + std::vector B_vec(K * N); + std::vector Bvnni_vec(K * N); + + std::vector C_ref(M * N); + + printf("Initializing source matrices...\n"); + fill_matrix(A_vec, M, K); + fill_matrix(B_vec, K, N); + + vnni_matrix(Bvnni_vec, B_vec, K, N, 4); + + if (validate) { + printf("Computing reference...\n"); + compute_reference(C_ref, A_vec, B_vec, M, N, K); + } + + printf("Creating source buffers...\n"); + cl::Buffer A{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A_vec.size() * sizeof(A_vec[0]), A_vec.data()}; + cl::Buffer B{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vec.size() * sizeof(B_vec[0]), B_vec.data()}; + cl::Buffer Bvnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Bvnni_vec.size() * sizeof(Bvnni_vec[0]), Bvnni_vec.data()}; + cl::Buffer C{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])}; + + printf("Running tests...\n"); + + if (mask & 0x1) { + i8_naive(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x2) { + i8_dpas_rowmajor<1, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<2, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<4, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<8, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x4) { + i8_dpas_rowmajor_tiled<8, 8, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 8, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 8, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 8, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 8, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 8, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 8, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x8) { + i8_dpas_vnni<1, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<2, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<4, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<8, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x10) { + i8_dpas_vnni_tiled<8, 8, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 8, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 
8, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 8, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 8, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 8, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 8, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x20) { + i8_dpas_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x40) { + i8_dpas_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x80) { + i8_dpas_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x100) { + i8_dpas_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x200) { + i8_dpas_blockread_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x400) { + i8_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x800) { + i8_dpas_blockread_vnni<1, 16>(context, 
program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x1000) { + i8_dpas_blockread_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + printf("Done.\n"); + + return 0; +} diff --git a/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl b/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl new file mode 100644 index 0000000..1d59121 --- /dev/null +++ b/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl @@ -0,0 +1,914 @@ +__attribute__((overloadable)) +int activation(int i) +{ +#if defined(ACTIVATION_RELU) + return max(i, 0); +#else // identity + return i; +#endif +} + +__attribute__((overloadable)) +int2 activation(int2 i) +{ + int2 res; + res.s0 = activation(i.s0); + res.s1 = activation(i.s1); + return res; +} + +__attribute__((overloadable)) +int4 activation(int4 i) +{ + int4 res; + res.s0 = activation(i.s0); + res.s1 = activation(i.s1); + res.s2 = activation(i.s2); + res.s3 = activation(i.s3); + return res; +} + +int8 activation(int8 i) +{ + int8 res; + res.s0 = activation(i.s0); + res.s1 = activation(i.s1); + res.s2 = activation(i.s2); + res.s3 = activation(i.s3); + res.s4 = activation(i.s4); + res.s5 = activation(i.s5); + res.s6 = activation(i.s6); + res.s7 = activation(i.s7); + return res; +} + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif +#if __has_builtin(__builtin_expect) == 0 +#define __builtin_expect(x) +#endif + +#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_char) + +typedef global char* global_aligned_char_ptr __attribute__((align_value(4))); + +inline int compute_m(const int num_sgs_x, const int num_sgs_y, const int tM, const int MM) +{ + const int m_start = get_group_id(1) * num_sgs_y; + const int m_index = num_sgs_y > 1 ? m_start + get_sub_group_id() / num_sgs_x : m_start; + return m_index * tM * MM; +} + +inline int compute_n(const int num_sgs_x, const int num_sgs_y, const int tN, const int NN) +{ + const int n_start = get_group_id(0) * num_sgs_x; + const int n_index = num_sgs_x > 1 ? n_start + get_sub_group_id() % num_sgs_x : n_start; + return n_index * tN * NN; +} + +// Emulated SIMD8 dpas: +__attribute__((overloadable)) +int emu_sub_group_i8_i8_matrix_mad_k32(int a, int8 b, int acc) +{ + int res = acc; + + // TODO: this could use integer dot products instead? 
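    // A sketch of that idea, assuming the cl_khr_integer_dot_product 4x8-bit feature
    // is available (this sample does not otherwise require it): each group of four
    // multiply-adds below could collapse into a single dot product with the same
    // result, e.g.
    //
    //     res += dot(as_char4(sub_group_broadcast(a, 0)), as_char4(b.s0));
    //
    // For reference, as_char4(sub_group_broadcast(a, i)) holds A[m][4*i .. 4*i+3] and
    // as_char4(b.si) holds the VNNI-packed B[4*i .. 4*i+3][n] for this lane's column n,
    // so the unrolled form below accumulates acc += sum_k A[m][k] * B[k][n] for K = 32.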
+ + res = as_char4(sub_group_broadcast(a, 0)).x * as_char4(b.s0).x + res; + res = as_char4(sub_group_broadcast(a, 0)).y * as_char4(b.s0).y + res; + res = as_char4(sub_group_broadcast(a, 0)).z * as_char4(b.s0).z + res; + res = as_char4(sub_group_broadcast(a, 0)).w * as_char4(b.s0).w + res; + + res = as_char4(sub_group_broadcast(a, 1)).x * as_char4(b.s1).x + res; + res = as_char4(sub_group_broadcast(a, 1)).y * as_char4(b.s1).y + res; + res = as_char4(sub_group_broadcast(a, 1)).z * as_char4(b.s1).z + res; + res = as_char4(sub_group_broadcast(a, 1)).w * as_char4(b.s1).w + res; + + res = as_char4(sub_group_broadcast(a, 2)).x * as_char4(b.s2).x + res; + res = as_char4(sub_group_broadcast(a, 2)).y * as_char4(b.s2).y + res; + res = as_char4(sub_group_broadcast(a, 2)).z * as_char4(b.s2).z + res; + res = as_char4(sub_group_broadcast(a, 2)).w * as_char4(b.s2).w + res; + + res = as_char4(sub_group_broadcast(a, 3)).x * as_char4(b.s3).x + res; + res = as_char4(sub_group_broadcast(a, 3)).y * as_char4(b.s3).y + res; + res = as_char4(sub_group_broadcast(a, 3)).z * as_char4(b.s3).z + res; + res = as_char4(sub_group_broadcast(a, 3)).w * as_char4(b.s3).w + res; + + res = as_char4(sub_group_broadcast(a, 4)).x * as_char4(b.s4).x + res; + res = as_char4(sub_group_broadcast(a, 4)).y * as_char4(b.s4).y + res; + res = as_char4(sub_group_broadcast(a, 4)).z * as_char4(b.s4).z + res; + res = as_char4(sub_group_broadcast(a, 4)).w * as_char4(b.s4).w + res; + + res = as_char4(sub_group_broadcast(a, 5)).x * as_char4(b.s5).x + res; + res = as_char4(sub_group_broadcast(a, 5)).y * as_char4(b.s5).y + res; + res = as_char4(sub_group_broadcast(a, 5)).z * as_char4(b.s5).z + res; + res = as_char4(sub_group_broadcast(a, 5)).w * as_char4(b.s5).w + res; + + res = as_char4(sub_group_broadcast(a, 6)).x * as_char4(b.s6).x + res; + res = as_char4(sub_group_broadcast(a, 6)).y * as_char4(b.s6).y + res; + res = as_char4(sub_group_broadcast(a, 6)).z * as_char4(b.s6).z + res; + res = as_char4(sub_group_broadcast(a, 6)).w * as_char4(b.s6).w + res; + + res = as_char4(sub_group_broadcast(a, 7)).x * as_char4(b.s7).x + res; + res = as_char4(sub_group_broadcast(a, 7)).y * as_char4(b.s7).y + res; + res = as_char4(sub_group_broadcast(a, 7)).z * as_char4(b.s7).z + res; + res = as_char4(sub_group_broadcast(a, 7)).w * as_char4(b.s7).w + res; + + return res; +} + +__attribute__((overloadable)) +int2 emu_sub_group_i8_i8_matrix_mad_k32(int2 a, int8 b, int2 acc) +{ + int2 res; + + res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0); + res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1); + + return res; +} + +__attribute__((overloadable)) +int4 emu_sub_group_i8_i8_matrix_mad_k32(int4 a, int8 b, int4 acc) +{ + int4 res; + + res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0); + res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1); + res.s2 = emu_sub_group_i8_i8_matrix_mad_k32(a.s2, b, acc.s2); + res.s3 = emu_sub_group_i8_i8_matrix_mad_k32(a.s3, b, acc.s3); + + return res; +} + +__attribute__((overloadable)) +int8 emu_sub_group_i8_i8_matrix_mad_k32(int8 a, int8 b, int8 acc) +{ + int8 res; + + res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0); + res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1); + res.s2 = emu_sub_group_i8_i8_matrix_mad_k32(a.s2, b, acc.s2); + res.s3 = emu_sub_group_i8_i8_matrix_mad_k32(a.s3, b, acc.s3); + res.s4 = emu_sub_group_i8_i8_matrix_mad_k32(a.s4, b, acc.s4); + res.s5 = emu_sub_group_i8_i8_matrix_mad_k32(a.s5, b, acc.s5); + res.s6 = emu_sub_group_i8_i8_matrix_mad_k32(a.s6, b, 
acc.s6); + res.s7 = emu_sub_group_i8_i8_matrix_mad_k32(a.s7, b, acc.s7); + + return res; +} + +// Emulated SIMD16 dpas: +__attribute__((overloadable)) +int emu_sub_group_i8_i8_matrix_mad_k32(short a, int8 b, int acc) +{ + float res = acc; + + res = as_char2(sub_group_broadcast(a, 0)).x * as_char4(b.s0).x + res; + res = as_char2(sub_group_broadcast(a, 0)).y * as_char4(b.s0).y + res; + res = as_char2(sub_group_broadcast(a, 1)).x * as_char4(b.s0).z + res; + res = as_char2(sub_group_broadcast(a, 1)).y * as_char4(b.s0).w + res; + + res = as_char2(sub_group_broadcast(a, 2)).x * as_char4(b.s1).x + res; + res = as_char2(sub_group_broadcast(a, 2)).y * as_char4(b.s1).y + res; + res = as_char2(sub_group_broadcast(a, 3)).x * as_char4(b.s1).z + res; + res = as_char2(sub_group_broadcast(a, 3)).y * as_char4(b.s1).w + res; + + res = as_char2(sub_group_broadcast(a, 4)).x * as_char4(b.s2).x + res; + res = as_char2(sub_group_broadcast(a, 4)).y * as_char4(b.s2).y + res; + res = as_char2(sub_group_broadcast(a, 5)).x * as_char4(b.s2).z + res; + res = as_char2(sub_group_broadcast(a, 5)).y * as_char4(b.s2).w + res; + + res = as_char2(sub_group_broadcast(a, 6)).x * as_char4(b.s3).x + res; + res = as_char2(sub_group_broadcast(a, 6)).y * as_char4(b.s3).y + res; + res = as_char2(sub_group_broadcast(a, 7)).x * as_char4(b.s3).z + res; + res = as_char2(sub_group_broadcast(a, 7)).y * as_char4(b.s3).w + res; + + res = as_char2(sub_group_broadcast(a, 8)).x * as_char4(b.s4).x + res; + res = as_char2(sub_group_broadcast(a, 8)).y * as_char4(b.s4).y + res; + res = as_char2(sub_group_broadcast(a, 9)).x * as_char4(b.s4).z + res; + res = as_char2(sub_group_broadcast(a, 9)).y * as_char4(b.s4).w + res; + + res = as_char2(sub_group_broadcast(a, 10)).x * as_char4(b.s5).x + res; + res = as_char2(sub_group_broadcast(a, 10)).y * as_char4(b.s5).y + res; + res = as_char2(sub_group_broadcast(a, 11)).x * as_char4(b.s5).z + res; + res = as_char2(sub_group_broadcast(a, 11)).y * as_char4(b.s5).w + res; + + res = as_char2(sub_group_broadcast(a, 12)).x * as_char4(b.s6).x + res; + res = as_char2(sub_group_broadcast(a, 12)).y * as_char4(b.s6).y + res; + res = as_char2(sub_group_broadcast(a, 13)).x * as_char4(b.s6).z + res; + res = as_char2(sub_group_broadcast(a, 13)).y * as_char4(b.s6).w + res; + + res = as_char2(sub_group_broadcast(a, 14)).x * as_char4(b.s7).x + res; + res = as_char2(sub_group_broadcast(a, 14)).y * as_char4(b.s7).y + res; + res = as_char2(sub_group_broadcast(a, 15)).x * as_char4(b.s7).z + res; + res = as_char2(sub_group_broadcast(a, 15)).y * as_char4(b.s7).w + res; + + return res; +} + +__attribute__((overloadable)) +int2 emu_sub_group_i8_i8_matrix_mad_k32(short2 a, int8 b, int2 acc) +{ + int2 res; + + res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0); + res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1); + + return res; +} + +__attribute__((overloadable)) +int4 emu_sub_group_i8_i8_matrix_mad_k32(short4 a, int8 b, int4 acc) +{ + int4 res; + + res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0); + res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1); + res.s2 = emu_sub_group_i8_i8_matrix_mad_k32(a.s2, b, acc.s2); + res.s3 = emu_sub_group_i8_i8_matrix_mad_k32(a.s3, b, acc.s3); + + return res; +} + +__attribute__((overloadable)) +int8 emu_sub_group_i8_i8_matrix_mad_k32(short8 a, int8 b, int8 acc) +{ + int8 res; + + res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0); + res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1); + res.s2 = emu_sub_group_i8_i8_matrix_mad_k32(a.s2, b, 
acc.s2); + res.s3 = emu_sub_group_i8_i8_matrix_mad_k32(a.s3, b, acc.s3); + res.s4 = emu_sub_group_i8_i8_matrix_mad_k32(a.s4, b, acc.s4); + res.s5 = emu_sub_group_i8_i8_matrix_mad_k32(a.s5, b, acc.s5); + res.s6 = emu_sub_group_i8_i8_matrix_mad_k32(a.s6, b, acc.s6); + res.s7 = emu_sub_group_i8_i8_matrix_mad_k32(a.s7, b, acc.s7); + + return res; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads four values. +int load_a_rowmajor_d8_m1_k32_sg8(global char* A, int rowStart, int colStart, int stride) +{ + int ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 4 + colStart / 4; + ret = intel_sub_group_block_read(A_ui + offset_ui); + + return ret; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads four values. +int2 load_a_rowmajor_d8_m2_k32_sg8(global char* A, int rowStart, int colStart, int stride) +{ + int2 ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 4 + colStart / 4; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + + return ret; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads four values. +int4 load_a_rowmajor_d8_m4_k32_sg8(global char* A, int rowStart, int colStart, int stride) +{ + int4 ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 4 + colStart / 4; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + + return ret; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads four values. +int8 load_a_rowmajor_d8_m8_k32_sg8(global char* A, int rowStart, int colStart, int stride) +{ + int8 ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 4 + colStart / 4; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s4 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s5 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s6 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s7 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + + return ret; +} + +#if 0 + +// M rows x K columns x V tiles (in the K dimension) +// This is the SIMD8 version, where each work-item loads two values. +// The first tile is returned the first components of the return value, the the next tile, etc. 
+int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + uint16 ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s08 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s19 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s2a = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s3b = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s4c = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s5d = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s6e = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s7f = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + + return as_int16(ret); +} + +// M rows x K columns x V tiles (in the K dimension) +void prefetch_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(A + offset) % 4 == 0); + prefetch(A + offset, 2); +#endif // defined(PREFETCH_DEFAULT) +} + +#endif + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads two values. +short load_a_rowmajor_d8_m1_k32_sg16(global char* A, int rowStart, int colStart, int stride) +{ + ushort ret; + + global ushort* A_us = (global ushort*)A; + uint offset_us = rowStart * stride / 2 + colStart / 2; + + ret = intel_sub_group_block_read_us(A_us + offset_us); + + return as_short(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads two values. +short2 load_a_rowmajor_d8_m2_k32_sg16(global char* A, int rowStart, int colStart, int stride) +{ + ushort2 ret; + + global ushort* A_us = (global ushort*)A; + uint offset_us = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s1 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + + return as_short2(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads two values. +short4 load_a_rowmajor_d8_m4_k32_sg16(global char* A, int rowStart, int colStart, int stride) +{ + ushort4 ret; + + global ushort* A_us = (global ushort*)A; + uint offset_us = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s1 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s2 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s3 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + + return as_short4(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads two values. 
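// A usage sketch for the loader below (m, k, and K are illustrative values): for a
// row-major int8 matrix A with K columns, an 8-row tile starting at row m and
// K-offset k could be loaded as
//
//     short8 a = load_a_rowmajor_d8_m8_k32_sg16(A, m, k, K);
//
// With a 16-lane sub-group, each intel_sub_group_block_read_us returns one ushort
// (two int8 values) per work-item, so one read covers 16 x 2 = 32 int8 columns,
// i.e. a full K=32 slice of a row of A.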
+short8 load_a_rowmajor_d8_m8_k32_sg16(global char* A, int rowStart, int colStart, int stride) +{ + ushort8 ret; + + global ushort* A_us = (global ushort*)A; + uint offset_us = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s1 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s2 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s3 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s4 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s5 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s6 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s7 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + + return as_short8(ret); +} + +#if 0 + +// M rows x K columns x V tiles (in the K dimension) +// This is the SIMD16 version, where each work-item loads one value. +// The first tile is returned the first components of the return value, the the next tile, etc. +short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort16 ret; + + uint offset = rowStart * stride + colStart; + ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride; + + return as_short16(ret); +} + +// M rows x K columns x V tiles (in the M and K dimensions) +void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(A + offset) % 4 == 0); + prefetch(A + offset, 2); +#endif // defined(PREFETCH_DEFAULT) +} + +#endif + +// K rows x N columns: +// Each work-item loads K values and converts to VNNI. +// Stride is in units of elements. 
+int8 load_b_rowmajor_d8_k32_nx(global char* B, int rowStart, int colStart, int stride) +{ + int8 ret; + + global uchar* B_uc = (global uchar*)B; + uint offset = rowStart * stride + colStart; + + uchar row0 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row1 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row2 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row3 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row4 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row5 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row6 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row7 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row8 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row9 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row10 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row11 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row12 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row13 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row14 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row15 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row16 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row17 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row18 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row19 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row20 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row21 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row22 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row23 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row24 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row25 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row26 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row27 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row28 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row29 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row30 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row31 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + + ret.s0 = as_int((uchar4)(row0, row1, row2, row3)); + ret.s1 = as_int((uchar4)(row4, row5, row6, row7)); + ret.s2 = as_int((uchar4)(row8, row9, row10, row11)); + ret.s3 = as_int((uchar4)(row12, row13, row14, row15)); + ret.s4 = as_int((uchar4)(row16, row17, row18, row19)); + ret.s5 = as_int((uchar4)(row20, row21, row22, row23)); + ret.s6 = as_int((uchar4)(row24, row25, row26, row27)); + ret.s7 = as_int((uchar4)(row28, row29, row30, row31)); + + return ret; +} + +// K rows x N columns: +// Each work-item loads K values that has already been converted to VNNI. +// Stride is in units of elements. 
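// A worked example of the addressing used below (the concrete numbers are
// illustrative): once the host-side vnni_matrix() has packed four consecutive K-rows
// of every column into adjacent bytes, B can be read as a (K/4) x N matrix of uints,
// so original element (rowStart, colStart) begins at uint offset
// rowStart / 4 * stride + colStart. For N = stride = 512, rowStart = 32, colStart = 64:
//
//     offset_ui = 32 / 4 * 512 + 64 = 4160
//
// Each of the eight block reads below then advances by one uint row, i.e. four K
// values, to cover K = 32.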
+int8 load_b_vnni_d8_k32_nx(global char* B, int rowStart, int colStart, int stride) +{ + int8 ret; + + global uint* B_ui = (global uint*)B; + uint offset_ui = rowStart / 4 * stride + colStart; + + ret.s0 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s1 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s2 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s3 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s4 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s5 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s6 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s7 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + + return ret; +} + +#if 0 + +// K rows x N columns x V tiles (in the N dimension) +void prefetch_b_rowmajor_d16_k16_n8v4_sg8(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(B + offset) % 4 == 0); + prefetch(B + offset, 2); offset += 8 * stride; + __builtin_assume((ulong)(B + offset) % 4 == 0); + prefetch(B + offset, 2); offset += 8 * stride; +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns x V tiles (in the N dimension) +void prefetch_b_rowmajor_d16_k16_n16v2_sg16(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(B + offset) % 4 == 0); + prefetch(B + offset, 2); +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns x V tiles (in the N dimension) +void prefetch_b_vnni_d16_k16_n8v2_sg8(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + global uint* B_ui = (global uint*)B; + uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0); + prefetch(B_ui + offset_ui, 1); +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns x V tiles (in the K dimension) +void prefetch_b_vnni_d16_k16v2_n16_sg16(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + global uint* B_ui = (global uint*)B; + uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0); + prefetch(B_ui + offset_ui, 1); +#endif // defined(PREFETCH_DEFAULT) +} + +#endif + +void store_c_rowmajor_int32_m1_nx(global int* C, int v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint v_ui = as_uint(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride; +} + +void store_c_rowmajor_int32_m2_nx(global int* C, int2 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint2 v_ui = as_uint2(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; +} + +void store_c_rowmajor_int32_m4_nx(global int* C, int4 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint4 v_ui = as_uint4(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, 
v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; +} + +void store_c_rowmajor_int32_m8_nx(global int* C, int8 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint8 v_ui = as_uint8(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s4); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s5); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s6); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride; +} + +#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) + +#if 0 +#ifdef cl_intel_subgroup_extended_block_read + +// Note for 2D block reads: +// - the tile width and height is encoded into the function name. +// - base_address is the byte address. Must be 64B aligned. +// - width is the width of the entire matrix, in bytes. Must be >= 64B. Must be 4B aligned. +// - height is the height of the entire matrix, or equivalently the number of rows. +// - pitch is the number of bytes between rows of the entire matrix. Must be >= 64B. Must be a multiple of 8 bytes. +// - coord is the number of elements (x coord) and row (y coord) to read from. X coord must be multiple 4 for for 1B data and 2 for 2B data. 
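// An illustrative call using one of the functions listed below (m, k, M, and K are
// hypothetical values): an int8 tile of a row-major M x K matrix A, starting at
// element coordinate (x = k, y = m), could be read as
//
//     ushort16 a_tile = intel_subgroup_block_read_u8_m8k32v2(
//         A,              // base_address: 64B-aligned byte pointer
//         K,              // width of the full matrix in bytes (int8 data: K bytes/row)
//         M,              // height of the full matrix in rows
//         K,              // pitch in bytes between rows (densely packed here)
//         (int2)(k, m));  // x coord in elements (multiple of 4 for 1B data), y in rows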
+ +// Built-in functions are: + +// #ifdef cl_intel_subgroup_extended_block_read +// ushort2 intel_subgroup_block_read_u8_m1k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort4 intel_subgroup_block_read_u8_m2k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort8 intel_subgroup_block_read_u8_m4k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort16 intel_subgroup_block_read_u8_m8k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort2 intel_subgroup_block_read_u16_m1k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort4 intel_subgroup_block_read_u16_m2k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort8 intel_subgroup_block_read_u16_m4k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort16 intel_subgroup_block_read_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// uint8 intel_subgroup_block_read_transform_u8_k32(__global void *base_address, int width, int height, int pitch, int2 coord); +// uint8 intel_subgroup_block_read_transform_u16_k16(__global void *base_address, int width, int height, int pitch, int2 coord); +// uint8 intel_subgroup_block_read_transpose_u32_k8(__global void *base_address, int width, int height, int pitch, int2 coord); +// ulong4 intel_subgroup_block_read_transpose_u64_k4(__global void *base_address, int width, int height, int pitch, int2 coord); +// #endif //defined(cl_intel_subgroup_extended_block_read) + + +// For intrinsics, the pattern is: +// - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat +// - operation (optional): _transpose or _transform +// - for no transpose or transform: +// - type / elements size: _u8 or _u16 or _u32 or _u64 +// - number of tile rows: _m32 or _m16 or _m8 or _m4 or _m2 or _m1 +// - tile width: _k64 or _k32 or _k16 or _k8 +// - number of tiles: _v2 or _v1 +// - for transpose: +// - type / element size: _u64 or _u32 +// - number of tile rows: subgroup size (16) +// - tile width: _k4 (for _u64) or _k8 (for _u32) +// - number of tiles: 1 +// - for transform: +// - type / element size: _u16 or _u8 +// - number of tile rows: _k32 (for _u8) or _k16 (for _u16) +// - tile width: subgroup size (16) +// - number of tiles: 1 + +enum LSC_LDCC { + LSC_LDCC_DEFAULT = 0, + LSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached + LSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached + LSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached + LSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached + LSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached + LSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached + LSC_LDCC_L1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached +}; + +typedef ushort __attribute__((ext_vector_type(32))) ushort32; +typedef ushort __attribute__((ext_vector_type(64))) ushort64; + +typedef uint __attribute__((ext_vector_type(32))) uint32; + +// Define block reads, prefetches, and writes. 
These are supported by the hardware but are not in the headers: + +ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort32 __builtin_IB_subgroup_block_read_flat_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +ushort32 __builtin_IB_subgroup_block_read_flat_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +uint16 __builtin_IB_subgroup_block_read_flat_transform_u16_k32(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +uint16 __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +uint32 __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + + +void __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); + +void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC 
cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); + +void __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); + + +void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); +void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); +void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); +void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); +void __builtin_IB_subgroup_block_write_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint16 data); + +ushort intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort2 intel_subgroup_block_read_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort4 intel_subgroup_block_read_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort16 intel_subgroup_block_read_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +void intel_subgroup_block_read_u16_m32k16(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[4]) +{ + ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + dst[0] = tmp.lo.lo; + dst[1] = tmp.lo.hi; + dst[2] = tmp.hi.lo; + dst[3] = tmp.hi.hi; +} + +void intel_subgroup_block_read_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][2]) +{ + ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + dst[0][0] = tmp.lo.lo; + dst[0][1] = tmp.lo.hi; + dst[1][0] = tmp.hi.lo; + dst[1][1] = tmp.hi.hi; +} +void intel_subgroup_block_read_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord, 
ushort8 dst[2][4]) +{ + ushort64 tmp = __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + dst[0][0] = tmp.lo.lo.lo; + dst[0][1] = tmp.lo.lo.hi; + dst[0][2] = tmp.lo.hi.lo; + dst[0][3] = tmp.lo.hi.hi; + dst[1][0] = tmp.hi.lo.lo; + dst[1][1] = tmp.hi.lo.hi; + dst[1][2] = tmp.hi.hi.lo; + dst[1][3] = tmp.hi.hi.hi; +} + +uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +uint16 intel_subgroup_block_read_u32_m16k16(const __global void* base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} + +// Each block is K rows x N columns, where the K rows have been VNNI transformed. +int8 intel_subgroup_block_read_transform_u16_k16n16(__global void *base_address, int width, int height, int pitch, int2 coord) +{ + // Note: this function is in the headers, but is named confusingly and returns unsigned integers rather than signed integers: + return as_int8(intel_subgroup_block_read_transform_u16_k16(base_address, width, height, pitch, coord)); +} +int16 intel_subgroup_block_read_transform_u16_k32n16(__global void *base_address, int width, int height, int pitch, int2 coord) +{ + return as_int16(__builtin_IB_subgroup_block_read_flat_transform_u16_k32(as_long(base_address), width - 1, height - 1, pitch - 1, coord)); +} +int16 intel_subgroup_block_read_transform_u16_k16n16v2(__global void *base_address, int width, int height, int pitch, int2 coord) +{ + return as_int16(__builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord)); +} +void intel_subgroup_block_read_transform_u16_k32n16v2(__global void *base_address, int width, int height, int pitch, int2 coord, int8 dst[2][2]) +{ + uint32 tmp = __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + dst[0][0] = as_int8(tmp.lo.lo); + dst[0][1] = as_int8(tmp.lo.hi); + dst[1][0] = as_int8(tmp.hi.lo); + dst[1][1] = as_int8(tmp.hi.hi); +} + + +#define BLOCK_PREFETCH_CACHE_TYPE LSC_LDCC_L1C_L3C + +void intel_subgroup_block_prefetch_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_subgroup_block_prefetch_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_subgroup_block_prefetch_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_subgroup_block_prefetch_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 
coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_subgroup_block_prefetch_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_subgroup_block_prefetch_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_subgroup_block_prefetch_u16_m32k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_subgroup_block_prefetch_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_subgroup_block_prefetch_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_subgroup_block_prefetch_u32_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_subgroup_block_prefetch_u32_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} + + +void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} +void intel_subgroup_block_write_u32_m2k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} +void intel_subgroup_block_write_u32_m4k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} +void intel_subgroup_block_write_u32_m8k16(__global void* base_address, int width, int 
height, int pitch, int2 coord, uint8 data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} +void intel_subgroup_block_write_u32_m16k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint16 data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} + +#endif // cl_intel_subgroup_extended_block_read +#endif diff --git a/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl new file mode 100644 index 0000000..a862744 --- /dev/null +++ b/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl @@ -0,0 +1,752 @@ +#error "Needs to be updated!" + +#if !defined(tK) +#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the elemement type, likely 16 or 32." +#endif + +#if !defined(MM) +#error "MM is undefined! This should be defined as the number of matrix tiles in the M dimension." +#endif + +#if !defined(NN) +#error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension." +#endif + +#if !defined(KK) +#define KK 1 +#endif + +#if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS) +#if !defined(cl_intel_split_work_group_barrier) +#warning "Unexpected: cl_intel_split_work_group_barrier is not supported?" +#endif +#define split_barrier_arrive() +#define split_barrier_wait() +#else +#define split_barrier_arrive() intel_work_group_barrier_arrive(0) +#define split_barrier_wait() intel_work_group_barrier_wait(0) +#endif + +#define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN +#define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) + +#define HELPER_NAMEX(PREFIX, MM, NN) PREFIX ## _m ## MM ## _n ## NN +#define HELPER_NAME(PREFIX, MM, NN) HELPER_NAMEX(PREFIX, MM, NN) + +#if !defined(SGS_PER_WG_X) +#define SGS_PER_WG_X 1 +#endif + +#if !defined(SGS_PER_WG_Y) +#define SGS_PER_WG_Y 4 +#endif + +#if !defined(PREFETCH_DISTANCE) +#define PREFETCH_DISTANCE 1 +#endif + +void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[NN][KK]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[nn][kk] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[NN][KK]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[nn][kk] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } + } +} + +#if HAS_SIMD8 + +void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k) +{ + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); + } + } +} + +void HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=4) { + prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + 
prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int k, int8 aData[KK][MM]) +{ + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + } + } + } +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 8; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + // Initial prefetch: + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + // TODO: skip prefetch on the last iterations. + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + + int8 aData[KK][MM]; + HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); + + int8 bData[NN][KK]; + HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); + + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + } + } +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
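+    // Same overall structure as the row-major tiled kernel above: each sub-group
+    // accumulates an MM x NN grid of tM x tN tiles, the A and B tiles are prefetched
+    // PREFETCH_DISTANCE K-blocks ahead of the compute loop, and split barriers keep the
+    // sub-groups in a work-group loosely in step (presumably so prefetched data is still
+    // resident when the other sub-groups need it).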
+ const int tM = 8; + const int tN = 8; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + // Initial prefetch: + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + // TODO: skip prefetch on the last iterations. + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + + int8 aData[KK][MM]; + HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); + + int8 bData[NN][KK]; + HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData); + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + } + } +} + +#endif // HAS_SIMD8 + +void HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k) +{ + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); + } + } +} + +void HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(btile_prefetch_vnni, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int k, short8 aData[KK][MM]) +{ + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + } + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
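+    // The compute_m()/compute_n() helpers map this sub-group to the origin of its
+    // MM x NN block of tiles within the work-group's region of C; each work-group runs
+    // SGS_PER_WG_X x SGS_PER_WG_Y sub-groups, per the reqd_work_group_size above.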
+ const int tM = 8; + const int tN = 16; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + // Initial prefetch: + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + // TODO: skip prefetch on the last iterations. + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + + short8 aData[KK][MM]; + HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData); + + int8 bData[NN][KK]; + HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + // Initial prefetch: + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + // TODO: skip prefetch on the last iterations. 
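+        // (During the final PREFETCH_DISTANCE iterations these prefetches target K-blocks
+        // past the end of the loop range; that should be harmless since prefetches are
+        // only hints, but it is wasted work, hence the TODO above.)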
+ HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + + short8 aData[KK][MM]; + HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData); + + int8 bData[NN][KK]; + HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData); + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + } + } +} + +#ifdef cl_intel_subgroup_extended_block_read + +void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k, short8 aData[KK][MM]) +{ + if (KK % 2 == 0 & MM % 4 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=4) { + //if (get_sub_group_local_id() == 0) { + // printf("atile block load : %d, %d, %2d: m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), m, k, mm, kk, k + kk * tK, m + mm * tM); + //} + ushort8 tmp[2][4]; + intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + for (int tkk = 0; tkk < 2; tkk++) { + for (int tmm = 0; tmm < 4; tmm++) { + aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); + } + } + } + } + } else if (KK % 2 == 0 & MM % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + ushort8 tmp[2][2]; + intel_subgroup_block_read_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + for (int tkk = 0; tkk < 2; tkk++) { + for (int tmm = 0; tmm < 2; tmm++) { + aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); + } + } + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else if (MM % 4 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm+=4) { + ushort8 tmp[4]; + intel_subgroup_block_read_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + for (int tmm = 0; tmm < 4; tmm++) { + aData[kk][mm + tmm] = as_short8(tmp[tmm]); + } + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + } + } + } +} + +void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) +{ + if (KK % 2 == 0 & NN % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn+=2) { + //if (get_sub_group_local_id() == 0) { + // printf("btile block load: %d, %d, %2d: n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), n, k, nn, kk, n + nn * tN, k + kk * tK); + //} + int8 tmp[2][2]; + intel_subgroup_block_read_transform_u16_k32n16v2(B, 
N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), tmp); + for (int tnn = 0; tnn < 2; tnn++) { + for (int tkk = 0; tkk < 2; tkk++) { + bData[nn + tnn][kk + tkk] = tmp[tnn][tkk]; + } + } + } + } + } else if (NN % 2 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + int16 bTemp = intel_subgroup_block_read_transform_u16_k16n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + bData[nn + 0][kk] = bTemp.lo; + bData[nn + 1][kk] = bTemp.hi; + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + int16 bTemp = intel_subgroup_block_read_transform_u16_k32n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + bData[nn][kk + 0] = bTemp.lo; + bData[nn][kk + 1] = bTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[nn][kk] = intel_subgroup_block_read_transform_u16_k16n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } +} + +void HELPER_NAME(btile_block_load_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) +{ + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + int16 bTemp = as_int16(intel_subgroup_block_read_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + bData[nn][kk + 0] = bTemp.lo; + bData[nn][kk + 1] = bTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[nn][kk] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + } + } + } +} + +void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k) +{ + if (KK == 2 & MM == 4 & SGS_PER_WG_X >= 4) { + const int sg_index_x = get_sub_group_id() % SGS_PER_WG_X; // index in [0, SGS_PER_WG_X) + const int kk = 0; + const int mm = sg_index_x % 4; + //if (get_sub_group_local_id() == 0) { + // printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM); + //} + intel_subgroup_block_prefetch_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } else if (KK % 2 == 0 & MM % 4 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=4) { + intel_subgroup_block_prefetch_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (KK % 2 == 0 & MM % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + intel_subgroup_block_prefetch_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + intel_subgroup_block_prefetch_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (MM % 4 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm+=4) { + intel_subgroup_block_prefetch_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + 
intel_subgroup_block_prefetch_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } +} + +void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) +{ + if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { + const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) + const int nn = sg_index_y % 2 * 2; // nn(sg_index_y) == 0, 2, 0, 2, 0, 2, 0, 2, ... + const int kk = sg_index_y / 2 % 2; // kk(sg_index_y) == 0, 0, 1, 1, 0, 0, 1, 1, ... + //if (get_sub_group_local_id() == 0) { + // printf("btile block prefetch: %d, %d, %2d: sg_y = %d, n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_y, n, k, nn, kk, n + nn * tN, k + kk * tK); + //} + intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } else if (KK % 2 == 0 & NN % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn += 2) { + intel_subgroup_block_prefetch_u16_m32k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } else if (NN % 2 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_prefetch_u16_m32k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_prefetch_u16_m16k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } +} + +void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) +{ + if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { + const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) + const int nn = sg_index_y % 4; // nn(sg_index_y) == 0, 1, 2, 3, 0, 1, 2, 3 + const int kk = 0; // kk(sg_index_y) == 0, 0, 0, 0, 0, 0, 0, 0 + intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_prefetch_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + } + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
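+    // This variant uses the 2D block read/write/prefetch built-ins rather than the
+    // strided loads above, so the full surface dimensions are required: M is computed
+    // here because the block messages take the matrix height in addition to its width
+    // and pitch.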
+ const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM * MM; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + int8 bData[NN][KK]; + HELPER_NAME(btile_block_load_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); + + short8 aData[KK][MM]; + HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + + HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM * MM; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + int8 bData[NN][KK]; + HELPER_NAME(btile_block_load_vnni, MM, NN)(B, tN, K, N, k, n, bData); + + short8 aData[KK][MM]; + HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + + // TODO: skip prefetch on the last iterations. 
+ HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); + } + } +} + +#endif // cl_intel_subgroup_extended_block_read diff --git a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl new file mode 100644 index 0000000..6e27d8d --- /dev/null +++ b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl @@ -0,0 +1,613 @@ +#include "matrix_helpers_i8.cl" + +#if EMULATE_tN8 +#define mat_mul_sg8 emu_sub_group_i8_i8_matrix_mad_k32 +#else +#define mat_mul_sg8 intel_sub_group_i8_i8_matrix_mad_k32 +#endif + +#if EMULATE_tN16 +#define mat_mul_sg16 emu_sub_group_i8_i8_matrix_mad_k32 +#else +#define mat_mul_sg16 intel_sub_group_i8_i8_matrix_mad_k32 +#endif + +kernel void i8_naive(global int* C, global char* A, global char* B, int K) +{ + const int N = get_global_size(0); + const int m = get_global_id(1); + const int n = get_global_id(0); + + int sum = 0; + for (int k = 0; k < K; k++) { + sum = A[m * K + k] * B[k * N + n] + sum; + } + + sum = activation(sum); + C[m * N + n] = sum; +} + +// For all i8 kernels tK == 32: +#define tK 32 + +#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_char) && defined(cl_intel_required_subgroup_size) + +#if HAS_SIMD8 + +// rowmajor kernels: + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_rowmajor_m1_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + int aData = load_a_rowmajor_d8_m1_k32_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m1_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_rowmajor_m2_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + int2 aData = load_a_rowmajor_d8_m2_k32_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m2_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_rowmajor_m4_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
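+    // With a sub-group size of 8, each int component of aData carries four packed int8
+    // K values from one row of A, so the eight work-items of the sub-group together cover
+    // the full tK = 32 accumulation depth per dpas call.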
+ const int tM = 4; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + int4 aData = load_a_rowmajor_d8_m4_k32_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m4_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_rowmajor_m8_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + int8 aData = load_a_rowmajor_d8_m8_k32_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); +} + +// vnni kernels: + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_vnni_m1_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + int aData = load_a_rowmajor_d8_m1_k32_sg8(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m1_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_vnni_m2_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + int2 aData = load_a_rowmajor_d8_m2_k32_sg8(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m2_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_vnni_m4_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + int4 aData = load_a_rowmajor_d8_m4_k32_sg8(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m4_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_vnni_m8_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
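+    // As in the other _vnni_ kernels, B is expected in VNNI-packed layout: four int8 K
+    // values per column are packed into each dword, so load_b_vnni_d8_k32_nx fetches a
+    // 32-deep K block as eight dword rows instead of reading B row by row.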
+    const int tM = 8;
+    const int tN = 8;
+    const int N = get_global_size(0);
+    const int m = get_group_id(1) * tM;
+    const int n = get_group_id(0) * tN;
+
+    int8 sum = 0;
+    for (int k = 0; k < K; k += tK) {
+        int8 aData = load_a_rowmajor_d8_m8_k32_sg8(A, m, k, K);
+        int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N);
+        sum = mat_mul_sg8(aData, bData, sum);
+    }
+
+    sum = activation(sum);
+    store_c_rowmajor_int32_m8_nx(C, sum, m, n, N);
+}
+
+#endif // HAS_SIMD8
+
+// rowmajor kernels:
+
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void i8_dpas_rowmajor_m1_n16(global int* C, global char* A, global char* B, int K)
+{
+    __builtin_assume(K > 0); // Always at least one K iteration.
+    const int tM = 1;
+    const int tN = 16;
+    const int N = get_global_size(0);
+    const int m = get_group_id(1) * tM;
+    const int n = get_group_id(0) * get_local_size(0);
+
+    int sum = 0;
+    for (int k = 0; k < K; k += tK) {
+        short aData = load_a_rowmajor_d8_m1_k32_sg16(A, m, k, K);
+        int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N);
+        sum = mat_mul_sg16(aData, bData, sum);
+    }
+
+    sum = activation(sum);
+    store_c_rowmajor_int32_m1_nx(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void i8_dpas_rowmajor_m2_n16(global int* C, global char* A, global char* B, int K)
+{
+    __builtin_assume(K > 0); // Always at least one K iteration.
+    const int tM = 2;
+    const int tN = 16;
+    const int N = get_global_size(0);
+    const int m = get_group_id(1) * tM;
+    const int n = get_group_id(0) * get_local_size(0);
+
+    int2 sum = 0;
+    for (int k = 0; k < K; k += tK) {
+        short2 aData = load_a_rowmajor_d8_m2_k32_sg16(A, m, k, K);
+        int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N);
+        sum = mat_mul_sg16(aData, bData, sum);
+    }
+
+    sum = activation(sum);
+    store_c_rowmajor_int32_m2_nx(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void i8_dpas_rowmajor_m4_n16(global int* C, global char* A, global char* B, int K)
+{
+    __builtin_assume(K > 0); // Always at least one K iteration.
+    const int tM = 4;
+    const int tN = 16;
+    const int N = get_global_size(0);
+    const int m = get_group_id(1) * tM;
+    const int n = get_group_id(0) * get_local_size(0);
+
+    int4 sum = 0;
+    for (int k = 0; k < K; k += tK) {
+        short4 aData = load_a_rowmajor_d8_m4_k32_sg16(A, m, k, K);
+        int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N);
+        sum = mat_mul_sg16(aData, bData, sum);
+    }
+
+    sum = activation(sum);
+    store_c_rowmajor_int32_m4_nx(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void i8_dpas_rowmajor_m8_n16(global int* C, global char* A, global char* B, int K)
+{
+    __builtin_assume(K > 0); // Always at least one K iteration.
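+    // For the sub-group size 16 kernels the A fragment element type is short rather than
+    // int: two packed int8 K values per work-item times 16 work-items again covers the
+    // tK = 32 accumulation depth.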
+ const int tM = 8; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * get_local_size(0); + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = load_a_rowmajor_d8_m8_k32_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); +} + +// vnni kernels: + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_vnni_m1_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + short aData = load_a_rowmajor_d8_m1_k32_sg16(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m1_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_vnni_m2_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData = load_a_rowmajor_d8_m2_k32_sg16(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m2_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_vnni_m4_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData = load_a_rowmajor_d8_m4_k32_sg16(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m4_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_vnni_m8_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 8; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = load_a_rowmajor_d8_m8_k32_sg16(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); +} + +#if 0 + +#ifdef cl_intel_subgroup_extended_block_read + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_m1_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int M = get_global_size(1); + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_m2_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_m4_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
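+    // Note: this block-read section is compiled out by the #if 0 above; these kernels
+    // still call the 16-bit (u16) block-read helpers carried over from the bfloat16
+    // sample, so they would presumably need int8-specific block reads before being
+    // re-enabled.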
+ const int tM = 4; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_m8_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_vnni_m1_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_vnni_m2_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 2; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_vnni_m4_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_vnni_m8_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); +} + +#endif // cl_intel_subgroup_extended_block_read + +// Tiled matrix multiplication kernels, generated from a template: + +#define MM 1 +#define NN 1 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 1 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#define MM 1 +#define NN 2 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 2 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#define MM 4 +#define NN 2 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 4 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#define MM 4 +#define NN 4 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#endif // disabling these cases for now + +#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) + +#undef tK diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index 30f877b..af509b2 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -90,4 +90,5 @@ if(BUILD_EXTENSION_SAMPLES) endif() add_subdirectory( 99_matrixexperiments ) +add_subdirectory( 99_matrixexperimentsi8 ) add_subdirectory( 99_matrixexperimentstf32 )