Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
e5eb4ed
Add device operation to conv signature. Use unions to hold conv layou…
vpietila-amd Oct 29, 2025
74ba32e
Add predicates for all device op instances.
vpietila-amd Oct 29, 2025
fbdded6
Use the device op signature for validation.
vpietila-amd Oct 29, 2025
7b14dde
Merge branch 'develop' into vpietila/ckb-generalize-conv-factory
vpietila-amd Oct 29, 2025
ee13982
Fix ckb CMakeLists.txt file for tests.
vpietila-amd Oct 30, 2025
28e0d5f
Fix building CK Builder instance traits after the introduction of dir…
vpietila-amd Oct 30, 2025
e129843
Merge branch 'vpietila/fix-building-ckb-tests' into vpietila/ckb-gene…
vpietila-amd Oct 30, 2025
c8eac6f
Fix clang-formatting.
vpietila-amd Oct 30, 2025
d0d33d9
Add factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device op.
vpietila-amd Oct 30, 2025
c1609ff
Add conv factory for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
vpietila-amd Oct 31, 2025
59c619c
Rename elements per wave per shuffle member in the epilogue concept.
vpietila-amd Oct 31, 2025
e77fdbe
Merge remote-tracking branch 'origin/develop' into vpietila/ckb-add-r…
vpietila-amd Oct 31, 2025
f6c5cc3
clang-format
vpietila-amd Oct 31, 2025
795f03a
Add concepts and types for optional device op template parameters.
vpietila-amd Oct 31, 2025
e9b9b49
Add optional compute, direct load, and loop scheduler arguments to co…
vpietila-amd Oct 31, 2025
31c432e
Add number of groups to merge template parameter.
vpietila-amd Oct 31, 2025
bac14ad
clang-format.
vpietila-amd Oct 31, 2025
fc0b1ff
Merge branch 'develop' into vpietila/ckb-add-remining-fwd-conv-device…
vpietila-amd Oct 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ concept ThreadBlockDescriptor = requires(T t) {
{ t.tile_size.k } -> std::convertible_to<size_t>;
};

// Concept for parameters that describe a gridwise GEMM problem.
// Concept for parameters that describe a gridwise XDL GEMM problem.
template <typename T>
concept GridwiseGemmDescriptor = requires(T t) {
concept GridwiseXdlGemmDescriptor = requires(T t) {
{ t.ak1 } -> std::convertible_to<size_t>;
{ t.bk1 } -> std::convertible_to<size_t>;
{ t.m_per_xdl } -> std::convertible_to<size_t>;
Expand All @@ -35,6 +35,24 @@ concept GridwiseGemmDescriptor = requires(T t) {
{ t.n_xdl_per_wave } -> std::convertible_to<size_t>;
};

// Concept for parameter that describe block GEMM problem.
template <typename T>
concept BlockGemmDescriptor = requires(T t) {
{ t.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
{ t.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
};

// Concept for parameters that describe a gridwise WMMA GEMM problem.
template <typename T>
concept GridwiseWmmaGemmDescriptor = requires(T t) {
{ t.k1 } -> std::convertible_to<size_t>;
{ t.m_per_wmma } -> std::convertible_to<size_t>;
{ t.n_per_wmma } -> std::convertible_to<size_t>;
{ t.m_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.n_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.pipeline_version } -> std::convertible_to<GridwiseGemmPipelineVersion>;
};

// Concept for vectorized data transfer for convolution input tensors.
template <typename T>
concept BlockTransferDescriptor = requires(T t) {
Expand Down Expand Up @@ -66,8 +84,8 @@ concept LdsTransferDescriptor = requires(T t) {
// LDS).
template <typename T>
concept EpilogueDescriptor = requires(T t) {
{ t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
{ t.n_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
{ t.m_per_wave_per_shuffle } -> std::convertible_to<size_t>;
{ t.n_per_wave_per_shuffle } -> std::convertible_to<size_t>;
{ t.scalar_per_vector } -> std::convertible_to<size_t>;
};

Expand All @@ -77,7 +95,7 @@ concept AccessOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
};

// No requirements yet for a ConvAlogorithm concept.
// No requirements yet for a ConvAlgorithm concept.
template <typename T>
concept ConvAlgorithmDescriptor = std::is_class_v<T>;

Expand All @@ -91,10 +109,16 @@ concept SpecifiesThreadBlock = requires {
{ T::thread_block } -> ThreadBlockDescriptor;
};

// Concept to check if a struct specifies gridwise GEMM info.
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseGemm = requires {
{ T::gridwise_gemm } -> GridwiseGemmDescriptor;
concept SpecifiesGridwiseXdlGemm = requires {
{ T::gridwise_gemm } -> GridwiseXdlGemmDescriptor;
};

// Concept to check if a struct specifies gridwise WMMA GEMM info.
template <typename T>
concept SpecifiesGridwiseWmmaGemm = requires {
{ T::gridwise_gemm } -> GridwiseWmmaGemmDescriptor;
};

// Concept to check if a struct specifies convolution input and output block transfer info.
Expand Down Expand Up @@ -127,15 +151,36 @@ concept SpecifiesSourceAccessOrder = requires(T t) {
{ T::block_transfer.src_access_order_b } -> AccessOrderDescriptor;
};

// Concept to check if struct specifies block_gemm_pipeline_version.
// Concept to check if struct specifies block GEMM.
template <typename T>
concept SpecifiesGemmPipelineVersion = requires {
{ T::pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
concept SpecifiesBlockGemm = requires {
{ T::block_gemm.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
};

template <typename T>
concept SpecifiesFwdConcSpecialization = requires {
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
};

template <typename T>
concept SpecifiesGemmSpecialization = requires {
{ T::gemm_specialization } -> std::convertible_to<GemmSpecialization>;
};

template <typename T>
concept SpecifiesNumPrefetchStages = requires {
{ T::num_gemm_k_prefetch_stages } -> std::convertible_to<size_t>;
};

template <typename T>
concept SpecifiesNumGroupsToMerge = requires {
{ T::num_groups_to_merge } -> std::convertible_to<size_t>;
};

template <typename T>
concept SpecifiesLoopScheduler = requires {
{ T::loop_scheduler } -> std::convertible_to<LoopScheduler>;
};

} // namespace ck_tile::builder
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ concept InputVectorTransferLimits = requires {
// Limits for output vector transfer.
template <auto Value>
concept OutputVectorTransferLimits = requires {
requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 &&
Value.n_xdl_per_wave_per_shuffle > 0;
requires Value.scalar_per_vector > 0 && Value.m_per_wave_per_shuffle > 0 &&
Value.n_per_wave_per_shuffle > 0;
};

// Limits for access order. Must be a permutation of {0, 1, 2}.
Expand Down
Loading