forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 14
/
sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu
179 lines (152 loc) · 9.4 KB
/
sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide Ptr-Array GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x_ptr_array.hpp"
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
using namespace cute;
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array, 128x128x64_2x2x1) {
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue
>;
using namespace test::gemm::device;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
bool result = TestAll<Gemm>(1.0, 1.0);
EXPECT_TRUE(result);
}
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array, 128x128x64_2x2x1_NoSmemEpi) {
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue
>;
using namespace test::gemm::device;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(TestAll<Gemm>(1.0, 0.0));
}
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)