44from fbscribelogger import make_scribe_logger
55
66import torch ._C ._instruction_counter as i_counter
7+ import torch ._dynamo .config as config
8+ from torch ._dynamo .utils import CompileTimeInstructionCounter
79
810
911scribe_log_torch_benchmark_compile_time = make_scribe_logger (
5153
5254
5355class BenchmarkBase (ABC ):
54- _instruction_count = False
56+ # measure total number of instruction spent in _work.
57+ _enable_instruction_count = False
58+
59+ # measure total number of instruction spent in convert_frame.compile_inner
60+ # TODO is there other parts we need to add ?
61+ _enable_compile_time_instruction_count = False
5562
5663 def enable_instruction_count (self ):
57- self ._instruction_count = True
64+ self ._enable_instruction_count = True
65+ return self
66+
67+ def enable_compile_time_instruction_count (self ):
68+ self ._enable_compile_time_instruction_count = True
5869 return self
5970
6071 def name (self ):
@@ -64,29 +75,44 @@ def description(self):
6475 return ""
6576
6677 @abstractmethod
67- def prepare (self ):
78+ def _prepare (self ):
6879 pass
6980
7081 @abstractmethod
71- def work (self ):
82+ def _work (self ):
7283 pass
7384
74- def prepare_once (self ): # noqa: B027
85+ def _prepare_once (self ): # noqa: B027
7586 pass
7687
77- def count_instructions (self ):
88+ def _count_instructions (self ):
7889 print (f"collecting instruction count for { self .name ()} " )
79- self .prepare_once ()
80-
8190 results = []
8291 for i in range (10 ):
83- self .prepare ()
92+ self ._prepare ()
8493 id = i_counter .start ()
85- self .work ()
94+ self ._work ()
8695 count = i_counter .end (id )
8796 print (f"instruction count for iteration { i } is { count } " )
88- if i != 0 :
89- results .append (count )
97+ results .append (count )
98+ return min (results )
99+
100+ def _count_compile_time_instructions (self ):
101+ print (f"collecting compile time instruction count for { self .name ()} " )
102+ config .record_compile_time_instruction_count = True
103+
104+ results = []
105+ for i in range (10 ):
106+ self ._prepare ()
107+ # CompileTimeInstructionCounter.record is only called on convert_frame._compile_inner
108+ # hence this will only count instruction count spent in compile_inner.
109+ CompileTimeInstructionCounter .clear ()
110+ self ._work ()
111+ count = CompileTimeInstructionCounter .value ()
112+ print (f"compile time instruction count for iteration { i } is { count } " )
113+ results .append (count )
114+
115+ config .record_compile_time_instruction_count = False
90116 return min (results )
91117
92118 def append_results (self , path ):
@@ -102,12 +128,36 @@ def print(self):
102128 print (f"{ entry [0 ]} ,{ entry [1 ]} ,{ entry [2 ]} " )
103129
104130 def collect_all (self ):
131+ self ._prepare_once ()
105132 self .results = []
106- if self ._instruction_count :
107- r = self .count_instructions ()
133+ if (
134+ self ._enable_instruction_count
135+ and self ._enable_compile_time_instruction_count
136+ ):
137+ raise RuntimeError (
138+ "not supported until we update the logger, both logs to the same field now"
139+ )
140+
141+ if self ._enable_instruction_count :
142+ r = self ._count_instructions ()
108143 self .results .append ((self .name (), "instruction_count" , r ))
109144 scribe_log_torch_benchmark_compile_time (
110145 name = self .name (),
111146 instruction_count = r ,
112147 )
148+ if self ._enable_compile_time_instruction_count :
149+ r = self ._count_compile_time_instructions ()
150+
151+ self .results .append (
152+ (
153+ self .name (),
154+ "compile_time_instruction_count" ,
155+ r ,
156+ )
157+ )
158+ # TODO add a new field compile_time_instruction_count to the logger.
159+ scribe_log_torch_benchmark_compile_time (
160+ name = self .name (),
161+ instruction_count = r ,
162+ )
113163 return self
0 commit comments