@@ -48,40 +48,30 @@ def import_module_from_file(file_path):
4848    return  module 
4949
5050
51- # type is read from a catelog entry, the value of a key "__type__" 
52- def  get_class_from_artifact_type (type : str ):
53-     if  type  in  Artifact ._class_register :
54-         return  Artifact ._class_register [type ]
51+ # snake_case_class_name is read from a catelog entry, the value of a key "__type__" 
52+ # this method replaces the Artifact._class_register lookup, for all unitxt classes defined 
53+ # top level in any of the src/unitxt/*.py modules, which are all the classes that were registered 
54+ # by register_all_artifacts 
55+ def  get_class_from_artifact_type (snake_case_class_name : str ):
56+     if  snake_case_class_name  in  Artifact ._class_register :
57+         return  Artifact ._class_register [snake_case_class_name ]
5558
5659    module_path , class_name  =  find_unitxt_module_and_class_by_classname (
57-         snake_to_camel_case (type )
60+         snake_to_camel_case (snake_case_class_name )
5861    )
59-     if  module_path  ==  "class_register" :
60-         if  class_name  not  in   Artifact ._class_register :
61-             raise  ValueError (
62-                 f"Can not instantiate a class from type { type }  , because { class_name }   is currently not registered in Artifact._class_register." 
63-             )
64-         return  Artifact ._class_register [class_name ]
6562
6663    module  =  importlib .import_module (module_path )
6764
68-     if  "."  not  in   class_name :
69-         if  hasattr (module , class_name ) and  inspect .isclass (getattr (module , class_name )):
70-             return  getattr (module , class_name )
71-         if  class_name  in  Artifact ._class_register :
72-             return  Artifact ._class_register [class_name ]
73-         module_file  =  module .__file__  if  hasattr (module , "__file__" ) else  None 
74-         if  module_file :
75-             module  =  import_module_from_file (module_file )
76- 
77-         assert  class_name  in  Artifact ._class_register 
78-         return  Artifact ._class_register [class_name ]
79- 
80-     class_name_components  =  class_name .split ("." )
81-     klass  =  getattr (module , class_name_components [0 ])
82-     for  i  in  range (1 , len (class_name_components )):
83-         klass  =  getattr (klass , class_name_components [i ])
84-     return  klass 
65+     if  hasattr (module , class_name ) and  inspect .isclass (getattr (module , class_name )):
66+         klass  =  getattr (module , class_name )
67+         Artifact ._class_register [
68+             snake_case_class_name 
69+         ] =  klass   # use _class_register as a cache 
70+         return  klass 
71+ 
72+     raise  ValueError (
73+         f"Could not find the definition of class whose name, snake-cased is { snake_case_class_name }  " 
74+     )
8575
8676
8777def  is_name_legal_for_catalog (name ):
0 commit comments