@@ -277,6 +277,105 @@ def get_op_nodes_not_followed_by_specific_op(model, op1, op2):
277277
278278    return  not_selected_op1_nodes 
279279
280+ def  custom_write_calibration_table (calibration_cache , dir = "." ):
281+     """ 
282+     Helper function to write calibration table to files. 
283+     """ 
284+ 
285+     import  json 
286+     import  logging 
287+     import  flatbuffers 
288+     import  numpy  as  np 
289+ 
290+     import  onnxruntime .quantization .CalTableFlatBuffers .KeyValue  as  KeyValue 
291+     import  onnxruntime .quantization .CalTableFlatBuffers .TrtTable  as  TrtTable 
292+     from  onnxruntime .quantization .calibrate  import  CalibrationMethod , TensorData , TensorsData 
293+ 
294+     logging .info (f"calibration cache: { calibration_cache }  )
295+ 
296+     class  MyEncoder (json .JSONEncoder ):
297+         def  default (self , obj ):
298+             if  isinstance (obj , (TensorData , TensorsData )):
299+                 return  obj .to_dict ()
300+             if  isinstance (obj , TensorDataWrapper ):
301+                 return  obj .data_dict 
302+             if  isinstance (obj , np .ndarray ):
303+                 return  {"data" : obj .tolist (), "dtype" : str (obj .dtype ), "CLS" : "numpy.array" }
304+             if  isinstance (obj , CalibrationMethod ):
305+                 return  {"CLS" : obj .__class__ .__name__ , "value" : str (obj )}
306+             return  json .JSONEncoder .default (self , obj )
307+ 
308+     json_data  =  json .dumps (calibration_cache , cls = MyEncoder )
309+ 
310+     with  open (os .path .join (dir , "calibration.json" ), "w" ) as  file :
311+         file .write (json_data )  # use `json.loads` to do the reverse 
312+ 
313+     # Serialize data using FlatBuffers 
314+     zero  =  np .array (0 )
315+     builder  =  flatbuffers .Builder (1024 )
316+     key_value_list  =  []
317+ 
318+     for  key  in  sorted (calibration_cache .keys ()):
319+         values  =  calibration_cache [key ]
320+         d_values  =  values .to_dict ()
321+ 
322+         highest  =  d_values .get ("highest" , zero )
323+         lowest  =  d_values .get ("lowest" , zero )
324+ 
325+         highest_val  =  highest .item () if  hasattr (highest , "item" ) else  float (highest )
326+         lowest_val  =  lowest .item () if  hasattr (lowest , "item" ) else  float (lowest )
327+ 
328+         floats  =  [float (highest_val ), float (lowest_val )]
329+ 
330+         value  =  str (max (floats ))
331+ 
332+         flat_key  =  builder .CreateString (key )
333+         flat_value  =  builder .CreateString (value )
334+ 
335+         KeyValue .KeyValueStart (builder )
336+         KeyValue .KeyValueAddKey (builder , flat_key )
337+         KeyValue .KeyValueAddValue (builder , flat_value )
338+         key_value  =  KeyValue .KeyValueEnd (builder )
339+ 
340+         key_value_list .append (key_value )
341+ 
342+ 
343+     TrtTable .TrtTableStartDictVector (builder , len (key_value_list ))
344+     for  key_value  in  key_value_list :
345+         builder .PrependUOffsetTRelative (key_value )
346+     main_dict  =  builder .EndVector ()
347+ 
348+     TrtTable .TrtTableStart (builder )
349+     TrtTable .TrtTableAddDict (builder , main_dict )
350+     cal_table  =  TrtTable .TrtTableEnd (builder )
351+ 
352+     builder .Finish (cal_table )
353+     buf  =  builder .Output ()
354+ 
355+     with  open (os .path .join (dir , "calibration.flatbuffers" ), "wb" ) as  file :
356+         file .write (buf )
357+ 
358+     # Deserialize data (for validation) 
359+     if  os .environ .get ("QUANTIZATION_DEBUG" , 0 ) in  (1 , "1" ):
360+         cal_table  =  TrtTable .TrtTable .GetRootAsTrtTable (buf , 0 )
361+         dict_len  =  cal_table .DictLength ()
362+         for  i  in  range (dict_len ):
363+             key_value  =  cal_table .Dict (i )
364+             logging .info (key_value .Key ())
365+             logging .info (key_value .Value ())
366+ 
367+     # write plain text 
368+     with  open (os .path .join (dir , "calibration.cache" ), "w" ) as  file :
369+         for  key  in  sorted (calibration_cache .keys ()):
370+             values  =  calibration_cache [key ]
371+             d_values  =  values .to_dict ()
372+             floats  =  [
373+                 float (d_values .get ("highest" , zero ).item ()),
374+                 float (d_values .get ("lowest" , zero ).item ()),
375+             ]
376+             value  =  key  +  " "  +  str (max (floats ))
377+             file .write (value )
378+             file .write ("\n " )
280379
281380def  parse_input_args ():
282381    parser  =  argparse .ArgumentParser ()
@@ -553,8 +652,42 @@ def output_run_config(flags, samples):
553652        for  k , v  in  compute_range .data .items ():
554653            json_compute_range [k ] =  (float (v .range_value [0 ]), float (v .range_value [1 ]))
555654
655+         print ("Writing calibration table" )
656+         try :
657+             write_calibration_table (json_compute_range )
658+         except  AttributeError  as  e :
659+             class  TensorDataWrapper :
660+                 def  __init__ (self , data_dict ):
661+                     self .data_dict  =  data_dict 
662+ 
663+                 def  to_dict (self ):
664+                     return  self .data_dict 
665+ 
666+                 def  __repr__ (self ):
667+                     return  repr (self .data_dict )
668+ 
669+                 def  __serializable__ (self ):
670+                     return  self .data_dict 
671+ 
672+             calibration_data  =  {}
673+             for  k , v  in  compute_range .data .items ():
674+                 if  hasattr (v , 'to_dict' ):
675+                     tensor_dict  =  v .to_dict ()
676+                     processed_dict  =  {}
677+                     for  dk , dv  in  tensor_dict .items ():
678+                         if  isinstance (dv , np .ndarray ):
679+                             processed_dict [dk ] =  dv .item () if  dv .size  ==  1  else  dv .tolist ()
680+                         elif  isinstance (dv , np .number ):
681+                             processed_dict [dk ] =  dv .item ()
682+                         else :
683+                             processed_dict [dk ] =  dv 
684+                     calibration_data [k ] =  TensorDataWrapper (processed_dict )
685+                 else :
686+                     calibration_data [k ] =  v 
687+ 
688+             print ("Using custom calibration table function" )
689+             custom_write_calibration_table (calibration_data )
556690
557-         write_calibration_table (json_compute_range )
558691        print ("Calibration is done. Calibration cache is saved to calibration.json" )
559692
560693        model_quants  =  model_quants  +  "_int8" 
0 commit comments