Skip to content

Commit f3e6412

Browse files
committed
Change each __type__ in the catalog to the fully qualified class name, rather than the snake_case form of the class name, and remove _class_register altogether
Signed-off-by: dafnapension <[email protected]>
1 parent aac33a6 commit f3e6412

File tree

4,710 files changed

+26824
-26846
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

4,710 files changed

+26824
-26846
lines changed

docs/catalog.py

Lines changed: 18 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def imports_to_syntax_highlighted_html(subtypes: List[str])-> str:
5050
return ""
5151
module_to_class_names = defaultdict(list)
5252
for subtype in subtypes:
53-
subtype_class = Artifact._class_register.get(subtype)
54-
module_to_class_names[subtype_class.__module__].append(subtype_class.__name__)
53+
(module, class_name) = Artifact.get_module_class(subtype)
54+
module_to_class_names[module].append(class_name)
5555

5656
imports_txt = ""
5757
for modu in sorted(module_to_class_names.keys()):
@@ -101,31 +101,6 @@ def custom_walk(top):
101101
yield entry
102102

103103

104-
def all_subtypes_of_artifact(artifact):
105-
if (
106-
artifact is None
107-
or isinstance(artifact, str)
108-
or isinstance(artifact, bool)
109-
or isinstance(artifact, int)
110-
or isinstance(artifact, float)
111-
):
112-
return []
113-
if isinstance(artifact, list):
114-
to_return = []
115-
for art in artifact:
116-
to_return.extend(all_subtypes_of_artifact(art))
117-
return to_return
118-
# artifact is a dict
119-
to_return = []
120-
for key, value in artifact.items():
121-
if isinstance(value, str):
122-
if key == "__type__":
123-
to_return.append(value)
124-
else:
125-
to_return.extend(all_subtypes_of_artifact(value))
126-
return to_return
127-
128-
129104
def get_all_type_elements(nested_dict):
130105
type_elements = set()
131106

@@ -148,19 +123,18 @@ def recursive_search(d):
148123

149124
@lru_cache(maxsize=None)
150125
def artifact_type_to_link(artifact_type):
151-
artifact_class = Artifact._class_register.get(artifact_type)
152-
type_class_name = artifact_class.__name__
153-
artifact_class_id = f"{artifact_class.__module__}.{type_class_name}"
154-
return f'<a class="reference internal" href="../{artifact_class.__module__}.html#{artifact_class_id}" title="{artifact_class_id}"><code class="xref py py-class docutils literal notranslate"><span class="pre">{type_class_name}</span></code></a>'
126+
artifact_module, artifact_class_name = Artifact.get_module_class(artifact_type)
127+
return f'<a class="reference internal" href="../{artifact_module}.html#{artifact_module}.{artifact_class_name}" title="{artifact_module}.{artifact_class_name}"><code class="xref py py-class docutils literal notranslate"><span class="pre">{artifact_class_name}</span></code></a>'
155128

156129

157130
# flake8: noqa: C901
131+
132+
158133
def make_content(artifact, label, all_labels):
159-
artifact_type = artifact["__type__"]
160-
artifact_class = Artifact._class_register.get(artifact_type)
161-
type_class_name = artifact_class.__name__
162-
catalog_id = label.replace("catalog.", "")
134+
artifact_type = artifact["__type__"] #qualified class name
135+
artifact_class = Artifact.get_class_from_artifact_type(artifact_type)
163136

137+
catalog_id = label.replace("catalog.", "")
164138
result = ""
165139

166140
if "__description__" in artifact and artifact["__description__"] is not None:
@@ -203,23 +177,16 @@ def make_content(artifact, label, all_labels):
203177
)
204178

205179
for type_name in type_elements:
206-
# source = f'<span class="nt">__type__</span><span class="p">:</span><span class="w"> </span><span class="l l-Scalar l-Scalar-Plain">{type_name}</span>'
207-
source = f'<span class="n">__type__{type_name}</span><span class="p">'
208-
target = artifact_type_to_link(type_name)
209-
html_for_dict = html_for_dict.replace(
210-
source,
211-
f'<span class="n" STYLE="font-size:108%">{target}</span><span class="p">'
212-
# '<span class="nt">&quot;type&quot;</span><span class="p">:</span><span class="w"> </span>'
213-
# + target,
214-
)
215-
216-
pattern = r'(<span class="nt">)&quot;(.*?)&quot;(</span>)'
180+
artifact_module, artifact_class_name = Artifact.get_module_class(type_name)
181+
pattern = re.compile(f'<span class="n">__type__(.*?)<span class="n">{artifact_class_name}</span>')
182+
repl = '<span class="n" STYLE="font-size:108%">'+artifact_type_to_link(type_name)+"</span>"
183+
html_for_dict = pattern.sub(repl, html_for_dict)
217184

