Skip to content

Commit

Permalink
Merge pull request #80 from frasercrmck/mux-group-builtins
Browse files Browse the repository at this point in the history
[compiler] Introduce mux builtins for 'group operations'
  • Loading branch information
frasercrmck authored Aug 8, 2023
2 parents f297a2e + 7c64b47 commit 166d230
Show file tree
Hide file tree
Showing 41 changed files with 2,660 additions and 172 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ Upgrade guidance:
* The mux spec has been bumped:
* 0.77.0: to loosen the requirements on the mux `event` type used by
DMA builtins.
* 0.78.0: to introduce mux builtins for sub-group, work-group, and
vector-group operations.
* The `compiler::ImageArgumentSubstitutionPass` now replaces sampler typed
parameters in kernel functions with i32 parameters via a wrapper function.
The `host` target as a consequence now passes samplers to kernels as 32-bit
Expand Down
5 changes: 5 additions & 0 deletions doc/modules/mux/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ version increases mean backward compatible bug fixes have been applied.
Versions prior to 1.0.0 may contain breaking changes in minor
versions as the API is still under development.

0.78.0
------

* Added sub-group, work-group, and vector-group operation builtins.

0.77.0
------

Expand Down
114 changes: 113 additions & 1 deletion doc/specifications/mux-compiler-spec.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ComputeMux Compiler Specification
=================================

This is version 0.77.0 of the specification.
This is version 0.78.0 of the specification.

ComputeMux is Codeplay’s proprietary API for executing compute workloads across
heterogeneous devices. ComputeMux is an extremely lightweight,
Expand Down Expand Up @@ -922,6 +922,10 @@ A Mux implementation **shall** provide definitions for these builtin functions.
the current work-group.
* ``i32 __mux_get_max_sub_group_size()`` - Returns the maximum subgroup size
in the current kernel.
* ``i32 __mux_get_sub_group_size()`` - Returns the number of invocations in the
subgroup.
* ``i32 __mux_get_sub_group_local_id()`` - Returns the unique invocation ID
within the current sub-group.
* ``size_t __mux_get_group_id(i32 %i)`` - Returns the unique work-group
identifier for the ``%i``'th dimension.
* ``i32 __mux_get_work_dim()`` - Returns the number of dimensions in
Expand Down Expand Up @@ -997,6 +1001,114 @@ A Mux implementation **shall** provide definitions for these builtin functions.
as ``__mux_mem_barrier(i32 %scope, i32 %semantics)``. See `below
<#memory-and-control-barriers>`__ for more information.

Group operation builtins
~~~~~~~~~~~~~~~~~~~~~~~~

ComputeMux defines a variety of builtins to handle operations across a
sub-group, work-group, or *vector group*.

The builtin functions are overloadable and are mangled according to the type of
operand they operate on.

Each *work-group* operation takes as its first parameter a 32-bit integer
barrier identifier (``i32 %id``). Note that if barriers are used to implement
these operations, implementations **must** ensure uniqueness of these IDs
themselves, e.g., by running the ``compiler::utils::PrepareBarriersPass``. The
barrier identifier parameter is not mangled.

.. note::

The sub-group and work-group builtins are all **uniform**, that is, the
behaviour is undefined unless all invocations in the group reach this point
of execution.

Future versions of ComputeMux **may** add **non-uniform** versions of these
builtins.

The groups are defined as:

* ``work-group`` - a group of invocations running together as part of an ND
range. These builtins **must** only take scalar values.
* ``sub-group`` - a subset of invocations in a work-group which can synchronize
and share data efficiently. ComputeMux leaves the choice of sub-group size
and implementation to the target; ComputeMux only defines these builtins with
a "trivial" sub-group size of 1. These builtins **must** only take scalar
values.
* ``vec-group`` - a software level group of invocations processing data in
parallel *on a single invocation*. This allows the compiler to simulate a
sub-group without any hardware sub-group support (e.g., through
vectorization). These builtins **may** take scalar *or vector* values. The
scalar versions of these builtins are essentially identical to the
corresponding ``sub-group`` builtins with a sub-group size of 1.


``any``/``all`` builtins
++++++++++++++++++++++++

The ``any`` and ``all`` builtins return ``true`` if any/all of their operands
are ``true`` and ``false`` otherwise.

.. code:: llvm
i1 @__mux_sub_group_any_i1(i1 %x)
i1 @__mux_work_group_any_i1(i32 %id, i1 %x)
i1 @__mux_vec_group_any_v4i1(<4 x i1> %x)
``broadcast`` builtins
++++++++++++++++++++++

