|
1 |
| -import importlib |
2 |
| -import os |
3 |
| -import pathlib |
4 |
| -from typing import List |
5 |
| - |
6 |
| -OPBENCH_DIR = "operators" |
7 |
| -INTERNAL_OPBENCH_DIR = "fb" |
8 |
| - |
9 |
| - |
10 |
| -def _dir_contains_file(dir, file_name) -> bool: |
11 |
| - names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir())) |
12 |
| - return file_name in names |
13 |
| - |
14 |
| - |
15 |
| -def _is_internal_operator(op_name: str) -> bool: |
16 |
| - p = ( |
17 |
| - pathlib.Path(__file__) |
18 |
| - .parent.parent.joinpath(OPBENCH_DIR) |
19 |
| - .joinpath(INTERNAL_OPBENCH_DIR) |
20 |
| - .joinpath(op_name) |
21 |
| - ) |
22 |
| - if p.exists() and p.joinpath("__init__.py").exists(): |
23 |
| - return True |
24 |
| - return False |
25 |
| - |
26 |
| - |
27 |
| -def _list_opbench_paths() -> List[str]: |
28 |
| - p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR) |
29 |
| - # Only load the model directories that contain a "__init.py__" file |
30 |
| - opbench = sorted( |
31 |
| - str(child.absolute()) |
32 |
| - for child in p.iterdir() |
33 |
| - if child.is_dir() and _dir_contains_file(child, "__init__.py") |
34 |
| - ) |
35 |
| - p = p.joinpath(INTERNAL_OPBENCH_DIR) |
36 |
| - if p.exists(): |
37 |
| - o = sorted( |
38 |
| - str(child.absolute()) |
39 |
| - for child in p.iterdir() |
40 |
| - if child.is_dir() and _dir_contains_file(child, "__init__.py") |
41 |
| - ) |
42 |
| - opbench.extend(o) |
43 |
| - return opbench |
44 |
| - |
45 |
| - |
46 |
| -def list_operators() -> List[str]: |
47 |
| - operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths())) |
48 |
| - if INTERNAL_OPBENCH_DIR in operators: |
49 |
| - operators.remove(INTERNAL_OPBENCH_DIR) |
50 |
| - return operators |
51 |
| - |
52 |
| - |
53 |
| -def load_opbench_by_name(op_name: str): |
54 |
| - opbench_list = filter( |
55 |
| - lambda x: op_name.lower() == x.lower(), |
56 |
| - map(lambda y: os.path.basename(y), _list_opbench_paths()), |
57 |
| - ) |
58 |
| - opbench_list = list(opbench_list) |
59 |
| - if not opbench_list: |
60 |
| - raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.") |
61 |
| - assert ( |
62 |
| - len(opbench_list) == 1 |
63 |
| - ), f"Found more than one operators {opbench_list} matching the required name: {op_name}" |
64 |
| - op_name = opbench_list[0] |
65 |
| - op_pkg = ( |
66 |
| - op_name |
67 |
| - if not _is_internal_operator(op_name) |
68 |
| - else f"{INTERNAL_OPBENCH_DIR}.{op_name}" |
69 |
| - ) |
70 |
| - module = importlib.import_module(f".{op_pkg}", package=__name__) |
71 |
| - |
72 |
| - Operator = getattr(module, "Operator", None) |
73 |
| - if Operator is None: |
74 |
| - print(f"Warning: {module} does not define attribute Operator, skip it") |
75 |
| - return None |
76 |
| - if not hasattr(Operator, "name"): |
77 |
| - Operator.name = op_name |
78 |
| - return Operator |
| 1 | +from .op import list_operators, load_opbench_by_name |
0 commit comments