185+
# pattern = r'(<span class="nt">)&quot;(.*?)&quot;(</span>)'
218186
# Replacement function
219-
html_for_dict = re.sub(pattern, r"\1\2\3", html_for_dict)
187+
# html_for_dict = re.sub(pattern, r"\1\2\3", html_for_dict)
220188

221-
subtypes = all_subtypes_of_artifact(artifact)
222-
subtypes = list(set(subtypes))
189+
subtypes = type_elements
223190
subtypes.remove(artifact_type) # this was already documented
224191
html_for_imports = imports_to_syntax_highlighted_html(subtypes)
225192

@@ -235,13 +202,13 @@ def make_content(artifact, label, all_labels):
235202
result += " " + html_for_element + "\n"
236203

237204
if artifact_class.__doc__:
238-
explanation_str = f"Explanation about `{type_class_name}`"
205+
explanation_str = f"Explanation about `{artifact_class.__name__}`"
239206
result += f"\n{explanation_str}\n"
240207
result += "+" * len(explanation_str) + "\n\n"
241208
result += artifact_class.__doc__ + "\n"
242209

243210
for subtype in subtypes:
244-
subtype_class = Artifact._class_register.get(subtype)
211+
subtype_class = Artifact.get_class_from_artifact_type(subtype)
245212
subtype_class_name = subtype_class.__name__
246213
if subtype_class.__doc__:
247214
explanation_str = f"Explanation about `{subtype_class_name}`"

docs/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def autodoc_skip_member(app, what, name, obj, would_skip, options):
115115
class_name = obj.__qualname__.split(".")[0]
116116
if (
117117
class_name
118-
and Artifact.is_registered_class_name(class_name)
119118
and class_name != name
120119
):
121120
return True

prepare/metrics/custom_f1.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,4 +433,7 @@ class NERWithoutClassReporting(NER):
433433
global_target=global_target,
434434
)
435435

436-
add_to_catalog(metric, "metrics.ner", overwrite=True)
436+
if __name__ == "__main__" or __name__ == "custom_f1":
437+
# A class is defined in this module, so merely importing the module in order to retrieve that class must not trigger add_to_catalog.
438+
# The guard must still pass when this module is run directly from Python (__main__) or, for example, from test_preparation (custom_f1).
439+
add_to_catalog(metric, "metrics.ner", overwrite=True)

src/unitxt/artifact.py

Lines changed: 76 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import difflib
21
import inspect
32
import json
43
import os
54
import pkgutil
65
import re
76
import warnings
87
from abc import abstractmethod
8+
from importlib import import_module
99
from typing import Any, Dict, List, Optional, Tuple, Union, final
1010