The ``broadcast`` builtins broadcast the value corresponding to the local ID to
the result of all invocations in the group. The sub-group version of this
builtin takes an ``i32`` sub-group linear ID to identify the invocation to
broadcast, and the work-group version take three ``size_t`` indices to locate
the value to broadcast. Unused indices (e.g., in lower-dimension kernels)
**must** be set to zero - this is the same value returned by
``__mux_get_global_id`` for out-of-range dimensions.

.. code:: llvm
i64 @__mux_sub_group_broadcast_i64(i64 %val, i32 %sg_lid)
i32 @__mux_work_group_broadcast_i32(i32 %id, i32 %val, i64 %lidx, i64 %lidy, i64 %lidz)
i64 @__mux_vec_group_broadcast_v2i64(<2 x i64> %val, i32 %vec_id)
``reduce`` and ``scan`` builtins
++++++++++++++++++++++++++++++++

The ``reduce`` and ``scan`` builtins return the result of the group operation
for all values of their parameters specified by invocations in the group.

Scans may be either ``inclusive`` or ``exclusive``. Inclusive scans perform the
operation over all invocations in the group. Exclusive scans perform the
operation over the operation's identity value and all but the final invocation
in the group.

The group operation may be specified as one of:

* ``add``/``fadd`` - integer/floating-point addition.
* ``mul``/``fmul`` - integer/floating-point multiplication.
* ``smin``/``umin``/``fmin`` - signed integer/unsigned integer/floating-point minimum.
* ``smax``/``umax``/``fmax`` - signed integer/unsigned integer/floating-point maximum.
* ``and``/``or``/``xor`` - bitwise ``and``/``or``/``xor``.
* ``logical_and``/``logical_or``/``logical_xor`` - logical ``and``/``or``/``xor``.

Examples:

.. code:: llvm
i32 @__mux_sub_group_reduce_add_i32(i32 %val)
i32 @__mux_work_group_reduce_add_i32(i32 %id, i32 %val)
float @__mux_work_group_reduce_fadd_f32(i32 %id, float %val)
i32 @__mux_sub_group_scan_inclusive_mul_i32(i32 %val)
i32 @__mux_work_group_scan_inclusive_mul_i32(i32 %id, i32 %val)
float @__mux_work_group_scan_inclusive_fmul_f32(i32 %id, float %val)
i64 @__mux_sub_group_scan_exclusive_mul_i64(i64 %val)
i64 @__mux_work_group_scan_exclusive_mul_i64(i32 %id, i64 %val)
double @__mux_work_group_scan_exclusive_fmul_f64(i32 %id, double %val)
i64 @__mux_vec_group_scan_exclusive_mul_nxv1i64(<vscale x 1 x i64> %val)
Memory and Control Barriers
---------------------------
Expand Down
2 changes: 1 addition & 1 deletion doc/specifications/mux-runtime-spec.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ComputeMux Runtime Specification
================================

This is version 0.77.0 of the specification.
This is version 0.78.0 of the specification.

