Skip to content

Commit

Permalink
Merge pull request #124 from frasercrmck/subgroup-shuffles
Browse files Browse the repository at this point in the history
[compiler] Add mux sub-group shuffle builtins
  • Loading branch information
frasercrmck authored Sep 12, 2023
2 parents 9a68240 + 8a3d2f5 commit 142c2d4
Show file tree
Hide file tree
Showing 31 changed files with 1,362 additions and 395 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Upgrade guidance:
DMA builtins.
* 0.78.0: to introduce mux builtins for sub-group, work-group, and
vector-group operations.
* 0.79.0: to introduce mux builtins for sub-group shuffle operations.
* The `compiler::ImageArgumentSubstitutionPass` now replaces sampler typed
parameters in kernel functions with i32 parameters via a wrapper function.
The `host` target as a consequence now passes samplers to kernels as 32-bit
Expand Down
5 changes: 5 additions & 0 deletions doc/modules/mux/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ version increases mean backward compatible bug fixes have been applied.
Versions prior to 1.0.0 may contain breaking changes in minor
versions as the API is still under development.

0.79.0
------

* Added sub-group shuffle builtins.

0.78.0
------

Expand Down
104 changes: 99 additions & 5 deletions doc/specifications/mux-compiler-spec.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ComputeMux Compiler Specification
=================================

This is version 0.78.0 of the specification.
This is version 0.79.0 of the specification.

ComputeMux is Codeplay’s proprietary API for executing compute workloads across
heterogeneous devices. ComputeMux is an extremely lightweight,
Expand Down Expand Up @@ -915,15 +915,15 @@ A Mux implementation **shall** provide definitions for these builtin functions.
invocations within a work-group for the ``%i``'th dimension.
* ``size_t __mux_get_local_id(i32 %i)`` - Returns the unique local invocation
identifier for the ``%i``'th dimension.
* ``i32 __mux_get_sub_group_id()`` - Returns the subgroup ID.
* ``i32 __mux_get_sub_group_id()`` - Returns the sub-group ID.
* ``size_t __mux_get_num_groups(i32 %i)`` - Returns the number of work-groups
for the ``%i``'th dimension.
* ``i32 __mux_get_num_sub_groups()`` - Returns the number of subgroups for
* ``i32 __mux_get_num_sub_groups()`` - Returns the number of sub-groups for
the current work-group.
* ``i32 __mux_get_max_sub_group_size()`` - Returns the maximum subgroup size
* ``i32 __mux_get_max_sub_group_size()`` - Returns the maximum sub-group size
in the current kernel.
* ``i32 __mux_get_sub_group_size()`` - Returns the number of invocations in the
subgroup.
sub-group.
* ``i32 __mux_get_sub_group_local_id()`` - Returns the unique invocation ID
within the current sub-group.
* ``size_t __mux_get_group_id(i32 %i)`` - Returns the unique work-group
Expand Down Expand Up @@ -1110,6 +1110,100 @@ Examples:
i64 @__mux_vec_group_scan_exclusive_mul_nxv1i64(<vscale x 1 x i64> %val)
Sub-group ``shuffle`` builtin
+++++++++++++++++++++++++++++

The ``sub_group_shuffle`` builtin allows data to be arbitrarily transferred
between invocations in a sub-group. The data that is returned for this
invocation is the value of ``%val`` for the invocation identified by ``%lid``.

``%lid`` need not be the same value for all invocations in the sub-group.

.. code:: llvm
i32 @__mux_sub_group_shuffle_i32(i32 %val, i32 %lid)
Sub-group ``shuffle_up`` builtin
++++++++++++++++++++++++++++++++

The ``sub_group_shuffle_up`` builtin allows data to be transferred from an
invocation in the sub-group with a lower sub-group local invocation ID up to an
invocation in the sub-group with a higher sub-group local invocation ID.

The builtin has two operands: ``%prev`` and ``%curr``. To determine the result
of this builtin, first let ``SubgroupLocalInvocationId`` be equal to
``__mux_get_sub_group_local_id()``, let the signed shuffle index be equivalent
to this invocation’s ``SubgroupLocalInvocationId`` minus the specified
``%delta``, and ``MaxSubgroupSize`` be equal to
``__mux_get_max_sub_group_size()`` for the current kernel.

