Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions clang/lib/DPCT/RulesAsm/AsmMigration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3516,6 +3516,34 @@ class SYCLGen : public SYCLGenBase {
endstmt();
return SYCLGenSuccess();
}


/// Emit a call to the dpct `experimental::matrix::movmatrix` helper for an
/// inline-asm `movmatrix` instruction. Generation fails unless the instruction
/// has exactly one input operand and its first type is the builtin `b16` type.
bool handle_movmatrix(const InlineAsmInstruction *Inst) override {
  // Exactly one source operand is accepted.
  const bool HasSingleInput = Inst->getNumInputOperands() == 1;
  if (!HasSingleInput)
    return SYCLGenError();

  // Only the builtin b16 element type is accepted.
  const auto *ElemTy = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
  const bool IsB16 = ElemTy && ElemTy->getKind() == InlineAsmBuiltinType::b16;
  if (!IsB16)
    return SYCLGenError();

  // Callee name first, then the destination and the source as arguments.
  OS() << MapNames::getDpctNamespace() << "experimental::matrix::movmatrix(";
  if (emitStmt(Inst->getOutputOperand()))
    return SYCLGenError();
  OS() << ", ";
  if (emitStmt(Inst->getInputOperand(0)))
    return SYCLGenError();
  OS() << ")";
  endstmt();

  return SYCLGenSuccess();
}
};

/// Clean the special character in identifier.
Expand Down
31 changes: 31 additions & 0 deletions clang/runtime/dpct-rt/include/dpct/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2799,6 +2799,37 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag,
}
}

/// Transpose one 8x8 b16 (128 bytes) matrix per sub-group. Requires the
/// sub-group size of the kernel calling this function to be 32.
///
/// Every work-item holds two adjacent b16 elements of one matrix row: lane L
/// owns row L / 4, columns (L % 4) * 2 and (L % 4) * 2 + 1. On return each
/// lane holds the elements found at those same coordinates in the transposed
/// matrix. All work-items of the sub-group must call this function together,
/// because it exchanges data between lanes.
/// \param [out] output The register to store the transposed matrix fragment.
/// It refers to 2 b16 type elements.
/// \param [in] input The register holding the source matrix fragment. It
/// refers to 2 b16 type elements.
inline void movmatrix(uint32_t &output, uint32_t &input) {
  auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
  int laneid = sg.get_local_linear_id();

  // Coordinates (in the transposed matrix) of the two elements this lane
  // must produce.
  int elm0_row = laneid / 4;
  int elm0_col = (laneid % 4) * 2;
  int elm1_row = elm0_row;
  int elm1_col = elm0_col + 1;

  // Transposition swaps row and column: element (r, c) of the result is
  // element (c, r) of the source matrix.
  int src0_row = elm0_col;
  int src0_col = elm0_row;
  int src1_row = elm1_col;
  int src1_col = elm1_row;

  // Map the source coordinates back to the owning lane and to the position
  // (0 or 1) of the element inside that lane's 32-bit register.
  int src0_laneid = src0_row * 4 + src0_col / 2;
  int src0_pos = src0_col % 2;
  int src1_laneid = src1_row * 4 + src1_col / 2;
  int src1_pos = src1_col % 2;

  // Fetch the whole 32-bit register of each source lane; the wanted 16-bit
  // half is picked out below.
  auto recv0 = dpct::select_from_sub_group(sg, input, src0_laneid);
  auto recv1 = dpct::select_from_sub_group(sg, input, src1_laneid);

  auto ptr_out = reinterpret_cast<sycl::half *>(&output);
  ptr_out[0] = reinterpret_cast<sycl::half *>(&recv0)[src0_pos];
  ptr_out[1] = reinterpret_cast<sycl::half *>(&recv1)[src1_pos];
}

} // namespace matrix
} // namespace experimental

Expand Down
29 changes: 29 additions & 0 deletions clang/test/dpct/asm/movmatrix.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2
// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2
// RUN: dpct --format-range=none -out-root %T/movmatrix %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
// RUN: FileCheck %s --match-full-lines --input-file %T/movmatrix/movmatrix.dp.cpp
// RUN: %if build_lit %{icpx -c -fsycl %T/movmatrix/movmatrix.dp.cpp -o %T/movmatrix/movmatrix.dp.o %}

// clang-format off
#include <cuda_runtime.h>
#include <cstdint>
#include <cuda_bf16.h>

// Local shorthand for the __nv_bfloat162 type from <cuda_bf16.h>, used for the
// kernel operands below.
using bf16_2 = __nv_bfloat162;

//Syntax:
//movmatrix.sync.aligned.shape.trans.type d, a;
//.shape = {.m8n8};
//.type = {.b16};
// Only .m8n8.b16
//

// Kernel wrapping a single inline-PTX movmatrix instruction (m8n8, .trans,
// b16). Both operands are handed to the asm statement as uint32_t lvalues
// reinterpreted from the bf16_2 references.
__global__ void movmatrix(bf16_2 &dst, const bf16_2 &src) {

// The FileCheck directive below pins the migrated form of the asm statement:
// the operand expressions are expected to be carried over textually.
// CHECK: dpct::experimental::matrix::movmatrix(*(uint32_t *)(&dst), (*(uint32_t *)(&src)));
asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n"
: "+r"(*(uint32_t *)(&dst))
: "r"(*(uint32_t *)(&src)));
}

// clang-format on
Loading