1- import difflib
21import inspect
32import json
43import os
54import pkgutil
65import re
76import warnings
87from abc import abstractmethod
8+ from importlib import import_module
99from typing import Any , Dict , List , Optional , Tuple , Union , final
1010
1111from .dataclass import (
2222 separate_inside_and_outside_square_brackets ,
2323)
2424from .settings_utils import get_constants , get_settings
25- from .text_utils import camel_to_snake_case , is_camel_case , print_dict_as_yaml
25+ from .text_utils import print_dict_as_yaml
2626from .type_utils import isoftype , issubtype
2727from .utils import (
2828 artifacts_json_cache ,
@@ -133,21 +133,11 @@ def maybe_recover_artifacts_structure(obj):
133133 return obj
134134
135135
136- def get_closest_artifact_type (type ):
137- artifact_type_options = list (Artifact ._class_register .keys ())
138- matches = difflib .get_close_matches (type , artifact_type_options )
139- if matches :
140- return matches [0 ] # Return the closest match
141- return None
142-
143136
class UnrecognizedArtifactTypeError(ValueError):
    """Error for an artifact ``__type__`` string that cannot be resolved to a class."""

    def __init__(self, type) -> None:
        # Whatever follows the last dot of the type string is reported as the
        # probable class name (the whole string when it has no dot).
        maybe_class = type.rpartition(".")[2]
        message = (
            f"'{type}' is not a recognized artifact 'type'. "
            f"Make sure a the class defined this type (Probably called '{maybe_class}' or similar) "
            "is defined and/or imported anywhere in the code executed."
        )
        super().__init__(message)
152142
153143
@@ -160,7 +150,7 @@ def __init__(self, dic) -> None:
160150
161151
162152class Artifact (Dataclass ):
163- _class_register = {}
153+ # _class_register = {}
164154
165155 __type__ : str = Field (default = None , final = True , init = False )
166156 __title__ : str = NonPositionalField (
@@ -200,38 +190,60 @@ def verify_artifact_dict(cls, d):
200190 )
201191 if "__type__" not in d :
202192 raise MissingArtifactTypeError (d )
203- if not cls .is_registered_type (d ["__type__" ]):
204- raise UnrecognizedArtifactTypeError (d ["__type__" ])
193+ # if not cls.is_registered_type(d["__type__"]):
194+ # raise UnrecognizedArtifactTypeError(d["__type__"])
195+
@staticmethod
def fix_module_name_if_not_in_path(module):
    """Return a module name for *module*, rebuilt from its file path when needed.

    The module's own ``__name__`` is returned when the module has no
    ``__file__`` or when every dotted component of the name occurs in the
    file path. Otherwise the name is derived from ``__file__`` and
    ``__package__``. ``"dummy_module_name"`` is returned when neither a name
    nor a file is available, or when the package components are not found as
    a consecutive run in the file path.
    """
    file_path = getattr(module, "__file__", None)
    module_name = getattr(module, "__name__", None)

    if not module_name:
        if file_path:
            # Base name of the file, without its extension.
            return file_path.split(os.sep)[-1].split(".")[0]
        return "dummy_module_name"
    if not file_path:
        return module_name
    if all(part in file_path for part in module_name.split(".")):
        return module_name

    path_parts = file_path.split(os.sep)
    if path_parts[0] == "":
        # An absolute path yields an empty leading component; drop it.
        path_parts = path_parts[1:]
    path_parts[-1] = path_parts[-1].split(".")[0]  # omit the .py

    package = getattr(module, "__package__", None)
    if not package:
        return path_parts[-1]

    package_parts = package.split(".")
    assert all(part in path_parts for part in package_parts)

    width = len(package_parts)
    last_start = len(path_parts) - width
    # Look for the package components as a consecutive run inside the path.
    for start in range(last_start + 1):
        if path_parts[start : start + width] != package_parts:
            continue
        if start == last_start:
            # The package run ends the path: nothing is left to append.
            return package
        return package + "." + ".".join(path_parts[start + width :])
    return "dummy_module_name"
205222
@classmethod
def get_artifact_type(cls):
    """Return the type string for this class, built from its module and class name.

    A class whose ``__qualname__`` contains a dot yields
    ``"<module>/<qualname>"``; any other class yields
    ``"<module>.<ClassName>"``.
    """
    module = inspect.getmodule(cls)
    package = getattr(module, "__package__", None)
    # standardize module name
    module_name = Artifact.fix_module_name_if_not_in_path(module)
    if package and not module_name.startswith(package):
        module_name = f"{package}.{module_name}"

    qualname = getattr(cls, "__qualname__", "")
    if "." in qualname:
        return f"{module_name}/{qualname}"
    return f"{module_name}.{cls.__name__}"
209236
@classmethod
def get_module_class(cls, artifact_type: str):
    """Split an artifact type string into its module part and its class part.

    A string containing ``"/"`` is split on ``"/"``; any other string is
    split once on its last ``"."``. The list produced by the split is
    returned as is.
    """
    uses_slash = "/" in artifact_type
    return artifact_type.split("/") if uses_slash else artifact_type.rsplit(".", 1)
231244
def __init_subclass__(cls, **kwargs):
    """Forward the subclass-creation keyword arguments to the parent class hook."""
    super().__init_subclass__(**kwargs)
235247
236248 @classmethod
237249 def is_artifact_file (cls , path ):
@@ -242,34 +254,32 @@ def is_artifact_file(cls, path):
242254 return cls .is_artifact_dict (d )
243255
@classmethod
def get_class_from_artifact_type(cls, type: str):
    """Import the module named by an artifact type string and return the class it names.

    The type string is split with ``get_module_class``. The module part is
    imported, and the class part is resolved by following its dot-separated
    attribute names starting from the module.
    """
    module_path, class_name = cls.get_module_class(type)
    # NOTE(review): this imports whatever module the type string names; confirm
    # that type strings only ever come from trusted artifact files.
    target = import_module(module_path)
    for attribute in class_name.split("."):
        target = getattr(target, attribute)
    return target
256268
257269 @classmethod
258270 def _recursive_load (cls , obj ):
259271 if isinstance (obj , dict ):
260- new_d = {}
261- for key , value in obj .items ():
262- new_d [key ] = cls ._recursive_load (value )
263- obj = new_d
272+ obj = {key : cls ._recursive_load (value ) for key , value in obj .items ()}
273+ if cls .is_artifact_dict (obj ):
274+ try :
275+ artifact_type = obj .pop ("__type__" )
276+ artifact_class = cls .get_class_from_artifact_type (artifact_type )
277+ obj = artifact_class .process_data_after_load (obj )
278+ return artifact_class (** obj )
279+ except (ImportError , AttributeError ) as e :
280+ raise UnrecognizedArtifactTypeError (artifact_type ) from e
264281 elif isinstance (obj , list ):
265- obj = [cls ._recursive_load (value ) for value in obj ]
266- else :
267- pass
268- if cls .is_artifact_dict (obj ):
269- cls .verify_artifact_dict (obj )
270- artifact_class = cls ._class_register [obj .pop ("__type__" )]
271- obj = artifact_class .process_data_after_load (obj )
272- return artifact_class (** obj )
282+ return [cls ._recursive_load (value ) for value in obj ]
273283
274284 return obj
275285
@@ -283,7 +293,7 @@ def from_dict(cls, d, overwrite_args=None):
283293 @classmethod
284294 def load (cls , path , artifact_identifier = None , overwrite_args = None ):
285295 d = artifacts_json_cache (path )
286- if "__type__" in d and d ["__type__" ] == "artifact_link" :
296+ if "__type__" in d and d ["__type__" ]. endswith ( "ArtifactLink" ) :
287297 cls .from_dict (d ) # for verifications and warnings
288298 catalog , artifact_rep , _ = get_catalog_name_and_args (name = d ["to" ])
289299 return catalog .get_with_overwrite (
@@ -329,7 +339,7 @@ def verify_data_classification_policy(self):
329339
330340 @final
331341 def __post_init__ (self ):
332- self .__type__ = self .register_class ( self . __class__ )
342+ self .__type__ = self .__class__ . get_artifact_type ( )
333343
334344 for field in fields (self ):
335345 if issubtype (
@@ -347,14 +357,18 @@ def __post_init__(self):
347357
348358 def _to_raw_dict (self ):
349359 return {
350- "__type__" : self .__type__ ,
360+ "__type__" : self .__class__ . get_artifact_type () ,
351361 ** self .process_data_before_dump (self ._init_dict ),
352362 }
353363
354364 def __deepcopy__ (self , memo ):
355365 if id (self ) in memo :
356366 return memo [id (self )]
357- new_obj = Artifact .from_dict (self .to_dict ())
367+ try :
368+ new_obj = Artifact .from_dict (self .to_dict ())
369+ except :
370+ # needed only for artifacts defined inline for testing etc. E.g. 'NERWithoutClassReporting'
371+ new_obj = self
358372 memo [id (self )] = new_obj
359373 return new_obj
360374
0 commit comments