Skip to content

Commit 16124e7

Browse files
committed
use _class_register as a cache for instantiated classes
Signed-off-by: dafnapension <[email protected]>
1 parent 7f0f502 commit 16124e7

File tree

2 files changed

+19
-29
lines changed

2 files changed

+19
-29
lines changed

src/unitxt/artifact.py

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -48,40 +48,30 @@ def import_module_from_file(file_path):
4848
return module
4949

5050

51-
# type is read from a catelog entry, the value of a key "__type__"
52-
def get_class_from_artifact_type(type: str):
53-
if type in Artifact._class_register:
54-
return Artifact._class_register[type]
51+
# snake_case_class_name is read from a catelog entry, the value of a key "__type__"
52+
# this method replaces the Artifact._class_register lookup, for all unitxt classes defined
53+
# top level in any of the src/unitxt/*.py modules, which are all the classes that were registered
54+
# by register_all_artifacts
55+
def get_class_from_artifact_type(snake_case_class_name: str):
56+
if snake_case_class_name in Artifact._class_register:
57+
return Artifact._class_register[snake_case_class_name]
5558

5659
module_path, class_name = find_unitxt_module_and_class_by_classname(
57-
snake_to_camel_case(type)
60+
snake_to_camel_case(snake_case_class_name)
5861
)
59-
if module_path == "class_register":
60-
if class_name not in Artifact._class_register:
61-
raise ValueError(
62-
f"Can not instantiate a class from type {type}, because {class_name} is currently not registered in Artifact._class_register."
63-
)
64-
return Artifact._class_register[class_name]
6562

6663
module = importlib.import_module(module_path)
6764

68-
if "." not in class_name:
69-
if hasattr(module, class_name) and inspect.isclass(getattr(module, class_name)):
70-
return getattr(module, class_name)
71-
if class_name in Artifact._class_register:
72-
return Artifact._class_register[class_name]
73-
module_file = module.__file__ if hasattr(module, "__file__") else None
74-
if module_file:
75-
module = import_module_from_file(module_file)
76-
77-
assert class_name in Artifact._class_register
78-
return Artifact._class_register[class_name]
79-
80-
class_name_components = class_name.split(".")
81-
klass = getattr(module, class_name_components[0])
82-
for i in range(1, len(class_name_components)):
83-
klass = getattr(klass, class_name_components[i])
84-
return klass
65+
if hasattr(module, class_name) and inspect.isclass(getattr(module, class_name)):
66+
klass = getattr(module, class_name)
67+
Artifact._class_register[
68+
snake_case_class_name
69+
] = klass # use _class_register as a cache
70+
return klass
71+
72+
raise ValueError(
73+
f"Could not find the definition of class whose name, snake-cased is {snake_case_class_name}"
74+
)
8575

8676

8777
def is_name_legal_for_catalog(name):

tests/library/test_artifact_registration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ class DummyShouldBeRegistered(Artifact):
99
pass
1010

1111
# assert Artifact.is_registered_type("dummy_should_be_registered")
12-
# assert Artifact.is_registered_class(DummyShouldBeRegistered)
12+
assert "dummy_should_be_registered" in Artifact._class_register

0 commit comments

Comments
 (0)