Skip to content

Commit

Permalink
Merge pull request #133 from frasercrmck/split-degenerate-and-no
Browse files Browse the repository at this point in the history
[compiler] Encode 'no' and 'degenerate' subgroups distinctly
  • Loading branch information
frasercrmck authored Sep 19, 2023
2 parents f56bed7 + d85eb55 commit e1dbfed
Show file tree
Hide file tree
Showing 15 changed files with 137 additions and 24 deletions.
8 changes: 8 additions & 0 deletions doc/specifications/mux-compiler-spec.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,14 @@ different stages of the pipeline:
* - ``"mux-barrier-schedule"="val"``
- Typically found on call sites. Determines the ordering of work-item
execution after a barrier. See the `BarrierSchedule` enum.
* - ``"mux-no-subgroups"``
- Marks the function as not explicitly using sub-groups (e.g., identified
by the use of known mux sub-group builtins). If a pass introduces the
explicit use of sub-groups to a function, it should remove this
attribute.
* - ``"mux-degenerate-subgroups"``
- Marks the function as using degenerate sub-groups (i.e. one sub-group
for the entire local work-group).

``mux-kernel`` attribute
~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
3 changes: 3 additions & 0 deletions modules/compiler/source/base/source/pass_pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <compiler/utils/replace_address_space_qualifier_functions_pass.h>
#include <compiler/utils/replace_mux_math_decls_pass.h>
#include <compiler/utils/replace_wgc_pass.h>
#include <compiler/utils/sub_group_usage_pass.h>
#include <llvm/ADT/StringSwitch.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/Passes/PassBuilder.h>
Expand All @@ -50,6 +51,8 @@ void addPreVeczPasses(ModulePassManager &PM,
compiler::utils::OptimalBuiltinReplacementPass()));
}

PM.addPass(compiler::utils::SubgroupUsagePass());

if (tuner.degenerate_sub_groups) {
PM.addPass(compiler::utils::DegenerateSubGroupPass());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ entry:
ret i32 %sqr
}

; CHECK: define spir_func i32 @sub_groups(i32 [[X5:%.+]]) #[[ATTR1:[0-9]+]] {
; CHECK: define spir_func i32 @sub_groups(i32 [[X5:%.+]]) #[[ATTR0:[0-9]+]] {
; CHECK: entry:
; CHECK: [[C5_1:%.+]] = call spir_func i32 @clone_this(i32 [[X5]])
; CHECK: [[C5_2:%.+]] = call spir_func i32 @shared(i32 [[X5]])
Expand All @@ -66,7 +66,7 @@ entry:
ret i32 %add
}

; CHECK: define spir_func i32 @no_sub_groups(i32 [[X4:%.+]]) #[[ATTR0:[0-9]+]] {
; CHECK: define spir_func i32 @no_sub_groups(i32 [[X4:%.+]]) #[[ATTR0]] {
; CHECK: entry:
; CHECK: [[R4:%.+]] = call spir_func i32 @shared(i32 [[X4]])
; CHECK: ret i32 [[R4]]
Expand Down Expand Up @@ -101,7 +101,6 @@ declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32)

attributes #0 = { "mux-kernel"="entry-point" }

; CHECK-DAG: attributes #[[ATTR0]] = { "mux-degenerate-subgroups" "mux-kernel"="entry-point" }
; CHECK-DAG: attributes #[[ATTR1]] = { "mux-kernel"="entry-point" }
; CHECK-DAG: attributes #[[ATTR0]] = { "mux-kernel"="entry-point" }
; CHECK-DAG: attributes #[[ATTR2]] = { "mux-base-fn-name"="sub_groups" "mux-degenerate-subgroups" "mux-kernel"="entry-point" }
; CHECK-DAG: attributes #[[ATTR3]] = { "mux-base-fn-name"="clone_this" }
12 changes: 11 additions & 1 deletion modules/compiler/test/lit/passes/degenerate-sub-groups.ll
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,11 @@ entry:
ret void
}

