4848TRITONPARSE_MORE_TENSOR_INFORMATION = os .getenv (
4949 "TRITONPARSE_MORE_TENSOR_INFORMATION" , None
5050) in ["1" , "true" , "True" ]
51+ # Enable full Python source file extraction instead of just the function definition
52+ TRITON_FULL_PYTHON_SOURCE = os .getenv ("TRITON_FULL_PYTHON_SOURCE" , "0" ) in [
53+ "1" ,
54+ "true" ,
55+ "True" ,
56+ ]
57+ # Maximum file size for full source extraction (default 10MB)
58+ TRITON_MAX_SOURCE_SIZE = int (os .getenv ("TRITON_MAX_SOURCE_SIZE" , str (10 * 1024 * 1024 )))
5159# Inductor compiled kernel's launch tracing needs this flag to be set.
5260# If TRITON_TRACE_LAUNCH is enabled, also enable TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK
5361TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK = (
@@ -727,6 +735,17 @@ def extract_python_source_info(trace_data: Dict[str, Any], source):
727735 from the provided source object (typically an ASTSource or IRSource instance).
728736 It adds file path, line numbers, and the actual source code to the trace_data.
729737
738+ By default, only the function definition is extracted. Set TRITON_FULL_PYTHON_SOURCE=1
739+ to extract the entire Python source file.
740+ @TODO: we should enable it by default in next diff and track the compilation time regression
741+
742+ Environment Variables:
743+ TRITON_FULL_PYTHON_SOURCE: If set to "1", extract the full Python file
744+ instead of just the function definition.
745+ TRITON_MAX_SOURCE_SIZE: Maximum file size in bytes for full source extraction
746+ (default: 10MB). Files larger than this will fall back
747+ to function-only mode.
748+
730749 Args:
731750 trace_data (Dict[str, Any]): Dictionary to store extracted information
732751 source (Union[ASTSource, IRSource]): Source object containing kernel function information
@@ -738,23 +757,77 @@ def extract_python_source_info(trace_data: Dict[str, Any], source):
738757 if isinstance (source , IRSource ):
739758 return
740759
741- # Get the original Python source code for the kernel
760+ # Get the function reference
761+ if isinstance (fn := source .fn , JITFunction ):
762+ fn_ref = fn .fn
763+ else :
764+ fn_ref = source .fn
765+
766+ python_source_file = inspect .getfile (fn_ref )
767+
768+ # Get function range information
742769 if (
743770 isinstance (fn := source .fn , JITFunction )
744771 and hasattr (fn , "starting_line_number" )
745772 and hasattr (fn , "raw_src" )
746773 ):
747- start_line_number = fn .starting_line_number
774+ function_start_line = fn .starting_line_number
748775 source_lines = fn .raw_src
749776 else :
750- source_lines , start_line_number = inspect .getsourcelines (fn .fn )
777+ source_lines , function_start_line = inspect .getsourcelines (fn_ref )
778+
779+ function_end_line = function_start_line + len (source_lines ) - 1
780+
781+ if TRITON_FULL_PYTHON_SOURCE :
782+ # Full file mode: read the entire Python file
783+ try :
784+ # Check file size before reading
785+ file_size = os .path .getsize (python_source_file )
786+ except OSError as e :
787+ log .warning (
788+ f"Failed to check file size for { python_source_file } : { e } . "
789+ f"Falling back to function-only mode."
790+ )
791+ use_full_source = False
792+ else :
793+ if file_size > TRITON_MAX_SOURCE_SIZE :
794+ log .warning (
795+ f"Source file { python_source_file } is too large ({ file_size } bytes, "
796+ f"limit: { TRITON_MAX_SOURCE_SIZE } bytes). Falling back to function-only mode."
797+ )
798+ use_full_source = False
799+ else :
800+ use_full_source = True
801+
802+ if use_full_source :
803+ try :
804+ with open (python_source_file , "r" , encoding = "utf-8" ) as f :
805+ file_content = f .read ()
806+
807+ # Calculate total lines
808+ total_lines = len (file_content .split ("\n " ))
809+
810+ trace_data ["python_source" ] = {
811+ "file_path" : python_source_file ,
812+ "start_line" : 1 ,
813+ "end_line" : total_lines ,
814+ "code" : file_content ,
815+ # Add function range for frontend highlighting and scrolling
816+ "function_start_line" : function_start_line ,
817+ "function_end_line" : function_end_line ,
818+ }
819+ return
820+ except (OSError , UnicodeDecodeError ) as e :
821+ log .warning (
822+ f"Failed to read full source file { python_source_file } : { e } . "
823+ f"Falling back to function-only mode."
824+ )
751825
752- python_source_file = inspect .getfile (fn .fn )
753- end_line_number = start_line_number + len (source_lines )
826+ # Default behavior: only extract function definition
754827 trace_data ["python_source" ] = {
755828 "file_path" : python_source_file ,
756- "start_line" : start_line_number ,
757- "end_line" : end_line_number ,
829+ "start_line" : function_start_line ,
830+ "end_line" : function_end_line ,
758831 "code" : "" .join (source_lines ),
759832 }
760833
@@ -910,7 +983,7 @@ def get_root_dir(self):
910983 )
911984 elif not os .access (TRACE_LOG_DIR , os .W_OK ):
912985 log .info (
913- "TritonTraceHandler: disabled because %s is not writeable " ,
986+ "TritonTraceHandler: disabled because %s is not writable " ,
914987 TRACE_LOG_DIR ,
915988 )
916989 else :
0 commit comments