* If the shuffle index is greater than or equal to zero and less than the
``MaxSubgroupSize``, the result of this builtin is the value of the ``%curr``
operand for the invocation with ``SubgroupLocalInvocationId`` equal to the
shuffle index.

* If the shuffle index is less than zero but greater than or equal to the
negative ``MaxSubgroupSize``, the result of this builtin is the value of the
``%prev`` operand for the invocation with ``SubgroupLocalInvocationId`` equal
to the shuffle index plus the ``MaxSubgroupSize``.

All other values of the shuffle index are considered to be out-of-range.

``%delta`` need not be the same value for all invocations in the sub-group.

.. code:: llvm
i8 @__mux_sub_group_shuffle_up_i8(i8 %prev, i8 %curr, i32 %delta)
Sub-group ``shuffle_down`` builtin
++++++++++++++++++++++++++++++++++

The ``sub_group_shuffle_down`` builtin allows data to be transferred from an
invocation in the sub-group with a higher sub-group local invocation ID down to
a invocation in the sub-group with a lower sub-group local invocation ID.

The builtin has two operands: ``%curr`` and ``%next``. To determine the result
of this builtin , first let ``SubgroupLocalInvocationId`` be equal to
``__mux_get_sub_group_local_id()``, the unsigned shuffle index be equivalent to
the sum of this invocation’s ``SubgroupLocalInvocationId`` plus the specified
``%delta``, and ``MaxSubgroupSize`` be equal to
``__mux_get_max_sub_group_size()`` for the current kernel.

* If the shuffle index is less than the ``MaxSubgroupSize``, the result of this
builtin is the value of the ``%curr`` operand for the invocation with
``SubgroupLocalInvocationId`` equal to the shuffle index.

* If the shuffle index is greater than or equal to the ``MaxSubgroupSize`` but
less than twice the ``MaxSubgroupSize``, the result of this builtin is the
value of the ``%next`` operand for the invocation with
``SubgroupLocalInvocationId`` equal to the shuffle index minus the
``MaxSubgroupSize``. All other values of the shuffle index are considered to
be out-of-range.

All other values of the shuffle index are considered to be out-of-range.

``%delta`` need not be the same value for all invocations in the sub-group.

.. code:: llvm
float @__mux_sub_group_shuffle_down_f32(float %curr, float %next, i32 %delta)
Sub-group ``shuffle_xor`` builtin
+++++++++++++++++++++++++++++++++

These ``sub_group_shuffle_xor`` builtin allows for efficient sharing of data
between items within a sub-group.

The data that is returned for this invocation is the value of ``%val`` for the
invocation with sub-group local ID equal to this invocation’s sub-group local
ID XOR’d with the specified ``%xor_val``. If the result of the XOR is greater
than the current kernel's maximum sub-group size, then it is considered
out-of-range.

.. code:: llvm
double @__mux_sub_group_shuffle_xor_f64(double %val, i32 %xor_val)
Memory and Control Barriers
---------------------------

Expand Down
2 changes: 1 addition & 1 deletion doc/specifications/mux-runtime-spec.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ComputeMux Runtime Specification
================================

This is version 0.78.0 of the specification.
This is version 0.79.0 of the specification.

ComputeMux is Codeplay’s proprietary API for executing compute workloads across
heterogeneous devices. ComputeMux is an extremely lightweight,
Expand Down
38 changes: 38 additions & 0 deletions modules/compiler/spirv-ll/include/spirv-ll/opcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -2516,6 +2516,44 @@ class OpGroupLogicalXorKHR
OpGroupLogicalXorKHR(OpCode const &other) : OpGroupOperation(other) {}
};

class OpSubgroupShuffle : public OpResult {
public:
OpSubgroupShuffle(OpCode const &other)
: OpResult(other, spv::OpSubgroupShuffleINTEL) {}
spv::Id Data() const { return getValueAtOffset(3); }
spv::Id InvocationId() const { return getValueAtOffset(4); }
static const spv::Op ClassCode = spv::OpSubgroupShuffleINTEL;
};

