|
3 | 3 | import json |
4 | 4 | import os |
5 | 5 | import re |
| 6 | +import subprocess |
6 | 7 | import sys |
7 | 8 | import sysconfig |
8 | 9 | import warnings |
|
24 | 25 | separate_inside_and_outside_square_brackets, |
25 | 26 | ) |
26 | 27 | from .settings_utils import get_constants, get_settings |
| 28 | +from .text_utils import snake_to_camel_case |
27 | 29 | from .type_utils import isoftype, issubtype |
28 | 30 | from .utils import ( |
29 | 31 | artifacts_json_cache, |
@@ -227,9 +229,25 @@ def get_module_class_names(artifact_type: dict): |
227 | 229 | return artifact_type["module"], artifact_type["name"] |
228 | 230 |
|
229 | 231 |
|
def convert_str_type_to_dict(type: str) -> dict:
    """Expand a string-form artifact type into its dict form.

    The string is passed through ``snake_to_camel_case`` to obtain the class
    name (presumably a snake_case spelling of it -- confirm against
    ``text_utils``), and ``find_unitxt_module_by_classname`` is asked which
    module defines that class.

    Args:
        type: the string value of a ``"__type__"`` entry.

    Returns:
        A dict with the keys ``"module"`` and ``"name"``.

    Raises:
        ValueError: propagated from ``find_unitxt_module_by_classname`` when
            the defining module cannot be determined.
    """
    camel_name = snake_to_camel_case(type)
    owning_module = find_unitxt_module_by_classname(camel_case_class_name=camel_name)
    return {"module": owning_module, "name": camel_name}
| 238 | + |
| 239 | + |
230 | 240 | # type is the value of the "__type__" key read from a catalog entry: either a dict or a string type name |
231 | 241 | def get_class_from_artifact_type(type: dict): |
232 | | - module_path, class_name = get_module_class_names(type) |
| 242 | + if isinstance(type, str): |
| 243 | + if type in Artifact._class_register: |
| 244 | + return Artifact._class_register[type] |
| 245 | + |
| 246 | + class_name = snake_to_camel_case(type) |
| 247 | + module_path = find_unitxt_module_by_classname(camel_case_class_name=class_name) |
| 248 | + else: |
| 249 | + module_path, class_name = get_module_class_names(type) |
| 250 | + |
233 | 251 | if module_path == "class_register": |
234 | 252 | if class_name not in Artifact._class_register: |
235 | 253 | raise ValueError( |
@@ -487,12 +505,15 @@ def is_artifact_file(cls, path): |
487 | 505 | @classmethod |
488 | 506 | def load(cls, path, artifact_identifier=None, overwrite_args=None): |
489 | 507 | d = artifacts_json_cache(path) |
490 | | - if "__type__" in d and d["__type__"]["name"].endswith("ArtifactLink"): |
491 | | - from_dict(d) # for verifications and warnings |
492 | | - catalog, artifact_rep, _ = get_catalog_name_and_args(name=d["to"]) |
493 | | - return catalog.get_with_overwrite( |
494 | | - artifact_rep, overwrite_args=overwrite_args |
495 | | - ) |
| 508 | + if "__type__" in d: |
| 509 | + if isinstance(d["__type__"], str): |
| 510 | + d["__type__"] = convert_str_type_to_dict(d["__type__"]) |
| 511 | + if d["__type__"]["name"].endswith("ArtifactLink"): |
| 512 | + from_dict(d) # for verifications and warnings |
| 513 | + catalog, artifact_rep, _ = get_catalog_name_and_args(name=d["to"]) |
| 514 | + return catalog.get_with_overwrite( |
| 515 | + artifact_rep, overwrite_args=overwrite_args |
| 516 | + ) |
496 | 517 |
|
497 | 518 | new_artifact = from_dict(d, overwrite_args=overwrite_args) |
498 | 519 | new_artifact.__id__ = artifact_identifier |
@@ -898,3 +919,22 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]: |
898 | 919 | return None |
899 | 920 |
|
900 | 921 | return data_classification.get(artifact) |
| 922 | + |
| 923 | + |
def find_unitxt_module_by_classname(camel_case_class_name: str, search_dir=None):
    """Find the module under the unitxt package that defines the given class.

    Every ``.py`` file below the package directory is scanned for a top-level
    ``class <name>`` statement (the name must be followed by a non-word
    character, so ``Foo`` does not match ``class FooBar``). The scan is done
    in-process instead of shelling out to ``grep``, so it does not depend on an
    external binary and works the same on every platform.

    Args:
        camel_case_class_name: exact (CamelCase) name of the class to locate.
        search_dir: directory to scan; defaults to the directory holding this
            file (the ``unitxt`` package directory).

    Returns:
        The dotted module path, rooted at ``unitxt``, e.g. ``"unitxt.metrics"``
        or ``"unitxt.sub.module"`` for a file in a sub-directory.

    Raises:
        ValueError: if the class is defined in no file, or in more than one.
    """
    package_dir = os.path.dirname(__file__) if search_dir is None else search_dir

    # Escape the name so it is always matched literally; (?!\w) mirrors the
    # whole-word semantics of the former `grep -w`.
    class_def = re.compile(
        r"^class +" + re.escape(camel_case_class_name) + r"(?!\w)",
        re.MULTILINE,
    )

    matching_files = []
    for dirpath, _dirnames, filenames in os.walk(package_dir):
        for filename in filenames:
            # Only python sources can be turned into an importable module path.
            if not filename.endswith(".py"):
                continue
            file_path = os.path.join(dirpath, filename)
            try:
                with open(file_path, encoding="utf-8", errors="replace") as f:
                    source = f.read()
            except OSError:
                # An unreadable file cannot be inspected; keep scanning the rest.
                continue
            if class_def.search(source):
                matching_files.append(file_path)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(matching_files) != 1:
        raise ValueError(
            f"Could not find the unitxt module, under unitxt/src/unitxt, in which class {camel_case_class_name} is defined"
            f" (matching files: {sorted(matching_files)})"
        )

    # Drop the ".py" suffix and convert the path, relative to the package
    # directory, into a dotted module name using the platform's separator.
    relative_path = os.path.relpath(matching_files[0], package_dir)[: -len(".py")]
    return ".".join(["unitxt", *relative_path.split(os.sep)])
0 commit comments