Skip to content

Commit

Permalink
Fix temporary storage layout
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed May 9, 2023
1 parent dee9014 commit 1574d91
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions cub/agent/agent_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -184,22 +184,21 @@ struct AgentScan
typedef BlockScanRunningPrefixOp<AccumT, ScanOpT> RunningPrefixCallbackOp;

// Shared memory type for this thread block
- union _TempStorage
+ struct _TempStorage
{
- // Smem needed for tile loading
- typename BlockLoadT::TempStorage load;
-
- // Smem needed for tile storing
- typename BlockStoreT::TempStorage store;
-
- struct ScanStorage
- {
- // Smem needed for cooperative prefix callback
- typename TilePrefixCallbackOpT::TempStorage prefix;
+ union {
+ // Smem needed for tile loading
+ typename BlockLoadT::TempStorage load;
+
+ // Smem needed for tile storing
+ typename BlockStoreT::TempStorage store;

// Smem needed for tile scanning
typename BlockScanT::TempStorage scan;
- } scan_storage;
+ };
+
+ // Smem needed for cooperative prefix callback
+ typename TilePrefixCallbackOpT::TempStorage prefix;
};

// Alias wrapper allowing storage to be unioned
Expand Down Expand Up @@ -229,7 +228,7 @@ struct AgentScan
AccumT &block_aggregate,
Int2Type<false> /*is_inclusive*/)
{
- BlockScanT(temp_storage.scan_storage.scan)
+ BlockScanT(temp_storage.scan)
.ExclusiveScan(items, items, init_value, scan_op, block_aggregate);
block_aggregate = scan_op(init_value, block_aggregate);
}
Expand All @@ -243,7 +242,7 @@ struct AgentScan
AccumT &block_aggregate,
Int2Type<true> /*is_inclusive*/)
{
- BlockScanT(temp_storage.scan_storage.scan)
+ BlockScanT(temp_storage.scan)
.InclusiveScan(items, items, scan_op, block_aggregate);
}

Expand All @@ -256,7 +255,7 @@ struct AgentScan
PrefixCallback &prefix_op,
Int2Type<false> /*is_inclusive*/)
{
- BlockScanT(temp_storage.scan_storage.scan)
+ BlockScanT(temp_storage.scan)
.ExclusiveScan(items, items, scan_op, prefix_op);
}

Expand All @@ -269,7 +268,7 @@ struct AgentScan
PrefixCallback &prefix_op,
Int2Type<true> /*is_inclusive*/)
{
- BlockScanT(temp_storage.scan_storage.scan)
+ BlockScanT(temp_storage.scan)
.InclusiveScan(items, items, scan_op, prefix_op);
}

Expand Down Expand Up @@ -334,7 +333,7 @@ struct AgentScan
{
// Reset dsmem
TilePrefixCallbackOpT prefix_op(tile_state,
- temp_storage.scan_storage.prefix,
+ temp_storage.prefix,
scan_op,
tile_idx);
prefix_op.InitializeDSMem();
Expand Down

0 comments on commit 1574d91

Please sign in to comment.