ComputeMux is Codeplay’s proprietary API for executing compute workloads across
heterogeneous devices. ComputeMux is an extremely lightweight,
Expand Down
14 changes: 8 additions & 6 deletions doc/tutorials/custom-lowering-work-item-builtins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,11 @@ The code for this example is as follows:
.. code:: cpp
class MyMuxImpl : public utils::BIMuxInfoConcept {
virtual llvm::Function *defineMuxBuiltin(utils::BuiltinID ID,
llvm::Module &M) override {
virtual llvm::Function *defineMuxBuiltin(
utils::BuiltinID ID, llvm::Module &M,
llvm::ArrayRef<llvm::Type *> OverloadInfo = {}) override {
if (ID != utils::eMuxBuiltinGetLocalId) {
return BIMuxInfoConcept::defineMuxBuiltin(ID, M);
return BIMuxInfoConcept::defineMuxBuiltin(ID, M, OverloadInfo);
}
llvm::Function *F =
M.getFunction(utils::BuiltinInfo::getMuxBuiltinName(ID));
Expand Down Expand Up @@ -390,8 +391,9 @@ data beyond the view of ComputeMux, e.g., in the driver or the HAL.
return List;
}
virtual llvm::Function *defineMuxBuiltin(utils::BuiltinID ID,
llvm::Module &M) override {
virtual llvm::Function *defineMuxBuiltin(
utils::BuiltinID ID, llvm::Module &M,
llvm::ArrayRef<llvm::Type *> OverloadInfo = {}) override {
if (ID == utils::eMuxBuiltinGetLocalId) {
llvm::Function *F =
M.getFunction(utils::BuiltinInfo::getMuxBuiltinName(ID));
Expand All @@ -407,7 +409,7 @@ data beyond the view of ComputeMux, e.g., in the driver or the HAL.
B.CreateRet(std::prev(F->arg_end()));
return F;
}
return BIMuxInfoConcept::defineMuxBuiltin(ID, M);
return BIMuxInfoConcept::defineMuxBuiltin(ID, M, OverloadInfo);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ class RefSiG1BIMuxInfo : public compiler::utils::BIMuxInfoConcept {
public:
static llvm::StructType *getExecStateStruct(llvm::Module &M);

llvm::Function *getOrDeclareMuxBuiltin(compiler::utils::BuiltinID ID,
llvm::Module &M) override;
llvm::Function *getOrDeclareMuxBuiltin(
compiler::utils::BuiltinID ID, llvm::Module &M,
llvm::ArrayRef<llvm::Type *> OverloadInfo = {}) override;

llvm::Function *defineMuxBuiltin(compiler::utils::BuiltinID ID,
llvm::Module &M) override;
llvm::Function *defineMuxBuiltin(
compiler::utils::BuiltinID ID, llvm::Module &M,
llvm::ArrayRef<llvm::Type *> OverloadInfo = {}) override;
};

} // namespace refsi_g1_wi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ StructType *RefSiG1BIMuxInfo::getExecStateStruct(Module &M) {
}

Function *RefSiG1BIMuxInfo::getOrDeclareMuxBuiltin(
compiler::utils::BuiltinID ID, Module &M) {
auto *F = compiler::utils::BIMuxInfoConcept::getOrDeclareMuxBuiltin(ID, M);
compiler::utils::BuiltinID ID, Module &M,
llvm::ArrayRef<llvm::Type *> OverloadInfo) {
auto *F = compiler::utils::BIMuxInfoConcept::getOrDeclareMuxBuiltin(
ID, M, OverloadInfo);
if (!F) {
return F;
}
Expand All @@ -78,7 +80,8 @@ Function *RefSiG1BIMuxInfo::getOrDeclareMuxBuiltin(
}

Function *RefSiG1BIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,
Module &M) {
Module &M,
ArrayRef<Type *> OverloadInfo) {
assert(compiler::utils::BuiltinInfo::isMuxBuiltinID(ID) &&
"Only handling mux builtins");
Function *F =
Expand Down Expand Up @@ -213,7 +216,8 @@ Function *RefSiG1BIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,
}

if (ID != compiler::utils::eMuxBuiltinGetLocalId) {
return compiler::utils::BIMuxInfoConcept::defineMuxBuiltin(ID, M);
return compiler::utils::BIMuxInfoConcept::defineMuxBuiltin(ID, M,
OverloadInfo);
}

Optional<unsigned> ParamIdx;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ namespace refsi_m1 {

class RefSiM1BIMuxInfo : public compiler::utils::BIMuxInfoConcept {
public:
llvm::Function *defineMuxBuiltin(compiler::utils::BuiltinID ID,
llvm::Module &M) override;
llvm::Function *defineMuxBuiltin(
compiler::utils::BuiltinID ID, llvm::Module &M,
llvm::ArrayRef<llvm::Type *> OverloadInfo = {}) override;
};

} // namespace refsi_m1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ void defineRefSiDmaWait(Function &F) {
}

Function *RefSiM1BIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,
Module &M) {
Module &M,
ArrayRef<Type *> OverloadInfo) {
assert(compiler::utils::BuiltinInfo::isMuxBuiltinID(ID) &&
"Only handling mux builtins");
auto FnName = compiler::utils::BuiltinInfo::getMuxBuiltinName(ID);
Expand All @@ -390,7 +391,8 @@ Function *RefSiM1BIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,

switch (ID) {
default:
return compiler::utils::BIMuxInfoConcept::defineMuxBuiltin(ID, M);
return compiler::utils::BIMuxInfoConcept::defineMuxBuiltin(ID, M,
OverloadInfo);
case compiler::utils::eMuxBuiltinDMARead1D:
case compiler::utils::eMuxBuiltinDMAWrite1D:
defineRefSiDma1D(*F, *this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include <compiler/utils/replace_atomic_funcs_pass.h>
#include <compiler/utils/replace_barriers_pass.h>
#include <compiler/utils/replace_c11_atomic_funcs_pass.h>
#include <compiler/utils/replace_group_funcs_pass.h>
#include <compiler/utils/replace_local_module_scope_variables_pass.h>
#include <compiler/utils/replace_mem_intrinsics_pass.h>
#include <compiler/utils/replace_mux_math_decls_pass.h>
Expand Down
1 change: 1 addition & 0 deletions modules/compiler/source/base/source/base_pass_registry.def
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ MODULE_PASS("replace-atomic-funcs", compiler::utils::ReplaceAtomicFuncsPass())
MODULE_PASS("replace-barriers", compiler::utils::ReplaceBarriersPass())
MODULE_PASS("replace-c11-atomic-funcs",
compiler::utils::ReplaceC11AtomicFuncsPass())
MODULE_PASS("replace-group-funcs", compiler::utils::ReplaceGroupFuncsPass())
MODULE_PASS("replace-wgc", compiler::utils::ReplaceWGCPass())

MODULE_PASS("replace-module-scope-vars",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ class HostBIMuxInfo : public compiler::utils::BIMuxInfoConcept {
llvm::SmallVector<compiler::utils::BuiltinInfo::SchedParamInfo, 4>
getMuxSchedulingParameters(llvm::Module &M) override;

llvm::Function *defineMuxBuiltin(compiler::utils::BuiltinID ID,
llvm::Module &M) override;
llvm::Function *defineMuxBuiltin(
compiler::utils::BuiltinID ID, llvm::Module &M,
llvm::ArrayRef<llvm::Type *> OverloadInfo) override;

llvm::Value *initializeSchedulingParamForWrappedKernel(
const compiler::utils::BuiltinInfo::SchedParamInfo &Info,
Expand Down
6 changes: 4 additions & 2 deletions modules/compiler/targets/host/source/HostMuxBuiltinInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ HostBIMuxInfo::getMuxSchedulingParameters(Module &M) {
}

Function *HostBIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,
Module &M) {
Module &M,
ArrayRef<Type *> OverloadInfo) {
assert(compiler::utils::BuiltinInfo::isMuxBuiltinID(ID) &&
"Only handling mux builtins");
Function *F =
Expand All @@ -157,7 +158,8 @@ Function *HostBIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,

switch (ID) {
default:
return compiler::utils::BIMuxInfoConcept::defineMuxBuiltin(ID, M);
return compiler::utils::BIMuxInfoConcept::defineMuxBuiltin(ID, M,
OverloadInfo);
case compiler::utils::eMuxBuiltinGetLocalSize:
ParamIdx = SchedParamIndices::SCHED;
DefaultVal = 1;
Expand Down
3 changes: 2 additions & 1 deletion modules/compiler/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

add_ca_executable(UnitCompiler
${CMAKE_CURRENT_SOURCE_DIR}/common.h
${CMAKE_CURRENT_SOURCE_DIR}/group_ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/info.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/library.cpp
Expand All @@ -30,7 +31,7 @@ target_include_directories(UnitCompiler PRIVATE
${PROJECT_SOURCE_DIR}/modules/compiler/include)

target_link_libraries(UnitCompiler PRIVATE cargo
compiler-static mux ca_gtest_main compiler-utils)
compiler-static mux ca_gtest_main compiler-base compiler-utils)

target_resources(UnitCompiler NAMESPACES ${BUILTINS_NAMESPACES})

Expand Down
25 changes: 25 additions & 0 deletions modules/compiler/test/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
#include <compiler/module.h>
#include <compiler/target.h>
#include <gtest/gtest.h>
#include <llvm/AsmParser/Parser.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/SourceMgr.h>
#include <mux/mux.h>
#include <mux/utils/helpers.h>

Expand Down Expand Up @@ -323,6 +326,28 @@ static inline std::vector<const compiler::Info *> deferrableCompilers() {
return deferrable_compilers;
}

/// @brief Fixture for testing behavior of the compiler with LLVM modules.
///
/// Tests based on this fixture should test the behavior of
/// LLVM-based APIs and transforms.
struct CompilerLLVMModuleTest : ::testing::Test {
void SetUp() override {}

std::unique_ptr<llvm::Module> parseModule(llvm::StringRef Assembly) {
llvm::SMDiagnostic Error;
auto M = llvm::parseAssemblyString(Assembly, Error, Context);

std::string ErrMsg;
llvm::raw_string_ostream OS(ErrMsg);
Error.print("", OS);
EXPECT_TRUE(M) << OS.str();

return M;
}

llvm::LLVMContext Context;
};

/// @brief Macro for instantiating test fixture parameterized over all
/// compiler targets available on the platform.
///
Expand Down
Loading

0 comments on commit 166d230

Please sign in to comment.