class OpSubgroupShuffleUp : public OpResult {
public:
OpSubgroupShuffleUp(OpCode const &other)
: OpResult(other, spv::OpSubgroupShuffleUpINTEL) {}
spv::Id Previous() const { return getValueAtOffset(3); }
spv::Id Current() const { return getValueAtOffset(4); }
spv::Id Delta() const { return getValueAtOffset(5); }
static const spv::Op ClassCode = spv::OpSubgroupShuffleUpINTEL;
};

class OpSubgroupShuffleDown : public OpResult {
public:
OpSubgroupShuffleDown(OpCode const &other)
: OpResult(other, spv::OpSubgroupShuffleDownINTEL) {}
spv::Id Current() const { return getValueAtOffset(3); }
spv::Id Next() const { return getValueAtOffset(4); }
spv::Id Delta() const { return getValueAtOffset(5); }
static const spv::Op ClassCode = spv::OpSubgroupShuffleDownINTEL;
};

class OpSubgroupShuffleXor : public OpResult {
public:
OpSubgroupShuffleXor(OpCode const &other)
: OpResult(other, spv::OpSubgroupShuffleXorINTEL) {}
spv::Id Data() const { return getValueAtOffset(3); }
spv::Id Value() const { return getValueAtOffset(4); }
static const spv::Op ClassCode = spv::OpSubgroupShuffleXorINTEL;
};

