11import argparse
2+ import csv
23import logging
34
45from typing import Any , Callable , List , Optional
@@ -41,9 +42,28 @@ def parse_args(args):
4142 parser .add_argument ("--m" , type = int )
4243 parser .add_argument ("--k" , type = int )
4344 parser .add_argument ("--n" , type = int )
45+ parser .add_argument ("--filepath" , type = str , default = None )
4446 return parser .parse_args (args )
4547
4648
def read_fp8_shapes(filepath):
    """Read GEMM problem shapes from a CSV file.

    Blank lines and comment lines (first non-space character ``#``) are
    skipped, so the shape file may be annotated. Every remaining row is
    parsed as a tuple of ints — e.g. a ``m,n,k`` triple per line.

    Args:
        filepath: Path to the CSV file of integer shape rows.

    Returns:
        List of tuples of ints, one tuple per data row (empty list for a
        file containing only comments/blank lines).

    Raises:
        OSError: if the file cannot be opened/read.
        ValueError: if a row contains a non-integer field.
        (Both are logged, then re-raised with the original traceback.)
    """
    fp8_shapes = []
    try:
        # newline="" is the documented open() mode for the csv module.
        with open(filepath, "r", newline="") as csvfile:
            # Strip comment/blank lines lazily before csv parsing.
            filtered_lines = (
                line
                for line in csvfile
                if line.strip() and not line.lstrip().startswith("#")
            )
            for row in csv.reader(filtered_lines):
                fp8_shapes.append(tuple(map(int, row)))
    except Exception as e:
        # Lazy %-formatting (no f-string) per logging best practice;
        # bare `raise` preserves the original traceback (not `raise e`).
        logger.error("Failed to read fp8 shapes from %s: %s", filepath, e)
        raise
    return fp8_shapes
65+
66+
4767class Operator (BenchmarkOperator ):
4868 DEFAULT_METRICS = ["tflops" , "gbps" , "latency" ]
4969 DEFAULT_PRECISION = "fp8"
@@ -70,6 +90,10 @@ def args(m, n, k):
7090 yield args (m , n , k )
7191 elif self .extra_args .m :
7292 yield args (self .extra_args .m , self .extra_args .n , self .extra_args .k )
93+ elif self .extra_args .filepath :
94+ fp8_shapes = read_fp8_shapes (self .extra_args .filepath )
95+ for m , n , k in fp8_shapes :
96+ yield args (m , n , k )
7397 else :
7498 for i in range (10 , 15 ):
7599 for j in range (0 , 4 ):
@@ -114,8 +138,8 @@ def pt2_fp8_gemm(self, a, b) -> Callable:
114138 scale_b = torch .ones ((1 , N ), dtype = torch .float32 , device = b .device )
115139 out_dtype = torch .bfloat16
116140 else :
117- scale_a = torch .tensor (1.0 , device = a .device )
118- scale_b = torch .tensor (1.0 , device = a .device )
141+ scale_a = torch .tensor (1.0 , dtype = torch . float32 , device = a .device )
142+ scale_b = torch .tensor (1.0 , dtype = torch . float32 , device = b .device )
119143 out_dtype = torch .float16
120144 f = lambda a , b : torch ._scaled_mm (
121145 a , b , scale_a , scale_b , use_fast_accum = True , out_dtype = out_dtype
0 commit comments