diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index a1e01b6ed238..e8d645c8ae6e 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -3516,6 +3516,34 @@ class SYCLGen : public SYCLGenBase { endstmt(); return SYCLGenSuccess(); } + + + bool handle_movmatrix(const InlineAsmInstruction *Inst) override { + if (Inst->getNumInputOperands() != 1) + return SYCLGenError(); + + const auto *Type = dyn_cast(Inst->getType(0)); + + if (!Type || Type->getKind() != InlineAsmBuiltinType::b16) + return SYCLGenError(); + + OS() << MapNames::getDpctNamespace() << "experimental::matrix::movmatrix("; + + + if (emitStmt(Inst->getOutputOperand())) + return SYCLGenError(); + + OS() << ", "; + + if (emitStmt(Inst->getInputOperand(0))) + return SYCLGenError(); + + OS() << ")"; + + endstmt(); + + return SYCLGenSuccess(); + } }; /// Clean the special character in identifier. diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index fbb56f0f7cf5..0e26f1313a9c 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2799,6 +2799,37 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag, } } +/// Transpose 1 8x8 b16 (128 bytes) matrix per sub-group. Requires the sub-group +/// size of kernel calling this function to be 32. +/// \param [output] output: The register to store the transposed matrix fragment. It refers to 2 +/// b16 type elements. +/// \param [in] input: The register to store the matrix fragment. It refers to 2 b16 +/// type elements. +void movmatrix(uint32_t &output, uint32_t &input) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int laneid = sg.get_local_linear_id(); + + int elm0_row = laneid / 4; + int elm0_col = (laneid % 4) * 2; + int elm1_row = elm0_row; + int elm1_col = elm0_col + 1; + int src0_row = elm0_col; + int src0_col = elm0_row; + int src1_row = elm1_col; + int src1_col = elm1_row; + int src0_laneid = src0_row * 4 + src0_col / 2; + int src0_pos = src0_col % 2; + int src1_laneid = src1_row * 4 + src1_col / 2; + int src1_pos = src1_col % 2; + + auto recv0 = dpct::select_from_sub_group(sg, *(uint32_t *)(&input), src0_laneid); + auto recv1 = dpct::select_from_sub_group(sg, *(uint32_t *)(&input), src1_laneid); + + auto ptr_out = reinterpret_cast(&output); + ptr_out[0] = reinterpret_cast(&recv0)[src0_pos]; + ptr_out[1] = reinterpret_cast(&recv1)[src1_pos]; +} + } // namespace matrix } // namespace experimental diff --git a/clang/test/dpct/asm/movmatrix.cu b/clang/test/dpct/asm/movmatrix.cu new file mode 100644 index 000000000000..d491c6344cdc --- /dev/null +++ b/clang/test/dpct/asm/movmatrix.cu @@ -0,0 +1,29 @@ +// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2 +// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2 +// RUN: dpct --format-range=none -out-root %T/movmatrix %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only +// RUN: FileCheck %s --match-full-lines --input-file %T/movmatrix/movmatrix.dp.cpp +// RUN: %if build_lit %{icpx -c -fsycl %T/movmatrix/movmatrix.dp.cpp -o %T/movmatrix/movmatrix.dp.o %} + +// clang-format off +#include +#include +#include + +using bf16_2 = __nv_bfloat162; + +//Syntax: +//movmatrix.sync.aligned.shape.trans.type d, a; +//.shape = {.m8n8}; +//.type = {.b16};#include +// Only .m8n8.b16 +// + +__global__ void movmatrix(bf16_2 &dst, const bf16_2 &src) { + + // CHECK: dpct::experimental::matrix::movmatrix(*(uint32_t *)(&dst), (*(uint32_t *)(&src))); + asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" + : "+r"(*(uint32_t *)(&dst)) + : "r"(*(uint32_t *)(&src))); +} + +// clang-format on \ No newline at end of file