Fix BlockScan accumulator type handling #6443
bc27d15 to c122fad
cub/cub/block/block_reduce.cuh
Outdated
  // Reduce partials
- T partial = cub::ThreadReduce(inputs, reduction_op);
+ T partial =
+   cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(inputs)>, ReductionOp, T, T>(inputs, reduction_op);
This looks like a partial solution. It could regress for small integer types, for example a reduction or scan over int8_t. It is better to perform the computation in 32-bit and cast back at the end.
Good catch. I dropped the explicit T and taught ThreadReduce to keep its __accumulator_t promotion, so int8_t still widens to 32-bit.
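The promotion discussed above can be sketched in plain host C++17. This is an illustration only: promoted_accum_t, reduce_promoted, and example_sum are hypothetical names, and the thread does not state that __accumulator_t follows exactly this rule. The sketch widens integral types narrower than 32 bits, accumulates in the wider type, and casts back once at the end.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <type_traits>

// Hypothetical promotion trait (not CUB's): integral types narrower than
// 32 bits widen to a 32-bit type of the same signedness; others are unchanged.
template <typename T>
using promoted_accum_t = std::conditional_t<
  std::is_integral_v<T> && (sizeof(T) < 4),
  std::conditional_t<std::is_signed_v<T>, std::int32_t, std::uint32_t>,
  T>;

// Hypothetical sequential reduction: accumulate in the promoted type and
// cast back to T only once, after the last element.
template <typename T, int N, typename Op>
T reduce_promoted(const T (&items)[N], Op op)
{
  using accum_t = promoted_accum_t<T>;
  accum_t acc   = static_cast<accum_t>(items[0]);
  for (int i = 1; i < N; ++i)
  {
    acc = static_cast<accum_t>(op(acc, static_cast<accum_t>(items[i])));
  }
  return static_cast<T>(acc);
}

// Usage sketch: the intermediate value 200 does not fit in int8_t,
// but the 32-bit accumulator holds it until the final cast.
inline std::int8_t example_sum()
{
  const std::int8_t items[4] = {100, 100, -100, -50};
  return reduce_promoted(items, std::plus<>{});
}
```

Casting back only at the end keeps every intermediate step in a 32-bit register-sized type, which is the behavior the review comment asks to preserve for int8_t.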
cub/cub/block/block_reduce.cuh
Outdated
  {
    // Reduce partials
-   T partial = cub::ThreadReduce(inputs, ::cuda::std::plus<>{});
+   T partial = cub::ThreadReduce<::cuda::std::remove_reference_t<decltype(inputs)>, ::cuda::std::plus<>, T, T>(
Nit: please isolate the first template parameter with a using alias to improve readability.
@fbusato Thanks! I reverted that spot to plain ThreadReduce(inputs, …), so there’s nothing left to alias. Let me know if you’d still like a using helper there.
85c2484 to 0bcd084
Summary
- ThreadReduce accumulator types pinned to the block value type across BlockScan and BlockReduce
- BlockScan with a functor returning a wider type

Motivation
#5668 shows that BlockScan widens the accumulator when the scan functor returns a wider type than the block value. That implicit widening breaks user code that relies on the original type and can even hit deleted overloads.

Explanation
ThreadReduce was deducing its accumulator type from the functor instead of the block value T. The patch explicitly instantiates ThreadReduce with AccumT = T everywhere BlockScan and BlockReduce dispatch through it, including the raking specialization. The new unit test exercises an operator that returns long long for int inputs and verifies the accumulator remains int.

Rationale
ThreadReduce call sites; public APIs and template parameters stay the same.

Testing
pre-commit run --files cub/cub/block/block_scan.cuh cub/cub/block/block_reduce.cuh cub/cub/block/specializations/block_reduce_raking_commutative_only.cuh cub/test/catch2_test_block_scan.cu
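
The difference between a deduced and a pinned accumulator, as described in the Explanation, can be illustrated with a host-only C++17 sketch. It does not use CUB: widening_plus, deduced_accum_t, reduce_as, and example_pinned are hypothetical names chosen for this example, and the sequential loop only stands in for what a per-thread reduction does.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical functor like the one the test describes:
// int inputs, long long result.
struct widening_plus
{
  long long operator()(int a, int b) const
  {
    return static_cast<long long>(a) + b;
  }
};

// Accumulator type deduced from the functor's return type.
template <typename T, typename Op>
using deduced_accum_t = std::decay_t<std::invoke_result_t<Op, T, T>>;

// Hypothetical sequential reduction with an explicit accumulator type.
template <typename AccumT, typename T, int N, typename Op>
AccumT reduce_as(const T (&items)[N], Op op)
{
  AccumT acc = static_cast<AccumT>(items[0]);
  for (int i = 1; i < N; ++i)
  {
    // The functor may return a wider type; the result is narrowed back to AccumT.
    acc = static_cast<AccumT>(op(acc, items[i]));
  }
  return acc;
}

// Deduction follows the functor and yields long long for int inputs.
static_assert(std::is_same_v<deduced_accum_t<int, widening_plus>, long long>,
              "deduced accumulator widens");

// Usage sketch: pinning AccumT to the value type keeps the result an int.
inline int example_pinned()
{
  const int items[3] = {1, 2, 3};
  return reduce_as<int>(items, widening_plus{});
}
```

With AccumT pinned to int, the caller's types never see long long, which is the property the description says the new unit test checks.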