|
3 | 3 | import json |
4 | 4 | import os |
5 | 5 | import re |
| 6 | +import subprocess |
6 | 7 | import sys |
7 | 8 | import sysconfig |
8 | 9 | import warnings |
|
24 | 25 | separate_inside_and_outside_square_brackets, |
25 | 26 | ) |
26 | 27 | from .settings_utils import get_constants, get_settings |
| 28 | +from .text_utils import snake_to_camel_case |
27 | 29 | from .type_utils import isoftype, issubtype |
28 | 30 | from .utils import ( |
29 | 31 | artifacts_json_cache, |
@@ -227,9 +229,29 @@ def get_module_class_names(artifact_type: dict): |
227 | 229 | return artifact_type["module"], artifact_type["name"] |
228 | 230 |
|
229 | 231 |
|
def convert_str_type_to_dict(type: str) -> dict:
    """Build the ``{"module": ..., "name": ...}`` type dict for a type given as a string.

    The string is first passed through ``snake_to_camel_case`` and the result is
    resolved with ``find_unitxt_module_and_class_by_classname``.

    Args:
        type: the string form of an artifact type (presumably snake_case, given
            the conversion helper applied to it).

    Returns:
        A dict with the keys ``"module"`` and ``"name"``. The name is the one
        reported by the lookup, which is not necessarily identical to the
        converted input string.
    """
    resolved_module, resolved_name = find_unitxt_module_and_class_by_classname(
        camel_case_class_name=snake_to_camel_case(type)
    )
    type_as_dict = {"module": resolved_module, "name": resolved_name}
    return type_as_dict
| 241 | + |
| 242 | + |
230 | 243 | # type is the dict read from a catalog entry, the value of a key "__type__" |
231 | 244 | def get_class_from_artifact_type(type: dict): |
232 | | - module_path, class_name = get_module_class_names(type) |
| 245 | + if isinstance(type, str): |
| 246 | + if type in Artifact._class_register: |
| 247 | + return Artifact._class_register[type] |
| 248 | + |
| 249 | + module_path, class_name = find_unitxt_module_and_class_by_classname( |
| 250 | + snake_to_camel_case(type) |
| 251 | + ) |
| 252 | + else: |
| 253 | + module_path, class_name = get_module_class_names(type) |
| 254 | + |
233 | 255 | if module_path == "class_register": |
234 | 256 | if class_name not in Artifact._class_register: |
235 | 257 | raise ValueError( |
@@ -487,12 +509,15 @@ def is_artifact_file(cls, path): |
487 | 509 | @classmethod |
488 | 510 | def load(cls, path, artifact_identifier=None, overwrite_args=None): |
489 | 511 | d = artifacts_json_cache(path) |
490 | | - if "__type__" in d and d["__type__"]["name"].endswith("ArtifactLink"): |
491 | | - from_dict(d) # for verifications and warnings |
492 | | - catalog, artifact_rep, _ = get_catalog_name_and_args(name=d["to"]) |
493 | | - return catalog.get_with_overwrite( |
494 | | - artifact_rep, overwrite_args=overwrite_args |
495 | | - ) |
| 512 | + if "__type__" in d: |
| 513 | + if isinstance(d["__type__"], str): |
| 514 | + d["__type__"] = convert_str_type_to_dict(d["__type__"]) |
| 515 | + if d["__type__"]["name"].endswith("ArtifactLink"): |
| 516 | + from_dict(d) # for verifications and warnings |
| 517 | + catalog, artifact_rep, _ = get_catalog_name_and_args(name=d["to"]) |
| 518 | + return catalog.get_with_overwrite( |
| 519 | + artifact_rep, overwrite_args=overwrite_args |
| 520 | + ) |
496 | 521 |
|
497 | 522 | new_artifact = from_dict(d, overwrite_args=overwrite_args) |
498 | 523 | new_artifact.__id__ = artifact_identifier |
@@ -898,3 +923,29 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]: |
898 | 923 | return None |
899 | 924 |
|
900 | 925 | return data_classification.get(artifact) |
| 926 | + |
| 927 | + |
def find_unitxt_module_and_class_by_classname(
    camel_case_class_name: str, search_dir: Optional[str] = None
):
    """Find the module, under the unitxt package directory, that defines a class.

    Every ``.py`` file below the searched directory is scanned for a top-level
    ``class <name>`` statement. The name is matched case-insensitively and as a
    whole word, so ``Foo`` does not match ``class FooBar``. Exactly one
    definition must exist across all scanned files.

    The scan is done in pure Python rather than by shelling out to ``grep``, so
    it does not depend on an external binary, is not affected by colons or
    non-ASCII characters in paths, and is not weakened when assertions are
    disabled with ``-O``.

    Args:
        camel_case_class_name: the class name to look for; letter case is ignored.
        search_dir: the package directory to scan. Defaults to the directory
            containing this file.

    Returns:
        A ``(module, class_name)`` tuple. ``module`` is the dotted module path
        starting with the name of the searched directory (e.g. ``unitxt.metrics``),
        and ``class_name`` is spelled exactly as it appears in the defining file.

    Raises:
        ValueError: if no definition, or more than one, is found, or if a file
            could not be read.
    """
    package_dir = os.path.abspath(
        os.path.dirname(__file__) if search_dir is None else search_dir
    )
    # Dotted module paths are built relative to the parent of the package
    # directory, so that they start with the package name itself.
    package_parent = os.path.dirname(package_dir)

    # The capture group returns the name as written in the file, whatever the
    # amount of spacing after "class"; \b enforces a whole-word match.
    class_definition = re.compile(
        r"^class +(" + re.escape(camel_case_class_name) + r")\b",
        re.IGNORECASE | re.MULTILINE,
    )
    failure_message = f"Could not find the unitxt module, under unitxt/src/unitxt, in which class {camel_case_class_name} is defined"

    definitions = []  # (module, class name as written) for every match
    try:
        for root, _, file_names in os.walk(package_dir):
            for file_name in file_names:
                if not file_name.endswith(".py"):
                    continue
                file_path = os.path.join(root, file_name)
                with open(file_path, encoding="utf-8", errors="replace") as source:
                    content = source.read()
                found_names = class_definition.findall(content)
                if not found_names:
                    continue
                relative_path = os.path.relpath(file_path, package_parent)
                module = relative_path[: -len(".py")].replace(os.sep, ".")
                definitions.extend((module, name) for name in found_names)
    except OSError as e:
        raise ValueError(failure_message) from e

    if len(definitions) != 1:
        raise ValueError(failure_message) from LookupError(
            f"expected exactly one definition, found: {sorted(definitions)}"
        )
    return definitions[0]
0 commit comments