From 7c64b4701454369996f50e3bd554316cfe23b754 Mon Sep 17 00:00:00 2001
From: Fraser Cormack
Date: Thu, 3 Aug 2023 15:18:34 +0100
Subject: [PATCH] [compiler] Introduce mux builtins for 'group operations'

This introduces a set of builtins to represent operations like OpenCL's
`sub_group_xxx` and `work_group_xxx` builtins and SPIR-V's `OpGroupXXX`
instructions in a language-agnostic way.

It also introduces 'vector group' operations, which represent sub-groups as
the vectorizer currently thinks of them: as groups of work-items being
'simulated' together via vectorized code, despite being executed on only one
invocation. These have no direct OpenCL equivalent, but can be used by the
compiler to simulate larger sub-groups than the hardware actually provides.

Until now, the construction kit's handling of these operations has been
centered around analysis of the OpenCL builtins, which is inflexible and
brittle. We may one day want to support IR that is not compatible with
OpenCL-like IR, for instance. It also means that the compiler is constrained
by the OpenCL semantics of these operations.

These builtins are not yet generated in the default pipeline, but there's a
new pass which introduces them to the module by asking the language-level
`BuiltinInfo` to replace calls to its concept of group builtins with mux's
concept of these builtins.

The builtins are also wired up through the identification and get-or-declare
APIs, but cannot yet be *defined* by mux. This will come in a later change.
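
As an example of the mapping (illustrative; this is the behaviour exercised
by the new lit tests), the pass rewrites a call to an OpenCL sub-group
reduction into a call to the corresponding mux builtin:

    %call = call spir_func i32 @_Z20sub_group_reduce_addi(i32 %x)
    ; becomes
    %call1 = call i32 @__mux_sub_group_reduce_add_i32(i32 %x)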
---
 CHANGELOG.md                                  |   2 +
 doc/modules/mux/changes.rst                   |   5 +
 doc/specifications/mux-compiler-spec.rst      | 114 +++-
 doc/specifications/mux-runtime-spec.rst       |   2 +-
 .../custom-lowering-work-item-builtins.rst    |  14 +-
 .../refsi_g1_wi/refsi_mux_builtin_info.h      |  10 +-
 .../source/refsi_mux_builtin_info.cpp         |  12 +-
 .../include/refsi_m1/refsi_mux_builtin_info.h |   5 +-
 .../source/refsi_mux_builtin_info.cpp         |   6 +-
 .../source/base_module_pass_machinery.cpp     |   1 +
 .../source/base/source/base_pass_registry.def |   1 +
 .../host/include/host/host_mux_builtin_info.h |   5 +-
 .../host/source/HostMuxBuiltinInfo.cpp        |   6 +-
 modules/compiler/test/CMakeLists.txt          |   3 +-
 modules/compiler/test/common.h                |  25 +
 modules/compiler/test/group_ops.cpp           | 442 +++
 .../degenerate-sub-group-broadcast-32bit.ll   |   4 +
 .../passes/degenerate-sub-groups-cloning.ll   |   4 +
 .../passes/degenerate-sub-groups-cloning2.ll  |   4 +
 .../test/lit/passes/degenerate-sub-groups.ll  |   4 +
 .../test/lit/passes/replace-sub-group-ops.ll  | 284 ++++
 .../test/lit/passes/replace-work-group-ops.ll | 300 +++
 modules/compiler/test/mangling.cpp            |  23 +-
 modules/compiler/test/utils.cpp               |  23 +-
 modules/compiler/utils/CMakeLists.txt         |   2 +
 .../include/compiler/utils/builtin_info.h     | 160 ++++-
 .../include/compiler/utils/cl_builtin_info.h  |   7 +-
 .../compiler/utils/group_collective_helpers.h |  21 +-
 .../compiler/utils/replace_group_funcs_pass.h |  40 ++
 .../compiler/utils/source/builtin_info.cpp    | 609 +++++++++++++++++-
 .../compiler/utils/source/cl_builtin_info.cpp | 476 +++++++++++++-
 .../utils/source/define_mux_builtins_pass.cpp |   3 +-
 .../source/degenerate_sub_group_pass.cpp      |   2 +-
 .../utils/source/group_collective_helpers.cpp |  43 +-
 .../utils/source/mux_builtin_info.cpp         |  60 +-
 .../utils/source/replace_group_funcs_pass.cpp |  50 ++
 .../utils/source/replace_wgc_pass.cpp         |  52 +-
 modules/mux/include/mux/mux.h                 |   2 +-
 modules/mux/targets/host/include/host/host.h  |   2 +-
 .../mux/targets/riscv/include/riscv/riscv.h   |   2 +-
 modules/mux/tools/api/mux.xml                 |   2 +-
 41 files changed, 2660 insertions(+), 172 deletions(-)
 create mode 100644 modules/compiler/test/group_ops.cpp
 create mode 100644 modules/compiler/test/lit/passes/replace-sub-group-ops.ll
 create mode 100644 modules/compiler/test/lit/passes/replace-work-group-ops.ll
 create mode 100644 modules/compiler/utils/include/compiler/utils/replace_group_funcs_pass.h
 create mode 100644 modules/compiler/utils/source/replace_group_funcs_pass.cpp

diff --git a/CHANGELOG.md b/CHANGELOG.md
index da9a99900..75d916281 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,8 @@ Upgrade guidance:
 * The mux spec has been bumped:
   * 0.77.0: to loosen the requirements on the mux `event` type used by DMA
     builtins.
+  * 0.78.0: to introduce mux builtins for sub-group, work-group, and
+    vector-group operations.
 * The `compiler::ImageArgumentSubstitutionPass` now replaces sampler typed
   parameters in kernel functions with i32 parameters via a wrapper function.
   The `host` target as a consequence now passes samplers to kernels as 32-bit
diff --git a/doc/modules/mux/changes.rst b/doc/modules/mux/changes.rst
index 007c03ba4..4314823ba 100644
--- a/doc/modules/mux/changes.rst
+++ b/doc/modules/mux/changes.rst
@@ -11,6 +11,11 @@ version increases mean backward compatible bug fixes have been applied.
 Versions prior to 1.0.0 may contain breaking changes in minor versions as the
 API is still under development.
 
+0.78.0
+------
+
+* Added sub-group, work-group, and vector-group operation builtins.
+
 0.77.0
 ------
 
diff --git a/doc/specifications/mux-compiler-spec.rst b/doc/specifications/mux-compiler-spec.rst
index 08a776c43..2b96eb02f 100644
--- a/doc/specifications/mux-compiler-spec.rst
+++ b/doc/specifications/mux-compiler-spec.rst
@@ -1,7 +1,7 @@
 ComputeMux Compiler Specification
 =================================
 
-   This is version 0.77.0 of the specification.
+   This is version 0.78.0 of the specification.
 
 ComputeMux is Codeplay’s proprietary API for executing compute workloads
 across heterogeneous devices. ComputeMux is an extremely lightweight,
@@ -922,6 +922,10 @@ A Mux implementation **shall** provide definitions for these builtin functions.
   the current work-group.
 * ``i32 __mux_get_max_sub_group_size()`` - Returns the maximum subgroup size
   in the current kernel.
+* ``i32 __mux_get_sub_group_size()`` - Returns the number of invocations in the
+  current sub-group.
+* ``i32 __mux_get_sub_group_local_id()`` - Returns the unique invocation ID
+  within the current sub-group.
 * ``size_t __mux_get_group_id(i32 %i)`` - Returns the unique work-group
   identifier for the ``%i``'th dimension.
 * ``i32 __mux_get_work_dim()`` - Returns the number of dimensions in
@@ -997,6 +1001,114 @@ A Mux implementation **shall** provide definitions for these builtin functions.
   as ``__mux_mem_barrier(i32 %scope, i32 %semantics)``. See
   `below <#memory-and-control-barriers>`__ for more information.
 
+Group operation builtins
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+ComputeMux defines a variety of builtins to handle operations across a
+sub-group, work-group, or *vector group*.
+
+The builtin functions are overloadable and are mangled according to the type
+of operand they operate on.
+
+Each *work-group* operation takes as its first parameter a 32-bit integer
+barrier identifier (``i32 %id``). Note that if barriers are used to implement
+these operations, implementations **must** ensure uniqueness of these IDs
+themselves, e.g., by running the ``compiler::utils::PrepareBarriersPass``. The
+barrier identifier parameter is not mangled.
+
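+For example (an illustrative sketch, not normative IR), two work-group
+operations in the same kernel must be given distinct barrier identifiers:
+
+.. code:: llvm
+
+  %sum = call i32 @__mux_work_group_reduce_add_i32(i32 0, i32 %x)
+  %max = call i32 @__mux_work_group_reduce_smax_i32(i32 1, i32 %y)
+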
+.. note::
+
+  The sub-group and work-group builtins are all **uniform**, that is, the
+  behaviour is undefined unless all invocations in the group reach this point
+  of execution.
+
+  Future versions of ComputeMux **may** add **non-uniform** versions of these
+  builtins.
+
+The groups are defined as:
+
+* ``work-group`` - a group of invocations running together as part of an ND
+  range. These builtins **must** only take scalar values.
+* ``sub-group`` - a subset of invocations in a work-group which can
+  synchronize and share data efficiently. ComputeMux leaves the choice of
+  sub-group size and implementation to the target; ComputeMux only defines
+  these builtins with a "trivial" sub-group size of 1. These builtins **must**
+  only take scalar values.
+* ``vec-group`` - a software-level group of invocations processing data in
+  parallel *on a single invocation*. This allows the compiler to simulate a
+  sub-group without any hardware sub-group support (e.g., through
+  vectorization). These builtins **may** take scalar *or vector* values. The
+  scalar versions of these builtins are essentially identical to the
+  corresponding ``sub-group`` builtins with a sub-group size of 1.
+
+
+``any``/``all`` builtins
+++++++++++++++++++++++++
+
+The ``any`` and ``all`` builtins return ``true`` if any/all of their operands
+are ``true``, and ``false`` otherwise.
+
+.. code:: llvm
+
+  i1 @__mux_sub_group_any_i1(i1 %x)
+  i1 @__mux_work_group_any_i1(i32 %id, i1 %x)
+  i1 @__mux_vec_group_any_v4i1(<4 x i1> %x)
+
+``broadcast`` builtins
+++++++++++++++++++++++
+
+The ``broadcast`` builtins broadcast the value corresponding to the local ID
+to the result of all invocations in the group. The sub-group version of this
+builtin takes an ``i32`` sub-group linear ID to identify the invocation to
+broadcast, and the work-group version takes three ``size_t`` indices to locate
+the value to broadcast. Unused indices (e.g., in lower-dimension kernels)
+**must** be set to zero - this is the same value returned by
+``__mux_get_global_id`` for out-of-range dimensions.
+
+.. code:: llvm
+
+  i64 @__mux_sub_group_broadcast_i64(i64 %val, i32 %sg_lid)
+  i32 @__mux_work_group_broadcast_i32(i32 %id, i32 %val, i64 %lidx, i64 %lidy, i64 %lidz)
+  i64 @__mux_vec_group_broadcast_v2i64(<2 x i64> %val, i32 %vec_id)
+
+``reduce`` and ``scan`` builtins
+++++++++++++++++++++++++++++++++
+
+The ``reduce`` and ``scan`` builtins return the result of the group operation
+applied to the values of the parameter provided by all invocations in the
+group.
+
+Scans may be either ``inclusive`` or ``exclusive``. Inclusive scans perform
+the operation over all invocations in the group. Exclusive scans perform the
+operation over the operation's identity value and all but the final invocation
+in the group.
+
+The group operation may be specified as one of:
+
+* ``add``/``fadd`` - integer/floating-point addition.
+* ``mul``/``fmul`` - integer/floating-point multiplication.
+* ``smin``/``umin``/``fmin`` - signed integer/unsigned integer/floating-point
+  minimum.
+* ``smax``/``umax``/``fmax`` - signed integer/unsigned integer/floating-point
+  maximum.
+* ``and``/``or``/``xor`` - bitwise ``and``/``or``/``xor``.
+* ``logical_and``/``logical_or``/``logical_xor`` - logical
+  ``and``/``or``/``xor``.
+
+Examples:
+
+.. code:: llvm
+
+  i32 @__mux_sub_group_reduce_add_i32(i32 %val)
+  i32 @__mux_work_group_reduce_add_i32(i32 %id, i32 %val)
+  float @__mux_work_group_reduce_fadd_f32(i32 %id, float %val)
+
+  i32 @__mux_sub_group_scan_inclusive_mul_i32(i32 %val)
+  i32 @__mux_work_group_scan_inclusive_mul_i32(i32 %id, i32 %val)
+  float @__mux_work_group_scan_inclusive_fmul_f32(i32 %id, float %val)
+
+  i64 @__mux_sub_group_scan_exclusive_mul_i64(i64 %val)
+  i64 @__mux_work_group_scan_exclusive_mul_i64(i32 %id, i64 %val)
+  double @__mux_work_group_scan_exclusive_fmul_f64(i32 %id, double %val)
+
+  i64 @__mux_vec_group_scan_exclusive_mul_nxv1i64(<vscale x 1 x i64> %val)
+
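+As an illustration of the scan semantics (the values here are assumed purely
+for the sake of example): for a group of four invocations holding the ``i32``
+values ``1, 2, 3, 4``, an inclusive ``add`` scan returns ``1, 3, 6, 10`` to
+the respective invocations, while an exclusive ``add`` scan returns
+``0, 1, 3, 6``, where ``0`` is the identity of integer addition.
+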
 Memory and Control Barriers
 ---------------------------
 
diff --git a/doc/specifications/mux-runtime-spec.rst b/doc/specifications/mux-runtime-spec.rst
index 377657a7d..daca264d2 100644
--- a/doc/specifications/mux-runtime-spec.rst
+++ b/doc/specifications/mux-runtime-spec.rst
@@ -1,7 +1,7 @@
 ComputeMux Runtime Specification
 ================================
 
-   This is version 0.77.0 of the specification.
+   This is version 0.78.0 of the specification.
 
 ComputeMux is Codeplay’s proprietary API for executing compute workloads
 across heterogeneous devices. ComputeMux is an extremely lightweight,
diff --git a/doc/tutorials/custom-lowering-work-item-builtins.rst b/doc/tutorials/custom-lowering-work-item-builtins.rst
index a1735ae8c..cc8c1518b 100644
--- a/doc/tutorials/custom-lowering-work-item-builtins.rst
+++ b/doc/tutorials/custom-lowering-work-item-builtins.rst
@@ -123,10 +123,11 @@ The code for this example is as follows:
 .. code:: cpp
 
     class MyMuxImpl : public utils::BIMuxInfoConcept {
-      virtual llvm::Function *defineMuxBuiltin(utils::BuiltinID ID,
-                                               llvm::Module &M) override {
+      virtual llvm::Function *defineMuxBuiltin(
+          utils::BuiltinID ID, llvm::Module &M,
+          llvm::ArrayRef<llvm::Type *> OverloadInfo = {}) override {
         if (ID != utils::eMuxBuiltinGetLocalId) {
-          return BIMuxInfoConcept::defineMuxBuiltin(ID, M);
+          return BIMuxInfoConcept::defineMuxBuiltin(ID, M, OverloadInfo);
         }
 
         llvm::Function *F =
             M.getFunction(utils::BuiltinInfo::getMuxBuiltinName(ID));
@@ -390,8 +391,9 @@ data beyond the view of ComputeMux, e.g., in the driver or the HAL.
       return List;
     }
 
-    virtual llvm::Function *defineMuxBuiltin(utils::BuiltinID ID,
-                                             llvm::Module &M) override {
+    virtual llvm::Function *defineMuxBuiltin(
+        utils::BuiltinID ID, llvm::Module &M,
+        llvm::ArrayRef<llvm::Type *> OverloadInfo = {}) override {
       if (ID == utils::eMuxBuiltinGetLocalId) {
         llvm::Function *F =
             M.getFunction(utils::BuiltinInfo::getMuxBuiltinName(ID));
@@ -407,7 +409,7 @@ data beyond the view of ComputeMux, e.g., in the driver or the HAL.
         B.CreateRet(std::prev(F->arg_end()));
         return F;
       }
-      return BIMuxInfoConcept::defineMuxBuiltin(ID, M);
+      return BIMuxInfoConcept::defineMuxBuiltin(ID, M, OverloadInfo);
     }
   };
 
diff --git a/examples/refsi/refsi_g1_wi/compiler/refsi_g1_wi/include/refsi_g1_wi/refsi_mux_builtin_info.h b/examples/refsi/refsi_g1_wi/compiler/refsi_g1_wi/include/refsi_g1_wi/refsi_mux_builtin_info.h
index 02d5ac5ec..41b304a83 100644
--- a/examples/refsi/refsi_g1_wi/compiler/refsi_g1_wi/include/refsi_g1_wi/refsi_mux_builtin_info.h
+++ b/examples/refsi/refsi_g1_wi/compiler/refsi_g1_wi/include/refsi_g1_wi/refsi_mux_builtin_info.h
@@ -42,11 +42,13 @@ class RefSiG1BIMuxInfo : public compiler::utils::BIMuxInfoConcept {
  public:
   static llvm::StructType *getExecStateStruct(llvm::Module &M);
 
-  llvm::Function *getOrDeclareMuxBuiltin(compiler::utils::BuiltinID ID,
-                                         llvm::Module &M) override;
+  llvm::Function *getOrDeclareMuxBuiltin(
+      compiler::utils::BuiltinID ID, llvm::Module &M,
+      llvm::ArrayRef<llvm::Type *> OverloadInfo = {}) override;
 
-  llvm::Function *defineMuxBuiltin(compiler::utils::BuiltinID ID,
-                                   llvm::Module &M) override;
+  llvm::Function *defineMuxBuiltin(
+      compiler::utils::BuiltinID ID, llvm::Module &M,
+      llvm::ArrayRef<llvm::Type *> OverloadInfo = {}) override;
 };
 
 }  // namespace refsi_g1_wi
diff --git a/examples/refsi/refsi_g1_wi/compiler/refsi_g1_wi/source/refsi_mux_builtin_info.cpp b/examples/refsi/refsi_g1_wi/compiler/refsi_g1_wi/source/refsi_mux_builtin_info.cpp
index 4ddf6a34b..caaa600b7 100644
--- a/examples/refsi/refsi_g1_wi/compiler/refsi_g1_wi/source/refsi_mux_builtin_info.cpp
+++ b/examples/refsi/refsi_g1_wi/compiler/refsi_g1_wi/source/refsi_mux_builtin_info.cpp
@@ -57,8 +57,10 @@ StructType *RefSiG1BIMuxInfo::getExecStateStruct(Module &M) {
 }
 
 Function *RefSiG1BIMuxInfo::getOrDeclareMuxBuiltin(
-    compiler::utils::BuiltinID ID, Module &M) {
-  auto *F = compiler::utils::BIMuxInfoConcept::getOrDeclareMuxBuiltin(ID, M);
+    compiler::utils::BuiltinID ID, Module &M,
+    llvm::ArrayRef<llvm::Type *> OverloadInfo) {
+  auto *F = compiler::utils::BIMuxInfoConcept::getOrDeclareMuxBuiltin(
+      ID, M, OverloadInfo);
   if (!F) {
     return F;
   }
@@ -78,7 +80,8 @@ Function *RefSiG1BIMuxInfo::getOrDeclareMuxBuiltin(
 }
 
 Function *RefSiG1BIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,
-                                             Module &M) {
+                                             Module &M,
+                                             ArrayRef<Type *> OverloadInfo) {
   assert(compiler::utils::BuiltinInfo::isMuxBuiltinID(ID) &&
          "Only handling mux builtins");
   Function *F =
@@ -213,7 +216,8 @@ Function *RefSiG1BIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,
   }
 
   if (ID != compiler::utils::eMuxBuiltinGetLocalId) {
-    return compiler::utils::BIMuxInfoConcept::defineMuxBuiltin(ID, M);
+    return compiler::utils::BIMuxInfoConcept::defineMuxBuiltin(ID, M,
+                                                               OverloadInfo);
   }
 
   Optional ParamIdx;
diff --git a/examples/refsi/refsi_m1/compiler/refsi_m1/include/refsi_m1/refsi_mux_builtin_info.h b/examples/refsi/refsi_m1/compiler/refsi_m1/include/refsi_m1/refsi_mux_builtin_info.h
index fdb6a992c..22a097a5e 100644
--- a/examples/refsi/refsi_m1/compiler/refsi_m1/include/refsi_m1/refsi_mux_builtin_info.h
+++ b/examples/refsi/refsi_m1/compiler/refsi_m1/include/refsi_m1/refsi_mux_builtin_info.h
@@ -23,8 +23,9 @@ namespace refsi_m1 {
 
 class RefSiM1BIMuxInfo : public compiler::utils::BIMuxInfoConcept {
  public:
-  llvm::Function *defineMuxBuiltin(compiler::utils::BuiltinID ID,
-                                   llvm::Module &M) override;
+  llvm::Function *defineMuxBuiltin(
+      compiler::utils::BuiltinID ID, llvm::Module &M,
+      llvm::ArrayRef<llvm::Type *> OverloadInfo = {}) override;
 };
 
 }  // namespace refsi_m1
diff --git a/examples/refsi/refsi_m1/compiler/refsi_m1/source/refsi_mux_builtin_info.cpp b/examples/refsi/refsi_m1/compiler/refsi_m1/source/refsi_mux_builtin_info.cpp
index a1cb810f8..8b7ff9e56 100644
--- a/examples/refsi/refsi_m1/compiler/refsi_m1/source/refsi_mux_builtin_info.cpp
+++ b/examples/refsi/refsi_m1/compiler/refsi_m1/source/refsi_mux_builtin_info.cpp
@@ -375,7 +375,8 @@ void defineRefSiDmaWait(Function &F) {
 }
 
 Function *RefSiM1BIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,
-                                             Module &M) {
+                                             Module &M,
+                                             ArrayRef<Type *> OverloadInfo) {
   assert(compiler::utils::BuiltinInfo::isMuxBuiltinID(ID) &&
          "Only handling mux builtins");
   auto FnName = compiler::utils::BuiltinInfo::getMuxBuiltinName(ID);
@@ -390,7 +391,8 @@ Function *RefSiM1BIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,
 
   switch (ID) {
     default:
-      return compiler::utils::BIMuxInfoConcept::defineMuxBuiltin(ID, M);
+      return compiler::utils::BIMuxInfoConcept::defineMuxBuiltin(ID, M,
+                                                                 OverloadInfo);
     case compiler::utils::eMuxBuiltinDMARead1D:
     case compiler::utils::eMuxBuiltinDMAWrite1D:
       defineRefSiDma1D(*F, *this);
diff --git a/modules/compiler/source/base/source/base_module_pass_machinery.cpp b/modules/compiler/source/base/source/base_module_pass_machinery.cpp
index 249f7049b..b20453db3 100644
--- a/modules/compiler/source/base/source/base_module_pass_machinery.cpp
+++ b/modules/compiler/source/base/source/base_module_pass_machinery.cpp
@@ -57,6 +57,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
diff --git a/modules/compiler/source/base/source/base_pass_registry.def b/modules/compiler/source/base/source/base_pass_registry.def
index 617f3c1f0..9fec2f19f 100644
--- a/modules/compiler/source/base/source/base_pass_registry.def
+++ b/modules/compiler/source/base/source/base_pass_registry.def
@@ -43,6 +43,7 @@ MODULE_PASS("replace-atomic-funcs", compiler::utils::ReplaceAtomicFuncsPass())
 MODULE_PASS("replace-barriers", compiler::utils::ReplaceBarriersPass())
 MODULE_PASS("replace-c11-atomic-funcs",
             compiler::utils::ReplaceC11AtomicFuncsPass())
+MODULE_PASS("replace-group-funcs", compiler::utils::ReplaceGroupFuncsPass())
 MODULE_PASS("replace-wgc", compiler::utils::ReplaceWGCPass())
 MODULE_PASS("replace-module-scope-vars",
diff --git a/modules/compiler/targets/host/include/host/host_mux_builtin_info.h b/modules/compiler/targets/host/include/host/host_mux_builtin_info.h
index cff5c0d13..8d300eb3f 100644
--- a/modules/compiler/targets/host/include/host/host_mux_builtin_info.h
+++ b/modules/compiler/targets/host/include/host/host_mux_builtin_info.h
@@ -46,8 +46,9 @@ class HostBIMuxInfo : public compiler::utils::BIMuxInfoConcept {
   llvm::SmallVector getMuxSchedulingParameters(llvm::Module &M) override;
 
-  llvm::Function *defineMuxBuiltin(compiler::utils::BuiltinID ID,
-                                   llvm::Module &M) override;
+  llvm::Function *defineMuxBuiltin(
+      compiler::utils::BuiltinID ID, llvm::Module &M,
+      llvm::ArrayRef<llvm::Type *> OverloadInfo) override;
 
   llvm::Value *initializeSchedulingParamForWrappedKernel(
       const compiler::utils::BuiltinInfo::SchedParamInfo &Info,
diff --git a/modules/compiler/targets/host/source/HostMuxBuiltinInfo.cpp b/modules/compiler/targets/host/source/HostMuxBuiltinInfo.cpp
index 5b8d70a9e..ac559d8a5 100644
--- a/modules/compiler/targets/host/source/HostMuxBuiltinInfo.cpp
+++ b/modules/compiler/targets/host/source/HostMuxBuiltinInfo.cpp
@@ -138,7 +138,8 @@ HostBIMuxInfo::getMuxSchedulingParameters(Module &M) {
 }
 
 Function *HostBIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,
-                                          Module &M) {
+                                          Module &M,
+                                          ArrayRef<Type *> 
OverloadInfo) { assert(compiler::utils::BuiltinInfo::isMuxBuiltinID(ID) && "Only handling mux builtins"); Function *F = @@ -157,7 +158,8 @@ Function *HostBIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID, switch (ID) { default: - return compiler::utils::BIMuxInfoConcept::defineMuxBuiltin(ID, M); + return compiler::utils::BIMuxInfoConcept::defineMuxBuiltin(ID, M, + OverloadInfo); case compiler::utils::eMuxBuiltinGetLocalSize: ParamIdx = SchedParamIndices::SCHED; DefaultVal = 1; diff --git a/modules/compiler/test/CMakeLists.txt b/modules/compiler/test/CMakeLists.txt index 5f60a7689..b8c15e29d 100644 --- a/modules/compiler/test/CMakeLists.txt +++ b/modules/compiler/test/CMakeLists.txt @@ -16,6 +16,7 @@ add_ca_executable(UnitCompiler ${CMAKE_CURRENT_SOURCE_DIR}/common.h + ${CMAKE_CURRENT_SOURCE_DIR}/group_ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/info.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/library.cpp @@ -30,7 +31,7 @@ target_include_directories(UnitCompiler PRIVATE ${PROJECT_SOURCE_DIR}/modules/compiler/include) target_link_libraries(UnitCompiler PRIVATE cargo - compiler-static mux ca_gtest_main compiler-utils) + compiler-static mux ca_gtest_main compiler-base compiler-utils) target_resources(UnitCompiler NAMESPACES ${BUILTINS_NAMESPACES}) diff --git a/modules/compiler/test/common.h b/modules/compiler/test/common.h index 462f663b0..f1d13160e 100644 --- a/modules/compiler/test/common.h +++ b/modules/compiler/test/common.h @@ -28,6 +28,9 @@ #include #include #include +#include +#include +#include #include #include @@ -323,6 +326,28 @@ static inline std::vector deferrableCompilers() { return deferrable_compilers; } +/// @brief Fixture for testing behavior of the compiler with LLVM modules. +/// +/// Tests based on this fixture should test the behavior of +/// LLVM-based APIs and transforms. +struct CompilerLLVMModuleTest : ::testing::Test { + void SetUp() override {} + + std::unique_ptr parseModule(llvm::StringRef Assembly) { + llvm::SMDiagnostic Error; + auto M = llvm::parseAssemblyString(Assembly, Error, Context); + + std::string ErrMsg; + llvm::raw_string_ostream OS(ErrMsg); + Error.print("", OS); + EXPECT_TRUE(M) << OS.str(); + + return M; + } + + llvm::LLVMContext Context; +}; + /// @brief Macro for instantiating test fixture parameterized over all /// compiler targets available on the platform. /// diff --git a/modules/compiler/test/group_ops.cpp b/modules/compiler/test/group_ops.cpp new file mode 100644 index 000000000..640f4a1c8 --- /dev/null +++ b/modules/compiler/test/group_ops.cpp @@ -0,0 +1,442 @@ +// Copyright (C) Codeplay Software Limited +// +// Licensed under the Apache License, Version 2.0 (the "License") with LLVM +// Exceptions; you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. 
+//
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include "common.h"
+#include "compiler/module.h"
+
+using namespace llvm;
+using namespace compiler::utils;
+
+class GroupOpsTest : public CompilerLLVMModuleTest {
+ public:
+  std::unique_ptr<compiler::BaseModulePassMachinery> PassMach;
+
+  void SetUp() override {
+    CompilerLLVMModuleTest::SetUp();
+
+    auto Callback = [](const llvm::Module &) {
+      return compiler::utils::BuiltinInfo(
+          compiler::utils::createCLBuiltinInfo(/*builtins*/ nullptr));
+    };
+
+    PassMach = std::make_unique<compiler::BaseModulePassMachinery>(
+        Context, /*TM*/ nullptr, /*Info*/ multi_llvm::None, Callback,
+        /*verify*/ false, /*logging level*/ DebugLogging::None,
+        /*time passes*/ false);
+    PassMach->initializeStart();
+    PassMach->initializeFinish();
+  }
+
+  struct GroupOp {
+    GroupOp(StringRef FnName, StringRef LLVMTy, GroupCollective C)
+        : MangledFnName(FnName), LLVMTy(LLVMTy), Collective(C) {}
+
+    std::string getLLVMFnString(StringRef ParamName = "%x") const {
+      std::string FnStr =
+          LLVMTy + " @" + MangledFnName + "(" + LLVMTy + " " + ParamName.str();
+      if (Collective.Op == GroupCollective::OpKind::Broadcast) {
+        if (Collective.Scope == GroupCollective::ScopeKind::SubGroup) {
+          FnStr += ", i32 %sg_lid";
+        } else {
+          FnStr += ", i64 %lid_x, i64 %lid_y, i64 %lid_z";
+        }
+      }
+      FnStr += ")";
+      return FnStr;
+    }
+
+    std::string MangledFnName;
+    std::string LLVMTy;
+    GroupCollective Collective;
+  };
+
+  static std::string getGroupBuiltinBaseName(GroupCollective::ScopeKind Scope) {
+    return std::string(Scope == GroupCollective::ScopeKind::SubGroup ? "sub"
+                       : Scope == GroupCollective::ScopeKind::VectorGroup
+                           ? "vec"
+                           : "work") +
+           "_group_";
+  }
+
+  std::vector<GroupOp> getGroupBroadcasts(GroupCollective::ScopeKind Scope) {
+    std::vector<GroupOp> GroupOps;
+    std::string BaseName = getGroupBuiltinBaseName(Scope);
+
+    NameMangler Mangler(&Context);
+    Type *const I32Ty = IntegerType::getInt32Ty(Context);
+    Type *const I64Ty = IntegerType::getInt64Ty(Context);
+    Type *const FloatTy = Type::getFloatTy(Context);
+
+    GroupCollective Collective;
+    Collective.IsLogical = false;
+    Collective.Scope = Scope;
+    // Broadcasts don't expect a recurrence kind.
+    Collective.Recurrence = RecurKind::None;
+    Collective.Op = GroupCollective::OpKind::Broadcast;
+
+    if (Scope == GroupCollective::ScopeKind::SubGroup ||
+        Scope == GroupCollective::ScopeKind::VectorGroup) {
+      std::string BuiltinName = BaseName + "broadcast";
+      SmallVector<TypeQualifiers, 2> QualsVec;
+      QualsVec.push_back(eTypeQualNone);
+      // And another for the index
+      QualsVec.push_back(eTypeQualNone);
+      // float version
+      GroupOps.emplace_back(
+          GroupOp(Mangler.mangleName(BuiltinName, {FloatTy, I32Ty}, QualsVec),
+                  "float", Collective));
+      // unsigned version
+      GroupOps.emplace_back(
+          GroupOp(Mangler.mangleName(BuiltinName, {I32Ty, I32Ty}, QualsVec),
+                  "i32", Collective));
+      // signed version
+      QualsVec[0] = eTypeQualSignedInt;
+      GroupOps.emplace_back(
+          GroupOp(Mangler.mangleName(BuiltinName, {I32Ty, I32Ty}, QualsVec),
+                  "i32", Collective));
+    } else {
+      SmallVector<Type *, 4> Args;
+      SmallVector<TypeQualifiers, 4> QualsVec;
+      std::string BuiltinName = BaseName + "broadcast";
+
+      // Qualifiers for the argument
+      Args.push_back(nullptr);
+      QualsVec.push_back(eTypeQualNone);
+      // Qualifiers for the indices
+      Args.push_back(I64Ty);
+      QualsVec.push_back(eTypeQualNone);
+      Args.push_back(I64Ty);
+      QualsVec.push_back(eTypeQualNone);
+      Args.push_back(I64Ty);
+      QualsVec.push_back(eTypeQualNone);
+      // float version
+      Args[0] = FloatTy;
+      GroupOps.emplace_back(
+          GroupOp(Mangler.mangleName(BuiltinName, Args, QualsVec), "float",
+                  Collective));
+      // unsigned version
+      Args[0] = I32Ty;
+      GroupOps.emplace_back(GroupOp(
+          Mangler.mangleName(BuiltinName, Args, QualsVec), "i32", Collective));
+
+      // signed version
+      Args[0] = I32Ty;
+      QualsVec[0] = eTypeQualSignedInt;
+      GroupOps.emplace_back(GroupOp(
+          Mangler.mangleName(BuiltinName, Args, QualsVec), "i32", Collective));
+    }
+
+    return GroupOps;
+  }
+
+  // GroupOpKind = "" for reductions, "exclusive" for exclusive scans and
+  // "inclusive" for inclusive scans.
+  std::vector<GroupOp> getGroupScansAndReductions(
+      GroupCollective::ScopeKind Scope, std::string GroupOpKind) {
+    const std::string BaseName = getGroupBuiltinBaseName(Scope);
+
+    NameMangler Mangler(&Context);
+    Type *const I32Ty = IntegerType::getInt32Ty(Context);
+    Type *const FloatTy = Type::getFloatTy(Context);
+
+    std::vector<GroupOp> GroupOps;
+
+    // All sorts of reductions and scans
+    for (StringRef OpKind : {"add", "mul", "max", "min", "and", "or", "xor",
+                             "logical_and", "logical_or", "logical_xor"}) {
+      GroupCollective Collective;
+      Collective.IsLogical = false;
+      Collective.Scope = Scope;
+
+      std::string BuiltinName = BaseName;
+      if (GroupOpKind.empty()) {
+        BuiltinName += "reduce";
+        Collective.Op = GroupCollective::OpKind::Reduction;
+      } else {
+        BuiltinName += "scan_";
+        Collective.Op = GroupOpKind == "inclusive"
+                            ? GroupCollective::OpKind::ScanInclusive
+                            : GroupCollective::OpKind::ScanExclusive;
+      }
+
+      if (OpKind == "add") {
+        Collective.Recurrence = RecurKind::Add;
+      } else if (OpKind == "mul") {
+        Collective.Recurrence = RecurKind::Mul;
+      } else if (OpKind == "max") {
+        Collective.Recurrence = RecurKind::UMax;
+      } else if (OpKind == "min") {
+        Collective.Recurrence = RecurKind::UMin;
+      } else if (OpKind == "and") {
+        Collective.Recurrence = RecurKind::And;
+      } else if (OpKind == "or") {
+        Collective.Recurrence = RecurKind::Or;
+      } else if (OpKind == "xor") {
+        Collective.Recurrence = RecurKind::Xor;
+      } else if (OpKind == "logical_and") {
+        Collective.IsLogical = true;
+        Collective.Recurrence = RecurKind::And;
+      } else if (OpKind == "logical_or") {
+        Collective.IsLogical = true;
+        Collective.Recurrence = RecurKind::Or;
+      } else if (OpKind == "logical_xor") {
+        Collective.IsLogical = true;
+        Collective.Recurrence = RecurKind::Xor;
+      } else {
+        llvm_unreachable("unhandled op kind");
+      }
+
+      BuiltinName += GroupOpKind + "_" + OpKind.str();
+
+      TypeQualifiers DefaultQuals;
+      DefaultQuals.push_back(eTypeQualNone);
+
+      TypeQualifiers SignedIntQuals;
+      SignedIntQuals.push_back(eTypeQualSignedInt);
+
+      if (OpKind == "add" || OpKind == "mul" || OpKind == "max" ||
+          OpKind == "min") {
+        // float version
+        if (OpKind == "add") {
+          Collective.Recurrence = RecurKind::FAdd;
+        } else if (OpKind == "mul") {
+          Collective.Recurrence = RecurKind::FMul;
+        } else if (OpKind == "max") {
+          Collective.Recurrence = RecurKind::FMax;
+        } else if (OpKind == "min") {
+          Collective.Recurrence = RecurKind::FMin;
+        } else {
+          llvm_unreachable("unhandled op kind");
+        }
+        GroupOps.emplace_back(
+            GroupOp(Mangler.mangleName(BuiltinName, FloatTy, DefaultQuals),
+                    "float", Collective));
+      }
+
+      // unsigned version
+      if (OpKind == "add") {
+        Collective.Recurrence = RecurKind::Add;
+      } else if (OpKind == "mul") {
+        Collective.Recurrence = RecurKind::Mul;
+      } else if (OpKind == "max") {
+        Collective.Recurrence = RecurKind::UMax;
+      } else if (OpKind == "min") {
+        Collective.Recurrence = RecurKind::UMin;
+      }
+
+      GroupOps.emplace_back(
+          GroupOp(Mangler.mangleName(BuiltinName, I32Ty, DefaultQuals), "i32",
+                  Collective));
+
+      // signed version
+      if (OpKind == "max") {
+        Collective.Recurrence = RecurKind::SMax;
+      } else if (OpKind == "min") {
+        Collective.Recurrence = RecurKind::SMin;
+      }
+      GroupOps.emplace_back(
+          GroupOp(Mangler.mangleName(BuiltinName, I32Ty, SignedIntQuals), "i32",
+                  Collective));
+    }
+
+    return GroupOps;
+  }
+
+  std::vector<GroupOp> getGroupBuiltins(GroupCollective::ScopeKind Scope,
+                                        bool IncludeAnyAll = true,
+                                        bool IncludeBroadcasts = true,
+                                        bool IncludeReductions = true,
+                                        bool IncludeScans = true) {
+    std::vector<GroupOp> GroupOps;
+    std::string BaseName = getGroupBuiltinBaseName(Scope);
+
+    if (IncludeAnyAll) {
+      GroupCollective Collective;
+      Collective.Op = GroupCollective::OpKind::Any;
+      Collective.Recurrence = RecurKind::Or;
+      Collective.IsLogical = false;
+      Collective.Scope = Scope;
+
+      NameMangler Mangler(&Context);
+      Type *const I32Ty = IntegerType::getInt32Ty(Context);
+
+      GroupOps.emplace_back(GroupOp(
+          Mangler.mangleName(BaseName + "any", I32Ty, {eTypeQualSignedInt}),
+          "i32", Collective));
+
+      Collective.Op = GroupCollective::OpKind::All;
+      Collective.Recurrence = RecurKind::And;
+      GroupOps.emplace_back(GroupOp(
+          Mangler.mangleName(BaseName + "all", I32Ty, {eTypeQualSignedInt}),
+          "i32", Collective));
+    }
+
+    if (IncludeBroadcasts) {
+      auto Broadcasts = getGroupBroadcasts(Scope);
+      GroupOps.insert(GroupOps.end(), Broadcasts.begin(), Broadcasts.end());
+    }
+
+    if (IncludeReductions) {
+      auto Reductions = getGroupScansAndReductions(Scope, "");
+      GroupOps.insert(GroupOps.end(), Reductions.begin(), Reductions.end());
+    }
+
+    if (IncludeScans) {
+      auto InclusiveScans = getGroupScansAndReductions(Scope, "inclusive");
+      GroupOps.insert(GroupOps.end(), InclusiveScans.begin(),
+                      InclusiveScans.end());
+
+      auto ExclusiveScans = getGroupScansAndReductions(Scope, "exclusive");
+      GroupOps.insert(GroupOps.end(), ExclusiveScans.begin(),
+                      ExclusiveScans.end());
+    }
+
+    return GroupOps;
+  }
+
+  std::string getTestModuleStr(const std::vector<std::string> &BuiltinCalls,
+                               const std::vector<std::string> &BuiltinDecls) {
+    std::string ModuleStr = R"(
+target triple = "spir64-unknown-unknown"
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+)";
+
+    ModuleStr += R"(
+define void @test_wrapper(i32 %i, float %f, i32 %sg_lid, i64 %lid_x, i64 %lid_y, i64 %lid_z) {
+)";
+
+    for (const auto &Call : BuiltinCalls) {
+      ModuleStr += "  " + Call + "\n";
+    }
+
+    ModuleStr += "  ret void\n}\n\n";
+
+    for (const auto &Decl : BuiltinDecls) {
+      ModuleStr += Decl + "\n";
+    }
+
+    ModuleStr += R"(
+!opencl.ocl.version = !{!0}
+
+!0 = !{i32 3, i32 0}
+)";
+
+    return ModuleStr;
+  }
+
+  // This tests:
+  // * auto-generates all possible OpenCL group builtins and calls them in a
+  //   single test function
+  // * runs the ReplaceGroupFuncsPass to replace those calls with the mux
+  //   builtins
+  // * tests a round-trip between identifying and declaring those mux builtins
+  template <GroupCollective::ScopeKind GroupScope>
+  void doTestBody() {
+    auto GroupOps = getGroupBuiltins(GroupScope);
+
+    std::vector<std::string> BuiltinDecls;
+    std::vector<std::string> BuiltinCalls;
+    unsigned Idx = 0;
+    for (const auto &Op : GroupOps) {
+      BuiltinDecls.push_back("declare " + Op.getLLVMFnString());
+
+      StringRef ParamName =
+          Op.LLVMTy == "float" ? "%f" : (Op.LLVMTy == "i32" ? "%i" : "");
+      BuiltinCalls.push_back("%call" + std::to_string(Idx) + " = call " +
+                             Op.getLLVMFnString(ParamName));
+      ++Idx;
+    }
+
+    std::string ModuleStr = getTestModuleStr(BuiltinCalls, BuiltinDecls);
+
+    auto M = parseModule(ModuleStr);
+
+    ModulePassManager PM;
+    PM.addPass(ReplaceGroupFuncsPass());
+
+    PM.run(*M, PassMach->getMAM());
+
+    auto &BI = PassMach->getMAM().getResult<BuiltinInfoAnalysis>(*M);
+
+    auto *TestFn = M->getFunction("test_wrapper");
+    ASSERT_TRUE(TestFn && !TestFn->empty());
+
+    auto &BB = TestFn->front();
+
+    DenseSet<Function *> MuxBuiltins;
+    DenseSet<BuiltinID> MuxBuiltinIDs;
+    // Note we expect the called functions in the basic block to be in the same
+    // order as the group operations we generated earlier.
+    unsigned GroupOpIdx = 0;
+    for (auto &I : BB) {
+      auto const *CI = dyn_cast<CallInst>(&I);
+      if (!CI) {
+        continue;
+      }
+      auto *const CalledFn = CI->getCalledFunction();
+      EXPECT_TRUE(CalledFn);
+      MuxBuiltins.insert(CalledFn);
+
+      auto Builtin = BI.analyzeBuiltin(*CalledFn);
+      std::string InfoStr = " for function " + CalledFn->getName().str() +
+                            " identified as ID " + std::to_string(Builtin.ID);
+      EXPECT_NE(Builtin.ID, eBuiltinInvalid) << InfoStr;
+      EXPECT_TRUE(BI.isMuxBuiltinID(Builtin.ID)) << InfoStr;
+
+      // Do a get-or-declare, and make sure we're getting back the exact same
+      // function.
+      auto *const BuiltinDecl =
+          BI.getOrDeclareMuxBuiltin(Builtin.ID, *M, Builtin.mux_overload_info);
+      EXPECT_TRUE(BuiltinDecl && BuiltinDecl == CalledFn) << InfoStr;
+
+      auto Info = BI.isMuxGroupCollective(Builtin.ID);
+      ASSERT_TRUE(Info) << InfoStr;
+
+      // Now check that the returned values are what we expect. We don't
+      // check 'type' or 'function' here as it's not set by either party.
+      assert(Info && "Asserting the optional to silence a compiler warning");
+      EXPECT_EQ(Info->Op, GroupOps[GroupOpIdx].Collective.Op) << InfoStr;
+      EXPECT_EQ(Info->Scope, GroupOps[GroupOpIdx].Collective.Scope) << InfoStr;
+      EXPECT_EQ(Info->IsLogical, GroupOps[GroupOpIdx].Collective.IsLogical)
+          << InfoStr;
+      EXPECT_EQ(Info->Recurrence, GroupOps[GroupOpIdx].Collective.Recurrence)
+          << InfoStr;
+
+      ++GroupOpIdx;
+    }
+  }
+};
+
+TEST_F(GroupOpsTest, OpenCLSubgroupOps) {
+  doTestBody<GroupCollective::ScopeKind::SubGroup>();
+}
+
+TEST_F(GroupOpsTest, OpenCLWorkgroupOps) {
+  doTestBody<GroupCollective::ScopeKind::WorkGroup>();
+}
diff --git a/modules/compiler/test/lit/passes/degenerate-sub-group-broadcast-32bit.ll b/modules/compiler/test/lit/passes/degenerate-sub-group-broadcast-32bit.ll
index ce57b8985..b58b35ed5 100644
--- a/modules/compiler/test/lit/passes/degenerate-sub-group-broadcast-32bit.ll
+++ b/modules/compiler/test/lit/passes/degenerate-sub-group-broadcast-32bit.ll
@@ -44,7 +44,11 @@ entry:
 }
 
 attributes #0 = { "mux-kernel"="entry-point" }
+
+!opencl.ocl.version = !{!1}
+
 !0 = !{i32 13, i32 64, i32 64}
+!1 = !{i32 3, i32 0}
 
 ; CHECK: declare spir_func i32 @_Z20work_group_broadcastijjj(i32, i32, i32, i32)
 declare spir_func i32 @_Z19sub_group_broadcastij(i32, i32)
diff --git a/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning.ll b/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning.ll
index 61c838575..d55f1e9b9 100644
--- a/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning.ll
+++ b/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning.ll
@@ -32,6 +32,10 @@ declare spir_func i32 @_Z20sub_group_reduce_addi(i32)
 
 attributes #0 = { "mux-kernel"="entry-point" }
 
+!opencl.ocl.version = !{!0}
+
+!0 = !{i32 3, i32 0}
+
 ; CHECK-LABEL: define spir_func i32 @sub_group_reduce_add_test.degenerate-subgroups
 ; CHECK: (i32 [[Y:%.*]]) #[[ATTR0:[0-9]+]]
 ; CHECK: entry:
diff --git a/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning2.ll b/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning2.ll
index 265827621..cc0808c3f 100644
--- a/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning2.ll
+++ b/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning2.ll
@@ -57,6 +57,10 @@ entry:
 
 declare spir_func i32 @_Z20sub_group_reduce_addi(i32)
 
+!opencl.ocl.version = !{!0}
+
+!0 = !{i32 3, i32 0}
+
 attributes #0 = { "mux-kernel"="entry-point" }
 
 ; CHECK: define spir_func i32 @clone_this.degenerate-subgroups(i32 [[X1:%.+]]) {
diff --git a/modules/compiler/test/lit/passes/degenerate-sub-groups.ll b/modules/compiler/test/lit/passes/degenerate-sub-groups.ll
index cd2b322c6..b1f61cb0d 100644
--- a/modules/compiler/test/lit/passes/degenerate-sub-groups.ll
+++ b/modules/compiler/test/lit/passes/degenerate-sub-groups.ll
@@ -273,3 +273,7 @@ declare spir_func void @__mux_sub_group_barrier(i32, i32, i32)
 attributes #0 = { "mux-kernel"="entry-point" }
 !0 = !{i32 13, i32 64, i32 64}
 ; CHECK: attributes #0 = { "mux-degenerate-subgroups" "mux-kernel"="entry-point" }
+
+!opencl.ocl.version = !{!1}
+
+!1 = !{i32 3, i32 0}
diff --git a/modules/compiler/test/lit/passes/replace-sub-group-ops.ll b/modules/compiler/test/lit/passes/replace-sub-group-ops.ll
new file mode 100644
index 000000000..4c2114aa6
--- /dev/null
+++ b/modules/compiler/test/lit/passes/replace-sub-group-ops.ll
@@ -0,0 +1,284 @@
+; Copyright (C) Codeplay Software Limited
+;
+; Licensed under the Apache License, Version 2.0 (the "License") with LLVM
+; Exceptions; you may not use this file 
except in compliance with the License. +; You may obtain a copy of the License at +; +; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +; License for the specific language governing permissions and limitations +; under the License. +; +; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +; RUN: muxc --passes replace-group-funcs,verify %s | FileCheck %s + +target triple = "spir64-unknown-unknown" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" + +define spir_func i32 @sub_group_size_test() { +; CHECK: %call1 = call i32 @__mux_get_sub_group_size() +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z18get_sub_group_sizev() + ret i32 %call +} + +define spir_func i32 @sub_group_local_id_test() { +; CHECK: %call1 = call i32 @__mux_get_sub_group_local_id() +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z22get_sub_group_local_idv() + ret i32 %call +} + +define spir_func i32 @sub_group_all_test(i32 %x) { +; CHECK: [[T0:%.*]] = icmp ne i32 %x, 0 +; CHECK: %call1 = call i1 @__mux_sub_group_all_i1(i1 [[T0]]) +; CHECK: [[T1:%.*]] = sext i1 %call1 to i32 +; CHECK: ret i32 [[T1]] + %call = call spir_func i32 @_Z13sub_group_alli(i32 %x) + ret i32 %call +} + +define spir_func i32 @sub_group_any_test(i32 %x) { +; CHECK: [[T0:%.*]] = icmp ne i32 %x, 0 +; CHECK: %call1 = call i1 @__mux_sub_group_any_i1(i1 [[T0]]) +; CHECK: [[T1:%.*]] = sext i1 %call1 to i32 +; CHECK: ret i32 [[T1]] + %call = call spir_func i32 @_Z13sub_group_anyi(i32 %x) + ret i32 %call +} + +define spir_func i32 @sub_group_broadcasti_test(i32 %x, i32 %lid) { +; CHECK: %call1 = call i32 @__mux_sub_group_broadcast_i32(i32 %x, i32 %lid) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z19sub_group_broadcastij(i32 %x, i32 %lid) + ret i32 %call +} + +define spir_func float @sub_group_broadcastf_test(float %x, i32 %lid) { +; CHECK: %call1 = call float @__mux_sub_group_broadcast_f32(float %x, i32 %lid) +; CHECK: ret float %call1 + %call = call spir_func float @_Z19sub_group_broadcastfj(float %x, i32 %lid) + ret float %call +} + +define spir_func i32 @sub_group_reduce_addi_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_reduce_add_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z20sub_group_reduce_addi(i32 %x) + ret i32 %call +} + +define spir_func float @sub_group_reduce_addf_test(float %x) { +; CHECK: %call1 = call float @__mux_sub_group_reduce_fadd_f32(float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z20sub_group_reduce_addf(float %x) + ret float %call +} + +define spir_func i32 @sub_group_reduce_mini_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_reduce_smin_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z20sub_group_reduce_mini(i32 %x) + ret i32 %call +} + +define spir_func i32 @sub_group_reduce_minu_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_reduce_umin_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z20sub_group_reduce_minj(i32 %x) + ret i32 %call +} + +define spir_func float @sub_group_reduce_minf_test(float %x) { +; CHECK: %call1 = call float @__mux_sub_group_reduce_fmin_f32(float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z20sub_group_reduce_minf(float %x) 
+ ret float %call +} + +define spir_func i32 @sub_group_reduce_maxi_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_reduce_smax_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z20sub_group_reduce_maxi(i32 %x) + ret i32 %call +} + +define spir_func i32 @sub_group_reduce_maxj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_reduce_umax_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z20sub_group_reduce_maxj(i32 %x) + ret i32 %call +} + +define spir_func float @sub_group_reduce_maxf_test(float %x) { +; CHECK: %call1 = call float @__mux_sub_group_reduce_fmax_f32(float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z20sub_group_reduce_maxf(float %x) + ret float %call +} + +define spir_func i32 @sub_group_scan_exclusive_addi_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_scan_exclusive_add_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z28sub_group_scan_exclusive_addi(i32 %x) + ret i32 %call +} + +define spir_func i32 @sub_group_scan_exclusive_addj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_scan_exclusive_add_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z28sub_group_scan_exclusive_addj(i32 %x) + ret i32 %call +} + +define spir_func float @sub_group_scan_exclusive_addf_test(float %x) { +; CHECK: %call1 = call float @__mux_sub_group_scan_exclusive_fadd_f32(float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z28sub_group_scan_exclusive_addf(float %x) + ret float %call +} + +define spir_func i32 @sub_group_scan_exclusive_mini_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_scan_exclusive_smin_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z28sub_group_scan_exclusive_mini(i32 %x) + ret i32 %call +} + +define spir_func i32 @sub_group_scan_exclusive_minj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_scan_exclusive_umin_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z28sub_group_scan_exclusive_minj(i32 %x) + ret i32 %call +} + +define spir_func float @sub_group_scan_exclusive_minf_test(float %x) { +; CHECK: %call1 = call float @__mux_sub_group_scan_exclusive_fmin_f32(float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z28sub_group_scan_exclusive_minf(float %x) + ret float %call +} + +define spir_func i32 @sub_group_scan_exclusive_maxi_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_scan_exclusive_smax_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z28sub_group_scan_exclusive_maxi(i32 %x) + ret i32 %call +} + +define spir_func i32 @sub_group_scan_exclusive_maxj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_scan_exclusive_umax_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z28sub_group_scan_exclusive_maxj(i32 %x) + ret i32 %call +} + +define spir_func float @sub_group_scan_exclusive_maxf_test(float %x) { +; CHECK: %call1 = call float @__mux_sub_group_scan_exclusive_fmax_f32(float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z28sub_group_scan_exclusive_maxf(float %x) + ret float %call +} + +define spir_func i32 @sub_group_scan_inclusive_addi_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_scan_inclusive_add_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z28sub_group_scan_inclusive_addi(i32 %x) + ret i32 %call +} + +define spir_func i32 @sub_group_scan_inclusive_addj_test(i32 %x) { +; CHECK: %call1 = call i32 
@__mux_sub_group_scan_inclusive_add_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z28sub_group_scan_inclusive_addj(i32 %x) + ret i32 %call +} + +define spir_func float @sub_group_scan_inclusive_addf_test(float %x) { +; CHECK: %call1 = call float @__mux_sub_group_scan_inclusive_fadd_f32(float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z28sub_group_scan_inclusive_addf(float %x) + ret float %call +} + +define spir_func i32 @sub_group_scan_inclusive_mini_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_scan_inclusive_smin_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z28sub_group_scan_inclusive_mini(i32 %x) + ret i32 %call +} + +define spir_func i32 @sub_group_scan_inclusive_minj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_scan_inclusive_umin_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z28sub_group_scan_inclusive_minj(i32 %x) + ret i32 %call +} + +define spir_func float @sub_group_scan_inclusive_minf_test(float %x) { +; CHECK: %call1 = call float @__mux_sub_group_scan_inclusive_fmin_f32(float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z28sub_group_scan_inclusive_minf(float %x) + ret float %call +} + +define spir_func i32 @sub_group_scan_inclusive_maxi_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_scan_inclusive_smax_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z28sub_group_scan_inclusive_maxi(i32 %x) + ret i32 %call +} + +define spir_func i32 @sub_group_scan_inclusive_maxj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_sub_group_scan_inclusive_umax_i32(i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z28sub_group_scan_inclusive_maxj(i32 %x) + ret i32 %call +} + +define spir_func float @sub_group_scan_inclusive_maxf_test(float %x) { +; CHECK: %call1 = call float @__mux_sub_group_scan_inclusive_fmax_f32(float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z28sub_group_scan_inclusive_maxf(float %x) + ret float %call +} + +declare spir_func i32 @_Z18get_sub_group_sizev() +declare spir_func i32 @_Z22get_sub_group_local_idv() +declare spir_func i32 @_Z13sub_group_alli(i32) +declare spir_func i32 @_Z13sub_group_anyi(i32) +declare spir_func i32 @_Z19sub_group_broadcastij(i32, i32) +declare spir_func float @_Z19sub_group_broadcastfj(float, i32) +declare spir_func i32 @_Z20sub_group_reduce_addi(i32) +declare spir_func float @_Z20sub_group_reduce_addf(float) +declare spir_func i32 @_Z20sub_group_reduce_mini(i32) +declare spir_func i32 @_Z20sub_group_reduce_minj(i32) +declare spir_func float @_Z20sub_group_reduce_minf(float) +declare spir_func i32 @_Z20sub_group_reduce_maxi(i32) +declare spir_func i32 @_Z20sub_group_reduce_maxj(i32) +declare spir_func float @_Z20sub_group_reduce_maxf(float) +declare spir_func i32 @_Z28sub_group_scan_exclusive_addi(i32) +declare spir_func i32 @_Z28sub_group_scan_exclusive_addj(i32) +declare spir_func float @_Z28sub_group_scan_exclusive_addf(float) +declare spir_func i32 @_Z28sub_group_scan_exclusive_mini(i32) +declare spir_func i32 @_Z28sub_group_scan_exclusive_minj(i32) +declare spir_func float @_Z28sub_group_scan_exclusive_minf(float) +declare spir_func i32 @_Z28sub_group_scan_exclusive_maxi(i32) +declare spir_func i32 @_Z28sub_group_scan_exclusive_maxj(i32) +declare spir_func float @_Z28sub_group_scan_exclusive_maxf(float) +declare spir_func i32 @_Z28sub_group_scan_inclusive_addi(i32) +declare spir_func i32 @_Z28sub_group_scan_inclusive_addj(i32) +declare spir_func 
float @_Z28sub_group_scan_inclusive_addf(float) +declare spir_func i32 @_Z28sub_group_scan_inclusive_mini(i32) +declare spir_func i32 @_Z28sub_group_scan_inclusive_minj(i32) +declare spir_func float @_Z28sub_group_scan_inclusive_minf(float) +declare spir_func i32 @_Z28sub_group_scan_inclusive_maxi(i32) +declare spir_func i32 @_Z28sub_group_scan_inclusive_maxj(i32) +declare spir_func float @_Z28sub_group_scan_inclusive_maxf(float) + +!opencl.ocl.version = !{!0} + +!0 = !{i32 3, i32 0} diff --git a/modules/compiler/test/lit/passes/replace-work-group-ops.ll b/modules/compiler/test/lit/passes/replace-work-group-ops.ll new file mode 100644 index 000000000..21daedb47 --- /dev/null +++ b/modules/compiler/test/lit/passes/replace-work-group-ops.ll @@ -0,0 +1,300 @@ +; Copyright (C) Codeplay Software Limited +; +; Licensed under the Apache License, Version 2.0 (the "License") with LLVM +; Exceptions; you may not use this file except in compliance with the License. +; You may obtain a copy of the License at +; +; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +; License for the specific language governing permissions and limitations +; under the License. +; +; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +; RUN: muxc --passes replace-group-funcs,verify %s | FileCheck %s + +target triple = "spir64-unknown-unknown" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" + +define spir_func i32 @work_group_all_test(i32 %x) { +; CHECK: [[T0:%.*]] = icmp ne i32 %x, 0 +; CHECK: %call1 = call i1 @__mux_work_group_all_i1(i32 0, i1 [[T0]]) +; CHECK: [[T1:%.*]] = sext i1 %call1 to i32 +; CHECK: ret i32 [[T1]] + %call = call spir_func i32 @_Z14work_group_alli(i32 %x) + ret i32 %call +} + +define spir_func i32 @work_group_any_test(i32 %x) { +; CHECK: [[T0:%.*]] = icmp ne i32 %x, 0 +; CHECK: %call1 = call i1 @__mux_work_group_any_i1(i32 0, i1 [[T0]]) +; CHECK: [[T1:%.*]] = sext i1 %call1 to i32 +; CHECK: ret i32 [[T1]] + %call = call spir_func i32 @_Z14work_group_anyi(i32 %x) + ret i32 %call +} + +define spir_func i32 @work_group_broadcastxi_test(i32 %x, i64 %lid) { +; CHECK: %call1 = call i32 @__mux_work_group_broadcast_i32(i32 0, i32 %x, i64 %lid, i64 0, i64 0) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z20work_group_broadcastim(i32 %x, i64 %lid) + ret i32 %call +} + +define spir_func float @work_group_broadcastxf_test(float %x, i64 %lid) { +; CHECK: %call1 = call float @__mux_work_group_broadcast_f32(i32 0, float %x, i64 %lid, i64 0, i64 0) +; CHECK: ret float %call1 + %call = call spir_func float @_Z20work_group_broadcastfm(float %x, i64 %lid) + ret float %call +} + +define spir_func i32 @work_group_broadcastxyi_test(i32 %x, i64 %lidx, i64 %lidy) { +; CHECK: %call1 = call i32 @__mux_work_group_broadcast_i32(i32 0, i32 %x, i64 %lidx, i64 %lidy, i64 0) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z20work_group_broadcastimm(i32 %x, i64 %lidx, i64 %lidy) + ret i32 %call +} + +define spir_func float @work_group_broadcastxyf_test(float %x, i64 %lidx, i64 %lidy) { +; CHECK: %call1 = call float @__mux_work_group_broadcast_f32(i32 0, float %x, i64 %lidx, i64 %lidy, i64 0) +; CHECK: ret float %call1 + %call = call spir_func float @_Z20work_group_broadcastfmm(float %x, i64 %lidx, 
i64 %lidy) + ret float %call +} + +define spir_func i32 @work_group_broadcastxyzi_test(i32 %x, i64 %lidx, i64 %lidy, i64 %lidz) { +; CHECK: %call1 = call i32 @__mux_work_group_broadcast_i32(i32 0, i32 %x, i64 %lidx, i64 %lidy, i64 %lidz) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z20work_group_broadcastimmm(i32 %x, i64 %lidx, i64 %lidy, i64 %lidz) + ret i32 %call +} + +define spir_func float @work_group_broadcastxyzf_test(float %x, i64 %lidx, i64 %lidy, i64 %lidz) { +; CHECK: %call1 = call float @__mux_work_group_broadcast_f32(i32 0, float %x, i64 %lidx, i64 %lidy, i64 %lidz) +; CHECK: ret float %call1 + %call = call spir_func float @_Z20work_group_broadcastfmmm(float %x, i64 %lidx, i64 %lidy, i64 %lidz) + ret float %call +} + +define spir_func i32 @work_group_reduce_addi_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_reduce_add_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z21work_group_reduce_addi(i32 %x) + ret i32 %call +} + +define spir_func float @work_group_reduce_addf_test(float %x) { +; CHECK: %call1 = call float @__mux_work_group_reduce_fadd_f32(i32 0, float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z21work_group_reduce_addf(float %x) + ret float %call +} + +define spir_func i32 @work_group_reduce_mini_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_reduce_smin_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z21work_group_reduce_mini(i32 %x) + ret i32 %call +} + +define spir_func i32 @work_group_reduce_minu_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_reduce_umin_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z21work_group_reduce_minj(i32 %x) + ret i32 %call +} + +define spir_func float @work_group_reduce_minf_test(float %x) { +; CHECK: %call1 = call float @__mux_work_group_reduce_fmin_f32(i32 0, float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z21work_group_reduce_minf(float %x) + ret float %call +} + +define spir_func i32 @work_group_reduce_maxi_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_reduce_smax_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z21work_group_reduce_maxi(i32 %x) + ret i32 %call +} + +define spir_func i32 @work_group_reduce_maxj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_reduce_umax_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z21work_group_reduce_maxj(i32 %x) + ret i32 %call +} + +define spir_func float @work_group_reduce_maxf_test(float %x) { +; CHECK: %call1 = call float @__mux_work_group_reduce_fmax_f32(i32 0, float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z21work_group_reduce_maxf(float %x) + ret float %call +} + +define spir_func i32 @work_group_scan_exclusive_addi_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_scan_exclusive_add_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z29work_group_scan_exclusive_addi(i32 %x) + ret i32 %call +} + +define spir_func i32 @work_group_scan_exclusive_addj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_scan_exclusive_add_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z29work_group_scan_exclusive_addj(i32 %x) + ret i32 %call +} + +define spir_func float @work_group_scan_exclusive_addf_test(float %x) { +; CHECK: %call1 = call float @__mux_work_group_scan_exclusive_fadd_f32(i32 0, float %x) +; CHECK: ret float %call1 + %call = call spir_func float 
@_Z29work_group_scan_exclusive_addf(float %x) + ret float %call +} + +define spir_func i32 @work_group_scan_exclusive_mini_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_scan_exclusive_smin_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z29work_group_scan_exclusive_mini(i32 %x) + ret i32 %call +} + +define spir_func i32 @work_group_scan_exclusive_minj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_scan_exclusive_umin_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z29work_group_scan_exclusive_minj(i32 %x) + ret i32 %call +} + +define spir_func float @work_group_scan_exclusive_minf_test(float %x) { +; CHECK: %call1 = call float @__mux_work_group_scan_exclusive_fmin_f32(i32 0, float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z29work_group_scan_exclusive_minf(float %x) + ret float %call +} + +define spir_func i32 @work_group_scan_exclusive_maxi_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_scan_exclusive_smax_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z29work_group_scan_exclusive_maxi(i32 %x) + ret i32 %call +} + +define spir_func i32 @work_group_scan_exclusive_maxj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_scan_exclusive_umax_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z29work_group_scan_exclusive_maxj(i32 %x) + ret i32 %call +} + +define spir_func float @work_group_scan_exclusive_maxf_test(float %x) { +; CHECK: %call1 = call float @__mux_work_group_scan_exclusive_fmax_f32(i32 0, float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z29work_group_scan_exclusive_maxf(float %x) + ret float %call +} + +define spir_func i32 @work_group_scan_inclusive_addi_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_scan_inclusive_add_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z29work_group_scan_inclusive_addi(i32 %x) + ret i32 %call +} + +define spir_func i32 @work_group_scan_inclusive_addj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_scan_inclusive_add_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z29work_group_scan_inclusive_addj(i32 %x) + ret i32 %call +} + +define spir_func float @work_group_scan_inclusive_addf_test(float %x) { +; CHECK: %call1 = call float @__mux_work_group_scan_inclusive_fadd_f32(i32 0, float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z29work_group_scan_inclusive_addf(float %x) + ret float %call +} + +define spir_func i32 @work_group_scan_inclusive_mini_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_scan_inclusive_smin_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z29work_group_scan_inclusive_mini(i32 %x) + ret i32 %call +} + +define spir_func i32 @work_group_scan_inclusive_minj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_scan_inclusive_umin_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z29work_group_scan_inclusive_minj(i32 %x) + ret i32 %call +} + +define spir_func float @work_group_scan_inclusive_minf_test(float %x) { +; CHECK: %call1 = call float @__mux_work_group_scan_inclusive_fmin_f32(i32 0, float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z29work_group_scan_inclusive_minf(float %x) + ret float %call +} + +define spir_func i32 @work_group_scan_inclusive_maxi_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_scan_inclusive_smax_i32(i32 0, 
i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z29work_group_scan_inclusive_maxi(i32 %x) + ret i32 %call +} + +define spir_func i32 @work_group_scan_inclusive_maxj_test(i32 %x) { +; CHECK: %call1 = call i32 @__mux_work_group_scan_inclusive_umax_i32(i32 0, i32 %x) +; CHECK: ret i32 %call1 + %call = call spir_func i32 @_Z29work_group_scan_inclusive_maxj(i32 %x) + ret i32 %call +} + +define spir_func float @work_group_scan_inclusive_maxf_test(float %x) { +; CHECK: %call1 = call float @__mux_work_group_scan_inclusive_fmax_f32(i32 0, float %x) +; CHECK: ret float %call1 + %call = call spir_func float @_Z29work_group_scan_inclusive_maxf(float %x) + ret float %call +} + +declare spir_func i32 @_Z14work_group_alli(i32) +declare spir_func i32 @_Z14work_group_anyi(i32) +declare spir_func i32 @_Z20work_group_broadcastim(i32, i64) +declare spir_func float @_Z20work_group_broadcastfm(float, i64) +declare spir_func i32 @_Z20work_group_broadcastimm(i32, i64, i64) +declare spir_func float @_Z20work_group_broadcastfmm(float, i64, i64) +declare spir_func i32 @_Z20work_group_broadcastimmm(i32, i64, i64, i64) +declare spir_func float @_Z20work_group_broadcastfmmm(float, i64, i64, i64) +declare spir_func i32 @_Z21work_group_reduce_addi(i32) +declare spir_func float @_Z21work_group_reduce_addf(float) +declare spir_func i32 @_Z21work_group_reduce_mini(i32) +declare spir_func i32 @_Z21work_group_reduce_minj(i32) +declare spir_func float @_Z21work_group_reduce_minf(float) +declare spir_func i32 @_Z21work_group_reduce_maxi(i32) +declare spir_func i32 @_Z21work_group_reduce_maxj(i32) +declare spir_func float @_Z21work_group_reduce_maxf(float) +declare spir_func i32 @_Z29work_group_scan_exclusive_addi(i32) +declare spir_func i32 @_Z29work_group_scan_exclusive_addj(i32) +declare spir_func float @_Z29work_group_scan_exclusive_addf(float) +declare spir_func i32 @_Z29work_group_scan_exclusive_mini(i32) +declare spir_func i32 @_Z29work_group_scan_exclusive_minj(i32) +declare spir_func float @_Z29work_group_scan_exclusive_minf(float) +declare spir_func i32 @_Z29work_group_scan_exclusive_maxi(i32) +declare spir_func i32 @_Z29work_group_scan_exclusive_maxj(i32) +declare spir_func float @_Z29work_group_scan_exclusive_maxf(float) +declare spir_func i32 @_Z29work_group_scan_inclusive_addi(i32) +declare spir_func i32 @_Z29work_group_scan_inclusive_addj(i32) +declare spir_func float @_Z29work_group_scan_inclusive_addf(float) +declare spir_func i32 @_Z29work_group_scan_inclusive_mini(i32) +declare spir_func i32 @_Z29work_group_scan_inclusive_minj(i32) +declare spir_func float @_Z29work_group_scan_inclusive_minf(float) +declare spir_func i32 @_Z29work_group_scan_inclusive_maxi(i32) +declare spir_func i32 @_Z29work_group_scan_inclusive_maxj(i32) +declare spir_func float @_Z29work_group_scan_inclusive_maxf(float) + +!opencl.ocl.version = !{!0} + +!0 = !{i32 3, i32 0} diff --git a/modules/compiler/test/mangling.cpp b/modules/compiler/test/mangling.cpp index eb0ef1e2d..2228c0bfd 100644 --- a/modules/compiler/test/mangling.cpp +++ b/modules/compiler/test/mangling.cpp @@ -16,11 +16,6 @@ #include #include -#include -#include -#include -#include -#include #include #include @@ -31,23 +26,7 @@ using namespace compiler::utils; -struct ManglingTest : ::testing::Test { - void SetUp() override {} - - std::unique_ptr parseModule(llvm::StringRef Assembly) { - llvm::SMDiagnostic Error; - auto M = llvm::parseAssemblyString(Assembly, Error, Context); - - std::string ErrMsg; - llvm::raw_string_ostream OS(ErrMsg); - Error.print("", 
OS); - EXPECT_TRUE(M) << OS.str(); - - return M; - } - - llvm::LLVMContext Context; -}; +using ManglingTest = CompilerLLVMModuleTest; TEST_F(ManglingTest, MangleBuiltinTypes) { // With opaque pointers, before LLVM 17 we can't actually mangle OpenCL diff --git a/modules/compiler/test/utils.cpp b/modules/compiler/test/utils.cpp index 8fe6d7c3c..36a2df469 100644 --- a/modules/compiler/test/utils.cpp +++ b/modules/compiler/test/utils.cpp @@ -17,11 +17,6 @@ #include #include #include -#include -#include -#include -#include -#include #include #include @@ -32,23 +27,7 @@ using namespace compiler::utils; -struct CompilerUtilsTest : ::testing::Test { - void SetUp() override {} - - std::unique_ptr parseModule(llvm::StringRef Assembly) { - llvm::SMDiagnostic Error; - auto M = llvm::parseAssemblyString(Assembly, Error, Context); - - std::string ErrMsg; - llvm::raw_string_ostream OS(ErrMsg); - Error.print("", OS); - EXPECT_TRUE(M) << OS.str(); - - return M; - } - - llvm::LLVMContext Context; -}; +using CompilerUtilsTest = CompilerLLVMModuleTest; TEST_F(CompilerUtilsTest, CreateKernelWrapper) { auto M = parseModule(R"( diff --git a/modules/compiler/utils/CMakeLists.txt b/modules/compiler/utils/CMakeLists.txt index eccf8bc3a..bbbf04b9a 100644 --- a/modules/compiler/utils/CMakeLists.txt +++ b/modules/compiler/utils/CMakeLists.txt @@ -60,6 +60,7 @@ add_ca_library(compiler-utils STATIC ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/replace_atomic_funcs_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/replace_barriers_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/replace_c11_atomic_funcs_pass.h + ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/replace_group_funcs_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/replace_local_module_scope_variables_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/replace_mem_intrinsics_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/replace_mux_math_decls_pass.h @@ -111,6 +112,7 @@ add_ca_library(compiler-utils STATIC ${CMAKE_CURRENT_SOURCE_DIR}/source/replace_atomic_funcs_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/replace_barriers_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/replace_c11_atomic_funcs_pass.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/source/replace_group_funcs_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/replace_local_module_scope_variables_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/replace_mem_intrinsics_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/replace_mux_math_decls_pass.cpp diff --git a/modules/compiler/utils/include/compiler/utils/builtin_info.h b/modules/compiler/utils/include/compiler/utils/builtin_info.h index d96062be8..5f04d76ec 100644 --- a/modules/compiler/utils/include/compiler/utils/builtin_info.h +++ b/modules/compiler/utils/include/compiler/utils/builtin_info.h @@ -21,6 +21,7 @@ #ifndef COMPILER_UTILS_BUILTIN_INFO_H_INCLUDED #define COMPILER_UTILS_BUILTIN_INFO_H_INCLUDED +#include #include #include #include @@ -70,10 +71,65 @@ enum BaseBuiltinID { eMuxBuiltinGetGlobalLinearId, eMuxBuiltinGetLocalLinearId, eMuxBuiltinGetEnqueuedLocalSize, + eMuxBuiltinGetSubGroupSize, + eMuxBuiltinGetSubGroupLocalId, // Synchronization builtins eMuxBuiltinMemBarrier, eMuxBuiltinSubGroupBarrier, eMuxBuiltinWorkGroupBarrier, +#define GROUP_BUILTINS(SCOPE) \ + eFirstMux##SCOPE##groupCollectiveBuiltin, \ + eMuxBuiltin##SCOPE##groupAll = eFirstMux##SCOPE##groupCollectiveBuiltin, \ + eMuxBuiltin##SCOPE##groupAny, eMuxBuiltin##SCOPE##groupBroadcast, \ + eMuxBuiltin##SCOPE##groupReduceAdd, 
eMuxBuiltin##SCOPE##groupReduceFAdd, \ + eMuxBuiltin##SCOPE##groupReduceSMin, \ + eMuxBuiltin##SCOPE##groupReduceUMin, \ + eMuxBuiltin##SCOPE##groupReduceFMin, \ + eMuxBuiltin##SCOPE##groupReduceSMax, \ + eMuxBuiltin##SCOPE##groupReduceUMax, \ + eMuxBuiltin##SCOPE##groupReduceFMax, eMuxBuiltin##SCOPE##groupReduceMul, \ + eMuxBuiltin##SCOPE##groupReduceFMul, eMuxBuiltin##SCOPE##groupReduceAnd, \ + eMuxBuiltin##SCOPE##groupReduceOr, eMuxBuiltin##SCOPE##groupReduceXor, \ + eMuxBuiltin##SCOPE##groupReduceLogicalAnd, \ + eMuxBuiltin##SCOPE##groupReduceLogicalOr, \ + eMuxBuiltin##SCOPE##groupReduceLogicalXor, \ + eMuxBuiltin##SCOPE##groupScanAddInclusive, \ + eMuxBuiltin##SCOPE##groupScanFAddInclusive, \ + eMuxBuiltin##SCOPE##groupScanAddExclusive, \ + eMuxBuiltin##SCOPE##groupScanFAddExclusive, \ + eMuxBuiltin##SCOPE##groupScanSMinInclusive, \ + eMuxBuiltin##SCOPE##groupScanUMinInclusive, \ + eMuxBuiltin##SCOPE##groupScanFMinInclusive, \ + eMuxBuiltin##SCOPE##groupScanSMinExclusive, \ + eMuxBuiltin##SCOPE##groupScanUMinExclusive, \ + eMuxBuiltin##SCOPE##groupScanFMinExclusive, \ + eMuxBuiltin##SCOPE##groupScanSMaxInclusive, \ + eMuxBuiltin##SCOPE##groupScanUMaxInclusive, \ + eMuxBuiltin##SCOPE##groupScanFMaxInclusive, \ + eMuxBuiltin##SCOPE##groupScanSMaxExclusive, \ + eMuxBuiltin##SCOPE##groupScanUMaxExclusive, \ + eMuxBuiltin##SCOPE##groupScanFMaxExclusive, \ + eMuxBuiltin##SCOPE##groupScanMulInclusive, \ + eMuxBuiltin##SCOPE##groupScanFMulInclusive, \ + eMuxBuiltin##SCOPE##groupScanMulExclusive, \ + eMuxBuiltin##SCOPE##groupScanFMulExclusive, \ + eMuxBuiltin##SCOPE##groupScanAndInclusive, \ + eMuxBuiltin##SCOPE##groupScanAndExclusive, \ + eMuxBuiltin##SCOPE##groupScanOrInclusive, \ + eMuxBuiltin##SCOPE##groupScanOrExclusive, \ + eMuxBuiltin##SCOPE##groupScanXorInclusive, \ + eMuxBuiltin##SCOPE##groupScanXorExclusive, \ + eMuxBuiltin##SCOPE##groupScanLogicalAndInclusive, \ + eMuxBuiltin##SCOPE##groupScanLogicalAndExclusive, \ + eMuxBuiltin##SCOPE##groupScanLogicalOrInclusive, \ + eMuxBuiltin##SCOPE##groupScanLogicalOrExclusive, \ + eMuxBuiltin##SCOPE##groupScanLogicalXorInclusive, \ + eLastMux##SCOPE##groupCollectiveBuiltin, \ + eMuxBuiltin##SCOPE##groupScanLogicalXorExclusive = \ + eLastMux##SCOPE##groupCollectiveBuiltin + GROUP_BUILTINS(Work), + GROUP_BUILTINS(Sub), + GROUP_BUILTINS(Vec), // Marker - target builtins should start from here. eFirstTargetBuiltin, @@ -154,11 +210,15 @@ enum BuiltinProperties : int32_t { eBuiltinPropertyRematerializable = (1 << 14), /// @brief The builtin should be mapped to a mux synchronization builtin. /// - /// This mapping takes place in BuiltiInfo::mapSyncBuiltinToMuxSyncBuiltin. + /// This mapping takes place in BuiltinInfo::mapSyncBuiltinToMuxSyncBuiltin. eBuiltinPropertyMapToMuxSyncBuiltin = (1 << 15), /// @brief The builtin is known not to be convergent, i.e., it does not /// depend on any other work-item in any way. eBuiltinPropertyKnownNonConvergent = (1 << 16), + /// @brief The builtin should be mapped to a mux group builtin. + /// + /// This mapping takes place in BuiltinInfo::mapGroupBuiltinToMuxGroupBuiltin.
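+ ///
+ /// A minimal usage sketch (illustrative only; `BI` is assumed to be a
+ /// BuiltinInfo and `CI` a call to a builtin carrying this property):
+ ///
+ /// @code
+ /// if (BI.analyzeBuiltin(*CI.getCalledFunction()).properties &
+ ///     eBuiltinPropertyMapToMuxGroupBuiltin) {
+ ///   if (llvm::Instruction *NewI = BI.mapGroupBuiltinToMuxGroupBuiltin(CI)) {
+ ///     CI.replaceAllUsesWith(NewI);
+ ///     CI.eraseFromParent();
+ ///   }
+ /// }
+ /// @endcode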
+ eBuiltinPropertyMapToMuxGroupBuiltin = (1 << 17), }; /// @brief struct to hold information about a builtin function @@ -169,6 +229,9 @@ struct Builtin { BuiltinID const ID; /// @brief the Builtin Properties BuiltinProperties const properties; + /// @brief list of types used in overloading this builtin (only relevant for + /// overloadable mux builtins) + std::vector mux_overload_info = {}; /// @brief returns whether the builtin is valid bool isValid() const { return ID != eBuiltinInvalid; } @@ -215,6 +278,8 @@ constexpr const char get_global_linear_id[] = "__mux_get_global_linear_id"; constexpr const char get_local_linear_id[] = "__mux_get_local_linear_id"; constexpr const char get_enqueued_local_size[] = "__mux_get_enqueued_local_size"; +constexpr const char get_sub_group_size[] = "__mux_get_sub_group_size"; +constexpr const char get_sub_group_local_id[] = "__mux_get_sub_group_local_id"; // Barriers constexpr const char mem_barrier[] = "__mux_mem_barrier"; @@ -480,7 +545,15 @@ class BuiltinInfo { /// eBuiltinPropertyMapToMuxSyncBuiltin is set, the target must then remap /// the call to a new call to the correct mux builtin, remapping any /// arguments as required. - llvm::CallInst *mapSyncBuiltinToMuxSyncBuiltin(llvm::CallInst &CI); + llvm::Instruction *mapSyncBuiltinToMuxSyncBuiltin(llvm::CallInst &CI); + + /// @brief Remaps a call instruction to a call calling a mux group builtin. + /// + /// For a call to a builtin for which the property + /// eBuiltinPropertyMapToMuxGroupBuiltin is set, the target must then remap + /// the call to a new call to the correct mux builtin, remapping any + /// arguments as required. + llvm::Instruction *mapGroupBuiltinToMuxGroupBuiltin(llvm::CallInst &CI); /// @brief Get a builtin for printf. /// @return An identifier for the builtin, or the invalid builtin if there @@ -504,6 +577,15 @@ class BuiltinInfo { return ID > eBuiltinInvalid && ID < eFirstTargetBuiltin; } + /// @brief Returns true if the given ID is an overloadable ComputeMux builtin + /// ID. + /// + /// These builtins *require* extra overloading info when declaring or + /// defining. + static bool isOverloadableMuxBuiltinID(BuiltinID ID) { + return isMuxBuiltinID(ID) && isMuxGroupCollective(ID); + } + /// @brief Returns true if the given ID is a ComputeMux barrier builtin ID. static bool isMuxControlBarrierID(BuiltinID ID) { return ID == eMuxBuiltinSubGroupBarrier || @@ -518,8 +600,43 @@ class BuiltinInfo { ID == eMuxBuiltinDMAWrite3D; } + /// @brief Gets information about a mux group operation builtin + /// + /// Note: Does not set the 'function' or 'type' members of the + /// GroupCollective. + /// + /// FIXME: This matches an equivalent function in group_collective_helpers.h + /// which runs on OpenCL builtins. Remove that once the transition to mux + /// builtins is complete. + static std::optional isMuxGroupCollective(BuiltinID ID); + /// @brief Maps a ComputeMux builtin ID to its function name. - static llvm::StringRef getMuxBuiltinName(BuiltinID ID); + /// + /// @param OverloadInfo An array of types required to resolve certain + /// overloadable builtins, e.g., group builtins. + static std::string getMuxBuiltinName( + BuiltinID ID, llvm::ArrayRef OverloadInfo = {}); + + /// @brief Mangles a type using the LLVM intrinsic scheme + /// + /// This is an extremely simple mangling scheme matching LLVM's intrinsic + /// mangling system. It is only designed to be used with a specific set of + /// types and is not a general-purpose mangler. 
+ /// + /// * iXXX -> iXXX + /// * half -> f16 + /// * float -> f32 + /// * double -> f64 + /// * <N x Ty> -> vNTy + /// * <vscale x N x Ty> -> nxvNTy + static std::string getMangledTypeStr(llvm::Type *Ty); + + /// @brief Demangles a type using the LLVM intrinsic scheme - returns nullptr + /// if it was unable to demangle a type. + /// + /// @see getMangledTypeStr + static std::pair<llvm::Type *, llvm::StringRef> getDemangledTypeFromStr( + llvm::StringRef TyStr, llvm::LLVMContext &Ctx); /// @brief Defines the body of a ComputeMux builtin declaration /// @@ -527,12 +644,20 @@ class BuiltinInfo { /// function name, it is left alone and returned. /// /// Will declare any builtins it requires as transitive dependencies. - llvm::Function *defineMuxBuiltin(BuiltinID, llvm::Module &M); + /// + /// @param OverloadInfo An array of types required to resolve certain + /// overloadable builtins, e.g., group builtins. + llvm::Function *defineMuxBuiltin( + BuiltinID, llvm::Module &M, + llvm::ArrayRef<llvm::Type *> OverloadInfo = {}); /// @brief Gets a ComputeMux builtin from the module, or declares it /// - /// Only work-item builtins are supported. - llvm::Function *getOrDeclareMuxBuiltin(BuiltinID, llvm::Module &M); + /// @param OverloadInfo An array of types required to resolve certain + /// overloadable builtins, e.g., group builtins. + llvm::Function *getOrDeclareMuxBuiltin( + BuiltinID, llvm::Module &M, + llvm::ArrayRef<llvm::Type *> OverloadInfo = {}); struct SchedParamInfo { /// @brief An identifier providing resolution for targets to identify @@ -658,8 +783,10 @@ class BuiltinInfo { private: /// @brief Try to identify a builtin function. /// @param[in] F The function to identify. - /// @return Valid builtin ID if the name was identified. - BuiltinID identifyMuxBuiltin(llvm::Function const &F) const; + /// @return Valid builtin ID if the name was identified, as well as any types + /// required to overload the builtin ID. + std::pair<BuiltinID, std::vector<llvm::Type *>> identifyMuxBuiltin( + llvm::Function const &F) const; /// @brief Determine whether the given builtin function returns uniform values /// or not. An optional call instruction can be passed for more accuracy. @@ -682,10 +809,14 @@ class BIMuxInfoConcept { virtual ~BIMuxInfoConcept() = default; /// @brief See BuiltinInfo::defineMuxBuiltin. - virtual llvm::Function *defineMuxBuiltin(BuiltinID, llvm::Module &M); + virtual llvm::Function *defineMuxBuiltin( + BuiltinID, llvm::Module &M, + llvm::ArrayRef<llvm::Type *> OverloadInfo = {}); /// @brief See BuiltinInfo::getOrDeclareMuxBuiltin.
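+ ///
+ /// Overloadable group builtins need the extra types supplied; a hedged
+ /// sketch (hypothetical `BI`, `M` and `Ctx`) declaring the f32 overload of
+ /// the sub-group floating-point add reduction:
+ ///
+ /// @code
+ /// llvm::Function *F = BI.getOrDeclareMuxBuiltin(
+ ///     eMuxBuiltinSubgroupReduceFAdd, M, {llvm::Type::getFloatTy(Ctx)});
+ /// // F now declares "__mux_sub_group_reduce_fadd_f32".
+ /// @endcode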
- virtual llvm::Function *getOrDeclareMuxBuiltin(BuiltinID, llvm::Module &M); + virtual llvm::Function *getOrDeclareMuxBuiltin( + BuiltinID, llvm::Module &M, + llvm::ArrayRef OverloadInfo = {}); /// @brief See BuiltinInfo::getMuxSchedulingParameters virtual llvm::SmallVector @@ -825,8 +956,13 @@ class BILangInfoConcept { virtual bool requiresMapToMuxSyncBuiltin(BuiltinID) const { return false; } /// @see BuiltinInfo::mapSyncBuiltinToMuxSyncBuiltin - virtual llvm::CallInst *mapSyncBuiltinToMuxSyncBuiltin(llvm::CallInst &, - BIMuxInfoConcept &) { + virtual llvm::Instruction *mapSyncBuiltinToMuxSyncBuiltin( + llvm::CallInst &, BIMuxInfoConcept &) { + return nullptr; + } + /// @see BuiltinInfo::mapGroupBuiltinToMuxGroupBuiltin + virtual llvm::Instruction *mapGroupBuiltinToMuxGroupBuiltin( + llvm::CallInst &, BIMuxInfoConcept &) { return nullptr; } /// @see BuiltinInfo::getPrintfBuiltin diff --git a/modules/compiler/utils/include/compiler/utils/cl_builtin_info.h b/modules/compiler/utils/include/compiler/utils/cl_builtin_info.h index 98aed3967..e9cf8f2b8 100644 --- a/modules/compiler/utils/include/compiler/utils/cl_builtin_info.h +++ b/modules/compiler/utils/include/compiler/utils/cl_builtin_info.h @@ -116,8 +116,11 @@ class CLBuiltinInfo : public BILangInfoConcept { const override; /// @see BuiltinInfo::mapSyncBuiltinToMuxSyncBuiltin - llvm::CallInst *mapSyncBuiltinToMuxSyncBuiltin(llvm::CallInst &, - BIMuxInfoConcept &) override; + llvm::Instruction *mapSyncBuiltinToMuxSyncBuiltin( + llvm::CallInst &, BIMuxInfoConcept &) override; + /// @see BuiltinInfo::mapGroupBuiltinToMuxGroupBuiltin + llvm::Instruction *mapGroupBuiltinToMuxGroupBuiltin( + llvm::CallInst &, BIMuxInfoConcept &) override; /// @see BuiltinInfo::getPrintfBuiltin BuiltinID getPrintfBuiltin() const override; /// @see BuiltinInfo::getSubgroupLocalIdBuiltin diff --git a/modules/compiler/utils/include/compiler/utils/group_collective_helpers.h b/modules/compiler/utils/include/compiler/utils/group_collective_helpers.h index 7fa10545e..004b2f02b 100644 --- a/modules/compiler/utils/include/compiler/utils/group_collective_helpers.h +++ b/modules/compiler/utils/include/compiler/utils/group_collective_helpers.h @@ -63,8 +63,7 @@ llvm::Constant *getIdentityVal(llvm::RecurKind Kind, llvm::Type *Ty); /// @brief Represents a work-group or sub-group collective operation. struct GroupCollective { /// @brief The different operation types a group collective can represent. - enum class Op { - None, + enum class OpKind { All, Any, Reduction, @@ -74,24 +73,24 @@ struct GroupCollective { }; /// @brief The possible scopes of a group collective. - enum class Scope { None, WorkGroup, SubGroup }; + enum class ScopeKind { WorkGroup, SubGroup, VectorGroup }; /// @brief The operation type of the group collective. - Op op = Op::None; + OpKind Op = OpKind::All; /// @brief The scope of the group collective operation. - Scope scope = Scope::None; + ScopeKind Scope = ScopeKind::WorkGroup; /// @brief The llvm recurrence operation this can be mapped to. For broadcasts /// this will be None. - llvm::RecurKind recurKind = llvm::RecurKind::None; + llvm::RecurKind Recurrence = llvm::RecurKind::None; /// @brief The llvm function body for this group collective instance. - llvm::Function *func = nullptr; + llvm::Function *Func = nullptr; /// @brief The type the group operation is applied to. Will always be the - /// type of the first argument of `func`. - llvm::Type *type = nullptr; + /// type of the first argument of `Func`. 
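+ /// For a call to __mux_sub_group_reduce_fadd_f32, for instance, this would
+ /// be the 32-bit float type.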
+ llvm::Type *Ty = nullptr; /// @brief True if the operation is logical, rather than bitwise. - bool isLogical = false; + bool IsLogical = false; /// @brief Returns true for Any/All type collective operations. - bool isAnyAll() const { return op == Op::Any || op == Op::All; } + bool isAnyAll() const { return Op == OpKind::Any || Op == OpKind::All; } }; /// @brief Helper function to parse a group collective operation. diff --git a/modules/compiler/utils/include/compiler/utils/replace_group_funcs_pass.h b/modules/compiler/utils/include/compiler/utils/replace_group_funcs_pass.h new file mode 100644 index 000000000..38dc5a5a1 --- /dev/null +++ b/modules/compiler/utils/include/compiler/utils/replace_group_funcs_pass.h @@ -0,0 +1,40 @@ +// Copyright (C) Codeplay Software Limited +// +// Licensed under the Apache License, Version 2.0 (the "License") with LLVM +// Exceptions; you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +/// @file +/// +/// Replace group functions pass. + +#ifndef COMPILER_UTILS_REPLACE_GROUP_FUNCS_PASS_H_INCLUDED +#define COMPILER_UTILS_REPLACE_GROUP_FUNCS_PASS_H_INCLUDED + +#include + +namespace compiler { +namespace utils { + +/// @brief A pass that will replace calls to the group builtins with calls to +/// the equivalent mux functions + +class ReplaceGroupFuncsPass final + : public llvm::PassInfoMixin { + public: + llvm::PreservedAnalyses run(llvm::Module &, llvm::ModuleAnalysisManager &); +}; +} // namespace utils +} // namespace compiler + +#endif // COMPILER_UTILS_REPLACE_GROUP_FUNCS_PASS_H_INCLUDED diff --git a/modules/compiler/utils/source/builtin_info.cpp b/modules/compiler/utils/source/builtin_info.cpp index ad1121ff9..849e79816 100644 --- a/modules/compiler/utils/source/builtin_info.cpp +++ b/modules/compiler/utils/source/builtin_info.cpp @@ -42,8 +42,9 @@ Module *BuiltinInfo::getBuiltinsModule() { return nullptr; } -BuiltinID BuiltinInfo::identifyMuxBuiltin(Function const &F) const { - StringRef const Name = F.getName(); +std::pair> BuiltinInfo::identifyMuxBuiltin( + Function const &F) const { + StringRef Name = F.getName(); auto ID = StringSwitch(Name) .Case(MuxBuiltins::isftz, eMuxBuiltinIsFTZ) @@ -77,11 +78,159 @@ BuiltinID BuiltinInfo::identifyMuxBuiltin(Function const &F) const { .Case(MuxBuiltins::get_local_linear_id, eMuxBuiltinGetLocalLinearId) .Case(MuxBuiltins::get_enqueued_local_size, eMuxBuiltinGetEnqueuedLocalSize) + .Case(MuxBuiltins::get_sub_group_size, eMuxBuiltinGetSubGroupSize) + .Case(MuxBuiltins::get_sub_group_local_id, + eMuxBuiltinGetSubGroupLocalId) .Case(MuxBuiltins::work_group_barrier, eMuxBuiltinWorkGroupBarrier) .Case(MuxBuiltins::sub_group_barrier, eMuxBuiltinSubGroupBarrier) .Case(MuxBuiltins::mem_barrier, eMuxBuiltinMemBarrier) .Default(eBuiltinInvalid); - return ID; + if (ID != eBuiltinInvalid) { + return {ID, {}}; + } + + // Now check for group functions, which are a bit more involved as there's + // many of them and they're also mangled. 
We enforce that the mangling makes + // sense, otherwise the builtin is declared as invalid. + bool IsSubgroupOp = Name.consume_front("__mux_sub_group_"); + bool IsVecgroupOp = Name.consume_front("__mux_vec_group_"); + if (!IsSubgroupOp && !IsVecgroupOp && + !Name.consume_front("__mux_work_group_")) { + return {eBuiltinInvalid, {}}; + } + +#define SCOPED_GROUP_OP(OP) \ + (IsSubgroupOp ? eMuxBuiltinSubgroup##OP \ + : IsVecgroupOp ? eMuxBuiltinVecgroup##OP \ + : eMuxBuiltinWorkgroup##OP) + + // Most group operations have one argument, except for broadcasts. Despite + // that, we don't mangle the indices as they're fixed. + unsigned const NumExpectedMangledArgs = 1; + + if (Name.consume_front("any")) { + ID = SCOPED_GROUP_OP(Any); + } else if (Name.consume_front("all")) { + ID = SCOPED_GROUP_OP(All); + } else if (Name.consume_front("broadcast")) { + ID = SCOPED_GROUP_OP(Broadcast); + } else if (Name.consume_front("reduce_")) { + auto NextIdx = Name.find_first_of('_'); + std::string Group = Name.substr(0, NextIdx).str(); + Name = Name.drop_front(Group.size()); + + if (Group == "logical") { + Name = Name.drop_front(); // Drop the underscore + auto NextIdx = Name.find_first_of('_'); + auto RealGroup = Name.substr(0, NextIdx); + Group += "_" + RealGroup.str(); + Name = Name.drop_front(RealGroup.size()); + } + + ID = StringSwitch(Group) + .Case("add", SCOPED_GROUP_OP(ReduceAdd)) + .Case("fadd", SCOPED_GROUP_OP(ReduceFAdd)) + .Case("mul", SCOPED_GROUP_OP(ReduceMul)) + .Case("fmul", SCOPED_GROUP_OP(ReduceFMul)) + .Case("smin", SCOPED_GROUP_OP(ReduceSMin)) + .Case("umin", SCOPED_GROUP_OP(ReduceUMin)) + .Case("fmin", SCOPED_GROUP_OP(ReduceFMin)) + .Case("smax", SCOPED_GROUP_OP(ReduceSMax)) + .Case("umax", SCOPED_GROUP_OP(ReduceUMax)) + .Case("fmax", SCOPED_GROUP_OP(ReduceFMax)) + .Case("and", SCOPED_GROUP_OP(ReduceAnd)) + .Case("or", SCOPED_GROUP_OP(ReduceOr)) + .Case("xor", SCOPED_GROUP_OP(ReduceXor)) + .Case("logical_and", SCOPED_GROUP_OP(ReduceLogicalAnd)) + .Case("logical_or", SCOPED_GROUP_OP(ReduceLogicalOr)) + .Case("logical_xor", SCOPED_GROUP_OP(ReduceLogicalXor)) + .Default(eBuiltinInvalid); + } else if (Name.consume_front("scan_")) { + bool IsInclusive = Name.consume_front("inclusive_"); + if (!IsInclusive && !Name.consume_front("exclusive_")) { + return {eBuiltinInvalid, {}}; + } + + auto NextIdx = Name.find_first_of('_'); + std::string Group = Name.substr(0, NextIdx).str(); + Name = Name.drop_front(Group.size()); + + if (Group == "logical") { + auto NextIdx = Name.find_first_of('_', /*From*/ 1); + auto RealGroup = Name.substr(0, NextIdx); + Group += RealGroup.str(); + Name = Name.drop_front(RealGroup.size()); + } + + ID = StringSwitch(Group) + .Case("add", IsInclusive ? SCOPED_GROUP_OP(ScanAddInclusive) + : SCOPED_GROUP_OP(ScanAddExclusive)) + .Case("fadd", IsInclusive ? SCOPED_GROUP_OP(ScanFAddInclusive) + : SCOPED_GROUP_OP(ScanFAddExclusive)) + .Case("mul", IsInclusive ? SCOPED_GROUP_OP(ScanMulInclusive) + : SCOPED_GROUP_OP(ScanMulExclusive)) + .Case("fmul", IsInclusive ? SCOPED_GROUP_OP(ScanFMulInclusive) + : SCOPED_GROUP_OP(ScanFMulExclusive)) + .Case("smin", IsInclusive ? SCOPED_GROUP_OP(ScanSMinInclusive) + : SCOPED_GROUP_OP(ScanSMinExclusive)) + .Case("umin", IsInclusive ? SCOPED_GROUP_OP(ScanUMinInclusive) + : SCOPED_GROUP_OP(ScanUMinExclusive)) + .Case("fmin", IsInclusive ? SCOPED_GROUP_OP(ScanFMinInclusive) + : SCOPED_GROUP_OP(ScanFMinExclusive)) + .Case("smax", IsInclusive ? SCOPED_GROUP_OP(ScanSMaxInclusive) + : SCOPED_GROUP_OP(ScanSMaxExclusive)) + .Case("umax", IsInclusive ? 
SCOPED_GROUP_OP(ScanUMaxInclusive) + : SCOPED_GROUP_OP(ScanUMaxExclusive)) + .Case("fmax", IsInclusive ? SCOPED_GROUP_OP(ScanFMaxInclusive) + : SCOPED_GROUP_OP(ScanFMaxExclusive)) + .Case("and", IsInclusive ? SCOPED_GROUP_OP(ScanAndInclusive) + : SCOPED_GROUP_OP(ScanAndExclusive)) + .Case("or", IsInclusive ? SCOPED_GROUP_OP(ScanOrInclusive) + : SCOPED_GROUP_OP(ScanOrExclusive)) + .Case("xor", IsInclusive ? SCOPED_GROUP_OP(ScanXorInclusive) + : SCOPED_GROUP_OP(ScanXorExclusive)) + .Case("logical_and", + IsInclusive ? SCOPED_GROUP_OP(ScanLogicalAndInclusive) + : SCOPED_GROUP_OP(ScanLogicalAndExclusive)) + .Case("logical_or", IsInclusive + ? SCOPED_GROUP_OP(ScanLogicalOrInclusive) + : SCOPED_GROUP_OP(ScanLogicalOrExclusive)) + .Case("logical_xor", + IsInclusive ? SCOPED_GROUP_OP(ScanLogicalXorInclusive) + : SCOPED_GROUP_OP(ScanLogicalXorExclusive)) + .Default(eBuiltinInvalid); + } + + std::vector OverloadInfo; + if (ID != eBuiltinInvalid) { + // Consume the rest of this group Op function name. If we can't identify a + // series of mangled type names, this builtin is invalid. + unsigned NumMangledArgs = 0; + // Work-group builtins have an unmangled 'barrier ID' parameter first, which + // we want to skip. + unsigned Offset = ID >= eFirstMuxWorkgroupCollectiveBuiltin && + ID <= eLastMuxWorkgroupCollectiveBuiltin; + while (!Name.empty()) { + if (!Name.consume_front("_")) { + return {eBuiltinInvalid, {}}; + } + auto [Ty, NewName] = getDemangledTypeFromStr(Name, F.getContext()); + Name = NewName; + + auto ParamIdx = Offset + NumMangledArgs; + if (ParamIdx >= F.arg_size() || Ty != F.getArg(ParamIdx)->getType()) { + return {eBuiltinInvalid, {}}; + } + + ++NumMangledArgs; + OverloadInfo.push_back(Ty); + } + if (NumMangledArgs != NumExpectedMangledArgs) { + return {eBuiltinInvalid, {}}; + } + } + + return {ID, OverloadInfo}; } BuiltinUniformity BuiltinInfo::isBuiltinUniform(Builtin const &B, @@ -118,12 +267,34 @@ BuiltinUniformity BuiltinInfo::isBuiltinUniform(Builtin const &B, return eBuiltinUniformityAlways; } + case eMuxBuiltinGetSubGroupLocalId: + return eBuiltinUniformityInstanceID; case eMuxBuiltinGetLocalLinearId: case eMuxBuiltinGetGlobalLinearId: // TODO: This is fine for vectorizing in the x-axis, but currently we do // not support vectorizing along y or z (see CA-2843). return SimdDimIdx ? 
eBuiltinUniformityNever : eBuiltinUniformityInstanceID; + case eMuxBuiltinSubgroupAll: + case eMuxBuiltinSubgroupAny: + case eMuxBuiltinSubgroupBroadcast: + case eMuxBuiltinSubgroupReduceAdd: + case eMuxBuiltinSubgroupReduceFAdd: + case eMuxBuiltinSubgroupReduceMul: + case eMuxBuiltinSubgroupReduceFMul: + case eMuxBuiltinSubgroupReduceSMax: + case eMuxBuiltinSubgroupReduceUMax: + case eMuxBuiltinSubgroupReduceFMax: + case eMuxBuiltinSubgroupReduceSMin: + case eMuxBuiltinSubgroupReduceUMin: + case eMuxBuiltinSubgroupReduceFMin: + case eMuxBuiltinSubgroupReduceAnd: + case eMuxBuiltinSubgroupReduceOr: + case eMuxBuiltinSubgroupReduceXor: + case eMuxBuiltinSubgroupReduceLogicalAnd: + case eMuxBuiltinSubgroupReduceLogicalOr: + case eMuxBuiltinSubgroupReduceLogicalXor: + return eBuiltinUniformityAlways; } if (LangImpl) { return LangImpl->isBuiltinUniform(B, CI, SimdDimIdx); @@ -213,7 +384,7 @@ Builtin BuiltinInfo::analyzeBuiltin(Function const &F) const { return Builtin{F, eBuiltinUnknown, (BuiltinProperties)Properties}; } - BuiltinID ID = identifyMuxBuiltin(F); + auto [ID, OverloadInfo] = identifyMuxBuiltin(F); if (ID == eBuiltinInvalid) { // It's not a Mux builtin, so defer to the language implementation @@ -257,6 +428,7 @@ Builtin BuiltinInfo::analyzeBuiltin(Function const &F) const { case eMuxBuiltinGetGlobalLinearId: case eMuxBuiltinGetLocalLinearId: case eMuxBuiltinGetGlobalId: + case eMuxBuiltinGetSubGroupLocalId: Properties = eBuiltinPropertyWorkItem | eBuiltinPropertyRematerializable; break; case eMuxBuiltinGetLocalId: @@ -269,10 +441,17 @@ Builtin BuiltinInfo::analyzeBuiltin(Function const &F) const { Properties = eBuiltinPropertyNoSideEffects; break; } + + // Group functions are convergent. + if (isMuxGroupCollective(ID)) { + IsConvergent = true; + } + if (!IsConvergent) { Properties |= eBuiltinPropertyKnownNonConvergent; } - return Builtin{F, ID, (BuiltinProperties)Properties}; + + return Builtin{F, ID, (BuiltinProperties)Properties, OverloadInfo}; } BuiltinCall BuiltinInfo::analyzeBuiltinCall(CallInst const &CI, @@ -371,7 +550,7 @@ multi_llvm::Optional BuiltinInfo::getBuiltinRange( return multi_llvm::None; } -CallInst *BuiltinInfo::mapSyncBuiltinToMuxSyncBuiltin(CallInst &CI) { +Instruction *BuiltinInfo::mapSyncBuiltinToMuxSyncBuiltin(CallInst &CI) { if (LangImpl) { return LangImpl->mapSyncBuiltinToMuxSyncBuiltin(CI, *MuxImpl); } @@ -379,6 +558,14 @@ CallInst *BuiltinInfo::mapSyncBuiltinToMuxSyncBuiltin(CallInst &CI) { return nullptr; } +Instruction *BuiltinInfo::mapGroupBuiltinToMuxGroupBuiltin(CallInst &CI) { + if (LangImpl) { + return LangImpl->mapGroupBuiltinToMuxGroupBuiltin(CI, *MuxImpl); + } + // We shouldn't be mapping mux builtins to mux builtins, so we can stop here. + return nullptr; +} + BuiltinID BuiltinInfo::getPrintfBuiltin() const { if (LangImpl) { return LangImpl->getPrintfBuiltin(); @@ -429,11 +616,78 @@ Value *BuiltinInfo::initializeSchedulingParamForWrappedKernel( CalleeF); } -StringRef BuiltinInfo::getMuxBuiltinName(BuiltinID ID) { +// This provides an extremely simple mangling scheme matching LLVM's intrinsic +// mangling system. It is only designed to be used with a specific set of types +// and is not a general-purpose mangler. 
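+// Illustrative results, assuming the types handled below: i32 maps to "i32",
+// half to "f16", <4 x float> to "v4f32", and <vscale x 2 x i64> to "nxv2i64",
+// yielding overloaded names such as "__mux_sub_group_reduce_fadd_f32".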
+std::string BuiltinInfo::getMangledTypeStr(Type *Ty) { + std::string Result; + if (VectorType *VTy = dyn_cast(Ty)) { + ElementCount EC = VTy->getElementCount(); + if (EC.isScalable()) { + Result += "nx"; + } + return "v" + utostr(EC.getKnownMinValue()) + + getMangledTypeStr(VTy->getElementType()); + } + + if (Ty) { + switch (Ty->getTypeID()) { + default: + break; + case Type::HalfTyID: + return "f16"; + case Type::BFloatTyID: + return "bf16"; + case Type::FloatTyID: + return "f32"; + case Type::DoubleTyID: + return "f64"; + case Type::IntegerTyID: + return "i" + utostr(cast(Ty)->getBitWidth()); + } + } + llvm_unreachable("Unhandled type"); +} + +std::pair BuiltinInfo::getDemangledTypeFromStr( + StringRef TyStr, LLVMContext &Ctx) { + bool IsScalable = TyStr.consume_front("nx"); + if (TyStr.consume_front("v")) { + unsigned EC; + if (TyStr.consumeInteger(10, EC)) { + return {nullptr, TyStr}; + } + if (auto [EltTy, NewTyStr] = getDemangledTypeFromStr(TyStr, Ctx); EltTy) { + return {VectorType::get(EltTy, EC, IsScalable), NewTyStr}; + } + return {nullptr, TyStr}; + } + if (TyStr.consume_front("f16")) { + return {Type::getHalfTy(Ctx), TyStr}; + } + if (TyStr.consume_front("bf16")) { + return {Type::getBFloatTy(Ctx), TyStr}; + } + if (TyStr.consume_front("f32")) { + return {Type::getFloatTy(Ctx), TyStr}; + } + if (TyStr.consume_front("f64")) { + return {Type::getDoubleTy(Ctx), TyStr}; + } + unsigned IntBitWidth; + if (TyStr.consume_front("i") && !TyStr.consumeInteger(10, IntBitWidth)) { + return {IntegerType::get(Ctx, IntBitWidth), TyStr}; + } + + return {nullptr, TyStr}; +} + +std::string BuiltinInfo::getMuxBuiltinName(BuiltinID ID, + ArrayRef OverloadInfo) { assert(isMuxBuiltinID(ID)); switch (ID) { default: - llvm_unreachable("Unhandled mux builtin"); + break; case eMuxBuiltinIsFTZ: return MuxBuiltins::isftz; case eMuxBuiltinUseFast: @@ -490,6 +744,10 @@ StringRef BuiltinInfo::getMuxBuiltinName(BuiltinID ID) { return MuxBuiltins::get_local_linear_id; case eMuxBuiltinGetEnqueuedLocalSize: return MuxBuiltins::get_enqueued_local_size; + case eMuxBuiltinGetSubGroupSize: + return MuxBuiltins::get_sub_group_size; + case eMuxBuiltinGetSubGroupLocalId: + return MuxBuiltins::get_sub_group_local_id; case eMuxBuiltinMemBarrier: return MuxBuiltins::mem_barrier; case eMuxBuiltinWorkGroupBarrier: @@ -497,11 +755,147 @@ StringRef BuiltinInfo::getMuxBuiltinName(BuiltinID ID) { case eMuxBuiltinSubGroupBarrier: return MuxBuiltins::sub_group_barrier; } + + // A sneaky macro to do case statements on all scopes of a group operation. + // Note that it is missing a leading 'case' and a trailing ':' to trick + // clang-format into formatting it like a regular case statement. +#define CASE_GROUP_OP_ALL_SCOPES(OP) \ + eMuxBuiltinVecgroup##OP : case eMuxBuiltinSubgroup##OP: \ + case eMuxBuiltinWorkgroup##OP + + std::string BaseName = [](BuiltinID ID) { + // For simplicity, return all group operations as 'work_group' and replace + // the string with 'sub_group' or 'vec_group' post-hoc. 
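+ // For example, eMuxBuiltinSubgroupReduceAdd produces
+ // "__mux_work_group_reduce_add" here, which is rewritten to
+ // "__mux_sub_group_reduce_add" below before the overload suffix is
+ // appended (giving, e.g., "__mux_sub_group_reduce_add_i32").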
+ switch (ID) { + default: + return ""; + case CASE_GROUP_OP_ALL_SCOPES(All): + return "__mux_work_group_all"; + case CASE_GROUP_OP_ALL_SCOPES(Any): + return "__mux_work_group_any"; + case CASE_GROUP_OP_ALL_SCOPES(Broadcast): + return "__mux_work_group_broadcast"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceAdd): + return "__mux_work_group_reduce_add"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceFAdd): + return "__mux_work_group_reduce_fadd"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceSMin): + return "__mux_work_group_reduce_smin"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceUMin): + return "__mux_work_group_reduce_umin"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceFMin): + return "__mux_work_group_reduce_fmin"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceSMax): + return "__mux_work_group_reduce_smax"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceUMax): + return "__mux_work_group_reduce_umax"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceFMax): + return "__mux_work_group_reduce_fmax"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceMul): + return "__mux_work_group_reduce_mul"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceFMul): + return "__mux_work_group_reduce_fmul"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceAnd): + return "__mux_work_group_reduce_and"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceOr): + return "__mux_work_group_reduce_or"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceXor): + return "__mux_work_group_reduce_xor"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceLogicalAnd): + return "__mux_work_group_reduce_logical_and"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceLogicalOr): + return "__mux_work_group_reduce_logical_or"; + case CASE_GROUP_OP_ALL_SCOPES(ReduceLogicalXor): + return "__mux_work_group_reduce_logical_xor"; + case CASE_GROUP_OP_ALL_SCOPES(ScanAddInclusive): + return "__mux_work_group_scan_inclusive_add"; + case CASE_GROUP_OP_ALL_SCOPES(ScanFAddInclusive): + return "__mux_work_group_scan_inclusive_fadd"; + case CASE_GROUP_OP_ALL_SCOPES(ScanAddExclusive): + return "__mux_work_group_scan_exclusive_add"; + case CASE_GROUP_OP_ALL_SCOPES(ScanFAddExclusive): + return "__mux_work_group_scan_exclusive_fadd"; + case CASE_GROUP_OP_ALL_SCOPES(ScanSMinInclusive): + return "__mux_work_group_scan_inclusive_smin"; + case CASE_GROUP_OP_ALL_SCOPES(ScanUMinInclusive): + return "__mux_work_group_scan_inclusive_umin"; + case CASE_GROUP_OP_ALL_SCOPES(ScanFMinInclusive): + return "__mux_work_group_scan_inclusive_fmin"; + case CASE_GROUP_OP_ALL_SCOPES(ScanSMinExclusive): + return "__mux_work_group_scan_exclusive_smin"; + case CASE_GROUP_OP_ALL_SCOPES(ScanUMinExclusive): + return "__mux_work_group_scan_exclusive_umin"; + case CASE_GROUP_OP_ALL_SCOPES(ScanFMinExclusive): + return "__mux_work_group_scan_exclusive_fmin"; + case CASE_GROUP_OP_ALL_SCOPES(ScanSMaxInclusive): + return "__mux_work_group_scan_inclusive_smax"; + case CASE_GROUP_OP_ALL_SCOPES(ScanUMaxInclusive): + return "__mux_work_group_scan_inclusive_umax"; + case CASE_GROUP_OP_ALL_SCOPES(ScanFMaxInclusive): + return "__mux_work_group_scan_inclusive_fmax"; + case CASE_GROUP_OP_ALL_SCOPES(ScanSMaxExclusive): + return "__mux_work_group_scan_exclusive_smax"; + case CASE_GROUP_OP_ALL_SCOPES(ScanUMaxExclusive): + return "__mux_work_group_scan_exclusive_umax"; + case CASE_GROUP_OP_ALL_SCOPES(ScanFMaxExclusive): + return "__mux_work_group_scan_exclusive_fmax"; + case CASE_GROUP_OP_ALL_SCOPES(ScanMulInclusive): + return "__mux_work_group_scan_inclusive_mul"; + case CASE_GROUP_OP_ALL_SCOPES(ScanFMulInclusive): + return "__mux_work_group_scan_inclusive_fmul"; + case CASE_GROUP_OP_ALL_SCOPES(ScanMulExclusive): + return 
"__mux_work_group_scan_exclusive_mul"; + case CASE_GROUP_OP_ALL_SCOPES(ScanFMulExclusive): + return "__mux_work_group_scan_exclusive_fmul"; + case CASE_GROUP_OP_ALL_SCOPES(ScanAndInclusive): + return "__mux_work_group_scan_inclusive_and"; + case CASE_GROUP_OP_ALL_SCOPES(ScanAndExclusive): + return "__mux_work_group_scan_exclusive_and"; + case CASE_GROUP_OP_ALL_SCOPES(ScanOrInclusive): + return "__mux_work_group_scan_inclusive_or"; + case CASE_GROUP_OP_ALL_SCOPES(ScanOrExclusive): + return "__mux_work_group_scan_exclusive_or"; + case CASE_GROUP_OP_ALL_SCOPES(ScanXorInclusive): + return "__mux_work_group_scan_inclusive_xor"; + case CASE_GROUP_OP_ALL_SCOPES(ScanXorExclusive): + return "__mux_work_group_scan_exclusive_xor"; + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalAndInclusive): + return "__mux_work_group_scan_inclusive_logical_and"; + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalAndExclusive): + return "__mux_work_group_scan_exclusive_logical_and"; + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalOrInclusive): + return "__mux_work_group_scan_inclusive_logical_or"; + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalOrExclusive): + return "__mux_work_group_scan_exclusive_logical_or"; + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalXorInclusive): + return "__mux_work_group_scan_inclusive_logical_xor"; + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalXorExclusive): + return "__mux_work_group_scan_exclusive_logical_xor"; + } + }(ID); + + if (!BaseName.empty()) { + assert(!OverloadInfo.empty() && + "Must know how to overload group operation"); + if (ID >= eFirstMuxSubgroupCollectiveBuiltin && + ID <= eLastMuxSubgroupCollectiveBuiltin) { + // Replace 'work' with 'sub' + BaseName = BaseName.replace(6, 4, "sub"); + } else if (ID >= eFirstMuxVecgroupCollectiveBuiltin && + ID <= eLastMuxVecgroupCollectiveBuiltin) { + // Replace 'work' with 'vec' + BaseName = BaseName.replace(6, 4, "vec"); + } + auto *const Ty = OverloadInfo.front(); + return BaseName + "_" + getMangledTypeStr(Ty); + } + llvm_unreachable("Unhandled mux builtin"); } -Function *BuiltinInfo::defineMuxBuiltin(BuiltinID ID, Module &M) { +Function *BuiltinInfo::defineMuxBuiltin(BuiltinID ID, Module &M, + ArrayRef OverloadInfo) { assert(isMuxBuiltinID(ID) && "Only handling mux builtins"); - Function *F = M.getFunction(getMuxBuiltinName(ID)); + Function *F = M.getFunction(getMuxBuiltinName(ID, OverloadInfo)); // FIXME: We'd ideally want to declare it here to reduce pass // inter-dependencies. assert(F && "Function should have been pre-declared"); @@ -509,13 +903,202 @@ Function *BuiltinInfo::defineMuxBuiltin(BuiltinID ID, Module &M) { return F; } // Defer to the mux implementation to define this builtin. - return MuxImpl->defineMuxBuiltin(ID, M); + return MuxImpl->defineMuxBuiltin(ID, M, OverloadInfo); } -Function *BuiltinInfo::getOrDeclareMuxBuiltin(BuiltinID ID, Module &M) { +Function *BuiltinInfo::getOrDeclareMuxBuiltin(BuiltinID ID, Module &M, + ArrayRef OverloadInfo) { assert(isMuxBuiltinID(ID) && "Only handling mux builtins"); // Defer to the mux implementation to get/declare this builtin. 
- return MuxImpl->getOrDeclareMuxBuiltin(ID, M); + return MuxImpl->getOrDeclareMuxBuiltin(ID, M, OverloadInfo); +} + +std::optional BuiltinInfo::isMuxGroupCollective(BuiltinID ID) { + GroupCollective Collective; + + if (ID >= eFirstMuxSubgroupCollectiveBuiltin && + ID <= eLastMuxSubgroupCollectiveBuiltin) { + Collective.Scope = GroupCollective::ScopeKind::SubGroup; + } else if (ID >= eFirstMuxWorkgroupCollectiveBuiltin && + ID <= eLastMuxWorkgroupCollectiveBuiltin) { + Collective.Scope = GroupCollective::ScopeKind::WorkGroup; + } else if (ID >= eFirstMuxVecgroupCollectiveBuiltin && + ID <= eLastMuxVecgroupCollectiveBuiltin) { + Collective.Scope = GroupCollective::ScopeKind::VectorGroup; + } else { + return std::nullopt; + } + + // A sneaky macro to do case statements on all scopes of a group operation. + // Note that it is missing a leading 'case' and a trailing ':' to trick + // clang-format into formatting it like a regular case statement. +#define CASE_GROUP_OP_ALL_SCOPES(OP) \ + eMuxBuiltinVecgroup##OP : case eMuxBuiltinSubgroup##OP: \ + case eMuxBuiltinWorkgroup##OP + + switch (ID) { + default: + llvm_unreachable("Unhandled mux group builtin"); + case CASE_GROUP_OP_ALL_SCOPES(All): + Collective.Op = GroupCollective::OpKind::All; + break; + case CASE_GROUP_OP_ALL_SCOPES(Any): + Collective.Op = GroupCollective::OpKind::Any; + break; + case CASE_GROUP_OP_ALL_SCOPES(Broadcast): + Collective.Op = GroupCollective::OpKind::Broadcast; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceLogicalAnd): + case CASE_GROUP_OP_ALL_SCOPES(ReduceLogicalOr): + case CASE_GROUP_OP_ALL_SCOPES(ReduceLogicalXor): + Collective.IsLogical = true; + [[fallthrough]]; + case CASE_GROUP_OP_ALL_SCOPES(ReduceAdd): + case CASE_GROUP_OP_ALL_SCOPES(ReduceFAdd): + case CASE_GROUP_OP_ALL_SCOPES(ReduceMul): + case CASE_GROUP_OP_ALL_SCOPES(ReduceFMul): + case CASE_GROUP_OP_ALL_SCOPES(ReduceSMin): + case CASE_GROUP_OP_ALL_SCOPES(ReduceUMin): + case CASE_GROUP_OP_ALL_SCOPES(ReduceFMin): + case CASE_GROUP_OP_ALL_SCOPES(ReduceSMax): + case CASE_GROUP_OP_ALL_SCOPES(ReduceUMax): + case CASE_GROUP_OP_ALL_SCOPES(ReduceFMax): + case CASE_GROUP_OP_ALL_SCOPES(ReduceAnd): + case CASE_GROUP_OP_ALL_SCOPES(ReduceOr): + case CASE_GROUP_OP_ALL_SCOPES(ReduceXor): + Collective.Op = GroupCollective::OpKind::Reduction; + break; + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalAndInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalOrInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalXorInclusive): + Collective.IsLogical = true; + [[fallthrough]]; + case CASE_GROUP_OP_ALL_SCOPES(ScanAddInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanFAddInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanMulInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanFMulInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanSMinInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanUMinInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanFMinInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanSMaxInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanUMaxInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanFMaxInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanAndInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanOrInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanXorInclusive): + Collective.Op = GroupCollective::OpKind::ScanInclusive; + break; + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalAndExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalOrExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalXorExclusive): + Collective.IsLogical = true; + [[fallthrough]]; + case CASE_GROUP_OP_ALL_SCOPES(ScanAddExclusive): + case 
CASE_GROUP_OP_ALL_SCOPES(ScanFAddExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanMulExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanFMulExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanSMinExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanUMinExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanFMinExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanSMaxExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanUMaxExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanFMaxExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanAndExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanOrExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanXorExclusive): + Collective.Op = GroupCollective::OpKind::ScanExclusive; + break; + } + + // Then the recurrence kind. + if (Collective.Op == GroupCollective::OpKind::All) { + Collective.Recurrence = RecurKind::And; + } else if (Collective.Op == GroupCollective::OpKind::Any) { + Collective.Recurrence = RecurKind::Or; + } else if (Collective.Op == GroupCollective::OpKind::Reduction || + Collective.Op == GroupCollective::OpKind::ScanExclusive || + Collective.Op == GroupCollective::OpKind::ScanInclusive) { + switch (ID) { + case CASE_GROUP_OP_ALL_SCOPES(ReduceAdd): + case CASE_GROUP_OP_ALL_SCOPES(ScanAddInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanAddExclusive): + Collective.Recurrence = RecurKind::Add; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceFAdd): + case CASE_GROUP_OP_ALL_SCOPES(ScanFAddInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanFAddExclusive): + Collective.Recurrence = RecurKind::FAdd; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceMul): + case CASE_GROUP_OP_ALL_SCOPES(ScanMulInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanMulExclusive): + Collective.Recurrence = RecurKind::Mul; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceFMul): + case CASE_GROUP_OP_ALL_SCOPES(ScanFMulInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanFMulExclusive): + Collective.Recurrence = RecurKind::FMul; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceSMin): + case CASE_GROUP_OP_ALL_SCOPES(ScanSMinInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanSMinExclusive): + Collective.Recurrence = RecurKind::SMin; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceUMin): + case CASE_GROUP_OP_ALL_SCOPES(ScanUMinInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanUMinExclusive): + Collective.Recurrence = RecurKind::UMin; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceFMin): + case CASE_GROUP_OP_ALL_SCOPES(ScanFMinInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanFMinExclusive): + Collective.Recurrence = RecurKind::FMin; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceSMax): + case CASE_GROUP_OP_ALL_SCOPES(ScanSMaxInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanSMaxExclusive): + Collective.Recurrence = RecurKind::SMax; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceUMax): + case CASE_GROUP_OP_ALL_SCOPES(ScanUMaxInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanUMaxExclusive): + Collective.Recurrence = RecurKind::UMax; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceFMax): + case CASE_GROUP_OP_ALL_SCOPES(ScanFMaxInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanFMaxExclusive): + Collective.Recurrence = RecurKind::FMax; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceAnd): + case CASE_GROUP_OP_ALL_SCOPES(ReduceLogicalAnd): + case CASE_GROUP_OP_ALL_SCOPES(ScanAndInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanAndExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalAndInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalAndExclusive): + Collective.Recurrence = RecurKind::And; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceOr): + case 
CASE_GROUP_OP_ALL_SCOPES(ReduceLogicalOr): + case CASE_GROUP_OP_ALL_SCOPES(ScanOrInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanOrExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalOrInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalOrExclusive): + Collective.Recurrence = RecurKind::Or; + break; + case CASE_GROUP_OP_ALL_SCOPES(ReduceXor): + case CASE_GROUP_OP_ALL_SCOPES(ReduceLogicalXor): + case CASE_GROUP_OP_ALL_SCOPES(ScanXorInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanXorExclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalXorInclusive): + case CASE_GROUP_OP_ALL_SCOPES(ScanLogicalXorExclusive): + Collective.Recurrence = RecurKind::Xor; + break; + } + } else if (Collective.Op != GroupCollective::OpKind::Broadcast) { + llvm_unreachable("Unhandled mux group operation"); + } + + return Collective; } } // namespace utils diff --git a/modules/compiler/utils/source/cl_builtin_info.cpp b/modules/compiler/utils/source/cl_builtin_info.cpp index 4635439e5..62331a740 100644 --- a/modules/compiler/utils/source/cl_builtin_info.cpp +++ b/modules/compiler/utils/source/cl_builtin_info.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -116,6 +117,8 @@ enum CLBuiltinID : compiler::utils::BuiltinID { eCLBuiltinGetGlobalLinearId, /// @brief OpenCL builtin 'get_sub_group_local_id' (OpenCL >= 3.0). eCLBuiltinGetSubgroupLocalId, + /// @brief OpenCL builtin 'get_sub_group_size' (OpenCL >= 3.0). + eCLBuiltinGetSubgroupSize, // 6.12.2 Math Functions /// @brief OpenCL builtin 'fmax'. @@ -549,6 +552,7 @@ static const CLBuiltinEntry Builtins[] = { {eCLBuiltinGetLocalLinearId, "get_local_linear_id", OpenCLC20}, {eCLBuiltinGetGlobalLinearId, "get_global_linear_id", OpenCLC20}, {eCLBuiltinGetSubgroupLocalId, "get_sub_group_local_id", OpenCLC30}, + {eCLBuiltinGetSubgroupSize, "get_sub_group_size", OpenCLC30}, // 6.12.2 Math Functions {eCLBuiltinFMax, "fmax"}, @@ -1010,6 +1014,7 @@ Builtin CLBuiltinInfo::analyzeBuiltin(Function const &Callee) const { bool IsConvergent = false; unsigned Properties = eBuiltinPropertyNone; + llvm::SmallVector OverloadInfo; switch (ID) { default: // Assume convergence on unknown builtins. 
@@ -1085,11 +1090,15 @@ Builtin CLBuiltinInfo::analyzeBuiltin(Function const &Callee) const { case eCLBuiltinGetGlobalId: case eCLBuiltinGetLocalSize: case eCLBuiltinGetLocalLinearId: - case eCLBuiltinGetSubgroupLocalId: case eCLBuiltinGetGlobalLinearId: Properties |= eBuiltinPropertyWorkItem; Properties |= eBuiltinPropertyRematerializable; break; + case eCLBuiltinGetSubgroupLocalId: + Properties |= eBuiltinPropertyWorkItem; + Properties |= eBuiltinPropertyRematerializable; + Properties |= eBuiltinPropertyMapToMuxGroupBuiltin; + break; case eCLBuiltinGetLocalId: Properties |= eBuiltinPropertyWorkItem; Properties |= eBuiltinPropertyLocalID; @@ -1204,6 +1213,9 @@ Builtin CLBuiltinInfo::analyzeBuiltin(Function const &Callee) const { case eCLBuiltinAtomicWorkItemFence: Properties |= eBuiltinPropertyMapToMuxSyncBuiltin; break; + case eCLBuiltinGetSubgroupSize: + Properties |= eBuiltinPropertyMapToMuxGroupBuiltin; + break; // Subgroup collectives case eCLBuiltinSubgroupAll: case eCLBuiltinSubgroupAny: @@ -1273,6 +1285,11 @@ Builtin CLBuiltinInfo::analyzeBuiltin(Function const &Callee) const { case eCLBuiltinWorkgroupScanLogicalXorInclusive: case eCLBuiltinWorkgroupScanLogicalXorExclusive: IsConvergent = true; + Properties |= eBuiltinPropertyMapToMuxGroupBuiltin; + if (ID != eCLBuiltinWorkgroupAll && ID != eCLBuiltinWorkgroupAny && + ID != eCLBuiltinSubgroupAll && ID != eCLBuiltinSubgroupAny) { + OverloadInfo.push_back(Callee.getArg(0)->getType()); + } break; } @@ -2898,7 +2915,7 @@ static multi_llvm::Optional parseMemoryOrderParam(Value *const P) { return multi_llvm::None; } -CallInst *CLBuiltinInfo::mapSyncBuiltinToMuxSyncBuiltin( +Instruction *CLBuiltinInfo::mapSyncBuiltinToMuxSyncBuiltin( CallInst &CI, BIMuxInfoConcept &BIMuxImpl) { auto &M = *CI.getModule(); auto *const F = CI.getCalledFunction(); @@ -2982,6 +2999,461 @@ CallInst *CLBuiltinInfo::mapSyncBuiltinToMuxSyncBuiltin( } } +Instruction *CLBuiltinInfo::mapGroupBuiltinToMuxGroupBuiltin( + CallInst &CI, BIMuxInfoConcept &BIMuxImpl) { + auto &M = *CI.getModule(); + auto *const F = CI.getCalledFunction(); + assert(F && "No calling function?"); + auto const Builtin = analyzeBuiltin(*F); + + if (Builtin.ID == eCLBuiltinGetSubgroupSize || + Builtin.ID == eCLBuiltinGetSubgroupLocalId) { + BaseBuiltinID MuxBuiltinID = Builtin.ID == eCLBuiltinGetSubgroupSize + ? eMuxBuiltinGetSubGroupSize + : eMuxBuiltinGetSubGroupLocalId; + auto *const MuxBuiltinFn = + BIMuxImpl.getOrDeclareMuxBuiltin(MuxBuiltinID, M); + auto *const NewCI = + CallInst::Create(MuxBuiltinFn, /*Args*/ {}, CI.getName(), &CI); + NewCI->setAttributes(MuxBuiltinFn->getAttributes()); + return NewCI; + } + + // Some ops need extra checking to determine their mux ID: + // * add/mul operations are split into integer/float + // * min/max operations are split into signed/unsigned/float + // So we set a 'base' builtin ID for these operations to the (unsigned) + // integer variant and do a checking step afterwards where we refine the + // builtin ID. 
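+ // For instance (hypothetical input), a call to sub_group_reduce_min on a
+ // float argument starts out as eMuxBuiltinSubgroupReduceUMin here and is
+ // refined to eMuxBuiltinSubgroupReduceFMin below.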
+ bool RecheckOpType = false; + BaseBuiltinID MuxBuiltinID = eBuiltinInvalid; + switch (Builtin.ID) { + default: + return nullptr; + case eCLBuiltinSubgroupAll: + MuxBuiltinID = eMuxBuiltinSubgroupAll; + break; + case eCLBuiltinSubgroupAny: + MuxBuiltinID = eMuxBuiltinSubgroupAny; + break; + case eCLBuiltinSubgroupBroadcast: + MuxBuiltinID = eMuxBuiltinSubgroupBroadcast; + break; + case eCLBuiltinSubgroupReduceAdd: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinSubgroupReduceAdd; + break; + case eCLBuiltinSubgroupReduceMin: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinSubgroupReduceUMin; + break; + case eCLBuiltinSubgroupReduceMax: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinSubgroupReduceUMax; + break; + case eCLBuiltinSubgroupReduceMul: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinSubgroupReduceMul; + break; + case eCLBuiltinSubgroupReduceAnd: + MuxBuiltinID = eMuxBuiltinSubgroupReduceAnd; + break; + case eCLBuiltinSubgroupReduceOr: + MuxBuiltinID = eMuxBuiltinSubgroupReduceOr; + break; + case eCLBuiltinSubgroupReduceXor: + MuxBuiltinID = eMuxBuiltinSubgroupReduceXor; + break; + case eCLBuiltinSubgroupReduceLogicalAnd: + MuxBuiltinID = eMuxBuiltinSubgroupReduceLogicalAnd; + break; + case eCLBuiltinSubgroupReduceLogicalOr: + MuxBuiltinID = eMuxBuiltinSubgroupReduceLogicalOr; + break; + case eCLBuiltinSubgroupReduceLogicalXor: + MuxBuiltinID = eMuxBuiltinSubgroupReduceLogicalXor; + break; + case eCLBuiltinSubgroupScanAddInclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinSubgroupScanAddInclusive; + break; + case eCLBuiltinSubgroupScanAddExclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinSubgroupScanAddExclusive; + break; + case eCLBuiltinSubgroupScanMinInclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinSubgroupScanUMinInclusive; + break; + case eCLBuiltinSubgroupScanMinExclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinSubgroupScanUMinExclusive; + break; + case eCLBuiltinSubgroupScanMaxInclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinSubgroupScanUMaxInclusive; + break; + case eCLBuiltinSubgroupScanMaxExclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinSubgroupScanUMaxExclusive; + break; + case eCLBuiltinSubgroupScanMulInclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinSubgroupScanMulInclusive; + break; + case eCLBuiltinSubgroupScanMulExclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinSubgroupScanMulExclusive; + break; + case eCLBuiltinSubgroupScanAndInclusive: + MuxBuiltinID = eMuxBuiltinSubgroupScanAndInclusive; + break; + case eCLBuiltinSubgroupScanAndExclusive: + MuxBuiltinID = eMuxBuiltinSubgroupScanAndExclusive; + break; + case eCLBuiltinSubgroupScanOrInclusive: + MuxBuiltinID = eMuxBuiltinSubgroupScanOrInclusive; + break; + case eCLBuiltinSubgroupScanOrExclusive: + MuxBuiltinID = eMuxBuiltinSubgroupScanOrExclusive; + break; + case eCLBuiltinSubgroupScanXorInclusive: + MuxBuiltinID = eMuxBuiltinSubgroupScanXorInclusive; + break; + case eCLBuiltinSubgroupScanXorExclusive: + MuxBuiltinID = eMuxBuiltinSubgroupScanXorExclusive; + break; + case eCLBuiltinSubgroupScanLogicalAndInclusive: + MuxBuiltinID = eMuxBuiltinSubgroupScanLogicalAndInclusive; + break; + case eCLBuiltinSubgroupScanLogicalAndExclusive: + MuxBuiltinID = eMuxBuiltinSubgroupScanLogicalAndExclusive; + break; + case eCLBuiltinSubgroupScanLogicalOrInclusive: + MuxBuiltinID = eMuxBuiltinSubgroupScanLogicalOrInclusive; + break; + case eCLBuiltinSubgroupScanLogicalOrExclusive: + MuxBuiltinID = 
eMuxBuiltinSubgroupScanLogicalOrExclusive; + break; + case eCLBuiltinSubgroupScanLogicalXorInclusive: + MuxBuiltinID = eMuxBuiltinSubgroupScanLogicalXorInclusive; + break; + case eCLBuiltinSubgroupScanLogicalXorExclusive: + MuxBuiltinID = eMuxBuiltinSubgroupScanLogicalXorExclusive; + break; + case eCLBuiltinWorkgroupAll: + MuxBuiltinID = eMuxBuiltinWorkgroupAll; + break; + case eCLBuiltinWorkgroupAny: + MuxBuiltinID = eMuxBuiltinWorkgroupAny; + break; + case eCLBuiltinWorkgroupBroadcast: + MuxBuiltinID = eMuxBuiltinWorkgroupBroadcast; + break; + case eCLBuiltinWorkgroupReduceAdd: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinWorkgroupReduceAdd; + break; + case eCLBuiltinWorkgroupReduceMin: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinWorkgroupReduceUMin; + break; + case eCLBuiltinWorkgroupReduceMax: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinWorkgroupReduceUMax; + break; + case eCLBuiltinWorkgroupReduceMul: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinWorkgroupReduceMul; + break; + case eCLBuiltinWorkgroupReduceAnd: + MuxBuiltinID = eMuxBuiltinWorkgroupReduceAnd; + break; + case eCLBuiltinWorkgroupReduceOr: + MuxBuiltinID = eMuxBuiltinWorkgroupReduceOr; + break; + case eCLBuiltinWorkgroupReduceXor: + MuxBuiltinID = eMuxBuiltinWorkgroupReduceXor; + break; + case eCLBuiltinWorkgroupReduceLogicalAnd: + MuxBuiltinID = eMuxBuiltinWorkgroupReduceLogicalAnd; + break; + case eCLBuiltinWorkgroupReduceLogicalOr: + MuxBuiltinID = eMuxBuiltinWorkgroupReduceLogicalOr; + break; + case eCLBuiltinWorkgroupReduceLogicalXor: + MuxBuiltinID = eMuxBuiltinWorkgroupReduceLogicalXor; + break; + case eCLBuiltinWorkgroupScanAddInclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinWorkgroupScanAddInclusive; + break; + case eCLBuiltinWorkgroupScanAddExclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinWorkgroupScanAddExclusive; + break; + case eCLBuiltinWorkgroupScanMinInclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinWorkgroupScanUMinInclusive; + break; + case eCLBuiltinWorkgroupScanMinExclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinWorkgroupScanUMinExclusive; + break; + case eCLBuiltinWorkgroupScanMaxInclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinWorkgroupScanUMaxInclusive; + break; + case eCLBuiltinWorkgroupScanMaxExclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinWorkgroupScanUMaxExclusive; + break; + case eCLBuiltinWorkgroupScanMulInclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinWorkgroupScanMulInclusive; + break; + case eCLBuiltinWorkgroupScanMulExclusive: + RecheckOpType = true; + MuxBuiltinID = eMuxBuiltinWorkgroupScanMulExclusive; + break; + case eCLBuiltinWorkgroupScanAndInclusive: + MuxBuiltinID = eMuxBuiltinWorkgroupScanAndInclusive; + break; + case eCLBuiltinWorkgroupScanAndExclusive: + MuxBuiltinID = eMuxBuiltinWorkgroupScanAndExclusive; + break; + case eCLBuiltinWorkgroupScanOrInclusive: + MuxBuiltinID = eMuxBuiltinWorkgroupScanOrInclusive; + break; + case eCLBuiltinWorkgroupScanOrExclusive: + MuxBuiltinID = eMuxBuiltinWorkgroupScanOrExclusive; + break; + case eCLBuiltinWorkgroupScanXorInclusive: + MuxBuiltinID = eMuxBuiltinWorkgroupScanXorInclusive; + break; + case eCLBuiltinWorkgroupScanXorExclusive: + MuxBuiltinID = eMuxBuiltinWorkgroupScanXorExclusive; + break; + case eCLBuiltinWorkgroupScanLogicalAndInclusive: + MuxBuiltinID = eMuxBuiltinWorkgroupScanLogicalAndInclusive; + break; + case eCLBuiltinWorkgroupScanLogicalAndExclusive: + MuxBuiltinID = 
eMuxBuiltinWorkgroupScanLogicalAndExclusive;
+      break;
+    case eCLBuiltinWorkgroupScanLogicalOrInclusive:
+      MuxBuiltinID = eMuxBuiltinWorkgroupScanLogicalOrInclusive;
+      break;
+    case eCLBuiltinWorkgroupScanLogicalOrExclusive:
+      MuxBuiltinID = eMuxBuiltinWorkgroupScanLogicalOrExclusive;
+      break;
+    case eCLBuiltinWorkgroupScanLogicalXorInclusive:
+      MuxBuiltinID = eMuxBuiltinWorkgroupScanLogicalXorInclusive;
+      break;
+    case eCLBuiltinWorkgroupScanLogicalXorExclusive:
+      MuxBuiltinID = eMuxBuiltinWorkgroupScanLogicalXorExclusive;
+      break;
+  }
+
+  if (RecheckOpType) {
+    // We've assumed (unsigned) integer operations, but we may actually have
+    // signed integer, or floating point, operations. Refine the builtin ID to
+    // the correct 'overload' now.
+    compiler::utils::NameMangler Mangler(&F->getContext());
+    SmallVector<Type *, 4> ArgumentTypes;
+    SmallVector<compiler::utils::TypeQualifiers, 4> Qualifiers;
+
+    const auto DemangledName = std::string(
+        Mangler.demangleName(F->getName(), ArgumentTypes, Qualifiers));
+
+    assert(Qualifiers.size() == 1 && ArgumentTypes.size() == 1 &&
+           "Unknown collective builtin");
+    auto &Qual = Qualifiers[0];
+
+    bool IsSignedInt = false;
+    while (!IsSignedInt && Qual.getCount()) {
+      IsSignedInt |= Qual.pop_front() == compiler::utils::eTypeQualSignedInt;
+    }
+
+    bool IsFP = ArgumentTypes[0]->isFloatingPointTy();
+    switch (MuxBuiltinID) {
+    default:
+      llvm_unreachable("unknown group operation for which to check the type");
+    case eMuxBuiltinSubgroupReduceAdd:
+      MuxBuiltinID = IsFP ? eMuxBuiltinSubgroupReduceFAdd : MuxBuiltinID;
+      break;
+    case eMuxBuiltinSubgroupReduceMul:
+      MuxBuiltinID = IsFP ? eMuxBuiltinSubgroupReduceFMul : MuxBuiltinID;
+      break;
+    case eMuxBuiltinSubgroupReduceUMin:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinSubgroupReduceFMin
+               : (IsSignedInt ? eMuxBuiltinSubgroupReduceSMin : MuxBuiltinID);
+      break;
+    case eMuxBuiltinSubgroupReduceUMax:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinSubgroupReduceFMax
+               : (IsSignedInt ? eMuxBuiltinSubgroupReduceSMax : MuxBuiltinID);
+      break;
+    case eMuxBuiltinSubgroupScanAddInclusive:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinSubgroupScanFAddInclusive : MuxBuiltinID;
+      break;
+    case eMuxBuiltinSubgroupScanAddExclusive:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinSubgroupScanFAddExclusive : MuxBuiltinID;
+      break;
+    case eMuxBuiltinSubgroupScanMulInclusive:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinSubgroupScanFMulInclusive : MuxBuiltinID;
+      break;
+    case eMuxBuiltinSubgroupScanMulExclusive:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinSubgroupScanFMulExclusive : MuxBuiltinID;
+      break;
+    case eMuxBuiltinSubgroupScanUMinInclusive:
+      MuxBuiltinID = IsFP
+                         ? eMuxBuiltinSubgroupScanFMinInclusive
+                         : (IsSignedInt ? eMuxBuiltinSubgroupScanSMinInclusive
+                                        : MuxBuiltinID);
+      break;
+    case eMuxBuiltinSubgroupScanUMinExclusive:
+      MuxBuiltinID = IsFP
+                         ? eMuxBuiltinSubgroupScanFMinExclusive
+                         : (IsSignedInt ? eMuxBuiltinSubgroupScanSMinExclusive
+                                        : MuxBuiltinID);
+      break;
+    case eMuxBuiltinSubgroupScanUMaxInclusive:
+      MuxBuiltinID = IsFP
+                         ? eMuxBuiltinSubgroupScanFMaxInclusive
+                         : (IsSignedInt ? eMuxBuiltinSubgroupScanSMaxInclusive
+                                        : MuxBuiltinID);
+      break;
+    case eMuxBuiltinSubgroupScanUMaxExclusive:
+      MuxBuiltinID = IsFP
+                         ? eMuxBuiltinSubgroupScanFMaxExclusive
+                         : (IsSignedInt ? eMuxBuiltinSubgroupScanSMaxExclusive
+                                        : MuxBuiltinID);
+      break;
+    case eMuxBuiltinWorkgroupReduceAdd:
+      MuxBuiltinID = IsFP ? eMuxBuiltinWorkgroupReduceFAdd : MuxBuiltinID;
+      break;
+    case eMuxBuiltinWorkgroupReduceMul:
+      MuxBuiltinID = IsFP ?
eMuxBuiltinWorkgroupReduceFMul : MuxBuiltinID;
+      break;
+    case eMuxBuiltinWorkgroupReduceUMin:
+      MuxBuiltinID = IsFP ? eMuxBuiltinWorkgroupReduceFMin
+                          : (IsSignedInt ? eMuxBuiltinWorkgroupReduceSMin
+                                         : MuxBuiltinID);
+      break;
+    case eMuxBuiltinWorkgroupReduceUMax:
+      MuxBuiltinID = IsFP ? eMuxBuiltinWorkgroupReduceFMax
+                          : (IsSignedInt ? eMuxBuiltinWorkgroupReduceSMax
+                                         : MuxBuiltinID);
+      break;
+    case eMuxBuiltinWorkgroupScanAddInclusive:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinWorkgroupScanFAddInclusive : MuxBuiltinID;
+      break;
+    case eMuxBuiltinWorkgroupScanAddExclusive:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinWorkgroupScanFAddExclusive : MuxBuiltinID;
+      break;
+    case eMuxBuiltinWorkgroupScanMulInclusive:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinWorkgroupScanFMulInclusive : MuxBuiltinID;
+      break;
+    case eMuxBuiltinWorkgroupScanMulExclusive:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinWorkgroupScanFMulExclusive : MuxBuiltinID;
+      break;
+    case eMuxBuiltinWorkgroupScanUMinInclusive:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinWorkgroupScanFMinInclusive
+               : (IsSignedInt ? eMuxBuiltinWorkgroupScanSMinInclusive
+                              : MuxBuiltinID);
+      break;
+    case eMuxBuiltinWorkgroupScanUMinExclusive:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinWorkgroupScanFMinExclusive
+               : (IsSignedInt ? eMuxBuiltinWorkgroupScanSMinExclusive
+                              : MuxBuiltinID);
+      break;
+    case eMuxBuiltinWorkgroupScanUMaxInclusive:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinWorkgroupScanFMaxInclusive
+               : (IsSignedInt ? eMuxBuiltinWorkgroupScanSMaxInclusive
+                              : MuxBuiltinID);
+      break;
+    case eMuxBuiltinWorkgroupScanUMaxExclusive:
+      MuxBuiltinID =
+          IsFP ? eMuxBuiltinWorkgroupScanFMaxExclusive
+               : (IsSignedInt ? eMuxBuiltinWorkgroupScanSMaxExclusive
+                              : MuxBuiltinID);
+      break;
+    }
+  }
+
+  bool const IsAnyAll = MuxBuiltinID == eMuxBuiltinSubgroupAny ||
+                        MuxBuiltinID == eMuxBuiltinSubgroupAll ||
+                        MuxBuiltinID == eMuxBuiltinWorkgroupAny ||
+                        MuxBuiltinID == eMuxBuiltinWorkgroupAll;
+  SmallVector<Type *, 1> OverloadInfo;
+  if (!IsAnyAll) {
+    OverloadInfo.push_back(CI.getOperand(0)->getType());
+  } else {
+    OverloadInfo.push_back(IntegerType::getInt1Ty(M.getContext()));
+  }
+
+  auto *const MuxBuiltinFn =
+      BIMuxImpl.getOrDeclareMuxBuiltin(MuxBuiltinID, M, OverloadInfo);
+
+  assert(MuxBuiltinFn && "Missing mux builtin");
+  auto *const I32Ty = Type::getInt32Ty(M.getContext());
+  auto *const I64Ty = Type::getInt64Ty(M.getContext());
+
+  SmallVector<Value *, 4> Args;
+  if (MuxBuiltinID >= eFirstMuxWorkgroupCollectiveBuiltin &&
+      MuxBuiltinID <= eLastMuxWorkgroupCollectiveBuiltin) {
+    // Work-group operations have a barrier ID first.
+    Args.push_back(ConstantInt::get(I32Ty, 0));
+  }
+  // Then the argument itself. For any/all operations we must first convert it
+  // to an i1 (by comparing it against zero), because that is how the mux
+  // builtins expect their arguments.
+  auto *Val = CI.getOperand(0);
+  if (!IsAnyAll) {
+    Args.push_back(Val);
+  } else {
+    assert(Val->getType()->isIntegerTy());
+    auto *NEZero =
+        ICmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, Val,
+                         ConstantInt::getNullValue(Val->getType()), "", &CI);
+    Args.push_back(NEZero);
+  }
+
+  if (MuxBuiltinID == eMuxBuiltinSubgroupBroadcast) {
+    // Pass on the ID parameter.
+    Args.push_back(CI.getOperand(1));
+  }
+  if (MuxBuiltinID == eMuxBuiltinWorkgroupBroadcast) {
+    // The mux version always has three indices. Any missing ones are replaced
+    // with zeros.
+    for (unsigned i = 0, e = CI.arg_size(); i != 3; i++) {
+      Args.push_back(1 + i < e ?
CI.getOperand(1 + i) + : ConstantInt::getNullValue(I64Ty)); + } + } + + auto *const NewCI = CallInst::Create(MuxBuiltinFn, Args, CI.getName(), &CI); + NewCI->setAttributes(MuxBuiltinFn->getAttributes()); + + if (!IsAnyAll) { + return NewCI; + } + // For any/all we need to recreate the original i32 return value. + return SExtInst::Create(Instruction::SExt, NewCI, CI.getType(), "sext", &CI); +} + //////////////////////////////////////////////////////////////////////////////// Function *CLBuiltinLoader::materializeBuiltin(StringRef BuiltinName, diff --git a/modules/compiler/utils/source/define_mux_builtins_pass.cpp b/modules/compiler/utils/source/define_mux_builtins_pass.cpp index 8b8281430..75cf246e6 100644 --- a/modules/compiler/utils/source/define_mux_builtins_pass.cpp +++ b/modules/compiler/utils/source/define_mux_builtins_pass.cpp @@ -41,7 +41,8 @@ PreservedAnalyses compiler::utils::DefineMuxBuiltinsPass::run( // Define the builtin. If it declares any new dependent builtins, those // will be appended to the module's function list and so will be // encountered by later iterations. - if (BI.defineMuxBuiltin(BI.analyzeBuiltin(F).ID, M)) { + auto Builtin = BI.analyzeBuiltin(F); + if (BI.defineMuxBuiltin(Builtin.ID, M, Builtin.mux_overload_info)) { Changed = true; } } diff --git a/modules/compiler/utils/source/degenerate_sub_group_pass.cpp b/modules/compiler/utils/source/degenerate_sub_group_pass.cpp index cc613b408..f8ce88136 100644 --- a/modules/compiler/utils/source/degenerate_sub_group_pass.cpp +++ b/modules/compiler/utils/source/degenerate_sub_group_pass.cpp @@ -54,7 +54,7 @@ bool isSubGroupFunction(CallInst *CI) { auto *Fcn = CI->getCalledFunction(); assert(Fcn && "virtual calls are not supported"); if (auto GC = compiler::utils::isGroupCollective(Fcn)) { - return GC->scope == compiler::utils::GroupCollective::Scope::SubGroup; + return GC->Scope == compiler::utils::GroupCollective::ScopeKind::SubGroup; } return Fcn->getName() == compiler::utils::MuxBuiltins::sub_group_barrier; diff --git a/modules/compiler/utils/source/group_collective_helpers.cpp b/modules/compiler/utils/source/group_collective_helpers.cpp index c4b112dc7..47b55f00c 100644 --- a/modules/compiler/utils/source/group_collective_helpers.cpp +++ b/modules/compiler/utils/source/group_collective_helpers.cpp @@ -90,40 +90,42 @@ compiler::utils::isGroupCollective(llvm::Function *f) { // Parse the scope. if (L.Consume("work_group_")) { - collective.scope = GroupCollective::Scope::WorkGroup; + collective.Scope = GroupCollective::ScopeKind::WorkGroup; } else if (L.Consume("sub_group_")) { - collective.scope = GroupCollective::Scope::SubGroup; + collective.Scope = GroupCollective::ScopeKind::SubGroup; + } else if (L.Consume("vec_group_")) { + collective.Scope = GroupCollective::ScopeKind::VectorGroup; } else { return multi_llvm::None; } // Then the operation type. 
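   // e.g. having already consumed 'work_group_' above, a name like
   // 'work_group_reduce_max' is left as 'reduce_max'; we consume 'reduce_'
   // here and parse 'max' as the recurrence kind below.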
if (L.Consume("reduce_")) { - collective.op = GroupCollective::Op::Reduction; + collective.Op = GroupCollective::OpKind::Reduction; } else if (L.Consume("all")) { - collective.op = GroupCollective::Op::All; + collective.Op = GroupCollective::OpKind::All; } else if (L.Consume("any")) { - collective.op = GroupCollective::Op::Any; + collective.Op = GroupCollective::OpKind::Any; } else if (L.Consume("scan_exclusive_")) { - collective.op = GroupCollective::Op::ScanExclusive; + collective.Op = GroupCollective::OpKind::ScanExclusive; } else if (L.Consume("scan_inclusive_")) { - collective.op = GroupCollective::Op::ScanInclusive; + collective.Op = GroupCollective::OpKind::ScanInclusive; } else if (L.Consume("broadcast")) { - collective.op = GroupCollective::Op::Broadcast; + collective.Op = GroupCollective::OpKind::Broadcast; } else { return multi_llvm::None; } // Then the recurrence kind. - if (collective.op == GroupCollective::Op::All) { - collective.recurKind = RecurKind::And; - } else if (collective.op == GroupCollective::Op::Any) { - collective.recurKind = RecurKind::Or; - } else if (collective.op == GroupCollective::Op::Reduction || - collective.op == GroupCollective::Op::ScanExclusive || - collective.op == GroupCollective::Op::ScanInclusive) { + if (collective.Op == GroupCollective::OpKind::All) { + collective.Recurrence = RecurKind::And; + } else if (collective.Op == GroupCollective::OpKind::Any) { + collective.Recurrence = RecurKind::Or; + } else if (collective.Op == GroupCollective::OpKind::Reduction || + collective.Op == GroupCollective::OpKind::ScanExclusive || + collective.Op == GroupCollective::OpKind::ScanInclusive) { if (L.Consume("logical_")) { - collective.isLogical = true; + collective.IsLogical = true; } assert(Qualifiers.size() == 1 && ArgumentTypes.size() == 1 && @@ -148,7 +150,7 @@ compiler::utils::isGroupCollective(llvm::Function *f) { return multi_llvm::None; } - collective.recurKind = + collective.Recurrence = StringSwitch(OpKind) .Case("add", isInt ? RecurKind::Add : RecurKind::FAdd) @@ -164,7 +166,7 @@ compiler::utils::isGroupCollective(llvm::Function *f) { .Case("xor", RecurKind::Xor) .Default(RecurKind::None); - if (collective.recurKind == RecurKind::None) { + if (collective.Recurrence == RecurKind::None) { return multi_llvm::None; } } @@ -175,7 +177,8 @@ compiler::utils::isGroupCollective(llvm::Function *f) { return multi_llvm::None; } - collective.func = f; - collective.type = f->getArg(0)->getType(); + collective.Func = f; + collective.Ty = f->getArg(0)->getType(); + return collective; } diff --git a/modules/compiler/utils/source/mux_builtin_info.cpp b/modules/compiler/utils/source/mux_builtin_info.cpp index 8c9416c6c..87c161cfb 100644 --- a/modules/compiler/utils/source/mux_builtin_info.cpp +++ b/modules/compiler/utils/source/mux_builtin_info.cpp @@ -684,9 +684,10 @@ Function *BIMuxInfoConcept::defineDMAWait(Function &F) { return &F; } -Function *BIMuxInfoConcept::defineMuxBuiltin(BuiltinID ID, Module &M) { +Function *BIMuxInfoConcept::defineMuxBuiltin(BuiltinID ID, Module &M, + ArrayRef OverloadInfo) { assert(BuiltinInfo::isMuxBuiltinID(ID) && "Only handling mux builtins"); - Function *F = M.getFunction(BuiltinInfo::getMuxBuiltinName(ID)); + Function *F = M.getFunction(BuiltinInfo::getMuxBuiltinName(ID, OverloadInfo)); // FIXME: We'd ideally want to declare it here to reduce pass // inter-dependencies. 
assert(F && "Function should have been pre-declared"); @@ -806,24 +807,25 @@ Type *BIMuxInfoConcept::getRemappedTargetExtTy(Type *Ty) { return nullptr; } -Function *BIMuxInfoConcept::getOrDeclareMuxBuiltin(BuiltinID ID, Module &M) { +Function *BIMuxInfoConcept::getOrDeclareMuxBuiltin( + BuiltinID ID, Module &M, ArrayRef OverloadInfo) { assert(BuiltinInfo::isMuxBuiltinID(ID) && "Only handling mux builtins"); - auto FnName = BuiltinInfo::getMuxBuiltinName(ID); + auto FnName = BuiltinInfo::getMuxBuiltinName(ID, OverloadInfo); if (auto *const F = M.getFunction(FnName)) { return F; } - AttrBuilder AB(M.getContext()); + auto &Ctx = M.getContext(); + AttrBuilder AB(Ctx); auto *const SizeTy = getSizeType(M); - auto *const Int32Ty = Type::getInt32Ty(M.getContext()); - auto *const VoidTy = Type::getVoidTy(M.getContext()); + auto *const Int32Ty = Type::getInt32Ty(Ctx); + auto *const Int64Ty = Type::getInt64Ty(Ctx); + auto *const VoidTy = Type::getVoidTy(Ctx); Type *RetTy = nullptr; SmallVector ParamTys; SmallVector ParamNames; switch (ID) { - default: - return nullptr; // Ranked Getters case eMuxBuiltinGetLocalId: case eMuxBuiltinGetGlobalId: @@ -840,13 +842,17 @@ Function *BIMuxInfoConcept::getOrDeclareMuxBuiltin(BuiltinID ID, Module &M) { case eMuxBuiltinGetWorkDim: case eMuxBuiltinGetSubGroupId: case eMuxBuiltinGetNumSubGroups: + case eMuxBuiltinGetSubGroupSize: case eMuxBuiltinGetMaxSubGroupSize: + case eMuxBuiltinGetSubGroupLocalId: case eMuxBuiltinGetLocalLinearId: case eMuxBuiltinGetGlobalLinearId: { // Some builtins return uint, others return size_t RetTy = (ID == eMuxBuiltinGetWorkDim || ID == eMuxBuiltinGetSubGroupId || ID == eMuxBuiltinGetNumSubGroups || - ID == eMuxBuiltinGetMaxSubGroupSize) + ID == eMuxBuiltinGetSubGroupSize || + ID == eMuxBuiltinGetMaxSubGroupSize || + ID == eMuxBuiltinGetSubGroupLocalId) ? Int32Ty : SizeTy; // All of our mux getters are readonly - they may never write data @@ -894,6 +900,38 @@ Function *BIMuxInfoConcept::getOrDeclareMuxBuiltin(BuiltinID ID, Module &M) { AB.addAttribute(Attribute::Convergent); break; } + default: + // Group builtins are more easily found using this helper rather than + // explicitly enumerating each switch case. + if (auto Group = BuiltinInfo::isMuxGroupCollective(ID)) { + RetTy = OverloadInfo.front(); + ParamTys.push_back(RetTy); + ParamNames.push_back("val"); + AB.addAttribute(Attribute::Convergent); + // Broadcasts additionally add ID parameters + if (Group->Op == GroupCollective::OpKind::Broadcast) { + if (Group->Scope == GroupCollective::ScopeKind::SubGroup) { + ParamTys.push_back(Int32Ty); + ParamNames.push_back("lid"); + } else { + ParamTys.push_back(Int64Ty); + ParamNames.push_back("lidx"); + ParamTys.push_back(Int64Ty); + ParamNames.push_back("lidy"); + ParamTys.push_back(Int64Ty); + ParamNames.push_back("lidz"); + } + } + // All work-group operations have a 'barrier id' operand as their first + // parameter. 
+ if (Group->Scope == GroupCollective::ScopeKind::WorkGroup) { + ParamTys.insert(ParamTys.begin(), Int32Ty); + ParamNames.insert(ParamNames.begin(), "id"); + } + } else { + // Unknown mux builtin + return nullptr; + } } assert(RetTy); @@ -922,7 +960,7 @@ Function *BIMuxInfoConcept::getOrDeclareMuxBuiltin(BuiltinID ID, Module &M) { for (unsigned i = 0, e = ParamNames.size(); i != e; i++) { F->getArg(i)->setName(ParamNames[i]); - auto AB = AttrBuilder(M.getContext(), ParamAttrs[i]); + auto AB = AttrBuilder(Ctx, ParamAttrs[i]); F->getArg(i)->addAttrs(AB); } diff --git a/modules/compiler/utils/source/replace_group_funcs_pass.cpp b/modules/compiler/utils/source/replace_group_funcs_pass.cpp new file mode 100644 index 000000000..bc436260d --- /dev/null +++ b/modules/compiler/utils/source/replace_group_funcs_pass.cpp @@ -0,0 +1,50 @@ +// Copyright (C) Codeplay Software Limited +// +// Licensed under the Apache License, Version 2.0 (the "License") with LLVM +// Exceptions; you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +using namespace llvm; + +PreservedAnalyses compiler::utils::ReplaceGroupFuncsPass::run( + Module &M, ModuleAnalysisManager &AM) { + auto &BI = AM.getResult(M); + + SmallVector Calls; + for (auto &F : M.functions()) { + auto B = BI.analyzeBuiltin(F); + if (B.properties & eBuiltinPropertyMapToMuxGroupBuiltin) { + for (auto *U : F.users()) { + if (auto *CI = dyn_cast(U)) { + Calls.push_back(CI); + } + } + } + } + + if (Calls.empty()) { + return PreservedAnalyses::all(); + } + + for (auto *CI : Calls) { + if (auto *const NewCI = BI.mapGroupBuiltinToMuxGroupBuiltin(*CI)) { + CI->replaceAllUsesWith(NewCI); + CI->eraseFromParent(); + } + } + + return PreservedAnalyses::none(); +} diff --git a/modules/compiler/utils/source/replace_wgc_pass.cpp b/modules/compiler/utils/source/replace_wgc_pass.cpp index 12f5adfb8..5fc797853 100644 --- a/modules/compiler/utils/source/replace_wgc_pass.cpp +++ b/modules/compiler/utils/source/replace_wgc_pass.cpp @@ -68,7 +68,7 @@ Value *createSubgroupReduction(IRBuilder<> &Builder, llvm::Value *Src, const compiler::utils::GroupCollective &WGC) { StringRef name; compiler::utils::TypeQualifier Q = compiler::utils::eTypeQualNone; - switch (WGC.recurKind) { + switch (WGC.Recurrence) { default: return nullptr; case RecurKind::And: @@ -76,7 +76,7 @@ Value *createSubgroupReduction(IRBuilder<> &Builder, llvm::Value *Src, name = "sub_group_all"; Q = compiler::utils::eTypeQualSignedInt; } else { - name = !WGC.isLogical ? "sub_group_reduce_and" + name = !WGC.IsLogical ? "sub_group_reduce_and" : "sub_group_reduce_logical_and"; } break; @@ -85,7 +85,7 @@ Value *createSubgroupReduction(IRBuilder<> &Builder, llvm::Value *Src, name = "sub_group_any"; Q = compiler::utils::eTypeQualSignedInt; } else { - name = !WGC.isLogical ? "sub_group_reduce_or" + name = !WGC.IsLogical ? 
"sub_group_reduce_or" : "sub_group_reduce_logical_or"; } break; @@ -115,7 +115,7 @@ Value *createSubgroupReduction(IRBuilder<> &Builder, llvm::Value *Src, name = "sub_group_reduce_mul"; break; case RecurKind::Xor: - name = !WGC.isLogical ? "sub_group_reduce_xor" + name = !WGC.IsLogical ? "sub_group_reduce_xor" : "sub_group_reduce_logical_xor"; break; } @@ -340,11 +340,11 @@ Value *createBinOp(llvm::IRBuilder<> &Builder, llvm::Value *CurrentVal, void emitWorkGroupReductionBody(const compiler::utils::GroupCollective &WGC, compiler::utils::BuiltinInfo &BI) { // Create a global variable to do the reduction on. - auto &F = *WGC.func; + auto &F = *WGC.Func; auto *const Operand = F.getArg(0); auto *const ReductionType{Operand->getType()}; auto *const ReductionNeutralValue{ - compiler::utils::getNeutralVal(WGC.recurKind, WGC.type)}; + compiler::utils::getNeutralVal(WGC.Recurrence, WGC.Ty)}; auto *const Accumulator = new GlobalVariable{*F.getParent(), ReductionType, @@ -385,7 +385,7 @@ void emitWorkGroupReductionBody(const compiler::utils::GroupCollective &WGC, auto *const CurrentVal = Builder.CreateLoad(ReductionType, Accumulator, "current.val"); auto *const NextVal = createBinOp(Builder, CurrentVal, SubReduce, - WGC.recurKind, WGC.isAnyAll()); + WGC.Recurrence, WGC.isAnyAll()); Builder.CreateStore(NextVal, Accumulator); // Barrier, then read result and exit. @@ -417,7 +417,7 @@ void emitWorkGroupReductionBody(const compiler::utils::GroupCollective &WGC, void emitWorkGroupBroadcastBody(const compiler::utils::GroupCollective &WGC, compiler::utils::BuiltinInfo &BI) { // First arg is always the value to broadcast. - auto &F = *WGC.func; + auto &F = *WGC.Func; auto *const ValueToBroadcast = F.getArg(0); // Create a global variable to do the broadcast through. @@ -521,11 +521,11 @@ void emitWorkGroupBroadcastBody(const compiler::utils::GroupCollective &WGC, void emitWorkGroupScanBody(const compiler::utils::GroupCollective &WGC, compiler::utils::BuiltinInfo &BI) { // Create a global variable to do the scan on. - auto &F = *WGC.func; + auto &F = *WGC.Func; auto *const Operand = F.getArg(0); auto *const ReductionType{Operand->getType()}; auto *const ReductionNeutralValue{ - compiler::utils::getNeutralVal(WGC.recurKind, WGC.type)}; + compiler::utils::getNeutralVal(WGC.Recurrence, WGC.Ty)}; assert(ReductionNeutralValue && "Invalid neutral value"); auto &M = *F.getParent(); auto *const Accumulator = @@ -562,14 +562,14 @@ void emitWorkGroupScanBody(const compiler::utils::GroupCollective &WGC, Builder.CreateLoad(ReductionType, Accumulator, "current.val"); // Perform the subgroup scan operation and add it to the accumulator. - auto *SubScan = createSubgroupScan(Builder, Operand, WGC.recurKind, - IsInclusive, WGC.isLogical); + auto *SubScan = createSubgroupScan(Builder, Operand, WGC.Recurrence, + IsInclusive, WGC.IsLogical); assert(SubScan && "Invalid subgroup scan"); bool const NeedsIdentityFix = !IsInclusive && - (WGC.recurKind == RecurKind::FAdd || WGC.recurKind == RecurKind::FMin || - WGC.recurKind == RecurKind::FMax); + (WGC.Recurrence == RecurKind::FAdd || WGC.Recurrence == RecurKind::FMin || + WGC.Recurrence == RecurKind::FMax); // For FMin/FMax, we need to fix up the identity element on the zeroth // subgroup ID, because it will be +/-INFINITY, but we need it to be NaN. 
@@ -586,7 +586,7 @@ void emitWorkGroupScanBody(const compiler::utils::GroupCollective &WGC,
   }
 
   auto *const Result =
-      createBinOp(Builder, CurrentVal, SubScan, WGC.recurKind, WGC.isAnyAll());
+      createBinOp(Builder, CurrentVal, SubScan, WGC.Recurrence, WGC.isAnyAll());
 
   // Update the accumulator with the last element of the subgroup scan
   auto *const LastElement = Builder.CreateNUWSub(
@@ -600,11 +600,11 @@ void emitWorkGroupScanBody(const compiler::utils::GroupCollective &WGC,
   if (!IsInclusive) {
     auto *const LastSrcValue =
         createSubgroupBroadcast(Builder, Operand, LastElement, "wgc_sg_tail");
-    SubReduce = createBinOp(Builder, LastValue, LastSrcValue, WGC.recurKind,
+    SubReduce = createBinOp(Builder, LastValue, LastSrcValue, WGC.Recurrence,
                             WGC.isAnyAll());
   }
   auto *const NextVal = createBinOp(Builder, CurrentVal, SubReduce,
-                                    WGC.recurKind, WGC.isAnyAll());
+                                    WGC.Recurrence, WGC.isAnyAll());
   Builder.CreateStore(NextVal, Accumulator);
 
   // A third barrier ensures that if there are two or more scans, they can't get
@@ -613,7 +613,7 @@ void emitWorkGroupScanBody(const compiler::utils::GroupCollective &WGC,
 
   if (NeedsIdentityFix) {
     auto *const Identity =
-        compiler::utils::getIdentityVal(WGC.recurKind, WGC.type);
+        compiler::utils::getIdentityVal(WGC.Recurrence, WGC.Ty);
     auto *const getLocalIDFn =
         BI.getOrDeclareMuxBuiltin(compiler::utils::eMuxBuiltinGetLocalId, M);
     auto *const IsZero = compiler::utils::isThreadZero(EntryBB, *getLocalIDFn);
@@ -629,17 +629,17 @@ void emitWorkGroupScanBody(const compiler::utils::GroupCollective &WGC,
 /// @param[in] WGC Work-group collective function to be defined.
 void emitWorkGroupCollectiveBody(const compiler::utils::GroupCollective &WGC,
                                  compiler::utils::BuiltinInfo &BI) {
-  switch (WGC.op) {
-    case compiler::utils::GroupCollective::Op::All:
-    case compiler::utils::GroupCollective::Op::Any:
-    case compiler::utils::GroupCollective::Op::Reduction:
+  switch (WGC.Op) {
+    case compiler::utils::GroupCollective::OpKind::All:
+    case compiler::utils::GroupCollective::OpKind::Any:
+    case compiler::utils::GroupCollective::OpKind::Reduction:
       emitWorkGroupReductionBody(WGC, BI);
       break;
-    case compiler::utils::GroupCollective::Op::Broadcast:
+    case compiler::utils::GroupCollective::OpKind::Broadcast:
       emitWorkGroupBroadcastBody(WGC, BI);
       break;
-    case compiler::utils::GroupCollective::Op::ScanExclusive:
-    case compiler::utils::GroupCollective::Op::ScanInclusive:
+    case compiler::utils::GroupCollective::OpKind::ScanExclusive:
+    case compiler::utils::GroupCollective::OpKind::ScanInclusive:
       emitWorkGroupScanBody(WGC, BI);
       break;
     default:
@@ -663,7 +663,7 @@ PreservedAnalyses compiler::utils::ReplaceWGCPass::run(
   SmallVector<GroupCollective, 4> WGCollectives{};
 
   for (auto &F : M) {
     auto WGC = isGroupCollective(&F);
-    if (WGC && WGC->scope == GroupCollective::Scope::WorkGroup) {
+    if (WGC && WGC->Scope == GroupCollective::ScopeKind::WorkGroup) {
       WGCollectives.push_back(*WGC);
     }
   }
diff --git a/modules/mux/include/mux/mux.h b/modules/mux/include/mux/mux.h
index 7b10355e2..18dbbe640 100644
--- a/modules/mux/include/mux/mux.h
+++ b/modules/mux/include/mux/mux.h
@@ -37,7 +37,7 @@ extern "C" {
 /// @brief Mux major version number.
 #define MUX_MAJOR_VERSION 0
 /// @brief Mux minor version number.
-#define MUX_MINOR_VERSION 77
+#define MUX_MINOR_VERSION 78
 /// @brief Mux patch version number.
 #define MUX_PATCH_VERSION 0
 /// @brief Mux combined version number.
diff --git a/modules/mux/targets/host/include/host/host.h b/modules/mux/targets/host/include/host/host.h index 4b1699817..3aa8876e1 100644 --- a/modules/mux/targets/host/include/host/host.h +++ b/modules/mux/targets/host/include/host/host.h @@ -29,7 +29,7 @@ extern "C" { /// @brief Host major version number. #define HOST_MAJOR_VERSION 0 /// @brief Host minor version number. -#define HOST_MINOR_VERSION 77 +#define HOST_MINOR_VERSION 78 /// @brief Host patch version number. #define HOST_PATCH_VERSION 0 /// @brief Host combined version number. diff --git a/modules/mux/targets/riscv/include/riscv/riscv.h b/modules/mux/targets/riscv/include/riscv/riscv.h index 4614ba456..6b29800ca 100644 --- a/modules/mux/targets/riscv/include/riscv/riscv.h +++ b/modules/mux/targets/riscv/include/riscv/riscv.h @@ -29,7 +29,7 @@ extern "C" { /// @brief Riscv major version number. #define RISCV_MAJOR_VERSION 0 /// @brief Riscv minor version number. -#define RISCV_MINOR_VERSION 77 +#define RISCV_MINOR_VERSION 78 /// @brief Riscv patch version number. #define RISCV_PATCH_VERSION 0 /// @brief Riscv combined version number. diff --git a/modules/mux/tools/api/mux.xml b/modules/mux/tools/api/mux.xml index b88baca81..f07b49497 100644 --- a/modules/mux/tools/api/mux.xml +++ b/modules/mux/tools/api/mux.xml @@ -39,7 +39,7 @@ SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception ${FUNCTION_PREFIX}_MAJOR_VERSION0 ${Function_Prefix} major version number. - ${FUNCTION_PREFIX}_MINOR_VERSION77 + ${FUNCTION_PREFIX}_MINOR_VERSION78 ${Function_Prefix} minor version number. ${FUNCTION_PREFIX}_PATCH_VERSION0 ${Function_Prefix} patch version number.
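
A rough before/after sketch of the rewrite performed by the new
`compiler::utils::ReplaceGroupFuncsPass` (the OpenCL-side symbol uses the
usual Itanium mangling; the mux-side name assumes the operand-type suffix
convention used for the overloadable group builtins):

  ; before: call to the mangled OpenCL C builtin
  %r = call i32 @_Z20sub_group_reduce_addi(i32 %v)

  ; after: call to the language-agnostic mux builtin, overloaded on the
  ; operand type
  %r = call i32 @__mux_sub_group_reduce_add_i32(i32 %v)

Work-group operations additionally receive the leading `i32 %id` barrier
identifier operand, which the mapping currently always sets to 0.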