Skip to content

Commit

Permalink
Merge pull request #84 from frasercrmck/mux-group-builtins-in-pipeline
Browse files Browse the repository at this point in the history
[compiler] Enable mux subgroup/workgroup builtins
  • Loading branch information
frasercrmck authored Aug 10, 2023
2 parents 166d230 + 15459f9 commit a8d487b
Show file tree
Hide file tree
Showing 57 changed files with 1,053 additions and 2,234 deletions.
41 changes: 18 additions & 23 deletions doc/modules/compiler/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -983,14 +983,13 @@ and is provided the ``llvm::Module`` as a parameter.
ReplaceWGCPass
--------------

The ``ReplaceWGCPass`` provides software implementations of the OpenCL C
The ``ReplaceWGCPass`` provides software implementations of the ComputeMux
work-group collective builtins. Targets wishing to support work-group
collectives in software **may** run this pass. This pass makes heavy use of
barriers, so do not expect performance. Because it introduces barriers into the
module, this pass **must** be run before any barrier analysis or
materialization e.g., the `PrepareBarriersPass`_ and `HandleBarriersPass`_.


This pass introduces global variables into the module qualified with the
:ref:`local/Workgroup <overview/compiler/ir:Address Spaces>` address space and
therefore **must** be run before any pass that materializes ``__local``
Expand Down Expand Up @@ -1173,29 +1172,25 @@ such APIs, several of which are given here by way of example:
Sub-groups
----------

The implementation of OpenCL C sub-group builtins is split between several
files. A trivial implementation (meaning sub-group == work-item) is provided in
the builtins header ``modules/builtins/include/builtins/clbuiltins-3.0.h``.
Some builtins (i.e. ``get_max_sub_group_size``, ``get_num_sub_groups`` and
``get_sub_group_id``) are implemented in terms of ``__mux`` builtins since they
may require scheduling information to be passed to their parameter list on some
implementations. ``__mux_get_max_sub_group_size``,
``__mux_get_num_sub_groups``, ``__mux_get_sub_group_id`` and
``__mux_set_max_sub_group_size`` are defined in in
``modules/compiler/utils/source/define_mux_builtins_pass.cpp``.
A implementation of OpenCL C sub-group builtins is provided by the default
compiler pipeline.

The OpenCL C sub-group builtins are first translated into the corresponding
ComputeMux builtin functions. These functions are understood by the rest of the
compiler and can be identified and analyzed by the ``BuiltinInfo`` analysis.

A definition of these mux builtins for where the sub-group size is 1 is
provided by ``BIMuxInfoConcept`` used by the `DefineMuxBuiltinsPass`_.

Vectorized definitions of the various sub-group builtins are provided by the
VECZ pass which will overwrite the trivial definitions provided in the builtin
headers, so any target running VECZ (and the above passes) will be able to
support sub-groups. We still have to provide a fallback implementation (in
this case the trivial implementation defined in the builtin headers) in order
to accommodate for the situation where VECZ fails, or is disabled, in which
case the target still needs to support sub-groups since they are a device
feature.

If a target not running VECZ wishes to provide their own sub-group
implementation they should target the OpenCL C sub-group builtins directly,
there are no ``__mux`` builtins for sub-groups other than those defined above.
VECZ pass, so any target running VECZ (and the above passes) will be able to
support sub-groups of a larger size than 1. Note that VECZ does not currently
interact "on top of" the mux builtins - it replaces them in the functions it
vectorized. This is future work to allow the two to build on top of each other.

If a target wishes to provide their own sub-group implementation they should
provide a derived ``BIMuxInfoConcept`` and override ``defineMuxBuiltin`` for
the sub-group builtins.

Linker support
--------------
Expand Down
4 changes: 2 additions & 2 deletions doc/tutorials/custom-lowering-work-item-builtins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ The code for this example is as follows:
return BIMuxInfoConcept::defineMuxBuiltin(ID, M, OverloadInfo);
}
llvm::Function *F =
M.getFunction(utils::BuiltinInfo::getMuxBuiltinName(ID));
M.getFunction(utils::BuiltinInfo::getMuxBuiltinName(ID, OverloadInfo));
// Set some useful function attributes
setDefaultBuiltinAttributes(*F);
F->setLinkage(llvm::GlobalValue::InternalLinkage);
Expand Down Expand Up @@ -396,7 +396,7 @@ data beyond the view of ComputeMux, e.g., in the driver or the HAL.
llvm::ArrayRef<llvm::Type *> OverloadInfo = {}) override {
if (ID == utils::eMuxBuiltinGetLocalId) {
llvm::Function *F =
M.getFunction(utils::BuiltinInfo::getMuxBuiltinName(ID));
M.getFunction(utils::BuiltinInfo::getMuxBuiltinName(ID, OverloadInfo));
// Set some useful function attributes
setDefaultBuiltinAttributes(*F);
// We additionally know that our function is readnone
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ Function *RefSiG1BIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,
ArrayRef<Type *> OverloadInfo) {
assert(compiler::utils::BuiltinInfo::isMuxBuiltinID(ID) &&
"Only handling mux builtins");
Function *F =
M.getFunction(compiler::utils::BuiltinInfo::getMuxBuiltinName(ID));
Function *F = M.getFunction(
compiler::utils::BuiltinInfo::getMuxBuiltinName(ID, OverloadInfo));

// FIXME: We'd ideally want to declare it here to reduce pass
// inter-dependencies.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ llvm::ModulePassManager RefSiG1PassMachinery::getLateTargetPasses() {
PM.addPass(riscv::IRToBuiltinReplacementPass());

if (env_var_opts.early_link_builtins) {
PM.addPass(compiler::utils::LinkBuiltinsPass(/*EarlyLinking*/ true));
PM.addPass(compiler::utils::LinkBuiltinsPass());
}

// Bit nasty, but we must schedule a run of the DefineMuxDmaPass to define
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <compiler/utils/dma.h>
#include <llvm/ADT/StringSwitch.h>
#include <llvm/IR/Operator.h>
#include <multi_llvm/opaque_pointers.h>
#include <refsi_m1/refsi_mux_builtin_info.h>

Expand Down Expand Up @@ -379,7 +380,8 @@ Function *RefSiM1BIMuxInfo::defineMuxBuiltin(compiler::utils::BuiltinID ID,
ArrayRef<Type *> OverloadInfo) {
assert(compiler::utils::BuiltinInfo::isMuxBuiltinID(ID) &&
"Only handling mux builtins");
auto FnName = compiler::utils::BuiltinInfo::getMuxBuiltinName(ID);
auto FnName =
compiler::utils::BuiltinInfo::getMuxBuiltinName(ID, OverloadInfo);
Function *F = M.getFunction(FnName);

// FIXME: We'd ideally want to declare it here to reduce pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ llvm::ModulePassManager RefSiM1PassMachinery::getLateTargetPasses() {
PM.addPass(riscv::IRToBuiltinReplacementPass());

if (env_var_opts.early_link_builtins) {
PM.addPass(compiler::utils::LinkBuiltinsPass(/*EarlyLinking*/ true));
PM.addPass(compiler::utils::LinkBuiltinsPass());
}

// TODON'T temporary fix to get subgroup tests passing while we refactor
Expand Down
Loading

0 comments on commit a8d487b

Please sign in to comment.