class OpReadPipe : public OpResult {
public:
OpReadPipe(OpCode const &other) : OpResult(other, spv::OpReadPipe) {}
Expand Down
95 changes: 95 additions & 0 deletions modules/compiler/spirv-ll/source/builder_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <compiler/utils/builtin_info.h>
#include <compiler/utils/target_extension_types.h>
#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/SmallVector.h>
Expand Down Expand Up @@ -6758,6 +6759,100 @@ cargo::optional<Error> Builder::create<OpGroupLogicalXorKHR>(
return cargo::nullopt;
}

template <>
cargo::optional<Error> Builder::create<OpSubgroupShuffle>(
const OpSubgroupShuffle *op) {
std::string muxBuiltinName = "__mux_sub_group_shuffle_";

auto *data = module.getValue(op->Data());
SPIRV_LL_ASSERT_PTR(data);

auto *invocation_id = module.getValue(op->InvocationId());
SPIRV_LL_ASSERT_PTR(invocation_id);

auto retTy = module.getType(op->IdResultType());
SPIRV_LL_ASSERT_PTR(retTy);

muxBuiltinName += compiler::utils::BuiltinInfo::getMangledTypeStr(retTy);

auto *const ci = createBuiltinCall(
muxBuiltinName, retTy, {data, invocation_id}, /*convergent*/ true);
module.addID(op->IdResult(), op, ci);
return cargo::nullopt;
}

template <>
cargo::optional<Error> Builder::create<OpSubgroupShuffleUp>(
const OpSubgroupShuffleUp *op) {
std::string muxBuiltinName = "__mux_sub_group_shuffle_up_";

auto *previous = module.getValue(op->Previous());
SPIRV_LL_ASSERT_PTR(previous);

auto *current = module.getValue(op->Current());
SPIRV_LL_ASSERT_PTR(current);

auto *delta = module.getValue(op->Delta());
SPIRV_LL_ASSERT_PTR(delta);

auto retTy = module.getType(op->IdResultType());
SPIRV_LL_ASSERT_PTR(retTy);

muxBuiltinName += compiler::utils::BuiltinInfo::getMangledTypeStr(retTy);

auto *const ci = createBuiltinCall(
muxBuiltinName, retTy, {previous, current, delta}, /*convergent*/ true);
module.addID(op->IdResult(), op, ci);
return cargo::nullopt;
}

template <>
cargo::optional<Error> Builder::create<OpSubgroupShuffleDown>(
const OpSubgroupShuffleDown *op) {
std::string muxBuiltinName = "__mux_sub_group_shuffle_down_";

auto *current = module.getValue(op->Current());
SPIRV_LL_ASSERT_PTR(current);

auto *next = module.getValue(op->Next());
SPIRV_LL_ASSERT_PTR(next);

auto *delta = module.getValue(op->Delta());
SPIRV_LL_ASSERT_PTR(delta);

auto retTy = module.getType(op->IdResultType());
SPIRV_LL_ASSERT_PTR(retTy);

muxBuiltinName += compiler::utils::BuiltinInfo::getMangledTypeStr(retTy);

auto *const ci = createBuiltinCall(
muxBuiltinName, retTy, {current, next, delta}, /*convergent*/ true);
module.addID(op->IdResult(), op, ci);
return cargo::nullopt;
}

template <>
cargo::optional<Error> Builder::create<OpSubgroupShuffleXor>(
const OpSubgroupShuffleXor *op) {
std::string muxBuiltinName = "__mux_sub_group_shuffle_xor_";

auto *data = module.getValue(op->Data());
SPIRV_LL_ASSERT_PTR(data);

auto *value = module.getValue(op->Value());
SPIRV_LL_ASSERT_PTR(value);

auto retTy = module.getType(op->IdResultType());
SPIRV_LL_ASSERT_PTR(retTy);

muxBuiltinName += compiler::utils::BuiltinInfo::getMangledTypeStr(retTy);

auto *const ci = createBuiltinCall(muxBuiltinName, retTy, {data, value},
/*convergent*/ true);
module.addID(op->IdResult(), op, ci);
return cargo::nullopt;
}

template <>
cargo::optional<Error> Builder::create<OpReadPipe>(const OpReadPipe *) {
// Capability Pipes isn't supported by CL 1.2, see OpenCL SPIR-V
Expand Down
13 changes: 13 additions & 0 deletions modules/compiler/spirv-ll/source/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <spirv-ll/builder.h>
#include <spirv-ll/context.h>
#include <spirv-ll/module.h>
#include <spirv/unified1/spirv.hpp>

spirv_ll::Context::Context()
: llvmContext(new llvm::LLVMContext), llvmContextIsOwned(true) {}
Expand Down Expand Up @@ -971,6 +972,18 @@ cargo::expected<spirv_ll::Module, spirv_ll::Error> spirv_ll::Context::translate(
case spv::OpGroupLogicalXorKHR:
error = builder.create<OpGroupLogicalXorKHR>(op);
break;
case spv::OpSubgroupShuffleINTEL:
error = builder.create<OpSubgroupShuffle>(op);
break;
case spv::OpSubgroupShuffleUpINTEL:
error = builder.create<OpSubgroupShuffleUp>(op);
break;
case spv::OpSubgroupShuffleDownINTEL:
error = builder.create<OpSubgroupShuffleDown>(op);
break;
case spv::OpSubgroupShuffleXorINTEL:
error = builder.create<OpSubgroupShuffleXor>(op);
break;
case spv::OpReadPipe:
error = builder.create<OpReadPipe>(op);
break;
Expand Down
4 changes: 4 additions & 0 deletions modules/compiler/spirv-ll/source/opcodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,10 @@ bool OpCode::hasResult() const {
case spv::OpSubgroupBallotKHR:
case spv::OpSubgroupFirstInvocationKHR:
case spv::OpSubgroupReadInvocationKHR:
case spv::OpSubgroupShuffleINTEL:
case spv::OpSubgroupShuffleUpINTEL:
case spv::OpSubgroupShuffleDownINTEL:
case spv::OpSubgroupShuffleXorINTEL:
case spv::OpTranspose:
case spv::OpUConvert:
case spv::OpUDiv:
Expand Down
7 changes: 6 additions & 1 deletion modules/compiler/spirv-ll/test/spvasm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2374,7 +2374,12 @@ endif()

if (SpirvAsVersionYear GREATER_EQUAL 2022)
list(APPEND SPVASM_FILES
intel_opt_none.spvasm)
intel_opt_none.spvasm
subgroup_shuffle_intel.spvasm
subgroup_shuffle_up_intel.spvasm
subgroup_shuffle_down_intel.spvasm
subgroup_shuffle_xor_intel.spvasm
)
endif()

if (SpirvAsVersionYear GREATER 2022)
Expand Down
Loading

0 comments on commit 142c2d4

Please sign in to comment.