Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,11 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle<AL
return false;
}

if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
{
return false;
}

if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

// Full K needed for matrix B
const index_t Kt = karg.K;

auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
const index_t k_id = blockIdx.z * num_k_per_block;

GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_grid,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
karg,
k_id,
Kt);
}
#else
ignore = karg;
Expand All @@ -74,15 +82,23 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

// Full K needed for matrix B
const index_t Kt = karg.K;

auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
const index_t k_id = blockIdx.z * num_k_per_block;

GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_grid,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared_0,
p_shared_1,
karg);
karg,
k_id,
Kt);
}
#else
ignore = karg;
Expand Down Expand Up @@ -658,25 +674,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
}

if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
// b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;

b_k_split_offset = blockIdx.z * karg.KRead * NLane / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
}
}

if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
Expand All @@ -697,7 +694,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
}

index_t a_k_split_offset;
index_t b_k_split_offset;
index_t c_reduce_offset;
};

Expand Down Expand Up @@ -900,6 +896,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");

if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
{
return false;
}

if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
Expand Down Expand Up @@ -1134,7 +1135,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock)
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
Expand Down Expand Up @@ -1226,7 +1228,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
k_id,
KPack * (get_thread_local_1d_id() % WarpSize)));

// LDS allocation for A and B: be careful of alignment
Expand Down Expand Up @@ -1465,10 +1467,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared,
const Problem& problem)
const Problem& problem,
const index_t k_id,
const index_t Kt)
{
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
Expand All @@ -1491,7 +1495,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
problem,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bpreshuffled,
c_grid_desc_mblock_mperblock_nblock_nperblock);
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id);
}

template <typename AGridDesc_AK0_M_K1,
Expand All @@ -1509,7 +1514,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock)
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
Expand Down Expand Up @@ -1606,7 +1612,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
k_id,
KPack * (get_thread_local_1d_id() % WarpSize)));

// LDS allocation for A and B: be careful of alignment
Expand Down Expand Up @@ -1849,10 +1855,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
const Problem& problem)
const Problem& problem,
const index_t k_id,
const index_t Kt)
{
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
Expand All @@ -1877,7 +1885,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
problem,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bpreshuffled,
c_grid_desc_mblock_mperblock_nblock_nperblock);
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

// Full K needed for matrix B
const index_t Kt = karg.K;

auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
const index_t k_id = blockIdx.z * num_k_per_block;

GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_grid,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
karg.c_element_op,
k_id,
Kt);
}
#else
ignore = karg;
Expand All @@ -79,19 +87,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

// Full K needed for matrix B
const index_t Kt = karg.K;

auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
const index_t k_id = blockIdx.z * num_k_per_block;

GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_grid,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
karg.c_element_op,
k_id,
Kt);
}
#else
ignore = karg;
Expand Down Expand Up @@ -691,16 +707,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
}

if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
// KPack * NLane * KLane * K0 * N0
b_k_split_offset = k_id * karg.KRead * NLane;
}

if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
Expand All @@ -712,7 +718,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}

index_t a_k_split_offset;
index_t b_k_split_offset;
};

__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Expand Down Expand Up @@ -1163,7 +1168,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
const index_t k_id,
const index_t Kt)
{
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
Run<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
Expand All @@ -1176,7 +1183,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
block_2_ctile_map,
k_id,
Kt);
}

template <typename Block2CTileMap,
Expand All @@ -1192,11 +1201,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
const Block2CTileMap& block_2_ctile_map)
const Block2CTileMap& block_2_ctile_map,
const index_t k_id,
const index_t Kt)
{
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);

const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
Expand Down Expand Up @@ -1293,7 +1304,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
k_id,
KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));

// LDS allocation for A and B: be careful of alignment
Expand Down Expand Up @@ -1597,7 +1608,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
const index_t k_id,
const index_t Kt)
{
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
Expand All @@ -1611,7 +1624,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
block_2_ctile_map,
k_id,
Kt);
}

template <typename Block2CTileMap,
Expand All @@ -1628,11 +1643,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
const Block2CTileMap& block_2_ctile_map)
const Block2CTileMap& block_2_ctile_map,
const index_t k_id,
const index_t Kt)
{
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);

Expand Down Expand Up @@ -1731,7 +1748,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
k_id,
KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));

// LDS allocation for A and B: be careful of alignment
Expand Down
Loading