Skip to content

Commit f28882e

Browse files
authored
Generate metadata file for internal ops
Differential Revision: D86682742 Pull Request resolved: #650
1 parent 41b3d6f commit f28882e

File tree

4 files changed

+115
-43
lines changed

4 files changed

+115
-43
lines changed

benchmarks/tagging/ast_analyzer.py

Lines changed: 82 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ class Edge:
6060
call_type: str
6161
callee_descriptor: FuncDescriptor
6262

63+
def split_by_the_last_dot(s: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """Split a dotted name at its final "." into ``(prefix, last_component)``.

    Args:
        s: A dotted name such as ``"pkg.mod.func"``, or ``None``.

    Returns:
        ``(prefix, name)`` when ``s`` contains a dot, ``(None, s)`` when it
        does not, and ``(None, None)`` when ``s`` itself is ``None``. The
        result is always a 2-tuple, so it can be unpacked directly.
    """
    # The prefix produced by one split is fed back into this function by
    # callers, and that prefix is None for a name without a dot; testing
    # `"." in None` would raise TypeError, so None is handled explicitly.
    if s is None:
        return None, None
    prefix, dot, name = s.rpartition(".")
    if not dot:
        # No separator found: rpartition leaves the whole string in `name`.
        return None, s
    return prefix, name
6368

6469
class CallGraph(ast.NodeVisitor):
6570
def __init__(
@@ -158,6 +163,12 @@ def _record_call(
158163
):
159164
if caller is None:
160165
caller = self._cur_scope() or "<module>"
166+
# replace callee with caller class name if it is "self." call
167+
if "." in caller and callee.startswith("self."):
168+
caller_prefix, _ = split_by_the_last_dot(caller)
169+
# remove the "self." prefix
170+
callee_name = callee[5 :]
171+
callee = caller_prefix + "." + callee_name
161172
site = Site(
162173
self.filename, getattr(node, "lineno", -1), getattr(node, "col_offset", -1)
163174
)
@@ -167,8 +178,6 @@ def _record_call(
167178
and any([f"Operator.{backend}" in caller for backend in self.backends])
168179
and not callee == "tritonbench.utils.triton_op.register_benchmark"
169180
):
170-
if callee.startswith("self."):
171-
return
172181
if is_fbcode() and callee.startswith("liger_kernel."):
173182
return
174183
# identify this call belongs to which backend
@@ -188,9 +197,7 @@ def _record_call(
188197
)
189198
)
190199
elif any([backend in caller for backend in self.backends]):
191-
# "torch.ops" is a binary custom ops
192-
if callee.startswith("self."):
193-
return
200+
# skip aten calls
194201
if callee.startswith("torch.") and not callee.startswith("torch.ops."):
195202
return
196203
# we are sure there is no kernel defined in this package ;-)
@@ -202,7 +209,7 @@ def _record_call(
202209
if maybe_triton and callee_descriptor == None:
203210
callee_descriptor = FuncDescriptor(
204211
callee,
205-
["triton.jit"],
212+
["maybe.triton.jit"],
206213
Site(
207214
self.filename,
208215
getattr(node, "lineno", -1),
@@ -362,6 +369,27 @@ def visit_Call(self, node: ast.Call):
362369
callee = fn.value.id
363370
elif isinstance(fn.value, ast.Attribute):
364371
callee = fn.value.value.id
372+
if hasattr(fn.value, "attr"):
373+
callee = callee + "." + fn.value.attr
374+
elif isinstance(fn.value, ast.Call):
375+
# add hack to handle torch._library.capture_triton
376+
if (
377+
isinstance(fn.value.func, ast.Name)
378+
and fn.value.func.id == "capture_triton"
379+
) or (
380+
isinstance(fn.value.func.value, ast.Attribute)
381+
and fn.value.func.value.attr == "_library"
382+
and fn.value.func.attr == "capture_triton"
383+
):
384+
if isinstance(fn.value.args[0], ast.Call):
385+
callee_func = fn.value.args[0].func.id
386+
elif isinstance(fn.value.args[0], ast.Name):
387+
callee_func = fn.value.args[0].id
388+
elif isinstance(fn.value.args[0], ast.Attribute):
389+
callee_func = fn.value.args[0].value.id
390+
else:
391+
callee_func = "unknown"
392+
callee = f"<torch._library.capture_triton({callee_func})>"
365393
else:
366394
callee = "<dynamic_call>"
367395
maybe_triton = True # FIXME: this could also be cute, see blackwell_attentions cute dsl
@@ -377,16 +405,22 @@ def validate_edges(edges) -> Dict[str, str]:
377405
result_tags["tags"] = []
378406
result_tags["kernels"] = []
379407
for edge in edges:
408+
if edge.callee == "cutlass.cute.compile":
409+
result_tags["tags"].append("cutedsl")
410+
result_tags["kernels"].append(edge.caller)
380411
if edge.callee_descriptor and (
381412
"triton.jit" in edge.callee_descriptor.decorators
382413
or "<dynamic_decorator_triton.jit>" in edge.callee_descriptor.decorators
383414
):
384415
result_tags["tags"].append("triton")
385416
result_tags["kernels"].append(edge.callee)
386-
if edge.callee == "cutlass.cute.compile":
387-
result_tags["tags"].append("cutedsl")
388-
result_tags["kernels"].append(edge.caller)
389-
if edge.callee.startswith("torch.ops."):
417+
if edge.callee.startswith("<torch._library.capture_triton"):
418+
result_tags["tags"].append("triton")
419+
result_tags["kernels"].append(edge.callee)
420+
if (
421+
edge.callee.startswith("torch.ops.")
422+
and not "cutedsl" in result_tags["tags"]
423+
):
390424
result_tags["tags"].append("native_custom_ops")
391425
# definition is in cpp, so we don't have the definition site
392426
result_tags["kernels"].append(edge.callee)
@@ -398,16 +432,27 @@ def validate_edges(edges) -> Dict[str, str]:
398432
if edge.callee.startswith("tilelang.compile"):
399433
result_tags["tags"].append("tilelang")
400434
result_tags["kernels"].append(edge.caller)
435+
if "torch.ops.fbgemm" in edge.callee:
436+
result_tags["tags"].append("fbgemm")
437+
# heuristic: if no tag is found and maybe triton, apply triton tag
438+
if (
439+
not result_tags["tags"]
440+
and edge.callee_descriptor
441+
and "maybe.triton.jit" in edge.callee_descriptor.decorators
442+
):
443+
result_tags["tags"].append("triton")
444+
result_tags["kernels"].append(edge.callee)
401445
# remove duplicates
402446
result_tags["tags"] = list(set(result_tags["tags"]))
447+
result_tags["kernels"] = list(set(result_tags["kernels"]))
403448
if not result_tags["kernels"] and not result_tags["tags"]:
404449
return None
405450
return result_tags
406451

407452

408453
def gen_static_extension_tags(callee: str) -> Dict[str, List[str]]:
    """Build the tag record for a callee tagged as a native extension.

    Args:
        callee: Name of the callee to record as the kernel.

    Returns:
        A dict with ``"tags"`` set to ``["native_extension"]`` and
        ``"kernels"`` set to a one-element list holding ``callee``.
    """
    # Both values are lists (not sets), matching the list-valued "tags" and
    # "kernels" entries produced by validate_edges.
    return {"tags": ["native_extension"], "kernels": [callee]}
413458

@@ -433,55 +478,64 @@ def trace_callees(callees_with_module: List[Tuple[str, str]], depth=8):
433478
if callee.endswith(".apply"):
434479
callee = callee[: callee.rfind(".apply")] + ".forward"
435480

436-
callee_module = callee[: callee.rfind(".")] if "." in callee else None
437-
callee_name = callee[callee.rfind(".") + 1 :] if "." in callee else callee
438-
maybe_callee_module = (
439-
callee_module[: callee_module.rfind(".")]
440-
if callee_module and "." in callee_module
441-
else None
442-
)
443-
maybe_callee_class = (
444-
callee_module[callee_module.rfind(".") + 1 :]
445-
if callee_module and "." in callee_module
446-
else None
447-
)
481+
callee_module, callee_name = split_by_the_last_dot(callee)
482+
maybe_callee_module, maybe_callee_class = split_by_the_last_dot(callee_module)
483+
parent_module_name, _child_module_name = split_by_the_last_dot(module_name)
484+
448485
# best effort to find and import the module
449486
# print(f"callee: {callee}")
450487
# print(f"callee module: {callee_module}")
451488
# print(f"callee name: {callee_name}")
452489
# print(f"module name: {module_name}")
453490
# print(f"maybe callee module: {maybe_callee_module}")
454491
# print(f"maybe callee class: {maybe_callee_class}")
455-
if not callee_module and not maybe_callee_module:
492+
if callee_module == None and maybe_callee_module == None:
456493
continue
457494
try:
458495
module = importlib.import_module(callee_module)
459496
source_file = inspect.getfile(module)
497+
if not hasattr(module, callee_name):
498+
raise ModuleNotFoundError(f"Not found {callee_name} in {module}")
460499
except (ModuleNotFoundError, TypeError):
461500
try:
462501
# try with relative import
463-
module = importlib.import_module(f"{module_name}.{callee_module}")
502+
if parent_module_name is None:
503+
parent_module_name = module_name
504+
module = importlib.import_module(
505+
f"{parent_module_name}.{callee_module}"
506+
)
464507
source_file = inspect.getfile(module)
465508
except (ModuleNotFoundError, TypeError):
466509
if maybe_callee_module == None:
467510
continue
468511
try:
469512
module = importlib.import_module(maybe_callee_module)
470513
source_file = inspect.getfile(module)
514+
if not hasattr(module, maybe_callee_class):
515+
raise ModuleNotFoundError(
516+
f"Not found {maybe_callee_class} in {module}"
517+
)
471518
callee_name = f"{maybe_callee_class}.{callee_name}"
472519
except (ModuleNotFoundError, TypeError):
473520
try:
521+
# try with relative import
522+
if parent_module_name is None:
523+
parent_module_name = module_name
474524
module = importlib.import_module(
475-
f"{module_name}.{maybe_callee_module}"
525+
f"{parent_module_name}.{maybe_callee_module}"
476526
)
477527
source_file = inspect.getfile(module)
478528
callee_name = f"{maybe_callee_class}.{callee_name}"
479-
except Exception:
529+
except Exception as e:
480530
# give up
481531
print(
482-
f"Failed to load module {maybe_callee_module} from entity {callee}"
532+
f"Failed to load module {maybe_callee_module} from entity {callee}: {e}"
483533
)
484534
continue
535+
except Exception:
536+
# give up
537+
print(f"Failed to load module {callee_module} from entity {callee}")
538+
continue
485539
if not module:
486540
print(f"Failed to find {callee} at module {callee_module} ")
487541
continue

benchmarks/tagging/run.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import sys
99
from os.path import abspath, exists
1010
from pathlib import Path
11+
from typing import Any
1112

1213
import yaml
1314

@@ -125,11 +126,14 @@ def prevalidate_backends(backend_edges):
125126
op_with_tags[backend] = {"tags": ["xformers"]}
126127
elif any([callee.startswith("torch.ops.") for callee in callees]):
127128
custom_op_category = [
128-
callee[callee.rfind(".") + 1 :]
129-
for callee in callees
130-
if callee.startswith("torch.ops.")
129+
callee for callee in callees if callee.startswith("torch.ops.")
131130
]
132-
op_with_tags[backend] = {"tags": custom_op_category + ["native_custom_ops"]}
131+
op_with_tags[backend] = {
132+
"tags": ["native_custom_ops"],
133+
"kernels": custom_op_category,
134+
}
135+
if any(["fbgemm" in callee for callee in callees]):
136+
op_with_tags[backend]["tags"].append("fbgemm")
133137

134138
# Apply name-based heuristics for all prevalidated backends
135139
for backend in op_with_tags.keys():
@@ -167,9 +171,8 @@ def trace_op(op):
167171
for backend in remaining_backends:
168172
# special case for torch.compile
169173
callees = backend_edges[backend]
170-
base_module_name = module_name[: module_name.rfind(".")]
171-
callees_with_module: list[tuple[Unknown, Unknown]] = [
172-
(callee, base_module_name) for callee in callees
174+
callees_with_module: list[tuple[Any, Any]] = [
175+
(callee, module_name) for callee in callees
173176
]
174177
op_with_tags[op][backend] = trace_callees(callees_with_module)
175178
# Apply name-based heuristics

tritonbench/operators/gdpa/operator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
if has_tlx():
4545
from tritonbench.operators.gdpa.gdpa_blackwell_tlx import gdpa_forward_tlx
4646

47-
from .gdpa import gdpa
47+
from .gdpa import gdpa as gdpa_kernel
4848
from .gdpa_utils import generate_jagged_data
4949

5050

@@ -220,7 +220,7 @@ def gdpa(
220220
activation,
221221
):
222222
def _inner():
223-
real_output = gdpa(
223+
real_output = gdpa_kernel(
224224
query=jagged_q,
225225
key=jagged_k,
226226
value=jagged_v,
@@ -368,7 +368,7 @@ def gdpa_opt(
368368
activation,
369369
):
370370
def _inner():
371-
real_output = gdpa(
371+
real_output = gdpa_kernel(
372372
query=jagged_q,
373373
key=jagged_k,
374374
value=jagged_v,
@@ -406,7 +406,7 @@ def gdpa_opt_sorted(
406406
activation,
407407
):
408408
def _inner():
409-
real_output = gdpa(
409+
real_output = gdpa_kernel(
410410
query=jagged_q,
411411
key=jagged_k,
412412
value=jagged_v,

tritonbench/operators/op.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@
77

88
OPBENCH_DIR = "operators"
99
INTERNAL_OPBENCH_DIR = "fb"
10+
LIGER_OPERATORS = [
11+
"embedding",
12+
"rms_norm",
13+
"rope",
14+
"jsd",
15+
"fused_linear_jsd",
16+
"cross_entropy",
17+
"fused_linear_cross_entropy",
18+
"geglu",
19+
"kl_div",
20+
"swiglu",
21+
]
1022

1123

1224
def _dir_contains_file(dir, file_name) -> bool:
@@ -42,14 +54,17 @@ def _list_opbench_paths() -> List[str]:
4254
if child.is_dir() and _dir_contains_file(child, "__init__.py")
4355
)
4456
opbench.extend(o)
57+
opbench = [
58+
op
59+
for op in opbench
60+
if os.path.basename(op) not in LIGER_OPERATORS
61+
and not os.path.basename(op) == INTERNAL_OPBENCH_DIR
62+
]
4563
return opbench
4664

4765

4866
def list_operators() -> List[str]:
    """Return the base name of every path returned by ``_list_opbench_paths()``."""
    return list(map(os.path.basename, _list_opbench_paths()))
5368

5469

5570
def list_custom_triton_operators(

0 commit comments

Comments
 (0)