@@ -519,6 +519,31 @@ def _print_graph(self):
519519 if hasattr (mod , "graph" ) and isinstance (mod .graph , torch .fx .Graph ):
520520 print (mod .graph )
521521
522+ def _adapt_flat_args (self , flat_args , in_spec ):
523+ signature = self .module_call_graph [0 ].signature
524+ if in_spec == signature .in_spec :
525+ return flat_args
526+
527+ if self .flat_args_adapter is None :
528+ raise TypeError (
529+ "There is no flat args adapter sepcified. "
530+ "Are you sure you are calling this with the right arguments? "
531+ )
532+ else :
533+ flat_args = self .flat_args_adapter .adapt (
534+ target_spec = signature .in_spec ,
535+ input_spec = in_spec ,
536+ input_args = flat_args ,
537+ )
538+
539+ if len (flat_args ) != signature .in_spec .num_leaves :
540+ raise TypeError (
541+ f"Flat args adaption failed, number of args mismatch "
542+ f"Adatped: { len (flat_args )} \n "
543+ f"Exported module: { signature .in_spec .num_leaves } "
544+ )
545+ return flat_args
546+
522547 def forward (self , * args , ** kwargs ):
523548 signature = self .module_call_graph [0 ].signature
524549
@@ -544,26 +569,9 @@ def forward(self, *args, **kwargs):
544569 f"Input treespec: { in_spec } . " ,
545570 f"Exported module treespec: { signature .in_spec } " ,
546571 )
547- if self .flat_args_adapter is None :
548- raise TypeError (
549- "There is no flat args adapter sepcified. "
550- "Are you sure you are calling this with the right arguments? "
551- )
552- else :
553- if not self .adapted :
554- print ("Adapting flat arg to match exported module's treespec" )
555- flat_args = self .flat_args_adapter .adapt (
556- target_spec = signature .in_spec ,
557- input_spec = in_spec ,
558- input_args = flat_args ,
559- )
560- self .adapted = True
561- if len (flat_args ) != signature .in_spec .num_leaves :
562- raise TypeError (
563- f"Flat args adaption failed, number of args mismatch "
564- f"Adatped: { len (flat_args )} \n "
565- f"Exported module: { signature .in_spec .num_leaves } "
566- )
572+ print ("Adapting flat arg to match exported module's treespec" )
573+ flat_args = self ._adapt_flat_args (flat_args , in_spec )
574+ self .adapted = True
567575
568576 if self .check_input_constraints :
569577 # Import here to avoid an unfortunate circular dependency.
0 commit comments