1111
from .dataclass import (
@@ -22,7 +22,7 @@
2222
separate_inside_and_outside_square_brackets,
2323
)
2424
from .settings_utils import get_constants, get_settings
25-
from .text_utils import camel_to_snake_case, is_camel_case, print_dict_as_yaml
25+
from .text_utils import print_dict_as_yaml
2626
from .type_utils import isoftype, issubtype
2727
from .utils import (
2828
artifacts_json_cache,
@@ -133,21 +133,11 @@ def maybe_recover_artifacts_structure(obj):
133133
return obj
134134

135135

136-
def get_closest_artifact_type(type):
137-
artifact_type_options = list(Artifact._class_register.keys())
138-
matches = difflib.get_close_matches(type, artifact_type_options)
139-
if matches:
140-
return matches[0] # Return the closest match
141-
return None
142-
143136

144137
class UnrecognizedArtifactTypeError(ValueError):
145138
def __init__(self, type) -> None:
146-
maybe_class = "".join(word.capitalize() for word in type.split("_"))
139+
maybe_class = type.split(".")[-1]
147140
message = f"'{type}' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{maybe_class}' or similar) is defined and/or imported anywhere in the code executed."
148-
closest_artifact_type = get_closest_artifact_type(type)
149-
if closest_artifact_type is not None:
150-
message += f"\n\nDid you mean '{closest_artifact_type}'?"
151141
super().__init__(message)
152142

153143

@@ -160,7 +150,7 @@ def __init__(self, dic) -> None:
160150

161151

162152
class Artifact(Dataclass):
163-
_class_register = {}
153+
# _class_register = {}
164154

165155
__type__: str = Field(default=None, final=True, init=False)
166156
__title__: str = NonPositionalField(
@@ -200,38 +190,60 @@ def verify_artifact_dict(cls, d):
200190
)
201191
if "__type__" not in d:
202192
raise MissingArtifactTypeError(d)
203-
if not cls.is_registered_type(d["__type__"]):
204-
raise UnrecognizedArtifactTypeError(d["__type__"])
193+
# if not cls.is_registered_type(d["__type__"]):
194+
# raise UnrecognizedArtifactTypeError(d["__type__"])
195+
196+
@staticmethod
197+
def fix_module_name_if_not_in_path(module):
198+
module_name = getattr(module, "__name__", None)
199+
if not module_name:
200+
if getattr(module, "__file__", None):
201+
return module.__file__.split(os.sep)[-1].split(".")[0]
202+
return "dummy_module_name"
203+
if not getattr(module, "__file__", None):
204+
return module_name
205+
name_components = module.__name__.split(".")
206+
if all (name_component in module.__file__ for name_component in name_components):
207+
return module_name
208+
file_components = module.__file__.split(os.sep)
209+
if file_components[0] == "":
210+
file_components = file_components[1:]
211+
file_components[-1] = file_components[-1].split(".")[0] #omit the .py
212+
if not getattr(module, "__package__", None) or len(module.__package__) == 0:
213+
return file_components[-1]
214+
package_components = module.__package__.split(".")
215+
assert all(p_c in file_components for p_c in package_components)
216+
for i in range(len(file_components)-len(package_components)+1):
217+
if all(package_components[j] == file_components[i+j] for j in range(len(package_components))):
218+
if i==len(file_components)-len(package_components):
219+
return module.__package__
220+
return module.__package__ +"."+ (".".join(file_components[i+len(package_components):]))
221+
return "dummy_module_name"
205222

206223
@classmethod
207224
def get_artifact_type(cls):
208-
return camel_to_snake_case(cls.__name__)
225+
module = inspect.getmodule(cls)
226+
# standardize module name
227+
module_name = getattr(module, "__name__", None)
228+
module_package = getattr(module, "__package__", None)
229+
module_name = Artifact.fix_module_name_if_not_in_path(module)
230+
if module_package:
231+
if not module_name.startswith(module_package):
232+
module_name = module_package+"."+module_name
233+
if hasattr(cls, "__qualname__") and "." in cls.__qualname__:
234+
return module_name+"/"+cls.__qualname__
235+
return module_name+"."+cls.__name__
209236

210237
@classmethod
211-
def register_class(cls, artifact_class):
212-
assert issubclass(
213-
artifact_class, Artifact
214-
), f"Artifact class must be a subclass of Artifact, got '{artifact_class}'"
215-
assert is_camel_case(
216-
artifact_class.__name__
217-
), f"Artifact class name must be legal camel case, got '{artifact_class.__name__}'"
218-
219-
snake_case_key = camel_to_snake_case(artifact_class.__name__)
220-
221-
if cls.is_registered_type(snake_case_key):
222-
assert (
223-
str(cls._class_register[snake_case_key]) == str(artifact_class)
224-
), f"Artifact class name must be unique, '{snake_case_key}' already exists for {cls._class_register[snake_case_key]}. Cannot be overridden by {artifact_class}."
225-
226-
return snake_case_key
238+
def get_module_class(cls, artifact_type:str):
239+
if "/" in artifact_type:
240+
return artifact_type.split("/")
241+
return artifact_type.rsplit(".", 1)
227242

228-
cls._class_register[snake_case_key] = artifact_class
229243

230-
return snake_case_key
231244

232245
def __init_subclass__(cls, **kwargs):
233246
super().__init_subclass__(**kwargs)
234-
cls.register_class(cls)
235247

236248
@classmethod
237249
def is_artifact_file(cls, path):
@@ -242,34 +254,32 @@ def is_artifact_file(cls, path):
242254
return cls.is_artifact_dict(d)
243255

244256
@classmethod
245-
def is_registered_type(cls, type: str):
246-
return type in cls._class_register
257+
def get_class_from_artifact_type(cls, type:str):
258+
module_path, class_name = cls.get_module_class(type)
259+
module = import_module(module_path)
260+
if "." not in class_name:
261+
return getattr(module, class_name)
262+
class_name_components = class_name.split(".")
263+
klass = getattr(module, class_name_components[0])
264+
for i in range (1, len(class_name_components)):
265+
klass = getattr(klass, class_name_components[i])
266+
return klass
247267

248-
@classmethod
249-
def is_registered_class_name(cls, class_name: str):
250-
snake_case_key = camel_to_snake_case(class_name)
251-
return cls.is_registered_type(snake_case_key)
252-
253-
@classmethod
254-
def is_registered_class(cls, clz: object):
255-
return clz in set(cls._class_register.values())
256268

257269
@classmethod
258270
def _recursive_load(cls, obj):
259271
if isinstance(obj, dict):
260-
new_d = {}
261-
for key, value in obj.items():
262-
new_d[key] = cls._recursive_load(value)
263-
obj = new_d
272+
obj = {key: cls._recursive_load(value) for key, value in obj.items()}
273+
if cls.is_artifact_dict(obj):
274+
try:
275+
artifact_type = obj.pop("__type__")
276+
artifact_class = cls.get_class_from_artifact_type(artifact_type)
277+
obj = artifact_class.process_data_after_load(obj)
278+
return artifact_class(**obj)
279+
except (ImportError, AttributeError) as e:
280+
raise UnrecognizedArtifactTypeError(artifact_type) from e
264281
elif isinstance(obj, list):
265-
obj = [cls._recursive_load(value) for value in obj]
266-
else:
267-
pass
268-
if cls.is_artifact_dict(obj):
269-
cls.verify_artifact_dict(obj)
270-
artifact_class = cls._class_register[obj.pop("__type__")]
271-
obj = artifact_class.process_data_after_load(obj)
272-
return artifact_class(**obj)
282+
return [cls._recursive_load(value) for value in obj]
273283

274284
return obj
275285

@@ -283,7 +293,7 @@ def from_dict(cls, d, overwrite_args=None):
283293
@classmethod
284294
def load(cls, path, artifact_identifier=None, overwrite_args=None):
285295
d = artifacts_json_cache(path)
286-
if "__type__" in d and d["__type__"] == "artifact_link":
296+
if "__type__" in d and d["__type__"].endswith("ArtifactLink"):
287297
cls.from_dict(d) # for verifications and warnings
288298
catalog, artifact_rep, _ = get_catalog_name_and_args(name=d["to"])
289299
return catalog.get_with_overwrite(
@@ -329,7 +339,7 @@ def verify_data_classification_policy(self):
329339

330340
@final
331341
def __post_init__(self):
332-
self.__type__ = self.register_class(self.__class__)
342+
self.__type__ = self.__class__.get_artifact_type()
333343

334344
for field in fields(self):
335345
if issubtype(
@@ -347,14 +357,18 @@ def __post_init__(self):
347357

348358
def _to_raw_dict(self):
349359
return {
350-
"__type__": self.__type__,
360+
"__type__": self.__class__.get_artifact_type(),
351361
**self.process_data_before_dump(self._init_dict),
352362
}
353363

354364
def __deepcopy__(self, memo):
355365
if id(self) in memo:
356366
return memo[id(self)]
357-
new_obj = Artifact.from_dict(self.to_dict())
367+
try:
368+
new_obj = Artifact.from_dict(self.to_dict())
369+
except:
370+
# needed only for artifacts defined inline for testing etc. E.g. 'NERWithoutClassReporting'
371+
new_obj = self
358372
memo[id(self)] = new_obj
359373
return new_obj
360374

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"__type__": "artifact_link",
2+
"__type__": "unitxt.artifact.ArtifactLink",
33
"to": "augmentors.text.whitespace_prefix_suffix",
44
"__deprecated_msg__": "Artifact 'augmentors.augment_whitespace_prefix_and_suffix_task_input' is deprecated. Artifact 'augmentors.text.whitespace_prefix_suffix' will be instantiated instead. In future uses, please reference artifact 'augmentors.text.whitespace_prefix_suffix' directly."
55
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"__type__": "artifact_link",
2+
"__type__": "unitxt.artifact.ArtifactLink",
33
"to": "augmentors.text.whitespace_prefix_suffix",
44
"__deprecated_msg__": "Artifact 'augmentors.augment_whitespace_task_input' is deprecated. Artifact 'augmentors.text.whitespace_prefix_suffix' will be instantiated instead. In future uses, please reference artifact 'augmentors.text.whitespace_prefix_suffix' directly."
55
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"__type__": "gray_scale"
2+
"__type__": "unitxt.image_operators.GrayScale"
33
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"__type__": "grid_lines"
2+
"__type__": "unitxt.image_operators.GridLines"
33
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"__type__": "oldify"
2+
"__type__": "unitxt.image_operators.Oldify"
33
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"__type__": "pixel_noise"
2+
"__type__": "unitxt.image_operators.PixelNoise"
33
}

0 commit comments

Comments
 (0)