
Commit bec6d9f

xuzhao9 authored and facebook-github-bot committed
Install tritonbench as a library (#81)
Summary:
Now users can import Tritonbench as a library:

```
$ pip install -e .
$ python -c "import tritonbench; op = tritonbench.load_opbench_by_name('addmm');"
```

Clean up the init file so that dependencies like `os` and `importlib` will not pollute the namespace.

Pull Request resolved: #81
Reviewed By: FindHao
Differential Revision: D66544679
Pulled By: xuzhao9
fbshipit-source-id: 0339adef58d28f2c7f207a2869c21d4b55575386
1 parent c666f87 · commit bec6d9f
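A rough illustration of the namespace cleanup described in the summary (a minimal sketch, assuming the package has been installed with `pip install -e .` in a working environment): the loader API stays importable, while the stdlib modules used by the implementation no longer appear as attributes of `tritonbench.operators`.

```
# Sketch only; assumes `pip install -e .` has been run and dependencies are available.
import tritonbench.operators as operators

print(callable(operators.load_opbench_by_name))  # True: still part of the public API
print(hasattr(operators, "os"))                  # False after this change
print(hasattr(operators, "importlib"))           # False; the old __init__.py imported both
```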

File tree

24 files changed: +205 -151 lines changed

.gitignore

+1 -1
@@ -13,4 +13,4 @@ __pycache__/
 .DS_Store
 .ipynb_checkpoints/
 .idea
-torchbench.egg-info/
+*.egg-info/

README.md

+15 -1
@@ -26,12 +26,26 @@ By default, it will install the latest PyTorch nightly release and use the Trito
 
 ## Basic Usage
 
-To benchmark an operator, use the following command:
+To benchmark an operator, run the following command:
 
 ```
 $ python run.py --op gemm
 ```
 
+## Install as a library
+
+To install as a library:
+
+```
+$ pip install -e .
+# in your own benchmark script
+import tritonbench
+from tritonbench.utils import parser
+op_args = parser.parse_args()
+addmm_bench = tritonbench.load_opbench_by_name("addmm")(op_args)
+addmm_bench.run()
+```
+
 ## Submodules
 
 We depend on the following projects as a source of customized Triton or CUTLASS kernels:

pyproject.toml

+15
@@ -1,3 +1,18 @@
+[build-system]
+requires = ["setuptools>=40.8.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "tritonbench"
+version = "0.0.1"
+dependencies = [
+    "torch",
+    "triton",
+]
+
+[tool.setuptools.packages.find]
+include = ["tritonbench*"]
+
 [tool.ufmt]
 formatter = "ruff-api"
 excludes = ["submodules/"]

tritonbench/__init__.py

+5
@@ -0,0 +1,5 @@
+from .operators import list_operators, load_opbench_by_name
+from .operators_collection import (
+    list_operator_collections,
+    list_operators_by_collection,
+)
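The new top-level `__init__.py` above defines the public entry points of the library. A minimal usage sketch (assuming an installed environment; `"softmax"` is just an example operator name from this repository):

```
# Sketch of the exported entry points; see the README diff above for the full
# benchmark-script example that also parses arguments and runs the benchmark.
import tritonbench

print(tritonbench.list_operators())                     # names of available operator benchmarks
Softmax = tritonbench.load_opbench_by_name("softmax")   # returns the Operator class
# list_operator_collections and list_operators_by_collection are also exported;
# their signatures are not shown in this diff.
```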

tritonbench/operators/__init__.py

+1 -78
@@ -1,78 +1 @@
-import importlib
-import os
-import pathlib
-from typing import List
-
-OPBENCH_DIR = "operators"
-INTERNAL_OPBENCH_DIR = "fb"
-
-
-def _dir_contains_file(dir, file_name) -> bool:
-    names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir()))
-    return file_name in names
-
-
-def _is_internal_operator(op_name: str) -> bool:
-    p = (
-        pathlib.Path(__file__)
-        .parent.parent.joinpath(OPBENCH_DIR)
-        .joinpath(INTERNAL_OPBENCH_DIR)
-        .joinpath(op_name)
-    )
-    if p.exists() and p.joinpath("__init__.py").exists():
-        return True
-    return False
-
-
-def _list_opbench_paths() -> List[str]:
-    p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR)
-    # Only load the model directories that contain a "__init.py__" file
-    opbench = sorted(
-        str(child.absolute())
-        for child in p.iterdir()
-        if child.is_dir() and _dir_contains_file(child, "__init__.py")
-    )
-    p = p.joinpath(INTERNAL_OPBENCH_DIR)
-    if p.exists():
-        o = sorted(
-            str(child.absolute())
-            for child in p.iterdir()
-            if child.is_dir() and _dir_contains_file(child, "__init__.py")
-        )
-        opbench.extend(o)
-    return opbench
-
-
-def list_operators() -> List[str]:
-    operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths()))
-    if INTERNAL_OPBENCH_DIR in operators:
-        operators.remove(INTERNAL_OPBENCH_DIR)
-    return operators
-
-
-def load_opbench_by_name(op_name: str):
-    opbench_list = filter(
-        lambda x: op_name.lower() == x.lower(),
-        map(lambda y: os.path.basename(y), _list_opbench_paths()),
-    )
-    opbench_list = list(opbench_list)
-    if not opbench_list:
-        raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.")
-    assert (
-        len(opbench_list) == 1
-    ), f"Found more than one operators {opbench_list} matching the required name: {op_name}"
-    op_name = opbench_list[0]
-    op_pkg = (
-        op_name
-        if not _is_internal_operator(op_name)
-        else f"{INTERNAL_OPBENCH_DIR}.{op_name}"
-    )
-    module = importlib.import_module(f".{op_pkg}", package=__name__)
-
-    Operator = getattr(module, "Operator", None)
-    if Operator is None:
-        print(f"Warning: {module} does not define attribute Operator, skip it")
-        return None
-    if not hasattr(Operator, "name"):
-        Operator.name = op_name
-    return Operator
+from .op import list_operators, load_opbench_by_name

tritonbench/operators/addmm/hstu.py

+1
@@ -4,6 +4,7 @@
 
 import torch
 import triton
+
 from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH
 
 with add_path(str(SUBMODULE_PATH.joinpath("generative-recommenders"))):

tritonbench/operators/fp8_gemm_rowwise/operator.py

+1
@@ -6,6 +6,7 @@
 
 import torch
 import triton
+
 from tritonbench.utils.data_utils import get_production_shapes
 
 from tritonbench.utils.triton_op import (

tritonbench/operators/geglu/operator.py

+1
@@ -5,6 +5,7 @@
 
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaMLP
+
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     register_benchmark,

tritonbench/operators/gemm/operator.py

+1
@@ -6,6 +6,7 @@
 import torch
 import torch._inductor.config as inductor_config
 import triton
+
 from tritonbench.utils.data_utils import get_production_shapes
 
 from tritonbench.utils.path_utils import REPO_PATH

tritonbench/operators/jagged_layer_norm/operator.py

+1
@@ -7,6 +7,7 @@
 
 import torch
 import triton
+
 from tritonbench.utils.jagged_utils import (
     ABSOLUTE_TOLERANCE,
     EPSILON,

tritonbench/operators/jagged_mean/operator.py

+1
@@ -7,6 +7,7 @@
 
 import torch
 import triton
+
 from tritonbench.utils.jagged_utils import (
     ABSOLUTE_TOLERANCE,
     generate_input_vals,

tritonbench/operators/jagged_softmax/operator.py

+1
@@ -7,6 +7,7 @@
 
 import torch
 import triton
+
 from tritonbench.utils.jagged_utils import (
     ABSOLUTE_TOLERANCE,
     generate_input_vals,

tritonbench/operators/jagged_sum/operator.py

+1
@@ -7,6 +7,7 @@
 
 import torch
 import triton
+
 from tritonbench.utils.jagged_utils import (
     ABSOLUTE_TOLERANCE,
     generate_input_vals,

tritonbench/operators/layer_norm/operator.py

+1
@@ -3,6 +3,7 @@
 import torch
 import torch.nn.functional as F
 import triton
+
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,

tritonbench/operators/op.py

+78
@@ -0,0 +1,78 @@
+import importlib
+import os
+import pathlib
+from typing import List
+
+OPBENCH_DIR = "operators"
+INTERNAL_OPBENCH_DIR = "fb"
+
+
+def _dir_contains_file(dir, file_name) -> bool:
+    names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir()))
+    return file_name in names
+
+
+def _is_internal_operator(op_name: str) -> bool:
+    p = (
+        pathlib.Path(__file__)
+        .parent.parent.joinpath(OPBENCH_DIR)
+        .joinpath(INTERNAL_OPBENCH_DIR)
+        .joinpath(op_name)
+    )
+    if p.exists() and p.joinpath("__init__.py").exists():
+        return True
+    return False
+
+
+def _list_opbench_paths() -> List[str]:
+    p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR)
+    # Only load the model directories that contain a "__init.py__" file
+    opbench = sorted(
+        str(child.absolute())
+        for child in p.iterdir()
+        if child.is_dir() and _dir_contains_file(child, "__init__.py")
+    )
+    p = p.joinpath(INTERNAL_OPBENCH_DIR)
+    if p.exists():
+        o = sorted(
+            str(child.absolute())
+            for child in p.iterdir()
+            if child.is_dir() and _dir_contains_file(child, "__init__.py")
+        )
+        opbench.extend(o)
+    return opbench
+
+
+def list_operators() -> List[str]:
+    operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths()))
+    if INTERNAL_OPBENCH_DIR in operators:
+        operators.remove(INTERNAL_OPBENCH_DIR)
+    return operators
+
+
+def load_opbench_by_name(op_name: str):
+    opbench_list = filter(
+        lambda x: op_name.lower() == x.lower(),
+        map(lambda y: os.path.basename(y), _list_opbench_paths()),
+    )
+    opbench_list = list(opbench_list)
+    if not opbench_list:
+        raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.")
+    assert (
+        len(opbench_list) == 1
+    ), f"Found more than one operators {opbench_list} matching the required name: {op_name}"
+    op_name = opbench_list[0]
+    op_pkg = (
+        op_name
+        if not _is_internal_operator(op_name)
+        else f"{INTERNAL_OPBENCH_DIR}.{op_name}"
+    )
+    module = importlib.import_module(f"..{op_pkg}", package=__name__)
+
+    Operator = getattr(module, "Operator", None)
+    if Operator is None:
+        print(f"Warning: {module} does not define attribute Operator, skip it")
+        return None
+    if not hasattr(Operator, "name"):
+        Operator.name = op_name
+    return Operator
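Note the one behavioral adjustment in the move: the loader now lives in `tritonbench/operators/op.py` instead of the package `__init__.py`, so the relative import uses two leading dots to keep resolving operator modules against `tritonbench.operators`. A small sketch of the name resolution (with `addmm` used purely as an example operator name):

```
# Sketch: importlib.util.resolve_name performs the same anchor arithmetic that
# importlib.import_module applies before actually loading the module.
from importlib.util import resolve_name

# Inside op.py, __name__ == "tritonbench.operators.op", so two dots step up to the package:
print(resolve_name("..addmm", package="tritonbench.operators.op"))
# -> tritonbench.operators.addmm (the same target the old ".addmm" reached from __init__.py)
```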

tritonbench/operators/op_task.py

+1
@@ -7,6 +7,7 @@
 from typing import Any, Dict, List, Optional
 
 import torch
+
 from tritonbench.components.tasks import base as base_task
 from tritonbench.components.workers import subprocess_worker
 

tritonbench/operators/ragged_attention/hstu.py

+1
@@ -1,5 +1,6 @@
 import torch
 import triton
+
 from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH
 from tritonbench.utils.triton_op import IS_FBCODE
 

tritonbench/operators/ragged_attention/operator.py

+1
@@ -3,6 +3,7 @@
 from typing import Any, Callable, List, Optional
 
 import torch
+
 from tritonbench.utils.input import input_filter
 
 from tritonbench.utils.triton_op import (

tritonbench/operators/softmax/operator.py

+1
@@ -3,6 +3,7 @@
 import torch
 import triton
 import triton.language as tl
+
 from tritonbench.utils.data_utils import get_production_shapes
 
 from tritonbench.utils.triton_op import (

tritonbench/operators/sum/operator.py

+1
@@ -7,6 +7,7 @@
 import torch
 import triton
 import triton.language as tl
+
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
