@@ -305,7 +305,25 @@ def lib_linear():
305305 debug (d .actual_tensor (), ans , atol = atol , rtol = rtol )
306306
307307 assert torch .allclose (d .actual_tensor (), ans , atol = atol , rtol = rtol )
308-
def profile_operation(name, func, device, num_prerun, num_iterations):
    """Benchmark ``func`` and print its mean per-iteration latency.

    Args:
        name: Label used in the printed report.
        func: Zero-argument callable to benchmark.
        device: Device the workload runs on; CUDA events are used for
            timing only when this names a CUDA device, otherwise a
            wall-clock fallback is used (previously the device argument
            was ignored and CUDA was assumed unconditionally).
        num_prerun: Number of untimed warm-up calls.
        num_iterations: Number of timed calls (must be > 0).

    Returns:
        Mean elapsed time per iteration in milliseconds.
    """
    import time

    # Warm up so one-time costs (lazy init, allocator, caches) are excluded.
    for _ in range(num_prerun):
        func()

    if torch.cuda.is_available() and str(device).startswith("cuda"):
        # CUDA kernels launch asynchronously: time with device-side events
        # and synchronize so the measurement covers the actual GPU work.
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        for _ in range(num_iterations):
            func()
        end.record()

        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end)  # milliseconds
    else:
        # CPU (or CUDA-less build) fallback: wall-clock timing suffices.
        t0 = time.perf_counter()
        for _ in range(num_iterations):
            func()
        elapsed = (time.perf_counter() - t0) * 1000.0

    avg_ms = elapsed / num_iterations
    print(
        f"{name} took {avg_ms:.6f} ms over {num_iterations} iterations"
    )
    return avg_ms
309327 # Profiling workflow
310328 if PROFILE :
311329 # fmt: off
@@ -452,6 +470,25 @@ def lib_linear():
452470 )
453471
454472 lib_linear ()
def profile_operation(name, func, device, num_prerun, num_iterations):
    """Benchmark ``func`` and print its mean per-iteration latency.

    Args:
        name: Label used in the printed report.
        func: Zero-argument callable to benchmark.
        device: Device the workload runs on; CUDA events are used for
            timing only when this names a CUDA device, otherwise a
            wall-clock fallback is used (previously the device argument
            was ignored and CUDA was assumed unconditionally).
        num_prerun: Number of untimed warm-up calls.
        num_iterations: Number of timed calls (must be > 0).

    Returns:
        Mean elapsed time per iteration in milliseconds.
    """
    import time

    # Warm up so one-time costs (lazy init, allocator, caches) are excluded.
    for _ in range(num_prerun):
        func()

    if torch.cuda.is_available() and str(device).startswith("cuda"):
        # CUDA kernels launch asynchronously: time with device-side events
        # and synchronize so the measurement covers the actual GPU work.
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        for _ in range(num_iterations):
            func()
        end.record()

        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end)  # milliseconds
    else:
        # CPU (or CUDA-less build) fallback: wall-clock timing suffices.
        t0 = time.perf_counter()
        for _ in range(num_iterations):
            func()
        elapsed = (time.perf_counter() - t0) * 1000.0

    avg_ms = elapsed / num_iterations
    print(
        f"{name} took {avg_ms:.6f} ms over {num_iterations} iterations"
    )
    return avg_ms
455492 if PROFILE :
456493 # fmt: off
457494 profile_operation ("quant_linear" , lambda : lib_linear (), device , NUM_PRERUN , NUM_ITERATIONS )
0 commit comments