diff --git a/cub/agent/agent_scan.cuh b/cub/agent/agent_scan.cuh index 8e265d6c0..e8869034c 100644 --- a/cub/agent/agent_scan.cuh +++ b/cub/agent/agent_scan.cuh @@ -184,22 +184,21 @@ struct AgentScan typedef BlockScanRunningPrefixOp RunningPrefixCallbackOp; // Shared memory type for this thread block - union _TempStorage + struct _TempStorage { - // Smem needed for tile loading - typename BlockLoadT::TempStorage load; - - // Smem needed for tile storing - typename BlockStoreT::TempStorage store; - - struct ScanStorage - { - // Smem needed for cooperative prefix callback - typename TilePrefixCallbackOpT::TempStorage prefix; + union { + // Smem needed for tile loading + typename BlockLoadT::TempStorage load; + // Smem needed for tile storing + typename BlockStoreT::TempStorage store; + // Smem needed for tile scanning typename BlockScanT::TempStorage scan; - } scan_storage; + }; + + // Smem needed for cooperative prefix callback + typename TilePrefixCallbackOpT::TempStorage prefix; }; // Alias wrapper allowing storage to be unioned @@ -229,7 +228,7 @@ struct AgentScan AccumT &block_aggregate, Int2Type /*is_inclusive*/) { - BlockScanT(temp_storage.scan_storage.scan) + BlockScanT(temp_storage.scan) .ExclusiveScan(items, items, init_value, scan_op, block_aggregate); block_aggregate = scan_op(init_value, block_aggregate); } @@ -243,7 +242,7 @@ struct AgentScan AccumT &block_aggregate, Int2Type /*is_inclusive*/) { - BlockScanT(temp_storage.scan_storage.scan) + BlockScanT(temp_storage.scan) .InclusiveScan(items, items, scan_op, block_aggregate); } @@ -256,7 +255,7 @@ struct AgentScan PrefixCallback &prefix_op, Int2Type /*is_inclusive*/) { - BlockScanT(temp_storage.scan_storage.scan) + BlockScanT(temp_storage.scan) .ExclusiveScan(items, items, scan_op, prefix_op); } @@ -269,7 +268,7 @@ struct AgentScan PrefixCallback &prefix_op, Int2Type /*is_inclusive*/) { - BlockScanT(temp_storage.scan_storage.scan) + BlockScanT(temp_storage.scan) .InclusiveScan(items, items, scan_op, prefix_op); } @@ -334,7 +333,7 @@ struct AgentScan { // Reset dsmem TilePrefixCallbackOpT prefix_op(tile_state, - temp_storage.scan_storage.prefix, + temp_storage.prefix, scan_op, tile_idx); prefix_op.InitializeDSMem();