Skip to content

Commit ab1a835

Browse files
authored
Add 2GB limitation for grouped conv bwd weight (#3054)
1 parent 1fbb47a commit ab1a835

File tree

4 files changed

+33
-0
lines changed

4 files changed

+33
-0
lines changed

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1886,6 +1886,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
18861886
}
18871887
}
18881888

1889+
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
1890+
if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
1891+
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
1892+
arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
1893+
{
1894+
return false;
1895+
}
1896+
18891897
return true;
18901898
}
18911899

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1417,6 +1417,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
14171417
return false;
14181418
}
14191419

1420+
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
1421+
if(!(arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
1422+
arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
1423+
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB))
1424+
{
1425+
return false;
1426+
}
1427+
14201428
// Gridwise GEMM size
14211429
return true;
14221430
}

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1359,6 +1359,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
13591359
}
13601360
}
13611361

1362+
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
1363+
if(!(karg.M * karg.K * sizeof(ADataType) <= TwoGB &&
1364+
karg.N * karg.K * sizeof(BDataType) <= TwoGB &&
1365+
karg.M * karg.N * sizeof(CDataType) <= TwoGB))
1366+
{
1367+
return false;
1368+
}
1369+
13621370
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
13631371
return true;
13641372
}

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -581,6 +581,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
581581
return false;
582582
}
583583

584+
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
585+
586+
if(!(a_b_k0_m_k1_grid_desc.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB &&
587+
b_b_k0_n_k1_grid_desc.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB &&
588+
c_m_n_grid_desc.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
589+
{
590+
return false;
591+
}
592+
584593
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
585594
return true;
586595
}

0 commit comments

Comments
 (0)