
Commit 1bd70cf

oulgen authored and pobin6 committed
[RFC] Implement caching for user defined triton kernels (pytorch#140326)
This PR adds caching for user-defined triton kernels by putting the transitive closure of the kernel source code in node.meta along with the constant arguments.

One HUGE hack we do here: a node looks like

```
triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 1, grid = [(1, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'in_ptr0': arg0_1, 'in_ptr1': arg1_1, 'out_ptr': arg0_1}, tensors_to_clone = ['out_ptr']);
```

so we use a regex to remove the `kernel_idx = 0, constant_args_idx = 1` parts, since they are not relevant to the cache hash. This is horrible, and I'd eventually like to move away from pickle as the hashing mechanism, but that is a longer project.

Differential Revision: [D65895744](https://our.internmc.facebook.com/intern/diff/D65895744)

Pull Request resolved: pytorch#140326
Approved by: https://github.com/zou3519
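The normalization amounts to stripping those side-table indices out of the GraphModule's generated `_code` before it is pickled for the hash. A minimal sketch of the idea (the helper name below is hypothetical; in the PR the logic lives in `FxGraphCachePickler._reduce_graph_module`):

```python
import re

def normalize_fx_code_for_hashing(code: str) -> str:
    # kernel_idx / constant_args_idx are indices into a dynamo-time side table.
    # The kernel source and constant args are hashed separately, so the raw
    # indices only add noise to the key and can cause spurious cache misses.
    code = re.sub(r"kernel_idx = \d+", "", code)
    code = re.sub(r"constant_args_idx = \d+", "", code)
    return code

# Two traces of the same kernel may land at different side-table indices,
# but they normalize to the same string and thus the same cache-key input.
a = "triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 1, grid = [(1, 1, 1)])"
b = "triton_kernel_wrapper_functional(kernel_idx = 3, constant_args_idx = 7, grid = [(1, 1, 1)])"
assert normalize_fx_code_for_hashing(a) == normalize_fx_code_for_hashing(b)
```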
1 parent 204fbb3 commit 1bd70cf

File tree: 6 files changed (+196, -26 lines)


test/inductor/test_codecache.py

Lines changed: 89 additions & 18 deletions
```diff
@@ -49,7 +49,7 @@
 if HAS_TRITON:
     import triton  # @manual
 
-    from torch.testing._internal.triton_utils import add_kernel
+    from torch.testing._internal.triton_utils import add_kernel, sub_kernel
 
 torch._dynamo.config.fake_tensor_cache_enabled = True
 torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True
@@ -494,13 +494,41 @@ def fn2(q, k, v):
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
     @parametrize("bundle_triton", (False, True))
-    @parametrize("grad", (False, True))
-    def test_triton_higher_order_op_bypass(self, bundle_triton, grad):
+    def test_higher_order_op_bypass(self, bundle_triton):
         """
-        Verify that we bypass the cache when we have a triton higher order ops
+        Verify that we bypass the cache when we have a higher order ops
         and that bundler start/end works with a cache bypass.
         """
 
+        def fn(x):
+            def true_fn(x: torch.Tensor):
+                return x.cos()
+
+            def false_fn(x: torch.Tensor):
+                return x.sin()
+
+            return torch.cond(x.shape[0], true_fn, false_fn, (x,))
+
+        with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
+            compiled_fn = torch.compile(fn, dynamic=True, fullgraph=True)
+
+            x = torch.randn(4, 4, device=GPU_TYPE)
+            result = compiled_fn(x)
+
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+            self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)
+
+    @requires_gpu()
+    @requires_triton()
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @parametrize("bundle_triton", (False, True))
+    def test_triton_higher_order_op(self, bundle_triton):
+        """
+        Verify that we can cache user defined triton kernel higher order op
+        """
+
         def fn(x, y):
             n_elements = x.numel()
             grid = lambda meta: (  # noqa: E731
@@ -509,18 +537,54 @@ def fn(x, y):
             add_kernel[grid](x, y, x, n_elements, BLOCK_SIZE=4)
             return x
 
+        def fn2(x, y):
+            n_elements = x.numel()
+            grid = lambda meta: (  # noqa: E731
+                triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
+            )
+            sub_kernel[grid](x, y, x, n_elements, BLOCK_SIZE=4)
+            return x
+
         with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
             compiled_fn = torch.compile(fn, fullgraph=True)
+            compiled_fn2 = torch.compile(fn2, fullgraph=True)
+
+            x = torch.randn(4, device=GPU_TYPE)
+            y = torch.randn(4, device=GPU_TYPE)
 
-            x = torch.randn(4, device=GPU_TYPE, requires_grad=grad)
-            y = torch.randn(4, device=GPU_TYPE, requires_grad=grad)
             result = compiled_fn(x, y)
-            if grad:
-                result.sum().backward()
 
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
             self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-            self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+
+            # A second call should hit. (First reset so in-memory guards
+            # don't prevent compilation).
+            self.reset()
+
+            # Clean PyCodeCache and triton kernels
+            PyCodeCache.cache_clear()
+            shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
+
+            result = compiled_fn(x, y)
+
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+
+            # A second call should hit. (First reset so in-memory guards
+            # don't prevent compilation).
+            self.reset()
+
+            # Clean PyCodeCache and triton kernels
+            PyCodeCache.cache_clear()
+            shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
+
+            result = compiled_fn2(x, y)
+
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
 
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
@@ -808,15 +872,16 @@ def test_tensor_constants(self):
         self.assertFalse(GraphLowering.can_inline_constant(large))
 
         # By default, we hash the metadata and values independent of the size.
-        pickler = FxGraphCachePickler()
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)
 
         data = pickler.dumps(small)
         self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
         data = pickler.dumps(large)
         self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
 
         # If include_non_inlined=False, we only hash the values of small tensors.
-        pickler = FxGraphCachePickler(False)
+        pickler = FxGraphCachePickler(gm, False)
 
         data = pickler.dumps(small)
         self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
@@ -827,7 +892,8 @@ def test_hash_fake_tensors(self):
         """
         Test hashing (pickling) FakeTensors with various characteristics.
         """
-        pickler = FxGraphCachePickler()
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)
         with torch._subclasses.FakeTensorMode():
             # Verify that FakeTensors get pickled into a TensorMetadata:
             data = pickler.dumps(torch.randn(1))
@@ -933,7 +999,8 @@ def test_hash_kwargs(self):
         Test the special handling of the kwargs when hashing, i.e.,
         ordering of the kwargs dict and any set arguments.
         """
-        pickler = FxGraphCachePickler()
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)
 
         # Dict order of the kwargs should not affect hashes.
         details1 = FxGraphHashDetails(None, [], {"a": 0, "z": 1}, [])
@@ -981,7 +1048,8 @@ def test_hash_config_changes(self):
         with config.patch({"max_autotune": True}):
             details3 = FxGraphHashDetails(None, [], {}, [])
 
-        pickler = FxGraphCachePickler()
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)
 
         self.assertEqual(
             pickler.dumps(details1),
@@ -1016,7 +1084,8 @@ def uuid(self) -> Optional[Union[bytes, str]]:
         custom_pass._uuid = "2"
         details3 = FxGraphHashDetails(None, [], {}, [])
 
-        pickler = FxGraphCachePickler()
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)
 
         self.assertEqual(
             pickler.dumps(details1),
@@ -1031,8 +1100,9 @@ def test_bypass_unsupported(self):
         """
         Test _reduce_unsupported
         """
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
        with self.assertRaises(BypassFxGraphCache):
-            FxGraphCachePickler().dumps(
+            FxGraphCachePickler(gm).dumps(
                torch.fx.experimental._backward_state.BackwardState()
            )
 
@@ -1047,7 +1117,8 @@ def test_stable_strings(self):
 
         self.assertNotEqual(id(s1), id(s2))
 
-        pickler = FxGraphCachePickler()
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)
         self.assertEqual(
             pickler.dumps([s1, s1]),
             pickler.dumps([s1, s2]),
```

torch/_functorch/_aot_autograd/autograd_cache.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -227,8 +227,8 @@ def __init__(
 
 
 class AOTAutogradCachePickler(FxGraphCachePickler):
-    def __init__(self):
-        super().__init__()
+    def __init__(self, gm: torch.fx.GraphModule):
+        super().__init__(gm)
         self.dispatch_table: Dict
         self.dispatch_table.update(
             {
@@ -275,7 +275,7 @@ def autograd_cache_key(
     """
     check_cacheable(gm)
    details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config)
-    pickler = AOTAutogradCachePickler()
+    pickler = AOTAutogradCachePickler(gm)
    # The prefix distinguishes among the other kinds of objects we cache
    key = "a" + pickler.get_hash(details)
    debug_lines = pickler.debug_lines(details)
```

torch/_higher_order_ops/triton_kernel_wrap.py

Lines changed: 22 additions & 2 deletions
```diff
@@ -613,7 +613,7 @@ def identify_mutated_tensors(
 # Used for wrapping a Triton Kernel
 class TritonKernelWrapperMutation(HigherOrderOperator):
     def __init__(self) -> None:
-        super().__init__("triton_kernel_wrapper_mutation", cacheable=False)
+        super().__init__("triton_kernel_wrapper_mutation", cacheable=True)
 
     def __call__(
         self,
@@ -638,7 +638,7 @@ def __call__(
 # Used for wrapping a Triton Kernel in a functional manner
 class TritonKernelWrapperFunctional(HigherOrderOperator):
     def __init__(self) -> None:
-        super().__init__("triton_kernel_wrapper_functional", cacheable=False)
+        super().__init__("triton_kernel_wrapper_functional", cacheable=True)
 
     def __call__(
         self,
@@ -774,6 +774,26 @@ def trace_triton_kernel_wrapper(
         proxy_args,
         name=func_overload.__name__ + "_proxy",
     )
+
+    from triton.runtime.autotuner import Autotuner
+
+    from torch._inductor.codegen.wrapper import (
+        user_defined_triton_kernel_transitive_closure_source_code,
+    )
+
+    kernel = kernel_side_table.get_kernel(proxy_args["kernel_idx"])
+    if isinstance(kernel, Autotuner):
+        kernel = kernel.fn
+
+    kernel_source = user_defined_triton_kernel_transitive_closure_source_code(kernel)
+    constant_args = kernel_side_table.get_constant_args(proxy_args["constant_args_idx"])
+    # we add to node here so that it gets included in the inductor cache key
+    # when the graph is pickled
+    out_proxy.node.meta["user_defined_triton_kernel_source_and_constant_args"] = (
+        kernel_source,
+        constant_args,
+    )
+
     ret = track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
     return ret
```
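For context, the "transitive closure" source captured above is the user kernel's source plus that of every `@triton.jit` helper it reaches. A purely illustrative example (not part of this diff) of why that matters for cache correctness: if only the top-level kernel's `src` were hashed, editing the helper below would leave the cache key unchanged and could yield stale compiled code.

```python
# Hypothetical user-defined kernels, for illustration only.
import triton
import triton.language as tl

@triton.jit
def _plus_one(x):
    # Editing this helper changes behavior but not the top-level kernel's own
    # source, so the cache key must cover this function's source too.
    return x + 1

@triton.jit
def add_one_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, _plus_one(x), mask=mask)
```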

torch/_inductor/codecache.py

Lines changed: 64 additions & 2 deletions
```diff
@@ -7,6 +7,7 @@
 import hashlib
 import importlib
 import io
+import itertools
 import json
 import logging
 import os
@@ -525,7 +526,12 @@ class FxGraphCachePickler(pickle.Pickler):
     data that allow us to compute a stable, but safe hash.
     """
 
-    def __init__(self, include_non_inlined: bool = True) -> None:
+    def __init__(
+        self,
+        gm: torch.fx.GraphModule,
+        include_non_inlined: bool = True,
+        has_user_defined_triton_kernels: bool = False,
+    ) -> None:
         """
         Create an FX graph pickler. If include_non_inlined=True, then pickling will
         include the _values_ for all Tensors. (Note that any tensors are constants
@@ -548,6 +554,11 @@ def __init__(self, include_non_inlined: bool = True) -> None:
                 ),
             }
         )
+        if has_user_defined_triton_kernels:
+            # Need to use runtime type as GraphModule generates a singleton in __new__ function
+            self.dispatch_table[gm.__class__] = functools.partial(
+                self._reduce_graph_module
+            )
 
         # Run with pickler.fast so it doesn't intern strings, making the hash result more predictable
         # TODO: pickler.fast is technically deprecated. Will this work on new python versions?
@@ -614,6 +625,25 @@ def _reduce_unsupported(self, s: Any) -> NoReturn:
         """
         raise BypassFxGraphCache("Reduce unsupported")
 
+    def _reduce_graph_module(
+        self, gm: torch.fx.GraphModule
+    ) -> Tuple[Any, Tuple[Dict[str, Any], str]]:
+        """
+        Custom reducer for graph module to handle irrelevant data for user
+        defined triton kernels
+        Essentially what we are doing here is a huge hack where user defined
+        triton kernel contain a dynamo time side table and the arguments to the
+        call_function are indicies into this side table. These arguments are not
+        for hashing purposes since we included the source code into the cache
+        key and the numbers are prone to give false negatives due to ordering.
+        """
+        fn, (data, imports) = gm.__reduce__()
+        code = data["_code"]
+        code = re.sub(r"kernel_idx = \d+", "", code)
+        code = re.sub(r"constant_args_idx = \d+", "", code)
+        data["_code"] = code
+        return fn, (data, imports)
+
     def dumps(self, obj: Any) -> bytes:
         """
         Pickle an object and return a byte string.
@@ -775,6 +805,35 @@ def __init__(
                 else:
                     self.fx_kwargs[k] = v
 
+        from torch._higher_order_ops.triton_kernel_wrap import (
+            triton_kernel_wrapper_functional,
+            triton_kernel_wrapper_mutation,
+        )
+
+        # Node meta will not be part of gm's reduce function, so lets remember
+        # the kernel source code separately
+        self.user_defined_triton_source: List[Any] = []
+        if gm is not None:
+            for module in gm.modules():
+                if not isinstance(module, torch.fx.GraphModule):
+                    continue
+                for node in itertools.chain(
+                    module.graph.find_nodes(
+                        op="call_function", target=triton_kernel_wrapper_functional
+                    ),
+                    module.graph.find_nodes(
+                        op="call_function", target=triton_kernel_wrapper_mutation
+                    ),
+                ):
+                    data = node.meta.get(
+                        "user_defined_triton_kernel_source_and_constant_args", None
+                    )
+                    if data is None:
+                        raise AssertionError(
+                            "TritonKernelWrapper does not contain source code meta"
+                        )
+                    self.user_defined_triton_source.append(data)
+
         # Alignment checks
         self.inputs_to_check = inputs_to_check
 
@@ -833,7 +892,10 @@ def compiled_fx_graph_hash(
     include_non_inlined = not has_frozen_params(gm)
 
     details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
-    pickler = FxGraphCachePickler(include_non_inlined)
+    has_user_defined_triton_kernels = len(details.user_defined_triton_source) != 0
+    pickler = FxGraphCachePickler(
+        gm, include_non_inlined, has_user_defined_triton_kernels
+    )
     # The prefix distinguishes among the other kinds of objects we
     # cache in this module.
     key = "f" + pickler.get_hash(details)
```

torch/_inductor/codegen/wrapper.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -241,7 +241,7 @@ def writeline(line: str, example_grid: Optional[str] = None):
 def user_defined_triton_kernel_transitive_closure_source_code(kernel) -> str:
     """
     Given a triton kernel function pointer collect the transitive closure of
-    its dependancies
+    its dependencies
     """
     compile_wrapper = IndentedBuffer()
     compile_wrapper.splice(kernel.src, strip=True)
```
