diff --git a/.vscode/launch.json b/.vscode/launch.json
index 985fa924..bd31386c 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -176,9 +176,17 @@
             "program": "${workspaceFolder}/build_debug/stringzillas_${fileBasenameNoExtension}_cu20",
             "cwd": "${workspaceFolder}",
             "environment": [
+                {
+                    "name": "STRINGWARS_STRESS",
+                    "value": "0"
+                },
+                {
+                    "name": "STRINGWARS_FILTER",
+                    "value": "levenshtein_cuda:batch64"
+                },
                 {
                     "name": "STRINGWARS_DATASET",
-                    "value": "leipzig1M.txt"
+                    "value": "acgt_100.txt"
                 }
             ],
             "stopAtEntry": false,
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 79a3cc7a..421aaac7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -702,7 +702,6 @@ function (define_stringzillas_shared target source_file backend_flags)
     target_include_directories(${target} PRIVATE fork_union/include)
     target_compile_definitions(${target} PRIVATE "SZ_DYNAMIC_DISPATCH=1")
     target_compile_definitions(${target} PRIVATE "SZ_AVOID_LIBC=0")
-    target_compile_definitions(${target} PRIVATE "SZ_DEBUG=0")
 
     # Set backend-specific compilation flags
     foreach (flag ${backend_flags})
diff --git a/include/stringzillas/similarities.cuh b/include/stringzillas/similarities.cuh
index f20ae59f..5a03c214 100644
--- a/include/stringzillas/similarities.cuh
+++ b/include/stringzillas/similarities.cuh
@@ -1629,7 +1629,6 @@ __global__ void linear_score_on_each_cuda_warp_( //
     substituter_type_ const substituter, linear_gap_costs_t const gap_costs, //
     unsigned const shared_memory_size) {
 
-    // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
     using task_t = task_type_;
     using char_t = char_type_;
     using index_t = index_type_;
@@ -1674,7 +1673,7 @@ __global__ void linear_score_on_each_cuda_warp_( //
         size_t const longer_length = task.longer_length;
         auto &result_ref = task.result;
 
-        // We are going to store 3 diagonals of the matrix, assuming each would fit into a single ZMM register.
+        // We are going to store 3 diagonals of the matrix.
         // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`.
         unsigned const shorter_dim = static_cast<unsigned>(shorter_length + 1);
         unsigned const longer_dim = static_cast<unsigned>(longer_length + 1);
@@ -1822,7 +1821,6 @@ __global__ void affine_score_on_each_cuda_warp_( //
     substituter_type_ const substituter, affine_gap_costs_t const gap_costs, //
     unsigned const shared_memory_size) {
 
-    // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
     using task_t = task_type_;
     using char_t = char_type_;
     using index_t = index_type_;
@@ -1867,7 +1865,7 @@ __global__ void affine_score_on_each_cuda_warp_( //
         size_t const longer_length = task.longer_length;
         auto &result_ref = task.result;
 
-        // We are going to store 3 diagonals of the matrix, assuming each would fit into a single ZMM register.
+        // We are going to store 3 diagonals of the matrix.
        // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`.
         unsigned const shorter_dim = static_cast<unsigned>(shorter_length + 1);
         unsigned const longer_dim = static_cast<unsigned>(longer_length + 1);
@@ -2017,6 +2015,286 @@ __global__ void affine_score_on_each_cuda_warp_( //
     }
 }
 
+inline static constexpr unsigned levenshtein_on_each_cuda_thread_default_text_limit_k = 128;
+
+/**
+ * @brief Byte-wise equality comparison: returns 0xFF per byte where equal, 0x00 where different.
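+ * On device this lowers to the `__vcmpeq4` video instruction; the host path is a portable
+ * scalar emulation, handy for unit-testing the kernel's building blocks on the CPU.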
+ * Example: vcmpeq4(0x12345678, 0x12FF5678) = 0xFF00FFFF + */ +__forceinline__ __device__ __host__ sz_u32_t sz_u32_vcmpeq4_(sz_u32_t a, sz_u32_t b) { +#ifdef __CUDA_ARCH__ + return __vcmpeq4(a, b); +#else + sz_u32_t result = 0; + for (int i = 0; i < 4; ++i) { + sz_u8_t byte_a = (a >> (i * 8)) & 0xFF; + sz_u8_t byte_b = (b >> (i * 8)) & 0xFF; + if (byte_a == byte_b) result |= 0xFFu << (i * 8); + } + return result; +#endif +} + +/** + * @brief Byte-wise unsigned minimum. + * Example: vminu4(0x12345678, 0x34127856) = 0x12125656 + */ +__forceinline__ __device__ __host__ sz_u32_t sz_u32_vminu4_(sz_u32_t a, sz_u32_t b) { +#ifdef __CUDA_ARCH__ + return __vminu4(a, b); +#else + sz_u32_t result = 0; + for (int i = 0; i < 4; ++i) { + sz_u8_t byte_a = (a >> (i * 8)) & 0xFF; + sz_u8_t byte_b = (b >> (i * 8)) & 0xFF; + sz_u8_t min_byte = (byte_a < byte_b) ? byte_a : byte_b; + result |= (sz_u32_t)min_byte << (i * 8); + } + return result; +#endif +} + +/** + * @brief Byte-wise saturating addition (clamps at 0xFF). + * Example: vaddus4(0xFE020304, 0x03020100) = 0xFF040404 + */ +__forceinline__ __device__ __host__ sz_u32_t sz_u32_vaddus4_(sz_u32_t a, sz_u32_t b) { +#ifdef __CUDA_ARCH__ + return __vaddus4(a, b); +#else + sz_u32_t result = 0; + for (int i = 0; i < 4; ++i) { + sz_u8_t byte_a = (a >> (i * 8)) & 0xFF; + sz_u8_t byte_b = (b >> (i * 8)) & 0xFF; + sz_u32_t sum = byte_a + byte_b; + sz_u8_t sat_byte = (sum > 0xFF) ? 0xFF : (sz_u8_t)sum; + result |= (sz_u32_t)sat_byte << (i * 8); + } + return result; +#endif +} + +/** + * @brief Byte permutation: select 4 bytes from 8-byte source {x[0..3], y[0..3]}. + * Selector format (nibbles): each 4-bit value selects one of 8 source bytes. + * + * Selector nibbles: [3] [2] [1] [0] (hex digits in selector) + * Result bytes: byte3 byte2 byte1 byte0 + * + * Example: byte_perm(0x44332211, 0x88776655, 0x6543) + * Source: [11 22 33 44 | 55 66 77 88] (indices 0-7) + * Selector 0x6543 = nibbles [6,5,4,3] + * Result: [77 66 55 44] = 0x77665544 + */ +__forceinline__ __device__ __host__ sz_u32_t sz_u32_byte_perm_(sz_u32_t x, sz_u32_t y, sz_u32_t selector) { +#ifdef __CUDA_ARCH__ + return __byte_perm(x, y, selector); +#else + sz_u8_t source[8]; + for (int i = 0; i < 4; ++i) { + source[i] = (x >> (i * 8)) & 0xFF; + source[i + 4] = (y >> (i * 8)) & 0xFF; + } + + sz_u32_t result = 0; + for (int i = 0; i < 4; ++i) { + sz_u8_t sel = (selector >> (i * 4)) & 0x7; + result |= (sz_u32_t)source[sel] << (i * 8); + } + return result; +#endif +} + +/** + * @brief Register-only Levenshtein distance for strings up to 128 bytes using SIMD operations. + * + * Implements the Wagner-Fischer algorithm using a single-row optimization where only the current + * row of the DP matrix is stored in registers. The longer input string is also cached in registers + * to avoid repeated memory access. This design targets GPU architectures (H100) where register + * memory is abundant but shared/global memory access is expensive. The algorithm processes 4 columns + * at a time by packing them into uint32 vectors and using video instructions (vcmpeq4, vminu4, + * vaddus4, byte_perm) for parallel byte-wise operations. + * + * Each iteration computes one row by: (1) comparing the current character against 4 cached characters + * in parallel, (2) computing vertical and diagonal dependencies using SIMD min/add, and (3) propagating + * horizontal dependencies through 4 sequential iterations within each uint32. 
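+ *
+ * Worked example for one uint32 lane (illustrative, unit costs assumed): with the previous row
+ * holding top = {1,2,3,4}, the shifted-in diagonal {0,1,2,3}, and a match only in byte 1, the
+ * substitution costs are {1,0,1,1}; the parallel step yields min(diag + subst, top + gap) =
+ * {1,1,3,4}, and the four sequential passes with left = 1 finalize the row as {1,1,2,3}.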
+ * The result is extracted from the final row vector at the position corresponding to the
+ * longer string's length. Total register usage is approximately 70 registers (256 bytes for
+ * the arrays plus temporaries).
+ *
+ * Register layout (max_text_length_ = 128):
+ *
+ *      row_vec_[32]           : current DP row, 4 cells per uint32 (128 bytes)
+ *      longer_string_vec_[32] : cached string, 4 chars per uint32 (128 bytes)
+ *
+ * Processing order per uint32:
+ *
+ *      1. Compare:    match[0..3] = (char == longer[0..3])                  [parallel]
+ *      2. Vertical:   r[0..3] = min(top[0..3]+gap, diag[0..3]+cost)         [parallel]
+ *      3. Horizontal: r[0]=min(r[0],left+gap), r[1]=min(r[1],r[0]+gap), ... [sequential]
+ */
+template <unsigned max_text_length_ = levenshtein_on_each_cuda_thread_default_text_limit_k>
+struct register_optimal_levenshtein {
+    static constexpr unsigned max_text_length_k = max_text_length_;
+    static constexpr unsigned vec_count_k = max_text_length_k / sizeof(sz_u32_vec_t);
+
+    sz_u32_vec_t row_vec_[vec_count_k];
+    sz_u32_vec_t longer_string_vec_[vec_count_k];
+
+    __forceinline__ __device__ __host__ sz_u8_t operator()(     //
+        sz_u8_t const *longer_string, unsigned longer_length,   //
+        sz_u8_t const *shorter_string, unsigned shorter_length, //
+        uniform_substitution_costs_t const substituter, linear_gap_costs_t const gap_costs) {
+
+        // Initialize the first row with the vectorized variant of:
+        //      for (unsigned col_idx = 1; col_idx <= max_text_length_k; ++col_idx)
+        //          row[col_idx] = col_idx * gap_costs.open_or_extend;
+        for (unsigned i = 0, running_gap = gap_costs.open_or_extend; i < vec_count_k; ++i) {
+            row_vec_[i].u32 = running_gap;
+            running_gap += gap_costs.open_or_extend;
+            row_vec_[i].u32 |= running_gap << 8;
+            running_gap += gap_costs.open_or_extend;
+            row_vec_[i].u32 |= running_gap << 16;
+            running_gap += gap_costs.open_or_extend;
+            row_vec_[i].u32 |= running_gap << 24;
+            running_gap += gap_costs.open_or_extend;
+        }
+
+        // Load the longer string into the vector array (stays in registers, accessed in the inner loop)
+        for (unsigned i = 0; i < longer_length; ++i) longer_string_vec_[i / 4].u8s[i % 4] = longer_string[i];
+
+        // Broadcast costs to all 4 bytes per uint32
+        error_cost_t const gap_cost = gap_costs.open_or_extend;
+        error_cost_t const match_cost = substituter.match;
+        error_cost_t const mismatch_cost = substituter.mismatch;
+        sz_u32_vec_t gap_vec, match_vec, mismatch_vec;
+        gap_vec.u32 = gap_cost | (gap_cost << 8) | (gap_cost << 16) | (gap_cost << 24);
+        match_vec.u32 = match_cost | (match_cost << 8) | (match_cost << 16) | (match_cost << 24);
+        mismatch_vec.u32 = mismatch_cost | (mismatch_cost << 8) | (mismatch_cost << 16) | (mismatch_cost << 24);
+
+        // Outer loop: iterate over the shorter string (fewer iterations)
+        sz_u32_vec_t shorter_char_vec;
+        for (unsigned row_idx = 1; row_idx <= shorter_length; ++row_idx) {
+            // Load one character from the shorter string and broadcast it to all 4 bytes
+            sz_u8_t const shorter_char = shorter_string[row_idx - 1];
+            shorter_char_vec.u32 = shorter_char | (shorter_char << 8) | (shorter_char << 16) | (shorter_char << 24);
+
+            // Column 0 values (not stored in row_vec_)
+            sz_u8_t col0_prev = (row_idx - 1) * gap_cost;
+            sz_u8_t col0_curr = row_idx * gap_cost;
+
+            // Broadcast col0_prev for the diagonal construction
+            sz_u32_t prev_u32vec = col0_prev | (col0_prev << 8) | (col0_prev << 16) | (col0_prev << 24);
+
+            // Inner loop: process the longer string 4 bytes at a time
+            for (unsigned vec_idx = 0; vec_idx < vec_count_k; ++vec_idx) {
+                // Load DP values from the previous row ("top" in the DP matrix)
+                sz_u32_t top_u32vec = row_vec_[vec_idx].u32;
+
+                // Construct the diagonal: {prev[byte3], top[byte0], top[byte1], top[byte2]}
+                sz_u32_t diag_u32vec = sz_u32_byte_perm_(prev_u32vec, top_u32vec, 0x6543);
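+                // ? Illustration (with assumed values): for prev = 0x05040302 and top = 0x09080706,
+                // ? selector 0x6543 picks bytes {3, 4, 5, 6} of the 8-byte pool {prev, top}, so
+                // ? diag = {prev.byte3, top.byte0, top.byte1, top.byte2} = 0x08070605.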
+
+                // Compare shorter_char with 4 chars from the longer string
+                sz_u32_t longer_u32vec = longer_string_vec_[vec_idx].u32;
+                sz_u32_t match_u32vec = sz_u32_vcmpeq4_(shorter_char_vec.u32, longer_u32vec);
+
+                // Blend match_cost/mismatch_cost using the match mask:
+                // match_u32vec = 0xFF (match) or 0x00 (mismatch) per byte
+                sz_u32_t substitutions_u32vec = (match_vec.u32 & match_u32vec) | (mismatch_vec.u32 & ~match_u32vec);
+
+                // DP recurrence: min(diagonal + subst, top + gap, left + gap)
+                sz_u32_t from_diag_u32vec = sz_u32_vaddus4_(diag_u32vec, substitutions_u32vec);
+                sz_u32_t from_top_u32vec = sz_u32_vaddus4_(top_u32vec, gap_vec.u32);
+                sz_u32_t result_u32vec = sz_u32_vminu4_(from_diag_u32vec, from_top_u32vec);
+
+                // Propagate the left dependency across 4 bytes (sequential scan)
+                sz_u32_t left_source = (vec_idx == 0) ? col0_curr : (row_vec_[vec_idx - 1].u32 >> 24);
+                sz_u32_t left_u32vec = left_source | (left_source << 8) | (left_source << 16) | (left_source << 24);
+
+                // 4 iterations to propagate left-to-right within the uint32:
+                // after iteration N (counting from 1), bytes [0..N-1] are finalized
+                sz_u32_t shifted_u32vec = sz_u32_byte_perm_(left_u32vec, result_u32vec, 0x6540);
+                result_u32vec = sz_u32_vminu4_(result_u32vec, sz_u32_vaddus4_(shifted_u32vec, gap_vec.u32));
+
+                shifted_u32vec = sz_u32_byte_perm_(result_u32vec, result_u32vec, 0x2100);
+                result_u32vec = sz_u32_vminu4_(result_u32vec, sz_u32_vaddus4_(shifted_u32vec, gap_vec.u32));
+
+                shifted_u32vec = sz_u32_byte_perm_(result_u32vec, result_u32vec, 0x2110);
+                result_u32vec = sz_u32_vminu4_(result_u32vec, sz_u32_vaddus4_(shifted_u32vec, gap_vec.u32));
+
+                shifted_u32vec = sz_u32_byte_perm_(result_u32vec, result_u32vec, 0x2221);
+                result_u32vec = sz_u32_vminu4_(result_u32vec, sz_u32_vaddus4_(shifted_u32vec, gap_vec.u32));
+
+                prev_u32vec = top_u32vec;
+                row_vec_[vec_idx].u32 = result_u32vec;
+            }
+        }
+
+        // Extract the result once after the loop completes
+        unsigned result_vec_idx = (longer_length - 1) / 4;
+        unsigned result_byte_idx = (longer_length - 1) % 4;
+        sz_u32_t result_vec = row_vec_[result_vec_idx].u32;
+        sz_u8_t relevant_score = (result_vec >> (result_byte_idx * 8)) & 0xFF;
+        return relevant_score;
+    }
+};
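+
+// A minimal host-side sanity check of the functor above - possible because the SWAR helpers
+// also compile for the host. The cost-struct initializers below are assumptions about field
+// order; treat this as an illustrative sketch rather than a shipped test:
+//
+//      register_optimal_levenshtein<128> functor;
+//      sz_u8_t const kitten[] = "kitten", sitting[] = "sitting";
+//      uniform_substitution_costs_t substituter {/* match */ 0, /* mismatch */ 1};
+//      linear_gap_costs_t gap_costs {/* open_or_extend */ 1};
+//      sz_u8_t distance = functor(sitting, 7, kitten, 6, substituter, gap_costs); // == 3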
+ */ +template < // + typename task_type_, + typename char_type_ = char, // + typename score_type_ = size_t, // + sz_capability_t capability_ = sz_cap_cuda_k, // + unsigned max_text_length_ = levenshtein_on_each_cuda_thread_default_text_limit_k // + > +__global__ __launch_bounds__(256, 4) // Target: 256 threads/block, 4 blocks/SM minimum + void levenshtein_on_each_cuda_thread_( // + task_type_ *tasks, size_t tasks_count, // + uniform_substitution_costs_t const substituter, linear_gap_costs_t const gap_costs) { + + using task_t = task_type_; + using char_t = char_type_; + using score_t = score_type_; + + // We may have multiple warps operating in the same block. + unsigned const warp_size = warpSize; + size_t const global_thread_index = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + size_t const warps_per_block = static_cast(blockDim.x / warp_size); + size_t const warps_per_device = static_cast(gridDim.x * warps_per_block); + size_t const threads_per_device = warps_per_device * warp_size; + + // Instantiate register-optimal Levenshtein struct (processes 4 bytes at a time) + register_optimal_levenshtein levenshtein_computer; + + // We are computing N edit distances for N pairs of strings. Not a cartesian product! + for (size_t task_idx = global_thread_index; task_idx < tasks_count; task_idx += threads_per_device) { + task_t &task = tasks[task_idx]; + sz_u8_t const *shorter_ptr = reinterpret_cast(task.shorter_ptr); + sz_u8_t const *longer_ptr = reinterpret_cast(task.longer_ptr); + unsigned const shorter_length = task.shorter_length; + unsigned const longer_length = task.longer_length; + auto &result_ref = task.result; + + // Call SIMD/SWAR optimized Levenshtein distance computation + sz_u8_t distance = + levenshtein_computer(longer_ptr, longer_length, shorter_ptr, shorter_length, substituter, gap_costs); + + result_ref = distance; + } +} + #pragma endregion #pragma region - Levenshtein Distance in CUDA @@ -2052,6 +2330,15 @@ struct cuda_similarity_task_ { } constexpr size_t max_diagonal_length() const noexcept { return sz_max_of_two(shorter_length, longer_length) + 1; } + + constexpr bool fits_in_registers() const noexcept { + return + // ? 2-byte entries require too many registers + bytes_per_cell == one_byte_per_cell_k && + // ? H100 has 256x 32-bit registers per thread + shorter_length <= levenshtein_on_each_cuda_thread_default_text_limit_k && + longer_length <= levenshtein_on_each_cuda_thread_default_text_limit_k; + } }; /** @@ -2106,7 +2393,7 @@ struct levenshtein_distances; for (size_t i = 0; i < first_strings.size(); ++i) { // Ensure inputs are device-accessible (Unified/Device memory) @@ -2132,10 +2419,51 @@ struct levenshtein_distances({tasks.data(), tasks.size()}, specs); + // Now some warp-level tasks are so tiny, we can perform them at the single register level! + if constexpr (!is_affine_k) { + if (count_register_level_tasks > 0) { + std::partition( // + tasks.begin(), tasks.end(), [](task_t const &task) { return task.fits_in_registers(); }); + auto thread_level_kernel = + &levenshtein_on_each_cuda_thread_; + + // Store values in variables since we need their addresses + task_t *tasks_ptr = tasks.data(); + size_t tasks_count = count_register_level_tasks; // Only register-level tasks! 
+                void *thread_level_kernel_args[4];
+                thread_level_kernel_args[0] = (void *)(&tasks_ptr);
+                thread_level_kernel_args[1] = (void *)(&tasks_count);
+                thread_level_kernel_args[2] = (void *)(&substituter_);
+                thread_level_kernel_args[3] = (void *)(&gap_costs_);
+
+                // Launch configuration tuned for this register-heavy kernel with `__launch_bounds__(256, 4)`
+                unsigned const threads_per_block = 256;
+                unsigned const blocks_per_multiprocessor = 8; // Oversubscribe the grid: 256 × 8 = 2048
+                                                              // threads/SM requested; real residency is register-bound
+                unsigned const total_blocks = blocks_per_multiprocessor * specs.streaming_multiprocessors;
+                cudaError_t launch_error = cudaLaunchKernel(       //
+                    reinterpret_cast<void *>(thread_level_kernel), // Kernel function pointer
+                    dim3(total_blocks),                            // Grid dimensions
+                    dim3(threads_per_block),                       // Block dimensions
+                    thread_level_kernel_args,                      // Array of kernel argument pointers
+                    0,                                             // Shared memory per block (in bytes)
+                    executor.stream());                            // CUDA stream
+                if (launch_error != cudaSuccess) {
+                    if (launch_error == cudaErrorMemoryAllocation) return {status_t::bad_alloc_k, launch_error};
+                    else
+                        return {status_t::unknown_k, launch_error};
+                }
+            }
+        }
+
+        // Group the remaining non-register tasks into device-level and warp-level.
+        auto [device_level_tasks, warp_level_tasks, empty_tasks] = warp_tasks_grouping(
+            {tasks.data() + count_register_level_tasks, tasks.size() - count_register_level_tasks}, specs);
 
         if (device_level_tasks.size()) {
             auto device_level_u16_kernel =
@@ -2258,6 +2586,13 @@ struct levenshtein_distances
                 (indicative_task.memory_requirement * optimal_density);
@@ -2291,18 +2626,17 @@ struct levenshtein_distances
     std::vector<std::size_t> matrix_buffer(rows * cols);
 
     // Initialize the borders of the matrix.
-    for (std::size_t i = 0; i < rows; ++i) matrix_buffer[i * cols + 0] /* [i][0] in 2D */ = i * gap_cost;
-    for (std::size_t j = 0; j < cols; ++j) matrix_buffer[0 * cols + j] /* [0][j] in 2D */ = j * gap_cost;
-
-    for (std::size_t i = 1; i < rows; ++i) {
-        std::size_t const *last_row = &matrix_buffer[(i - 1) * cols];
-        std::size_t *row = &matrix_buffer[i * cols];
-        for (std::size_t j = 1; j < cols; ++j) {
-            std::size_t substitution_cost = (s1[i - 1] == s2[j - 1]) ? match_cost : mismatch_cost;
-            std::size_t if_deletion_or_insertion = std::min(last_row[j], row[j - 1]) + gap_cost;
-            row[j] = std::min(if_deletion_or_insertion, last_row[j - 1] + substitution_cost);
+    for (std::size_t row = 0; row < rows; ++row) matrix_buffer[row * cols + 0] /* [row][0] in 2D */ = row * gap_cost;
+    for (std::size_t col = 0; col < cols; ++col) matrix_buffer[0 * cols + col] /* [0][col] in 2D */ = col * gap_cost;
+
+    for (std::size_t row = 1; row < rows; ++row) {
+        std::size_t const *last_row_buffer = &matrix_buffer[(row - 1) * cols];
+        std::size_t *current_row_buffer = &matrix_buffer[row * cols];
+        for (std::size_t col = 1; col < cols; ++col) {
+            std::size_t substitution_cost = (s1[row - 1] == s2[col - 1]) ? match_cost : mismatch_cost;
+            std::size_t if_deletion_or_insertion =
+                std::min(last_row_buffer[col], current_row_buffer[col - 1]) + gap_cost;
+            current_row_buffer[col] = std::min(if_deletion_or_insertion, last_row_buffer[col - 1] + substitution_cost);
         }
     }
     return matrix_buffer.back();
 }
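+
+// A quick check of the baseline above (illustrative: it assumes the enclosing helper is named
+// `levenshtein_baseline` and takes `s1, len1, s2, len2` plus the three default unit costs,
+// mirroring `levenshtein_recursive_baseline` below). The classic pair needs 2 substitutions
+// and 1 insertion:
+//
+//      assert(levenshtein_baseline("kitten", 6, "sitting", 7) == 3);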
+
+#if 0 // ! Coming later :)
+
+/**
+ * The purpose of this class is to emulate SIMD processing in CUDA with 4-byte-wide words,
+ * fully unrolling the Levenshtein matrix calculation in diagonal order, only using properly
+ * aligned loads!
+ */
+template <unsigned max_length_, unsigned next_diagonal_index_ = 0>
+struct levenshtein_recursive_algorithm {
+
+    using text_t = char[max_length_];
+    using diagonal_scores_t = std::uint8_t[max_length_ + 1];
+
+    static constexpr unsigned diagonals_count = max_length_ * 2 + 1;
+    static constexpr unsigned max_diagonal_length = max_length_ + 1;
+    static constexpr unsigned next_diagonal_length = next_diagonal_index_ + 1;
+
+    std::size_t operator()(                                                          //
+        text_t s1, text_t s2,                                                        //
+        error_cost_t match_cost, error_cost_t mismatch_cost, error_cost_t gap_cost,  //
+        diagonal_scores_t previous, diagonal_scores_t current, diagonal_scores_t next) {
+
+        // Top-left cell of the matrix
+        if constexpr (next_diagonal_index_ == 0) {
+            next[0] = 0;
+            // Recurse to the next diagonal
+            levenshtein_recursive_algorithm<max_length_, next_diagonal_index_ + 1> next_run;
+            return next_run(s1, s2,                              // same texts as before
+                            match_cost, mismatch_cost, gap_cost, // same costs
+                            current, next, previous);            // 3-way rotation of the diagonals
+        }
+        // The cells below and next to it
+        else if constexpr (next_diagonal_index_ == 1) {
+            next[0] = gap_cost, next[1] = gap_cost;
+
+            // Recurse to the next diagonal
+            levenshtein_recursive_algorithm<max_length_, next_diagonal_index_ + 1> next_run;
+            return next_run(s1, s2,                              // same texts as before
+                            match_cost, mismatch_cost, gap_cost, // same costs
+                            current, next, previous);            // 3-way rotation of the diagonals
+        }
+        // We are still in the top-left triangle
+        else if constexpr (next_diagonal_index_ <= max_length_) {
+
+            constexpr unsigned next_diagonal_length = next_diagonal_index_ + 1;
+#pragma unroll
+            for (unsigned k = 1; k < next_diagonal_length - 1; ++k) {
+                char first_char = s1[next_diagonal_length - k];
+                char second_char = s2[k - 1];
+
+                error_cost_t substitution_cost = first_char == second_char ? match_cost : mismatch_cost;
+                error_cost_t if_substitution = previous[k - 1] + substitution_cost;
+                error_cost_t if_deletion_or_insertion = std::min(current[k], current[k - 1]) + gap_cost;
+                next[k] = std::min(if_deletion_or_insertion, if_substitution);
+            }
+
+            // Single-byte stores at the edges of the diagonal - to overwrite the noisy values:
+            next[0] = gap_cost * next_diagonal_index_;
+            next[next_diagonal_length - 1] = gap_cost * next_diagonal_index_;
+
+            // Recurse to the next diagonal
+            levenshtein_recursive_algorithm<max_length_, next_diagonal_index_ + 1> next_run;
+            return next_run(s1, s2,                              // same texts as before
+                            match_cost, mismatch_cost, gap_cost, // same costs
+                            current, next, previous);            // 3-way rotation of the diagonals
+        }
+        // We are in the last diagonal and need to return the bottom-right cell
+        else if constexpr (next_diagonal_index_ + 1 == diagonals_count) {
+            char first_char = s1[max_length_ - 1];
+            char second_char = s2[max_length_ - 1];
+
+            error_cost_t substitution_cost = first_char == second_char ? match_cost : mismatch_cost;
+            error_cost_t if_substitution = previous[0] + substitution_cost;
+            error_cost_t if_deletion_or_insertion = std::min(current[0], current[1]) + gap_cost;
+            next[0] = std::min(if_deletion_or_insertion, if_substitution);
+
+            return next[0];
+        }
+        // We are in the bottom-right triangle
+        else {
+
+            constexpr unsigned next_diagonal_length = diagonals_count - next_diagonal_index_;
+#pragma unroll
+            for (unsigned k = 0; k < next_diagonal_length; ++k) {
+                char first_char = s1[...];
+                char second_char = s2[...];
+
+                error_cost_t substitution_cost = first_char == second_char ? match_cost : mismatch_cost;
+                error_cost_t if_substitution = previous[k - 1] + substitution_cost;
+                error_cost_t if_deletion_or_insertion = std::min(current[k], current[k - 1]) + gap_cost;
+                next[k] = std::min(if_deletion_or_insertion, if_substitution);
+            }
+
+            // Recurse to the next diagonal
+            levenshtein_recursive_algorithm<max_length_, next_diagonal_index_ + 1> next_run;
+            return next_run(s1, s2,                              // same texts as before
+                            match_cost, mismatch_cost, gap_cost, // same costs
+                            current, next, previous);            // 3-way rotation of the diagonals
+        }
+    }
+};
+
+inline std::size_t levenshtein_recursive_baseline(                                              //
+    char const *s1, std::size_t len1, char const *s2, std::size_t len2,                         //
+    error_cost_t match_cost = 0, error_cost_t mismatch_cost = 1, error_cost_t gap_cost = 1) noexcept(false) {
+
+    static constexpr std::size_t max_length = 8;
+    levenshtein_recursive_algorithm<max_length> algorithm;
+
+    char s1_padded[max_length] = {0};
+    char s2_padded[max_length] = {0};
+    std::memcpy(s1_padded, s1, len1);
+    std::memcpy(s2_padded, s2, len2);
+
+    // Need 3 diagonals for the rotation
+    std::uint8_t previous[max_length + 1] = {0};
+    std::uint8_t current[max_length + 1] = {0};
+    std::uint8_t next[max_length + 1] = {0};
+
+    return algorithm(s1_padded, s2_padded, match_cost, mismatch_cost, gap_cost, previous, current, next);
+}
+
+#endif // ! Coming later :)
+
 /**
  * @brief Inefficient baseline Needleman-Wunsch alignment score computation, as implemented in most codebases.
  * @warning Allocates a new matrix on every call, with rows potentially scattered around memory.