@@ -56,14 +56,27 @@ def __init__(
 
     def get_input_iter(self):
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float8_e4m3fn)
+            a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float8_e4m3fn)
+                .to(torch.float16)
                 .T.contiguous()
                 .T
             )
-            return (a, b)
+
+            if self.extra_args.scaling_rowwise:
+                M, N = a.shape[0], b.shape[1]
+                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
+                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+            else:
+                scale_a = torch.tensor(1.0, device=a.device)
+                scale_b = torch.tensor(1.0, device=a.device)
+
+            # Kernels expect dtype=float8_e4m3fn
+            a = a.to(torch.float8_e4m3fn)
+            b = b.to(torch.float8_e4m3fn)
+
+            return (a, b, scale_a, scale_b)
 
         if (
             hasattr(self, "external_shapes") and self.external_shapes
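
For reference, here is a minimal standalone sketch of how torch._scaled_mm consumes fp8 operands with the two scale layouts produced above (tensor-wise scalar scales vs. row-/column-wise scales). The shapes, device string, and values below are illustrative assumptions and not part of the benchmark:

import torch

# Assumes a CUDA device with fp8 support (e.g. H100); shapes are arbitrary.
m, n, k = 128, 256, 64
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand in column-major layout,
# hence the .T.contiguous().T pattern used in args() above.
b = torch.randn(k, n, device="cuda").to(torch.float8_e4m3fn).T.contiguous().T

# Tensor-wise scaling: scalar scales, fp16 output.
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")
out = torch._scaled_mm(a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16)

# Row-wise scaling: (M, 1) scale for a, (1, N) scale for b, bf16 output.
scale_a_row = torch.ones((m, 1), dtype=torch.float32, device="cuda")
scale_b_col = torch.ones((1, n), dtype=torch.float32, device="cuda")
out_rowwise = torch._scaled_mm(a, b, scale_a_row, scale_b_col, use_fast_accum=True, out_dtype=torch.bfloat16)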
@@ -90,62 +103,49 @@ def args(m, n, k):
                 yield args(m, n, k)
 
     def get_x_val(self, example_inputs) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         return (m, n, k)
 
-    @register_benchmark(baseline=True)
-    def torch_fp8_gemm(self, a, b):
+    def _get_out_dtype(self):
         if self.extra_args.scaling_rowwise:
-            M, N = a.shape[0], b.shape[1]
-            scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-            scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-            out_dtype = torch.bfloat16
+            return torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
-            out_dtype = torch.float16
+            return torch.float16
 
+    @register_benchmark(baseline=True)
+    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
         )
 
     @register_benchmark()
-    def pt2_fp8_gemm(self, a, b) -> Callable:
+    def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
             max_autotune=True,
             max_autotune_gemm_backends="TRITON",
             autotune_fallback_to_aten=False,
         ):
-            if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-                out_dtype = torch.bfloat16
-            else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=b.device)
-                out_dtype = torch.float16
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
 
         return lambda: compiled(a, b)
 
     @register_benchmark()
-    def triton_fp8_gemm(self, a, b):
+    def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_persistent_fp8_gemm(self, a, b):
+    def triton_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: matmul_persistent(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_tma_persistent_fp8_gemm(self, a, b):
+    def triton_tma_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         b = b.T.contiguous()
         c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
         return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
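
With this change every registered benchmark shares the same four-element input tuple, so the Triton variants simply accept and ignore the scales. A rough sketch of how the inputs are assumed to reach a benchmark method (the instantiation and driver loop below are illustrative, not the actual tritonbench runner):

# Illustrative only; the real runner lives in the benchmark framework.
op = Operator(...)  # hypothetical instantiation of this operator
for a, b, scale_a, scale_b in op.get_input_iter():
    bench_fn = op.torch_fp8_gemm(a, b, scale_a, scale_b)  # returns a lambda
    result = bench_fn()  # the call that actually gets timed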
@@ -155,7 +155,7 @@ def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> float:
         def nbytes(t):
             return t.numel() * t.element_size()
 
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         c = fn()
         c = c[0] if isinstance(c, tuple) else c
 
@@ -168,7 +168,7 @@ def nbytes(t):
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
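
As a quick sanity check of the metric math, here is a worked example of the flops and nbytes formulas above (a standalone sketch; the latency value and the decimal TFLOPS/GB/s conversions are assumptions, not taken from the harness):

# Worked example of the flops / gbps formulas (values illustrative).
m = n = k = 4096
flops = 2 * m * n * k                        # 137,438,953,472 FLOPs per GEMM
fp8_bytes, out_bytes = 1, 2                  # float8_e4m3fn inputs, fp16/bf16 output
nbytes = (m * k + k * n) * fp8_bytes + m * n * out_bytes  # 67,108,864 bytes moved

latency_s = 0.5e-3                           # hypothetical measured latency
tflops = flops / latency_s / 1e12            # ~274.9 TFLOPS
gbps = nbytes / latency_s / 1e9              # ~134.2 GB/s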