1- import difflib
21import inspect
32import json
43import os
2322 separate_inside_and_outside_square_brackets ,
2423)
2524from .settings_utils import get_constants , get_settings
26- from .text_utils import is_camel_case
2725from .type_utils import isoftype , issubtype
2826from .utils import (
2927 artifacts_json_cache ,
@@ -134,21 +132,11 @@ def maybe_recover_artifacts_structure(obj):
134132 return obj
135133
136134
137- def get_closest_artifact_type (type ):
138- artifact_type_options = list (Artifact ._class_register .keys ())
139- matches = difflib .get_close_matches (type , artifact_type_options )
140- if matches :
141- return matches [0 ] # Return the closest match
142- return None
143-
144135
145136class UnrecognizedArtifactTypeError (ValueError ):
146137 def __init__ (self , type ) -> None :
147138 maybe_class = type .split ("." )[- 1 ]
148139 message = f"'{ type } ' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{ maybe_class } ' or similar) is defined and/or imported anywhere in the code executed."
149- closest_artifact_type = get_closest_artifact_type (type )
150- if closest_artifact_type is not None :
151- message += f"\n \n Did you mean '{ closest_artifact_type } '?"
152140 super ().__init__ (message )
153141
154142
@@ -161,7 +149,7 @@ def __init__(self, dic) -> None:
161149
162150
163151class Artifact (Dataclass ):
164- _class_register = {}
152+ # _class_register = {}
165153
166154 __type__ : str = Field (default = None , final = True , init = False )
167155 __title__ : str = NonPositionalField (
@@ -252,29 +240,9 @@ def get_module_class(cls, artifact_type:str):
252240 return artifact_type .rsplit ("." , 1 )
253241
254242
255- @classmethod
256- def register_class (cls , artifact_class ):
257- assert issubclass (
258- artifact_class , Artifact
259- ), f"Artifact class must be a subclass of Artifact, got '{ artifact_class } '"
260- assert is_camel_case (
261- artifact_class .__name__
262- ), f"Artifact class name must be legal camel case, got '{ artifact_class .__name__ } '"
263-
264- if cls .is_registered_type (cls .get_artifact_type ()):
265- assert (
266- str (cls ._class_register [cls .get_artifact_type ()]) == cls .get_artifact_type ()
267- ), f"Artifact class name must be unique, '{ cls .get_artifact_type ()} ' is already registered as { cls ._class_register [cls .get_artifact_type ()]} . Cannot be overridden by { artifact_class } ."
268-
269- return cls .get_artifact_type ()
270-
271- cls ._class_register [cls .get_artifact_type ()] = cls .get_artifact_type () # for now, still maintain the registry from qualified to qualified
272-
273- return cls .get_artifact_type ()
274243
275244 def __init_subclass__ (cls , ** kwargs ):
276245 super ().__init_subclass__ (** kwargs )
277- cls .register_class (cls )
278246
279247 @classmethod
280248 def is_artifact_file (cls , path ):
@@ -284,18 +252,6 @@ def is_artifact_file(cls, path):
284252 d = json .load (f )
285253 return cls .is_artifact_dict (d )
286254
287- @classmethod
288- def is_registered_type (cls , type : str ):
289- return type in cls ._class_register
290-
291- @classmethod
292- def is_registered_class_name (cls , class_name : str ):
293- for k in cls ._class_register :
294- _ , artifact_class_name = cls .get_module_class (k )
295- if artifact_class_name == class_name :
296- return True
297- return False
298-
299255 @classmethod
300256 def get_class_from_artifact_type (cls , type :str ):
301257 module_path , class_name = cls .get_module_class (type )
@@ -313,23 +269,17 @@ def get_class_from_artifact_type(cls, type:str):
313269 @classmethod
314270 def _recursive_load (cls , obj ):
315271 if isinstance (obj , dict ):
316- new_d = {}
317- for key , value in obj .items ():
318- new_d [key ] = cls ._recursive_load (value )
319- obj = new_d
272+ obj = {key : cls ._recursive_load (value ) for key , value in obj .items ()}
273+ if cls .is_artifact_dict (obj ):
274+ try :
275+ artifact_type = obj .pop ("__type__" )
276+ artifact_class = cls .get_class_from_artifact_type (artifact_type )
277+ obj = artifact_class .process_data_after_load (obj )
278+ return artifact_class (** obj )
279+ except (ImportError , AttributeError ) as e :
280+ raise UnrecognizedArtifactTypeError (artifact_type ) from e
320281 elif isinstance (obj , list ):
321- obj = [cls ._recursive_load (value ) for value in obj ]
322- else :
323- pass
324- if cls .is_artifact_dict (obj ):
325- cls .verify_artifact_dict (obj )
326- try :
327- artifact_type = obj .pop ("__type__" )
328- artifact_class = cls .get_class_from_artifact_type (artifact_type )
329- obj = artifact_class .process_data_after_load (obj )
330- return artifact_class (** obj )
331- except (ImportError , AttributeError ) as e :
332- raise UnrecognizedArtifactTypeError (artifact_type ) from e
282+ return [cls ._recursive_load (value ) for value in obj ]
333283
334284 return obj
335285
@@ -389,7 +339,7 @@ def verify_data_classification_policy(self):
389339
390340 @final
391341 def __post_init__ (self ):
392- self .__type__ = self .register_class ( self . __class__ )
342+ self .__type__ = self .__class__ . get_artifact_type ( )
393343
394344 for field in fields (self ):
395345 if issubtype (
0 commit comments