; CHECK: define spir_func void @no_sub_groups_test() [[ATTRS:#[0-9]+]] {
; Entry-point kernel with an empty body that already carries the
; "mux-no-subgroups" attribute (attribute group #1).
define spir_func void @no_sub_groups_test() #1 {
ret void
}

; CHECK-DAG: declare spir_func i1 @__mux_work_group_all_i1(i32, i1)
declare spir_func i1 @__mux_sub_group_all_i1(i1)
; CHECK-DAG: declare spir_func i1 @__mux_work_group_any_i1(i32, i1)
Expand Down Expand Up @@ -260,9 +265,14 @@ declare spir_func i32 @__mux_get_sub_group_local_id()
; CHECK-DAG: declare spir_func void @__mux_work_group_barrier(i32, i32, i32)
declare spir_func void @__mux_sub_group_barrier(i32, i32, i32)

; Check we didn't mark a function that uses no sub-groups as having degenerate
; sub-groups.
; CHECK-DAG: attributes [[ATTRS]] = { "mux-kernel"="entry-point" "mux-no-subgroups" }
; CHECK-DAG: attributes #0 = { "mux-degenerate-subgroups" "mux-kernel"="entry-point" }
attributes #0 = { "mux-kernel"="entry-point" }
attributes #1 = { "mux-kernel"="entry-point" "mux-no-subgroups" }

!0 = !{i32 13, i32 64, i32 64}
; CHECK: attributes #0 = { "mux-degenerate-subgroups" "mux-kernel"="entry-point" }

!opencl.ocl.version = !{!1}

Expand Down
13 changes: 13 additions & 0 deletions modules/compiler/test/lit/passes/sub-group-analysis.ll
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,17 @@ define spir_func void @function4() {
ret void
}

; This function does use sub-groups but has the 'mux-no-subgroups' attribute.
; Check that the analysis obeys the attribute rather than re-checking the
; function! A pass should not introduce new sub-group usage without removing
; that attribute!
; CHECK: Function 'function5' uses no sub-group builtins
define spir_kernel void @function5() #0 {
entry:
; Deliberate call to a sub-group builtin: attribute group #0
; ("mux-no-subgroups") is expected to win over this call in the analysis.
%lid = call i32 @__mux_get_sub_group_local_id()
ret void
}

declare i32 @__mux_get_sub_group_id()
declare i32 @__mux_get_sub_group_local_id()
declare i32 @__mux_sub_group_shuffle_i32(i32, i32)
Expand All @@ -86,3 +97,5 @@ declare i32 @__mux_get_max_sub_group_size()
declare void @__mux_set_sub_group_id(i32)
declare void @__mux_set_num_sub_groups(i32)
declare void @__mux_set_max_sub_group_size(i32)

attributes #0 = { "mux-no-subgroups" }
1 change: 1 addition & 0 deletions modules/compiler/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ add_ca_library(compiler-utils STATIC
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/scheduling.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/simple_callback_pass.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/sub_group_analysis.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/sub_group_usage_pass.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/target_extension_types.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/unique_opaque_structs_pass.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/vectorization_factor.h
Expand Down
15 changes: 14 additions & 1 deletion modules/compiler/utils/include/compiler/utils/attributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,20 @@ void setHasDegenerateSubgroups(llvm::Function &F);
/// @param[in] F Function to check.
bool hasDegenerateSubgroups(const llvm::Function &F);

/// @brief Returns the mux sub-group size for the current function.
/// @brief Marks a function as not explicitly using subgroups
///
/// May be set even with unresolved external functions, assuming those don't
/// explicitly use subgroups.
///
/// @param[in] F Function in which to encode the information.
void setHasNoExplicitSubgroups(llvm::Function &F);

/// @brief Returns whether the kernel does not explicitly use subgroups
///
/// @param[in] F Function to check.
bool hasNoExplicitSubgroups(const llvm::Function &F);

/// @brief Returns the mux subgroup size for the current function.
///
/// Currently always returns 1!
unsigned getMuxSubgroupSize(const llvm::Function &F);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (C) Codeplay Software Limited
//
// Licensed under the Apache License, Version 2.0 (the "License") with LLVM
// Exceptions; you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

/// @file
///
/// Sub-group usage attribute pass.

#ifndef COMPILER_UTILS_SUB_GROUP_USAGE_PASS_H_INCLUDED
#define COMPILER_UTILS_SUB_GROUP_USAGE_PASS_H_INCLUDED

#include <compiler/utils/attributes.h>
#include <compiler/utils/sub_group_analysis.h>
#include <llvm/IR/PassManager.h>

namespace compiler {
namespace utils {

/// @brief Module pass that caches sub-group usage as function attributes.
///
/// Every function definition that the sub-group analysis reports as not using
/// sub-groups is marked via setHasNoExplicitSubgroups. The cached attributes
/// are assumed to stay valid, i.e. later passes are expected not to invalidate
/// them.
class SubgroupUsagePass final : public llvm::PassInfoMixin<SubgroupUsagePass> {
 public:
  explicit SubgroupUsagePass() {}

  /// @brief Marks all sub-group-free function definitions in the module.
  ///
  /// @param[in,out] M Module whose functions are inspected and marked.
  /// @param[in] AM Analysis manager providing the SubgroupAnalysis result.
  ///
  /// @return Always reports all analyses as preserved.
  llvm::PreservedAnalyses run(llvm::Module &M,
                              llvm::ModuleAnalysisManager &AM) {
    const auto &UsageInfo = AM.getResult<SubgroupAnalysis>(M);

    for (auto &Fn : M) {
      // Declarations are never queried or marked.
      if (Fn.isDeclaration()) {
        continue;
      }
      // Functions reported as using sub-groups are left untouched.
      if (UsageInfo.usesSubgroups(Fn)) {
        continue;
      }
      setHasNoExplicitSubgroups(Fn);
    }

    return llvm::PreservedAnalyses::all();
  }
};

} // namespace utils
} // namespace compiler

#endif // COMPILER_UTILS_SUB_GROUP_USAGE_PASS_H_INCLUDED
11 changes: 11 additions & 0 deletions modules/compiler/utils/source/attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,17 @@ bool hasDegenerateSubgroups(const Function &F) {
return Attr.isValid();
}

// Name of the string function attribute that marks a function as not
// explicitly using sub-groups.
static constexpr const char *MuxNoSubgroupsAttrName = "mux-no-subgroups";

// Tags F with the value-less MuxNoSubgroupsAttrName string attribute.
void setHasNoExplicitSubgroups(Function &F) {
  const Attribute NoSubgroupsAttr =
      Attribute::get(F.getContext(), MuxNoSubgroupsAttrName);
  F.addFnAttr(NoSubgroupsAttr);
}

bool hasNoExplicitSubgroups(const Function &F) {
Attribute Attr = F.getFnAttribute(MuxNoSubgroupsAttrName);
return Attr.isValid();
}

unsigned getMuxSubgroupSize(const llvm::Function &) {
// FIXME: The mux sub-group size is currently assumed to be 1 for all
// functions, kernels, and targets. This helper function is just to avoid
Expand Down
9 changes: 0 additions & 9 deletions modules/compiler/utils/source/degenerate_sub_group_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,6 @@ PreservedAnalyses compiler::utils::DegenerateSubGroupPass::run(
// If there were no sub-group builtin calls we are done, exit early and
// preserve all analysis since we didn't touch the module.
if (usesSubgroups.empty()) {
for (auto *const K : kernels) {
// Set the attribute on every kernel that doesn't use any subgroups at
// all, so the vectorizer knows it can vectorize them however it likes.
setHasDegenerateSubgroups(*K);
}
return PreservedAnalyses::all();
}

Expand All @@ -344,10 +339,6 @@ PreservedAnalyses compiler::utils::DegenerateSubGroupPass::run(
for (auto *const K : kernels) {
bool const subgroups = usesSubgroups.contains(K);
if (!subgroups) {
// Set the attribute on every kernel that doesn't use any subgroups at
// all, so the vectorizer knows it can vectorize them however it likes.
setHasDegenerateSubgroups(*K);

// No need to clone kernels that don't use any subgroup functions.
kernelsToClone.erase(K);
}
Expand Down
8 changes: 5 additions & 3 deletions modules/compiler/utils/source/metadata_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ GenericMetadataAnalysis::Result GenericMetadataAnalysis::run(
auto kernel_name = Fn.getName().str();
auto source_name = getOrigFnNameOrFnName(Fn).str();

bool degenerate_sub_groups = compiler::utils::hasDegenerateSubgroups(Fn);
bool degenerate_or_no_sub_groups =
compiler::utils::hasDegenerateSubgroups(Fn) ||
compiler::utils::hasNoExplicitSubgroups(Fn);
FixedOrScalableQuantity<uint32_t> sub_group_size(
degenerate_sub_groups ? 0 : 1, false);
degenerate_or_no_sub_groups ? 0 : 1, false);
// If the function uses sub-groups and they are not degenerate,
// whole-function vectorization multiplies the sub-group size.
if (!degenerate_sub_groups) {
if (!degenerate_or_no_sub_groups) {
if (auto vf_info = parseWrapperFnMetadata(Fn)) {
VectorizationFactor vf = vf_info->first.vf;
sub_group_size =
Expand Down
9 changes: 9 additions & 0 deletions modules/compiler/utils/source/sub_group_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <compiler/utils/attributes.h>
#include <compiler/utils/builtin_info.h>
#include <compiler/utils/sub_group_analysis.h>
#include <llvm/ADT/PriorityWorklist.h>
Expand All @@ -33,6 +34,14 @@ GlobalSubgroupInfo::GlobalSubgroupInfo(Module &M, BuiltinInfo &BI) : BI(BI) {
continue;
}
auto SGI = std::make_unique<SubgroupInfo>();

// Assume the 'mux-no-subgroups' attribute is correct. If a pass introduces
// the use of sub-groups, then it should remove the attribute itself!
if (hasNoExplicitSubgroups(F)) {
FunctionMap.insert({&F, std::move(SGI)});
continue;
}

for (auto &BB : F) {
for (const auto &I : BB) {
if (auto *const CI = dyn_cast<CallInst>(&I)) {
Expand Down
4 changes: 2 additions & 2 deletions modules/mux/source/hal/include/mux/hal/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ struct kernel_variant_s {
/// Note that the last sub-group in a work-group may be smaller than this
/// value.
/// * If one, denotes a trivial sub-group.
/// * If zero, denotes a 'degenerate' sub-group (i.e., the size of the
/// work-group at enqueue time).
/// * If zero, denotes either no sub-groups or a 'degenerate' sub-group
/// (i.e., the size of the work-group at enqueue time).
uint32_t sub_group_size = 0;
};

Expand Down
4 changes: 2 additions & 2 deletions modules/mux/targets/host/include/host/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ struct binary_kernel_s {
/// Note that the last sub-group in a work-group may be smaller than this
/// value.
/// * If one, denotes a trivial sub-group.
/// * If zero, denotes a 'degenerate' sub-group (i.e., the size of the
/// work-group at enqueue time).
/// * If zero, denotes either no sub-groups or a 'degenerate' sub-group
/// (i.e., the size of the work-group at enqueue time).
uint32_t sub_group_size;
};

Expand Down
4 changes: 2 additions & 2 deletions modules/utils/targets/host/include/host/utils/jit_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ struct jit_kernel_s {
/// Note that the last sub-group in a work-group may be smaller than this
/// value.
/// * If one, denotes a trivial sub-group.
/// * If zero, denotes a 'degenerate' sub-group (i.e., the size of the
/// work-group at enqueue time).
/// * If zero, denotes either no sub-groups or a 'degenerate' sub-group
/// (i.e., the size of the work-group at enqueue time).
uint32_t sub_group_size;
};

Expand Down

0 comments on commit e1dbfed

Please sign in to comment.