From 85253f879c491da3b7c7038ecfb0751eba4235e6 Mon Sep 17 00:00:00 2001 From: Stanley Tsang Date: Thu, 18 Apr 2024 09:33:06 -0600 Subject: [PATCH] Added check to partition kernel if size is smaller than items_per_block (#538) (#546) Co-authored-by: Nick Breed <78807921+NB4444@users.noreply.github.com> Co-authored-by: Nick Breed --- .../rocprim/device/detail/device_merge.hpp | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/rocprim/include/rocprim/device/detail/device_merge.hpp b/rocprim/include/rocprim/device/detail/device_merge.hpp index f10ac8f80..3126f05c9 100644 --- a/rocprim/include/rocprim/device/detail/device_merge.hpp +++ b/rocprim/include/rocprim/device/detail/device_merge.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -68,24 +68,23 @@ void partition_kernel_impl(IndexIterator indices, const unsigned int spacing, BinaryFunction compare_function) { - const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); - const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); const unsigned int flat_block_size = ::rocprim::detail::block_size<0>(); + const unsigned int input_size = input1_size + input2_size; + const unsigned int id = flat_block_id * flat_block_size + flat_id; + const unsigned int partition_id = id * spacing; + const unsigned int partitions = (input_size + spacing - 1) / spacing; - unsigned int id = flat_block_id * flat_block_size + flat_id; + if(id > partitions) + { + return; + } - unsigned int partition_id = id * spacing; size_t diag = min(static_cast(partition_id), input1_size + input2_size); - unsigned int begin = - merge_path( - keys_input1, - keys_input2, - input1_size, - input2_size, - diag, - compare_function - ); + unsigned int begin + = merge_path(keys_input1, keys_input2, input1_size, input2_size, diag, compare_function); indices[id] = begin; } @@ -310,8 +309,10 @@ void merge_kernel_impl(IndexIterator indices, const unsigned int valid_in_last_block = count - block_offset; const bool is_incomplete_block = valid_in_last_block < items_per_block; - const unsigned int p1 = indices[flat_block_id]; - const unsigned int p2 = indices[flat_block_id + 1]; + const unsigned int partitions = (count + items_per_block - 1) / items_per_block; + + const unsigned int p1 = indices[rocprim::min(flat_block_id, partitions)]; + const unsigned int p2 = indices[rocprim::min(flat_block_id + 1, partitions)]; range_t range = compute_range(