Merge pull request #87 from frasercrmck/lower-to-mux-builtins
[compiler] Start to combine 'ReplaceXXX' passes
frasercrmck authored Aug 15, 2023
2 parents 5b0365a + a16085d commit f2e1c93
Showing 36 changed files with 545 additions and 563 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,8 @@ Upgrade guidance:
parameters in kernel functions with i32 parameters via a wrapper function.
The `host` target as a consequence now passes samplers to kernels as 32-bit
integer arguments, not as integer arguments disguised as pointer values.
* The `compiler::utils::ReplaceBarriersPass` has been replaced with the
`compiler::utils::LowerToMuxBuiltinsPass`.

## Version 3.0.0

38 changes: 27 additions & 11 deletions doc/modules/compiler/utils.rst
@@ -45,6 +45,33 @@ Their job is three-fold:
supplied to either pass on construction to encode this metadata. If not set,
the default ``xyz`` order is used.

LowerToMuxBuiltinsPass
----------------------

This pass replaces calls to the language implementation's builtin functions
with alternative sequences of instructions involving ComputeMux builtins.

Not all builtins **must** be lowered, but any builtins which can be
re-expressed in terms of the following ComputeMux builtins **should** be
lowered:

* Sync builtins: ``__mux_mem_barrier``, ``__mux_(sub|work)_group_barrier``.
* Work-item builtins: ``__mux_get_local_id``, ``__mux_get_group_id``, etc.
* Group builtins: ``__mux_(sub|work)_group_(any|all|scan|reduce|broadcast)``.

Targets **must** lower any language builtins which **can** be expressed in
terms of these ComputeMux builtins in order for other ComputeMux compiler
passes to correctly recognise the program semantics.

This is because these builtins have special semantics that the compiler and
LLVM generally cannot infer from the built-in properties of functions in
LLVM (e.g., attributes). They generally have some meaning "across"
invocations of the same program, or influence the behaviour of other
invocations running in parallel.

See the :ref:`full list of builtins
<specifications/mux-compiler-spec:Builtins>` for more information.
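
To make the idea of lowering concrete, here is a minimal, self-contained C++ sketch of the kind of name mapping such a pass might perform. It is not the pass's implementation: ``lookupMuxBuiltin`` and its table are hypothetical, the mangled OpenCL names are the ones that appear elsewhere in this commit, and choosing ``__mux_work_group_barrier`` as the target for the barrier functions is an assumption based on the builtin list above.

```cpp
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>

// Hypothetical helper (not part of the real pass): maps a mangled OpenCL
// builtin name to the ComputeMux builtin it could be lowered to.
// Returns std::nullopt for names that have no entry in the table.
std::optional<std::string> lookupMuxBuiltin(std::string_view Mangled) {
  // Illustrative table only; the real set of lowered builtins is larger.
  static const std::unordered_map<std::string_view, std::string> Table = {
      {"_Z12get_group_idj", "__mux_get_group_id"},
      {"_Z14get_num_groupsj", "__mux_get_num_groups"},
      // Assumption: all three barrier forms map to the work-group barrier.
      {"_Z7barrierj", "__mux_work_group_barrier"},
      {"_Z18work_group_barrierj", "__mux_work_group_barrier"},
      {"_Z18work_group_barrierjj", "__mux_work_group_barrier"},
  };
  const auto It = Table.find(Mangled);
  if (It == Table.end()) {
    return std::nullopt;  // Not a builtin this sketch knows how to lower.
  }
  return It->second;
}
```

For example, ``lookupMuxBuiltin("_Z12get_group_idj")`` yields ``__mux_get_group_id``, while an unrelated function name yields an empty optional and would be left untouched. A real pass would presumably also need to translate call arguments (for instance, the barrier flags), which this sketch ignores.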

AlignModuleStructsPass
----------------------

@@ -1048,17 +1075,6 @@ Removing this information is useful for debugging since the backend is less
likely to optimize away stack variables that are no longer used. As a result, this
pass should only be run on debug builds of the module.

ReplaceBarriersPass
-------------------

Replaces calls to OpenCL's mangled barrier function with the appropriate
``__mux`` memory barrier, deduced based on the flags passed to the barrier
call. It covers the following barrier functions:

* ``_Z7barrierj``
* ``_Z18work_group_barrierj``
* ``_Z18work_group_barrierjj``

RemoveFencesPass
----------------

@@ -29,6 +29,7 @@
#include <compiler/utils/simple_callback_pass.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Transforms/Utils/Cloning.h>
#include <metadata/handler/vectorize_info_metadata.h>
#include <multi_llvm/optional_helper.h>
#include <refsi_g1_wi/refsi_pass_machinery.h>

@@ -30,6 +30,7 @@
#include <compiler/utils/verify_reqd_sub_group_size_pass.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Transforms/Utils/Cloning.h>
#include <metadata/handler/vectorize_info_metadata.h>
#include <multi_llvm/optional_helper.h>
#include <refsi_m1/refsi_pass_machinery.h>
35 changes: 0 additions & 35 deletions modules/compiler/builtins/include/builtins/clbuiltins-3.0.h
@@ -26,41 +26,6 @@
#include <abacus/abacus_integer.h>
#undef ABACUS_ENABLE_OPENCL_3_0_BUILTINS

extern size_t __attribute__((pure)) __mux_get_local_linear_id(void);
extern size_t __attribute__((pure)) __mux_get_global_linear_id(void);
extern size_t __attribute__((pure)) __mux_get_enqueued_local_size(uint x);
extern uint __attribute__((pure)) __mux_get_sub_group_id(void);
extern uint __attribute__((pure)) __mux_get_num_sub_groups(void);
extern uint __attribute__((pure)) __mux_get_max_sub_group_size(void);

size_t __CL_WORK_ITEM_ATTRIBUTES get_local_linear_id(void) {
return __mux_get_local_linear_id();
}

size_t __CL_WORK_ITEM_ATTRIBUTES get_global_linear_id(void) {
return __mux_get_global_linear_id();
}

size_t __CL_WORK_ITEM_ATTRIBUTES get_enqueued_local_size(uint x) {
return __mux_get_enqueued_local_size(x);
}

uint __CL_WORK_ITEM_ATTRIBUTES get_max_sub_group_size(void) {
return __mux_get_max_sub_group_size();
}

uint __CL_WORK_ITEM_ATTRIBUTES get_num_sub_groups(void) {
return __mux_get_num_sub_groups();
}

uint __CL_WORK_ITEM_ATTRIBUTES get_enqueued_num_sub_groups(void) {
return get_num_sub_groups();
}

uint __CL_WORK_ITEM_ATTRIBUTES get_sub_group_id(void) {
return __mux_get_sub_group_id();
}

void __CL_BARRIER_ATTRIBUTES sub_group_barrier(cl_mem_fence_flags flags) {
(void)flags;
}
40 changes: 0 additions & 40 deletions modules/compiler/builtins/include/builtins/clbuiltins.h
@@ -36,51 +36,11 @@
extern bool __attribute__((const)) __mux_isftz(void);
extern bool __attribute__((const)) __mux_usefast(void);
extern bool __attribute__((const)) __mux_isembeddedprofile(void);
extern size_t __attribute__((pure)) __mux_get_global_size(uint x);
extern size_t __attribute__((pure)) __mux_get_global_id(uint x);
extern size_t __attribute__((pure)) __mux_get_global_offset(uint x);
extern size_t __attribute__((pure)) __mux_get_local_size(uint x);
extern size_t __attribute__((pure)) __mux_get_local_id(uint x);
extern size_t __attribute__((pure)) __mux_get_num_groups(uint x);
extern size_t __attribute__((pure)) __mux_get_group_id(uint x);
extern uint __attribute__((pure)) __mux_get_work_dim(void);

bool __CL_CONST_ATTRIBUTES __abacus_isftz() { return __mux_isftz(); }
bool __CL_CONST_ATTRIBUTES __abacus_usefast() { return __mux_usefast(); }
bool __CL_CONST_ATTRIBUTES __abacus_isembeddedprofile() {
return __mux_isembeddedprofile();
}

size_t __CL_WORK_ITEM_ATTRIBUTES get_global_size(uint x) {
return __mux_get_global_size(x);
}

size_t __CL_WORK_ITEM_ATTRIBUTES get_global_id(uint x) {
return __mux_get_global_id(x);
}

size_t __CL_WORK_ITEM_ATTRIBUTES get_global_offset(uint x) {
return __mux_get_global_offset(x);
}

size_t __CL_WORK_ITEM_ATTRIBUTES get_local_size(uint x) {
return __mux_get_local_size(x);
}

size_t __CL_WORK_ITEM_ATTRIBUTES get_local_id(uint x) {
return __mux_get_local_id(x);
}

size_t __CL_WORK_ITEM_ATTRIBUTES get_num_groups(uint x) {
return __mux_get_num_groups(x);
}

size_t __CL_WORK_ITEM_ATTRIBUTES get_group_id(uint x) {
return __mux_get_group_id(x);
}

uint __CL_WORK_ITEM_ATTRIBUTES get_work_dim(void) {
return __mux_get_work_dim();
}

#endif // OCL_CLBUILTINS_H_INCLUDED
43 changes: 0 additions & 43 deletions modules/compiler/builtins/scripts/generate_header_30.sh
@@ -119,29 +119,6 @@ function cxx_unsafe_end()
[[ "cxx" != "$generated_output_type" ]] && force_cxx_unsafe_end
}

function all_work_item()
{
echo "extern size_t __attribute__((pure)) __mux_get_local_linear_id(void);"
echo "extern size_t __attribute__((pure)) __mux_get_global_linear_id(void);"
echo "extern size_t __attribute__((pure)) __mux_get_enqueued_local_size(uint x);"
echo "extern uint __attribute__((pure)) __mux_get_sub_group_id(void);"
echo "extern uint __attribute__((pure)) __mux_get_num_sub_groups(void);"
echo "extern uint __attribute__((pure)) __mux_get_max_sub_group_size(void);"
echo ""
echo "size_t __CL_WORK_ITEM_ATTRIBUTES get_local_linear_id(void) {"
echo " return __mux_get_local_linear_id();"
echo "}"
echo ""
echo "size_t __CL_WORK_ITEM_ATTRIBUTES get_global_linear_id(void) {"
echo " return __mux_get_global_linear_id();"
echo "}"
echo ""
echo "size_t __CL_WORK_ITEM_ATTRIBUTES get_enqueued_local_size(uint x) {"
echo " return __mux_get_enqueued_local_size(x);"
echo "}"
echo ""
}

function all_ctz()
{
echo "#if __OPENCL_C_VERSION__ >= 200"
@@ -483,24 +460,6 @@ function all_sub_group()
echo "int __CL_BARRIER_ATTRIBUTES sub_group_all(int predicate);"
echo "int __CL_BARRIER_ATTRIBUTES sub_group_any(int predicate);"
echo ""
elif [[ "cl" == "$generated_output_type" ]]
then
echo "uint __CL_WORK_ITEM_ATTRIBUTES get_max_sub_group_size(void) {"
echo " return __mux_get_max_sub_group_size();"
echo "}"
echo ""
echo "uint __CL_WORK_ITEM_ATTRIBUTES get_num_sub_groups(void) {"
echo " return __mux_get_num_sub_groups();"
echo "}"
echo ""
echo "uint __CL_WORK_ITEM_ATTRIBUTES get_enqueued_num_sub_groups(void) {"
echo " return get_num_sub_groups();"
echo "}"
echo ""
echo "uint __CL_WORK_ITEM_ATTRIBUTES get_sub_group_id(void) {"
echo " return __mux_get_sub_group_id();"
echo "}"
echo ""
fi

if [[ "header" == "$generated_output_type" ]]
@@ -699,8 +658,6 @@ function output_for_type()

header > "$outputFile"

[[ "cl" == "$generated_output_type" ]] && all_work_item >> "$outputFile"

[[ "header" == "$generated_output_type" ]] && all_ctz >> "$outputFile"

[[ "header" == "$generated_output_type" ]] && all_typedefs >> "$outputFile"

@@ -37,6 +37,7 @@
#include <compiler/utils/verify_reqd_sub_group_size_pass.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Transforms/Utils/Cloning.h>
#include <metadata/handler/vectorize_info_metadata.h>
#include <multi_llvm/optional_helper.h>
#include <optional>
1 change: 1 addition & 0 deletions modules/compiler/riscv/source/riscv_pass_machinery.cpp
@@ -31,6 +31,7 @@
#include <compiler/utils/verify_reqd_sub_group_size_pass.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Transforms/Utils/Cloning.h>
#include <metadata/handler/vectorize_info_metadata.h>
#include <multi_llvm/optional_helper.h>
#include <riscv/ir_to_builtins_pass.h>

@@ -41,6 +41,7 @@
#include <compiler/utils/fixup_calling_convention_pass.h>
#include <compiler/utils/handle_barriers_pass.h>
#include <compiler/utils/link_builtins_pass.h>
+#include <compiler/utils/lower_to_mux_builtins_pass.h>
#include <compiler/utils/make_function_name_unique_pass.h>
#include <compiler/utils/materialize_absent_work_item_builtins_pass.h>
#include <compiler/utils/metadata_analysis.h>
@@ -55,9 +56,7 @@
#include <compiler/utils/replace_address_space_qualifier_functions_pass.h>
#include <compiler/utils/replace_async_copies_pass.h>
#include <compiler/utils/replace_atomic_funcs_pass.h>
-#include <compiler/utils/replace_barriers_pass.h>
#include <compiler/utils/replace_c11_atomic_funcs_pass.h>
-#include <compiler/utils/replace_group_funcs_pass.h>
#include <compiler/utils/replace_local_module_scope_variables_pass.h>
#include <compiler/utils/replace_mem_intrinsics_pass.h>
#include <compiler/utils/replace_mux_math_decls_pass.h>
@@ -142,9 +141,9 @@ Expected<StringRef> parseMakeFunctionNameUniquePassOptions(StringRef Params) {
}

template <size_t N>
-static ErrorOr<std::array<multi_llvm::Optional<uint64_t>, N>> parseIntList(
+static ErrorOr<std::array<std::optional<uint64_t>, N>> parseIntList(
StringRef OptionVal, bool AllowNegative = false) {
-std::array<multi_llvm::Optional<uint64_t>, N> Arr;
+std::array<std::optional<uint64_t>, N> Arr;
for (unsigned i = 0; i < N; i++) {
int64_t Res;
StringRef Val;
@@ -165,7 +164,7 @@ static ErrorOr<std::array<multi_llvm::Optional<uint64_t>, N>> parseIntList(
if (!OptionVal.empty()) {
return std::errc::argument_list_too_long;
}
-return std::array<multi_llvm::Optional<uint64_t>, 3U>(Arr);
+return std::array<std::optional<uint64_t>, 3U>(Arr);
}

constexpr const char LocalSizesOptName[] = "max-local-sizes=";
@@ -179,7 +178,7 @@ parseEncodeBuiltinRangeMetadataPassOptions(StringRef Params) {
std::tie(ParamName, Params) = Params.split(';');

StringRef OptName;
-std::array<multi_llvm::Optional<uint64_t>, 3> *SizesPtr = nullptr;
+std::array<std::optional<uint64_t>, 3> *SizesPtr = nullptr;
if (ParamName.consume_front(LocalSizesOptName)) {
OptName = LocalSizesOptName;
SizesPtr = &Opts.MaxLocalSizes;
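
The body of `parseIntList` is mostly elided in this diff, so the following is only a sketch of how a fixed-length list of optional integers can be parsed into `std::array<std::optional<uint64_t>, N>`, the type that replaces `multi_llvm::Optional` in the hunks above. The name `parseIntListSketch`, the `:` separator, and the use of `std::from_chars` are assumptions for illustration. The real function takes an `llvm::StringRef` and returns `ErrorOr`.

```cpp
#include <array>
#include <charconv>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string_view>
#include <system_error>

// Hypothetical stand-in for parseIntList (not the real implementation).
// Parses up to N unsigned integers separated by ':' (assumed separator).
// An empty field leaves the corresponding slot as std::nullopt.
// Returns std::nullopt if a field is malformed or there are too many fields.
template <std::size_t N>
std::optional<std::array<std::optional<std::uint64_t>, N>> parseIntListSketch(
    std::string_view Text) {
  std::array<std::optional<std::uint64_t>, N> Arr{};
  for (std::size_t I = 0; I < N; ++I) {
    // Split off the next field, consuming the separator if present.
    const std::size_t Sep = Text.find(':');
    const std::string_view Field = Text.substr(0, Sep);
    Text = (Sep == std::string_view::npos) ? std::string_view{}
                                           : Text.substr(Sep + 1);
    if (Field.empty()) {
      continue;  // Unset entry: keep std::nullopt.
    }
    std::uint64_t Value = 0;
    const char *End = Field.data() + Field.size();
    const auto [Ptr, Ec] = std::from_chars(Field.data(), End, Value);
    if (Ec != std::errc{} || Ptr != End) {
      return std::nullopt;  // Not a valid unsigned integer.
    }
    Arr[I] = Value;
  }
  if (!Text.empty()) {
    return std::nullopt;  // More than N fields were supplied.
  }
  return Arr;
}
```

With `N = 3`, the input `4::6` produces `{4, nullopt, 6}`, and `1:2:3:4` is rejected. Returning `Arr` directly avoids hard-coding the array length in the return statement.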
3 changes: 1 addition & 2 deletions modules/compiler/source/base/source/base_pass_registry.def
@@ -35,17 +35,16 @@ MODULE_PASS("define-mux-builtins", compiler::utils::DefineMuxBuiltinsPass())
MODULE_PASS("define-mux-dma", compiler::utils::DefineMuxDmaPass())
MODULE_PASS("degenerate-sub-groups", compiler::utils::DegenerateSubGroupPass())
MODULE_PASS("link-builtins", compiler::utils::LinkBuiltinsPass())
+MODULE_PASS("lower-to-mux-builtins", compiler::utils::LowerToMuxBuiltinsPass())

MODULE_PASS("missing-builtins",
compiler::utils::MaterializeAbsentWorkItemBuiltinsPass())
MODULE_PASS("prepare-barriers", compiler::utils::PrepareBarriersPass())
MODULE_PASS("rename-builtins", compiler::utils::RenameBuiltinsPass())
MODULE_PASS("replace-async-copies", compiler::utils::ReplaceAsyncCopiesPass())
MODULE_PASS("replace-atomic-funcs", compiler::utils::ReplaceAtomicFuncsPass())
-MODULE_PASS("replace-barriers", compiler::utils::ReplaceBarriersPass())
MODULE_PASS("replace-c11-atomic-funcs",
compiler::utils::ReplaceC11AtomicFuncsPass())
-MODULE_PASS("replace-group-funcs", compiler::utils::ReplaceGroupFuncsPass())
MODULE_PASS("replace-wgc", compiler::utils::ReplaceWGCPass())
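
Registry files of this shape are commonly consumed with the X-macro pattern: the includer defines `MODULE_PASS` before expanding the list, so each entry becomes a name-to-pass mapping. How `base_pass_registry.def` is actually included is not shown in this diff, so the sketch below is an assumption. The pass structs, the `SKETCH_PASS_REGISTRY` macro (standing in for the `.def` file), and `lookupPass` are all hypothetical.

```cpp
#include <string>

// Hypothetical pass types standing in for the real compiler::utils passes.
struct LinkBuiltinsPass {
  static const char *name() { return "LinkBuiltinsPass"; }
};
struct LowerToMuxBuiltinsPass {
  static const char *name() { return "LowerToMuxBuiltinsPass"; }
};

// Stand-in for the contents of a .def file: one MODULE_PASS entry per pass.
#define SKETCH_PASS_REGISTRY                                     \
  MODULE_PASS("link-builtins", LinkBuiltinsPass())               \
  MODULE_PASS("lower-to-mux-builtins", LowerToMuxBuiltinsPass())

// Returns the class name of the pass registered under Name, or an empty
// string if no entry matches.
std::string lookupPass(const std::string &Name) {
  // Each registry entry expands to one comparison against the pass name.
#define MODULE_PASS(NAME, CREATE_PASS) \
  if (Name == NAME) return (CREATE_PASS).name();
  SKETCH_PASS_REGISTRY
#undef MODULE_PASS
  return "";
}
```

Under this pattern, removing an entry such as `replace-barriers` from the list makes that name unresolvable, while adding `lower-to-mux-builtins` makes the new pass reachable by name.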

MODULE_PASS("replace-module-scope-vars",
9 changes: 5 additions & 4 deletions modules/compiler/source/base/source/module.cpp
@@ -46,13 +46,12 @@
#include <compiler/limits.h>
#include <compiler/utils/encode_builtin_range_metadata_pass.h>
#include <compiler/utils/llvm_global_mutex.h>
+#include <compiler/utils/lower_to_mux_builtins_pass.h>
#include <compiler/utils/metadata.h>
#include <compiler/utils/pass_machinery.h>
#include <compiler/utils/replace_async_copies_pass.h>
#include <compiler/utils/replace_atomic_funcs_pass.h>
-#include <compiler/utils/replace_barriers_pass.h>
#include <compiler/utils/replace_c11_atomic_funcs_pass.h>
-#include <compiler/utils/replace_group_funcs_pass.h>
#include <compiler/utils/replace_target_ext_tys_pass.h>
#include <compiler/utils/simple_callback_pass.h>
#include <compiler/utils/verify_reqd_sub_group_size_pass.h>
@@ -91,6 +90,7 @@
#include <llvm/Transforms/Vectorize/LoopVectorize.h>
#include <llvm/Transforms/Vectorize/SLPVectorizer.h>
#include <multi_llvm/llvm_version.h>
#include <multi_llvm/multi_llvm.h>
#include <multi_llvm/optional_helper.h>
#include <multi_llvm/triple.h>
#include <mux/mux.hpp>
@@ -1690,6 +1690,9 @@ Result BaseModule::finalize(
pm.addPass(compiler::utils::ReplaceTargetExtTysPass(RTETOpts));
#endif

+// Lower all language-level builtins with corresponding mux builtins
+pm.addPass(compiler::utils::LowerToMuxBuiltinsPass());

pm.addPass(llvm::createModuleToFunctionPassAdaptor(
compiler::SoftwareDivisionPass()));
pm.addPass(compiler::ImageArgumentSubstitutionPass());
@@ -1712,8 +1715,6 @@ Result BaseModule::finalize(
}));

pm.addPass(compiler::utils::ReplaceC11AtomicFuncsPass());
-pm.addPass(compiler::utils::ReplaceBarriersPass());
-pm.addPass(compiler::utils::ReplaceGroupFuncsPass());

if (options.prevec_mode != compiler::PreVectorizationMode::NONE) {
llvm::FunctionPassManager fpm;
21 changes: 8 additions & 13 deletions modules/compiler/source/base/source/printf_replacement_pass.cpp
@@ -15,6 +15,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <base/printf_replacement_pass.h>
+#include <compiler/utils/builtin_info.h>
#include <compiler/utils/device_info.h>
#include <compiler/utils/pass_functions.h>
#include <llvm/ADT/SmallPtrSet.h>
@@ -814,23 +815,17 @@ PreservedAnalyses compiler::PrintfReplacementPass::run(
return PreservedAnalyses::all();
}

+auto &BI = AM.getResult<compiler::utils::BuiltinInfoAnalysis>(module);
const auto &DI = AM.getResult<compiler::utils::DeviceInfoAnalysis>(module);
// Set up the double support for this run of the pass
double_support = DI.double_capabilities != 0;

-// get the type of size_t on the device
-Type *size_t_type = compiler::utils::getSizeType(module);
-
-FunctionType *FuncTy = FunctionType::get(
-size_t_type, Type::getInt32Ty(module.getContext()), false);
-Function *get_group_id = dyn_cast<Function>(
-module.getOrInsertFunction("_Z12get_group_idj", FuncTy).getCallee());
-assert(get_group_id && "Could not get or insert _Z12get_group_idj");
-get_group_id->setCallingConv(CallingConv::SPIR_FUNC);
-
-Function *get_num_groups = cast<Function>(
-module.getOrInsertFunction("_Z14get_num_groupsj", FuncTy).getCallee());
-get_num_groups->setCallingConv(CallingConv::SPIR_FUNC);
+Function *get_group_id =
+BI.getOrDeclareMuxBuiltin(compiler::utils::eMuxBuiltinGetGroupId, module);
+assert(get_group_id && "Could not get or insert __mux_get_group_id");
+Function *get_num_groups = BI.getOrDeclareMuxBuiltin(
+compiler::utils::eMuxBuiltinGetNumGroups, module);
+assert(get_num_groups && "Could not get or insert __mux_get_num_groups");

SmallVector<CallInst *, 32> callsToErase;
