@@ -60,6 +60,11 @@ class Edge:
6060 call_type : str
6161 callee_descriptor : FuncDescriptor
6262
63+ def split_by_the_last_dot (s : str ) -> Optional [Tuple [str , str ]]:
64+ if "." in s :
65+ return s .rsplit ("." , 1 )
66+ else :
67+ return None , s
6368
6469class CallGraph (ast .NodeVisitor ):
6570 def __init__ (
@@ -158,6 +163,12 @@ def _record_call(
158163 ):
159164 if caller is None :
160165 caller = self ._cur_scope () or "<module>"
166+ # replace callee with caller class name if it is "self." call
167+ if "." in caller and callee .startswith ("self." ):
168+ caller_prefix , _ = split_by_the_last_dot (caller )
169+ # remove the "self." prefix
170+ callee_name = callee [5 :]
171+ callee = caller_prefix + "." + callee_name
161172 site = Site (
162173 self .filename , getattr (node , "lineno" , - 1 ), getattr (node , "col_offset" , - 1 )
163174 )
@@ -167,8 +178,6 @@ def _record_call(
167178 and any ([f"Operator.{ backend } " in caller for backend in self .backends ])
168179 and not callee == "tritonbench.utils.triton_op.register_benchmark"
169180 ):
170- if callee .startswith ("self." ):
171- return
172181 if is_fbcode () and callee .startswith ("liger_kernel." ):
173182 return
174183 # identify this call belongs to which backend
@@ -188,9 +197,7 @@ def _record_call(
188197 )
189198 )
190199 elif any ([backend in caller for backend in self .backends ]):
191- # "torch.ops" is a binary custom ops
192- if callee .startswith ("self." ):
193- return
200+ # skip aten calls
194201 if callee .startswith ("torch." ) and not callee .startswith ("torch.ops." ):
195202 return
196203 # we are sure there is no kernel defined in this package ;-)
@@ -202,7 +209,7 @@ def _record_call(
202209 if maybe_triton and callee_descriptor == None :
203210 callee_descriptor = FuncDescriptor (
204211 callee ,
205- ["triton.jit" ],
212+ ["maybe. triton.jit" ],
206213 Site (
207214 self .filename ,
208215 getattr (node , "lineno" , - 1 ),
@@ -362,6 +369,27 @@ def visit_Call(self, node: ast.Call):
362369 callee = fn .value .id
363370 elif isinstance (fn .value , ast .Attribute ):
364371 callee = fn .value .value .id
372+ if hasattr (fn .value , "attr" ):
373+ callee = callee + "." + fn .value .attr
374+ elif isinstance (fn .value , ast .Call ):
375+ # add hack to handle torch._library.capture_triton
376+ if (
377+ isinstance (fn .value .func , ast .Name )
378+ and fn .value .func .id == "capture_triton"
379+ ) or (
380+ isinstance (fn .value .func .value , ast .Attribute )
381+ and fn .value .func .value .attr == "_library"
382+ and fn .value .func .attr == "capture_triton"
383+ ):
384+ if isinstance (fn .value .args [0 ], ast .Call ):
385+ callee_func = fn .value .args [0 ].func .id
386+ elif isinstance (fn .value .args [0 ], ast .Name ):
387+ callee_func = fn .value .args [0 ].id
388+ elif isinstance (fn .value .args [0 ], ast .Attribute ):
389+ callee_func = fn .value .args [0 ].value .id
390+ else :
391+ callee_func = "unknown"
392+ callee = f"<torch._library.capture_triton({ callee_func } )>"
365393 else :
366394 callee = "<dynamic_call>"
367395 maybe_triton = True # FIXME: this could also be cute, see blackwell_attentions cute dsl
@@ -377,16 +405,22 @@ def validate_edges(edges) -> Dict[str, str]:
377405 result_tags ["tags" ] = []
378406 result_tags ["kernels" ] = []
379407 for edge in edges :
408+ if edge .callee == "cutlass.cute.compile" :
409+ result_tags ["tags" ].append ("cutedsl" )
410+ result_tags ["kernels" ].append (edge .caller )
380411 if edge .callee_descriptor and (
381412 "triton.jit" in edge .callee_descriptor .decorators
382413 or "<dynamic_decorator_triton.jit>" in edge .callee_descriptor .decorators
383414 ):
384415 result_tags ["tags" ].append ("triton" )
385416 result_tags ["kernels" ].append (edge .callee )
386- if edge .callee == "cutlass.cute.compile" :
387- result_tags ["tags" ].append ("cutedsl" )
388- result_tags ["kernels" ].append (edge .caller )
389- if edge .callee .startswith ("torch.ops." ):
417+ if edge .callee .startswith ("<torch._library.capture_triton" ):
418+ result_tags ["tags" ].append ("triton" )
419+ result_tags ["kernels" ].append (edge .callee )
420+ if (
421+ edge .callee .startswith ("torch.ops." )
422+ and not "cutedsl" in result_tags ["tags" ]
423+ ):
390424 result_tags ["tags" ].append ("native_custom_ops" )
391425 # definition is in cpp, so we don't have the definition site
392426 result_tags ["kernels" ].append (edge .callee )
@@ -398,16 +432,27 @@ def validate_edges(edges) -> Dict[str, str]:
398432 if edge .callee .startswith ("tilelang.compile" ):
399433 result_tags ["tags" ].append ("tilelang" )
400434 result_tags ["kernels" ].append (edge .caller )
435+ if "torch.ops.fbgemm" in edge .callee :
436+ result_tags ["tags" ].append ("fbgemm" )
437+ # heuristic: if no tag is found and maybe triton, apply triton tag
438+ if (
439+ not result_tags ["tags" ]
440+ and edge .callee_descriptor
441+ and "maybe.triton.jit" in edge .callee_descriptor .decorators
442+ ):
443+ result_tags ["tags" ].append ("triton" )
444+ result_tags ["kernels" ].append (edge .callee )
401445 # remove duplicates
402446 result_tags ["tags" ] = list (set (result_tags ["tags" ]))
447+ result_tags ["kernels" ] = list (set (result_tags ["kernels" ]))
403448 if not result_tags ["kernels" ] and not result_tags ["tags" ]:
404449 return None
405450 return result_tags
406451
407452
408453def gen_static_extension_tags (callee : str ) -> Dict [str , str ]:
409454 result_tags = {}
410- result_tags ["tags" ] = { "native_extension" }
455+ result_tags ["tags" ] = [ "native_extension" ]
411456 result_tags ["kernels" ] = [callee ]
412457 return result_tags
413458
@@ -433,55 +478,64 @@ def trace_callees(callees_with_module: List[Tuple[str, str]], depth=8):
433478 if callee .endswith (".apply" ):
434479 callee = callee [: callee .rfind (".apply" )] + ".forward"
435480
436- callee_module = callee [: callee .rfind ("." )] if "." in callee else None
437- callee_name = callee [callee .rfind ("." ) + 1 :] if "." in callee else callee
438- maybe_callee_module = (
439- callee_module [: callee_module .rfind ("." )]
440- if callee_module and "." in callee_module
441- else None
442- )
443- maybe_callee_class = (
444- callee_module [callee_module .rfind ("." ) + 1 :]
445- if callee_module and "." in callee_module
446- else None
447- )
481+ callee_module , callee_name = split_by_the_last_dot (callee )
482+ maybe_callee_module , maybe_callee_class = split_by_the_last_dot (callee_module )
483+ parent_module_name , _child_module_name = split_by_the_last_dot (module_name )
484+
448485 # best effort to find and import the module
449486 # print(f"callee: {callee}")
450487 # print(f"callee module: {callee_module}")
451488 # print(f"callee name: {callee_name}")
452489 # print(f"module name: {module_name}")
453490 # print(f"maybe callee module: {maybe_callee_module}")
454491 # print(f"maybe callee class: {maybe_callee_class}")
455- if not callee_module and not maybe_callee_module :
492+ if callee_module == None and maybe_callee_module == None :
456493 continue
457494 try :
458495 module = importlib .import_module (callee_module )
459496 source_file = inspect .getfile (module )
497+ if not hasattr (module , callee_name ):
498+ raise ModuleNotFoundError (f"Not found { callee_name } in { module } " )
460499 except (ModuleNotFoundError , TypeError ):
461500 try :
462501 # try with relative import
463- module = importlib .import_module (f"{ module_name } .{ callee_module } " )
502+ if parent_module_name is None :
503+ parent_module_name = module_name
504+ module = importlib .import_module (
505+ f"{ parent_module_name } .{ callee_module } "
506+ )
464507 source_file = inspect .getfile (module )
465508 except (ModuleNotFoundError , TypeError ):
466509 if maybe_callee_module == None :
467510 continue
468511 try :
469512 module = importlib .import_module (maybe_callee_module )
470513 source_file = inspect .getfile (module )
514+ if not hasattr (module , maybe_callee_class ):
515+ raise ModuleNotFoundError (
516+ f"Not found { maybe_callee_class } in { module } "
517+ )
471518 callee_name = f"{ maybe_callee_class } .{ callee_name } "
472519 except (ModuleNotFoundError , TypeError ):
473520 try :
521+ # try with relative import
522+ if parent_module_name is None :
523+ parent_module_name = module_name
474524 module = importlib .import_module (
475- f"{ module_name } .{ maybe_callee_module } "
525+ f"{ parent_module_name } .{ maybe_callee_module } "
476526 )
477527 source_file = inspect .getfile (module )
478528 callee_name = f"{ maybe_callee_class } .{ callee_name } "
479- except Exception :
529+ except Exception as e :
480530 # give up
481531 print (
482- f"Failed to load module { maybe_callee_module } from entity { callee } "
532+ f"Failed to load module { maybe_callee_module } from entity { callee } : { e } "
483533 )
484534 continue
535+ except Exception :
536+ # give up
537+ print (f"Failed to load module { callee_module } from entity { callee } " )
538+ continue
485539 if not module :
486540 print (f"Failed to find { callee } at module { callee_module } " )
487541 continue
0 commit comments