From 929c52042eac500edf97770be60e7db28eaf427b Mon Sep 17 00:00:00 2001 From: anastasios Date: Fri, 3 Jul 2026 08:38:03 +0000 Subject: [PATCH] fix(causal_conv1d) make K a dynamic parameter --- csrc/kernel/kernel_gdn_causal_conv1d.cpp | 84 ++++++++++++------------ 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/csrc/kernel/kernel_gdn_causal_conv1d.cpp b/csrc/kernel/kernel_gdn_causal_conv1d.cpp index f1d9369f..e5acf85a 100644 --- a/csrc/kernel/kernel_gdn_causal_conv1d.cpp +++ b/csrc/kernel/kernel_gdn_causal_conv1d.cpp @@ -39,11 +39,12 @@ using namespace kernel_utils; // Weights W[K,channels] + bias[channels] are fp32 GM tensors. fp16 OR bf16 // I/O, fp32 accumulate. // -// Filter width K and per-tile channel width MAX_W are compile-time -// constants chosen at the call site as template parameters (no preprocessor -// config), e.g. -// constexpr uint32_t K = CAUSAL_CONV_K, MAX_W = CAUSAL_CONV_MAX_W; -// csilu::runConvSiluBatched(...); +// Filter width K is a runtime parameter passed to runConvSiluBatched; the +// per-tile channel width MAX_W stays a compile-time template parameter (it +// fixes the UB tile sizes), e.g. +// constexpr uint32_t MAX_W = CAUSAL_CONV_MAX_W; +// const uint32_t K = CAUSAL_CONV_K; +// csilu::runConvSiluBatched(..., K); // // 2-D-plus-batch work grid: workUnits = batch x sequenceChunkCount x // channelTileCount. Each work unit produces outputs [batchIndex] x @@ -61,7 +62,7 @@ using namespace kernel_utils; // noted): K weights + bias(1) + accumRingSize accumulators + (K-1) partial // products + input-as-fp32(1) = 2*K+accumRingSize+1 fp32 tiles; then the // I/O region inputTile[0..1] + output0 -// + output1 = 4 I/O tiles (input load double-buffered). A static_assert +// + output1 = 4 I/O tiles (input load double-buffered). A runtime guard // keeps the total within UB_BYTES_PER_CORE for the chosen K / MAX_W / // dtypes. NOTE: uses the PTO tile-op API (); the `csilu` // namespace avoids a clash with pto::detail. @@ -70,7 +71,7 @@ using namespace kernel_utils; namespace csilu { // Unified Buffer available per AIV core (Ascend 910B2 = 192 KiB). The UB -// static_assert in processWorkUnit checks the chosen layout fits; raise it +// runtime guard in processWorkUnit checks the chosen layout fits; raise it // for a next-gen NPU with a larger UB. constexpr uint32_t UB_BYTES_PER_CORE = 192u * 1024u; @@ -104,14 +105,13 @@ AICORE inline void applySiluToTile(TileT& dst, TileT& src, TileT& scratch) { // Process ONE work unit: outputs [outputRowStart,outputRowEnd) for channels // [channelTileBase,channelTileBase+tileChannelCount) of the sequence whose // first row is at element offset sequenceRowOffset. x[<0]=0 (no cache). -template +template AICORE inline void processWorkUnit( __gm__ IoElemType* input, __gm__ IoElemType* output, __gm__ AccumElemType* weights, __gm__ AccumElemType* bias, uint32_t channels, uint64_t sequenceRowOffset, uint32_t channelTileBase, int32_t tileChannelCount, uint32_t outputRowStart, uint32_t outputRowEnd, - uint32_t applyActivation) { + uint32_t applyActivation, uint32_t K) { using GlobalShape = pto::Shape<1, 1, 1, 1, DYNAMIC>; using GlobalStride = pto::Stride<1, 1, 1, 1, 1>; using GlobalIoTensor = @@ -127,33 +127,31 @@ AICORE inline void processWorkUnit( constexpr uint32_t ioTileBytes = MAX_W * sizeof(IoElemType); // accumulator ring (power of two >= K) so the K in-flight outputs never // alias. - constexpr uint32_t accumRingSize = roundUpToPowerOfTwo(K); - constexpr uint32_t accumRingMask = - accumRingSize - 1u; // ring-slot index mask - static_assert(K <= accumRingSize, "accumulator ring must hold all K taps"); + const uint32_t accumRingSize = roundUpToPowerOfTwo(K); + const uint32_t accumRingMask = accumRingSize - 1u; // ring-slot index mask // UB byte offsets. fp32 region: K weights (weight k at k*accumTileBytes) | // bias | accumRingSize accumulators | K-1 partial products | input-as-fp32. // Then the I/O region: 4 ioTileBytes-sized tiles (input load // double-buffered). - constexpr uint32_t ubBiasOffset = K * accumTileBytes; - constexpr uint32_t ubAccumRingBase = (K + 1u) * accumTileBytes; + const uint32_t ubBiasOffset = K * accumTileBytes; + const uint32_t ubAccumRingBase = (K + 1u) * accumTileBytes; // partial product for tap k at ubProductBase + (k-1)*accumTileBytes; also // reused as the SiLU scratch tile once the products have been summed. - constexpr uint32_t ubProductBase = (K + 1u + accumRingSize) * accumTileBytes; - constexpr uint32_t ubInputFp32Offset = - (2u * K + accumRingSize) * accumTileBytes; - constexpr uint32_t ubIoRegionBase = + const uint32_t ubProductBase = (K + 1u + accumRingSize) * accumTileBytes; + const uint32_t ubInputFp32Offset = (2u * K + accumRingSize) * accumTileBytes; + const uint32_t ubIoRegionBase = (2u * K + accumRingSize + 1u) * accumTileBytes; - static_assert( - ubIoRegionBase + 4u * ioTileBytes <= UB_BYTES_PER_CORE, - "conv1d UB exceeds UB_BYTES_PER_CORE: lower K/MAX_W or raise it"); + // K is now a runtime parameter, so the UB-fit check that used to be a + // static_assert becomes a runtime guard: if the chosen K / MAX_W / dtypes + // overflow the per-core UB, skip the work unit rather than corrupt memory. + if (ubIoRegionBase + 4u * ioTileBytes > UB_BYTES_PER_CORE) return; - constexpr uint32_t ubOutputOffset[2] = {ubIoRegionBase + ioTileBytes, - ubIoRegionBase + 2u * ioTileBytes}; + const uint32_t ubOutputOffset[2] = {ubIoRegionBase + ioTileBytes, + ubIoRegionBase + 2u * ioTileBytes}; // input double-buffer: slot 0 before the outputs, slot 1 after. - constexpr uint32_t ubInputOffset[2] = {ubIoRegionBase, - ubIoRegionBase + 3u * ioTileBytes}; + const uint32_t ubInputOffset[2] = {ubIoRegionBase, + ubIoRegionBase + 3u * ioTileBytes}; const uint32_t firstInputRow = (outputRowStart > (K - 1)) ? (outputRowStart - (K - 1)) : 0u; @@ -309,15 +307,14 @@ AICORE inline void processWorkUnit( wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); } -template +template AICORE void runConvSiluBatched(__gm__ IoElemType* input, __gm__ IoElemType* output, __gm__ AccumElemType* weights, __gm__ AccumElemType* bias, uint32_t batch, uint32_t seqLen, uint32_t channels, - uint32_t applyActivation) { - static_assert(K >= 1u, "K (filter width) must be >= 1"); + uint32_t applyActivation, uint32_t K) { + if (K < 1u) return; // K (filter width) must be >= 1 set_mask_norm(); set_vector_mask(-1, -1); @@ -393,10 +390,10 @@ AICORE void runConvSiluBatched(__gm__ IoElemType* input, uint32_t outputRowEnd = outputRowStart + sequenceChunkLength; if (outputRowEnd > seqLen) outputRowEnd = seqLen; const uint64_t sequenceRowOffset = (uint64_t)batchIndex * seqLen * channels; - processWorkUnit( + processWorkUnit( input, output, weights, bias, channels, sequenceRowOffset, channelTileBase, tileChannelCount, outputRowStart, outputRowEnd, - applyActivation); + applyActivation, K); } } @@ -433,10 +430,11 @@ extern "C" __global__ AICORE void gdn_causal_conv1d_fp16( __gm__ uint8_t* input, __gm__ uint8_t* output, __gm__ uint8_t* weights, __gm__ uint8_t* bias, uint32_t seqLen, uint32_t channels) { #if defined(__DAV_VEC__) - constexpr uint32_t K = CAUSAL_CONV_K, MAX_W = CAUSAL_CONV_MAX_W; - csilu::runConvSiluBatched( + constexpr uint32_t MAX_W = CAUSAL_CONV_MAX_W; + const uint32_t K = CAUSAL_CONV_K; + csilu::runConvSiluBatched( (__gm__ half*)input, (__gm__ half*)output, (__gm__ float*)weights, - (__gm__ float*)bias, 1u, seqLen, channels, 1u); + (__gm__ float*)bias, 1u, seqLen, channels, 1u, K); #else (void)input; (void)output; @@ -478,10 +476,11 @@ extern "C" __global__ AICORE void gdn_causal_conv1d_batched_fp16( __gm__ uint8_t* bias, uint32_t batch, uint32_t seqLen, uint32_t channels, uint32_t applyActivation) { #if defined(__DAV_VEC__) - constexpr uint32_t K = CAUSAL_CONV_K, MAX_W = CAUSAL_CONV_MAX_W; - csilu::runConvSiluBatched( + constexpr uint32_t MAX_W = CAUSAL_CONV_MAX_W; + const uint32_t K = CAUSAL_CONV_K; + csilu::runConvSiluBatched( (__gm__ half*)input, (__gm__ half*)output, (__gm__ float*)weights, - (__gm__ float*)bias, batch, seqLen, channels, applyActivation); + (__gm__ float*)bias, batch, seqLen, channels, applyActivation, K); #else (void)input; (void)output; @@ -527,11 +526,12 @@ extern "C" __global__ AICORE void gdn_causal_conv1d_batched_bf16( __gm__ uint8_t* bias, uint32_t batch, uint32_t seqLen, uint32_t channels, uint32_t applyActivation) { #if defined(__DAV_VEC__) - constexpr uint32_t K = CAUSAL_CONV_K, MAX_W = CAUSAL_CONV_MAX_W; - csilu::runConvSiluBatched( + constexpr uint32_t MAX_W = CAUSAL_CONV_MAX_W; + const uint32_t K = CAUSAL_CONV_K; + csilu::runConvSiluBatched( (__gm__ bfloat16_t*)input, (__gm__ bfloat16_t*)output, (__gm__ float*)weights, (__gm__ float*)bias, batch, seqLen, channels, - applyActivation); + applyActivation, K); #else (void)input; (void)output;