 if HAS_TRITON:
     import triton  # @manual

-    from torch.testing._internal.triton_utils import add_kernel
+    from torch.testing._internal.triton_utils import add_kernel, sub_kernel

 torch._dynamo.config.fake_tensor_cache_enabled = True
 torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True
@@ -494,13 +494,41 @@ def fn2(q, k, v):
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
     @parametrize("bundle_triton", (False, True))
-    @parametrize("grad", (False, True))
-    def test_triton_higher_order_op_bypass(self, bundle_triton, grad):
+    def test_higher_order_op_bypass(self, bundle_triton):
         """
-        Verify that we bypass the cache when we have a triton higher order ops
+        Verify that we bypass the cache when we have higher order ops
         and that bundler start/end works with a cache bypass.
         """

+        def fn(x):
+            def true_fn(x: torch.Tensor):
+                return x.cos()
+
+            def false_fn(x: torch.Tensor):
+                return x.sin()
+
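+            # torch.cond is a higher-order op, which is expected to trigger the cache bypass checked below.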
+            return torch.cond(x.shape[0], true_fn, false_fn, (x,))
+
+        with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
+            compiled_fn = torch.compile(fn, dynamic=True, fullgraph=True)
+
+            x = torch.randn(4, 4, device=GPU_TYPE)
+            result = compiled_fn(x)
+
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+            self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)
+
+    @requires_gpu()
+    @requires_triton()
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @parametrize("bundle_triton", (False, True))
+    def test_triton_higher_order_op(self, bundle_triton):
+        """
+        Verify that we can cache a user-defined triton kernel higher order op
+        """
+
         def fn(x, y):
             n_elements = x.numel()
             grid = lambda meta: (  # noqa: E731
@@ -509,18 +537,54 @@ def fn(x, y):
         add_kernel[grid](x, y, x, n_elements, BLOCK_SIZE=4)
         return x

+        def fn2(x, y):
+            n_elements = x.numel()
+            grid = lambda meta: (  # noqa: E731
+                triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
+            )
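+            # sub_kernel differs from add_kernel, so compiling fn2 is expected to
+            # produce a fresh cache miss rather than a hit (see the assertions below).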
+            sub_kernel[grid](x, y, x, n_elements, BLOCK_SIZE=4)
+            return x
+
         with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
             compiled_fn = torch.compile(fn, fullgraph=True)
+            compiled_fn2 = torch.compile(fn2, fullgraph=True)
+
+            x = torch.randn(4, device=GPU_TYPE)
+            y = torch.randn(4, device=GPU_TYPE)

-            x = torch.randn(4, device=GPU_TYPE, requires_grad=grad)
-            y = torch.randn(4, device=GPU_TYPE, requires_grad=grad)
             result = compiled_fn(x, y)
-            if grad:
-                result.sum().backward()

-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
             self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-            self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+
+            # A second call should hit. (First reset so in-memory guards
+            # don't prevent compilation).
+            self.reset()
+
+            # Clean PyCodeCache and triton kernels
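+            # so the cache hit below has to restore them from the cached artifacts.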
+            PyCodeCache.cache_clear()
+            shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
+
+            result = compiled_fn(x, y)
+
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+
+            # Now compile fn2; it uses a different kernel, so it should miss.
+            # (First reset so in-memory guards don't prevent compilation).
+            self.reset()
+
+            # Clean PyCodeCache and triton kernels
+            PyCodeCache.cache_clear()
+            shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
+
+            result = compiled_fn2(x, y)
+
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
@@ -808,15 +872,16 @@ def test_tensor_constants(self):
         self.assertFalse(GraphLowering.can_inline_constant(large))

         # By default, we hash the metadata and values independent of the size.
-        pickler = FxGraphCachePickler()
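+        # FxGraphCachePickler is now constructed with a GraphModule; an empty one is enough for these tests.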
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)

         data = pickler.dumps(small)
         self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
         data = pickler.dumps(large)
         self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)

         # If include_non_inlined=False, we only hash the values of small tensors.
-        pickler = FxGraphCachePickler(False)
+        pickler = FxGraphCachePickler(gm, False)

         data = pickler.dumps(small)
         self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
@@ -827,7 +892,8 @@ def test_hash_fake_tensors(self):
         """
         Test hashing (pickling) FakeTensors with various characteristics.
         """
-        pickler = FxGraphCachePickler()
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)
         with torch._subclasses.FakeTensorMode():
             # Verify that FakeTensors get pickled into a TensorMetadata:
             data = pickler.dumps(torch.randn(1))
@@ -933,7 +999,8 @@ def test_hash_kwargs(self):
         Test the special handling of the kwargs when hashing, i.e.,
         ordering of the kwargs dict and any set arguments.
         """
-        pickler = FxGraphCachePickler()
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)

         # Dict order of the kwargs should not affect hashes.
         details1 = FxGraphHashDetails(None, [], {"a": 0, "z": 1}, [])
@@ -981,7 +1048,8 @@ def test_hash_config_changes(self):
         with config.patch({"max_autotune": True}):
             details3 = FxGraphHashDetails(None, [], {}, [])

-        pickler = FxGraphCachePickler()
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)

         self.assertEqual(
             pickler.dumps(details1),
@@ -1016,7 +1084,8 @@ def uuid(self) -> Optional[Union[bytes, str]]:
         custom_pass._uuid = "2"
         details3 = FxGraphHashDetails(None, [], {}, [])

-        pickler = FxGraphCachePickler()
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)

         self.assertEqual(
             pickler.dumps(details1),
@@ -1031,8 +1100,9 @@ def test_bypass_unsupported(self):
         """
         Test _reduce_unsupported
         """
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
         with self.assertRaises(BypassFxGraphCache):
-            FxGraphCachePickler().dumps(
+            FxGraphCachePickler(gm).dumps(
                 torch.fx.experimental._backward_state.BackwardState()
             )

@@ -1047,7 +1117,8 @@ def test_stable_strings(self):

         self.assertNotEqual(id(s1), id(s2))

-        pickler = FxGraphCachePickler()
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)
         self.assertEqual(
             pickler.dumps([s1, s1]),
             pickler.dumps([s1, s2]),