Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 42 additions & 42 deletions csrc/kernel/kernel_gdn_causal_conv1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ using namespace kernel_utils;
// Weights W[K,channels] + bias[channels] are fp32 GM tensors. fp16 OR bf16
// I/O, fp32 accumulate.
//
// Filter width K and per-tile channel width MAX_W are compile-time
// constants chosen at the call site as template parameters (no preprocessor
// config), e.g.
// constexpr uint32_t K = CAUSAL_CONV_K, MAX_W = CAUSAL_CONV_MAX_W;
// csilu::runConvSiluBatched<bfloat16_t, float, K, MAX_W>(...);
// Filter width K is a runtime parameter passed to runConvSiluBatched; the
// per-tile channel width MAX_W stays a compile-time template parameter (it
// fixes the UB tile sizes), e.g.
// constexpr uint32_t MAX_W = CAUSAL_CONV_MAX_W;
// const uint32_t K = CAUSAL_CONV_K;
// csilu::runConvSiluBatched<bfloat16_t, float, MAX_W>(..., K);
//
// 2-D-plus-batch work grid: workUnits = batch x sequenceChunkCount x
// channelTileCount. Each work unit produces outputs [batchIndex] x
Expand All @@ -61,7 +62,7 @@ using namespace kernel_utils;
// noted): K weights + bias(1) + accumRingSize accumulators + (K-1) partial
// products + input-as-fp32(1) = 2*K+accumRingSize+1 fp32 tiles; then the
// I/O region inputTile[0..1] + output0
// + output1 = 4 I/O tiles (input load double-buffered). A static_assert
// + output1 = 4 I/O tiles (input load double-buffered). A runtime guard
// keeps the total within UB_BYTES_PER_CORE for the chosen K / MAX_W /
// dtypes. NOTE: uses the PTO tile-op API (<pto/pto-inst.hpp>); the `csilu`
// namespace avoids a clash with pto::detail.
Expand All @@ -70,7 +71,7 @@ using namespace kernel_utils;
namespace csilu {

// Unified Buffer available per AIV core (Ascend 910B2 = 192 KiB). The UB
// static_assert in processWorkUnit checks the chosen layout fits; raise it
// runtime guard in processWorkUnit checks the chosen layout fits; raise this
// constant for a next-gen NPU with a larger UB.
constexpr uint32_t UB_BYTES_PER_CORE{192u * 1024u};  // 192 KiB, in bytes

Expand Down Expand Up @@ -104,14 +105,13 @@ AICORE inline void applySiluToTile(TileT& dst, TileT& src, TileT& scratch) {
// Process ONE work unit: outputs [outputRowStart,outputRowEnd) for channels
// [channelTileBase,channelTileBase+tileChannelCount) of the sequence whose
// first row is at element offset sequenceRowOffset. x[<0]=0 (no cache).
template <typename IoElemType, typename AccumElemType, uint32_t K,
uint32_t MAX_W>
template <typename IoElemType, typename AccumElemType, uint32_t MAX_W>
AICORE inline void processWorkUnit(
__gm__ IoElemType* input, __gm__ IoElemType* output,
__gm__ AccumElemType* weights, __gm__ AccumElemType* bias,
uint32_t channels, uint64_t sequenceRowOffset, uint32_t channelTileBase,
int32_t tileChannelCount, uint32_t outputRowStart, uint32_t outputRowEnd,
uint32_t applyActivation) {
uint32_t applyActivation, uint32_t K) {
using GlobalShape = pto::Shape<1, 1, 1, 1, DYNAMIC>;
using GlobalStride = pto::Stride<1, 1, 1, 1, 1>;
using GlobalIoTensor =
Expand All @@ -127,33 +127,31 @@ AICORE inline void processWorkUnit(
constexpr uint32_t ioTileBytes = MAX_W * sizeof(IoElemType);
// accumulator ring (power of two >= K) so the K in-flight outputs never
// alias.
constexpr uint32_t accumRingSize = roundUpToPowerOfTwo(K);
constexpr uint32_t accumRingMask =
accumRingSize - 1u; // ring-slot index mask
static_assert(K <= accumRingSize, "accumulator ring must hold all K taps");
const uint32_t accumRingSize = roundUpToPowerOfTwo(K);
const uint32_t accumRingMask = accumRingSize - 1u; // ring-slot index mask

// UB byte offsets. fp32 region: K weights (weight k at k*accumTileBytes) |
// bias | accumRingSize accumulators | K-1 partial products | input-as-fp32.
// Then the I/O region: 4 ioTileBytes-sized tiles (input load
// double-buffered).
constexpr uint32_t ubBiasOffset = K * accumTileBytes;
constexpr uint32_t ubAccumRingBase = (K + 1u) * accumTileBytes;
const uint32_t ubBiasOffset = K * accumTileBytes;
const uint32_t ubAccumRingBase = (K + 1u) * accumTileBytes;
// partial product for tap k at ubProductBase + (k-1)*accumTileBytes; also
// reused as the SiLU scratch tile once the products have been summed.
constexpr uint32_t ubProductBase = (K + 1u + accumRingSize) * accumTileBytes;
constexpr uint32_t ubInputFp32Offset =
(2u * K + accumRingSize) * accumTileBytes;
constexpr uint32_t ubIoRegionBase =
const uint32_t ubProductBase = (K + 1u + accumRingSize) * accumTileBytes;
const uint32_t ubInputFp32Offset = (2u * K + accumRingSize) * accumTileBytes;
const uint32_t ubIoRegionBase =
(2u * K + accumRingSize + 1u) * accumTileBytes;
static_assert(
ubIoRegionBase + 4u * ioTileBytes <= UB_BYTES_PER_CORE,
"conv1d UB exceeds UB_BYTES_PER_CORE: lower K/MAX_W or raise it");
// K is now a runtime parameter, so the UB-fit check that used to be a
// static_assert becomes a runtime guard: if the chosen K / MAX_W / dtypes
// overflow the per-core UB, skip the work unit rather than corrupt memory.
if (ubIoRegionBase + 4u * ioTileBytes > UB_BYTES_PER_CORE) return;

constexpr uint32_t ubOutputOffset[2] = {ubIoRegionBase + ioTileBytes,
ubIoRegionBase + 2u * ioTileBytes};
const uint32_t ubOutputOffset[2] = {ubIoRegionBase + ioTileBytes,
ubIoRegionBase + 2u * ioTileBytes};
// input double-buffer: slot 0 before the outputs, slot 1 after.
constexpr uint32_t ubInputOffset[2] = {ubIoRegionBase,
ubIoRegionBase + 3u * ioTileBytes};
const uint32_t ubInputOffset[2] = {ubIoRegionBase,
ubIoRegionBase + 3u * ioTileBytes};

const uint32_t firstInputRow =
(outputRowStart > (K - 1)) ? (outputRowStart - (K - 1)) : 0u;
Expand Down Expand Up @@ -309,15 +307,14 @@ AICORE inline void processWorkUnit(
wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2);
}

template <typename IoElemType, typename AccumElemType, uint32_t K,
uint32_t MAX_W>
template <typename IoElemType, typename AccumElemType, uint32_t MAX_W>
AICORE void runConvSiluBatched(__gm__ IoElemType* input,
__gm__ IoElemType* output,
__gm__ AccumElemType* weights,
__gm__ AccumElemType* bias, uint32_t batch,
uint32_t seqLen, uint32_t channels,
uint32_t applyActivation) {
static_assert(K >= 1u, "K (filter width) must be >= 1");
uint32_t applyActivation, uint32_t K) {
if (K < 1u) return; // K (filter width) must be >= 1

set_mask_norm();
set_vector_mask(-1, -1);
Expand Down Expand Up @@ -393,10 +390,10 @@ AICORE void runConvSiluBatched(__gm__ IoElemType* input,
uint32_t outputRowEnd = outputRowStart + sequenceChunkLength;
if (outputRowEnd > seqLen) outputRowEnd = seqLen;
const uint64_t sequenceRowOffset = (uint64_t)batchIndex * seqLen * channels;
processWorkUnit<IoElemType, AccumElemType, K, MAX_W>(
processWorkUnit<IoElemType, AccumElemType, MAX_W>(
input, output, weights, bias, channels, sequenceRowOffset,
channelTileBase, tileChannelCount, outputRowStart, outputRowEnd,
applyActivation);
applyActivation, K);
}
}

Expand Down Expand Up @@ -433,10 +430,11 @@ extern "C" __global__ AICORE void gdn_causal_conv1d_fp16(
__gm__ uint8_t* input, __gm__ uint8_t* output, __gm__ uint8_t* weights,
__gm__ uint8_t* bias, uint32_t seqLen, uint32_t channels) {
#if defined(__DAV_VEC__)
constexpr uint32_t K = CAUSAL_CONV_K, MAX_W = CAUSAL_CONV_MAX_W;
csilu::runConvSiluBatched<half, float, K, MAX_W>(
constexpr uint32_t MAX_W = CAUSAL_CONV_MAX_W;
const uint32_t K = CAUSAL_CONV_K;
csilu::runConvSiluBatched<half, float, MAX_W>(
(__gm__ half*)input, (__gm__ half*)output, (__gm__ float*)weights,
(__gm__ float*)bias, 1u, seqLen, channels, 1u);
(__gm__ float*)bias, 1u, seqLen, channels, 1u, K);
#else
(void)input;
(void)output;
Expand Down Expand Up @@ -478,10 +476,11 @@ extern "C" __global__ AICORE void gdn_causal_conv1d_batched_fp16(
__gm__ uint8_t* bias, uint32_t batch, uint32_t seqLen, uint32_t channels,
uint32_t applyActivation) {
#if defined(__DAV_VEC__)
constexpr uint32_t K = CAUSAL_CONV_K, MAX_W = CAUSAL_CONV_MAX_W;
csilu::runConvSiluBatched<half, float, K, MAX_W>(
constexpr uint32_t MAX_W = CAUSAL_CONV_MAX_W;
const uint32_t K = CAUSAL_CONV_K;
csilu::runConvSiluBatched<half, float, MAX_W>(
(__gm__ half*)input, (__gm__ half*)output, (__gm__ float*)weights,
(__gm__ float*)bias, batch, seqLen, channels, applyActivation);
(__gm__ float*)bias, batch, seqLen, channels, applyActivation, K);
#else
(void)input;
(void)output;
Expand Down Expand Up @@ -527,11 +526,12 @@ extern "C" __global__ AICORE void gdn_causal_conv1d_batched_bf16(
__gm__ uint8_t* bias, uint32_t batch, uint32_t seqLen, uint32_t channels,
uint32_t applyActivation) {
#if defined(__DAV_VEC__)
constexpr uint32_t K = CAUSAL_CONV_K, MAX_W = CAUSAL_CONV_MAX_W;
csilu::runConvSiluBatched<bfloat16_t, float, K, MAX_W>(
constexpr uint32_t MAX_W = CAUSAL_CONV_MAX_W;
const uint32_t K = CAUSAL_CONV_K;
csilu::runConvSiluBatched<bfloat16_t, float, MAX_W>(
(__gm__ bfloat16_t*)input, (__gm__ bfloat16_t*)output,
(__gm__ float*)weights, (__gm__ float*)bias, batch, seqLen, channels,
applyActivation);
applyActivation, K);
#else
(void)input;
(void)output;
Expand Down