Skip to content

Commit 2976887

Browse files
FindHao authored and meta-codesync[bot] committed
Add full Python source extraction support
Summary: Add support for extracting full Python source files in addition to the default function-only mode. This provides better context for debugging and analysis by showing the complete file with function boundaries highlighted.

Changes:
- Add TRITON_FULL_PYTHON_SOURCE env var (default: 0, function-only mode)
- Add TRITON_MAX_SOURCE_SIZE env var (default: 10MB size limit)
- Modify extract_python_source_info() to support two modes:
  * Default: Extract only function definition (existing behavior)
  * Full mode: Extract entire Python file with function range markers
- Add function_start_line and function_end_line fields to output
- Add file size checking and error handling for large files
- Fix docstring: correct default size limit from 1MB to 10MB

Reviewed By: adamomainz

Differential Revision: D85825470

fbshipit-source-id: 90c0f7f915a708761c043b9973a22faa6e8c3938
1 parent 8b2f09f commit 2976887

File tree

1 file changed

+81
-8
lines changed

1 file changed

+81
-8
lines changed

tritonparse/structured_logging.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@
4848
TRITONPARSE_MORE_TENSOR_INFORMATION = os.getenv(
4949
"TRITONPARSE_MORE_TENSOR_INFORMATION", None
5050
) in ["1", "true", "True"]
51+
# Enable full Python source file extraction instead of just the function definition
52+
TRITON_FULL_PYTHON_SOURCE = os.getenv("TRITON_FULL_PYTHON_SOURCE", "0") in [
53+
"1",
54+
"true",
55+
"True",
56+
]
57+
# Maximum file size for full source extraction (default 10MB)
58+
TRITON_MAX_SOURCE_SIZE = int(os.getenv("TRITON_MAX_SOURCE_SIZE", str(10 * 1024 * 1024)))
5159
# Inductor compiled kernel's launch tracing needs this flag to be set.
5260
# If TRITON_TRACE_LAUNCH is enabled, also enable TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK
5361
TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK = (
@@ -727,6 +735,17 @@ def extract_python_source_info(trace_data: Dict[str, Any], source):
727735
from the provided source object (typically an ASTSource or IRSource instance).
728736
It adds file path, line numbers, and the actual source code to the trace_data.
729737
738+
By default, only the function definition is extracted. Set TRITON_FULL_PYTHON_SOURCE=1
739+
to extract the entire Python source file.
740+
TODO: enable this by default in the next diff and track any compilation-time regression.
741+
742+
Environment Variables:
743+
TRITON_FULL_PYTHON_SOURCE: If set to "1", "true", or "True", extract the full Python file
744+
instead of just the function definition.
745+
TRITON_MAX_SOURCE_SIZE: Maximum file size in bytes for full source extraction
746+
(default: 10MB). Files larger than this will fall back
747+
to function-only mode.
748+
730749
Args:
731750
trace_data (Dict[str, Any]): Dictionary to store extracted information
732751
source (Union[ASTSource, IRSource]): Source object containing kernel function information
@@ -738,23 +757,77 @@ def extract_python_source_info(trace_data: Dict[str, Any], source):
738757
if isinstance(source, IRSource):
739758
return
740759

741-
# Get the original Python source code for the kernel
760+
# Get the function reference
761+
if isinstance(fn := source.fn, JITFunction):
762+
fn_ref = fn.fn
763+
else:
764+
fn_ref = source.fn
765+
766+
python_source_file = inspect.getfile(fn_ref)
767+
768+
# Get function range information
742769
if (
743770
isinstance(fn := source.fn, JITFunction)
744771
and hasattr(fn, "starting_line_number")
745772
and hasattr(fn, "raw_src")
746773
):
747-
start_line_number = fn.starting_line_number
774+
function_start_line = fn.starting_line_number
748775
source_lines = fn.raw_src
749776
else:
750-
source_lines, start_line_number = inspect.getsourcelines(fn.fn)
777+
source_lines, function_start_line = inspect.getsourcelines(fn_ref)
778+
779+
function_end_line = function_start_line + len(source_lines) - 1
780+
781+
if TRITON_FULL_PYTHON_SOURCE:
782+
# Full file mode: read the entire Python file
783+
try:
784+
# Check file size before reading
785+
file_size = os.path.getsize(python_source_file)
786+
except OSError as e:
787+
log.warning(
788+
f"Failed to check file size for {python_source_file}: {e}. "
789+
f"Falling back to function-only mode."
790+
)
791+
use_full_source = False
792+
else:
793+
if file_size > TRITON_MAX_SOURCE_SIZE:
794+
log.warning(
795+
f"Source file {python_source_file} is too large ({file_size} bytes, "
796+
f"limit: {TRITON_MAX_SOURCE_SIZE} bytes). Falling back to function-only mode."
797+
)
798+
use_full_source = False
799+
else:
800+
use_full_source = True
801+
802+
if use_full_source:
803+
try:
804+
with open(python_source_file, "r", encoding="utf-8") as f:
805+
file_content = f.read()
806+
807+
# Calculate total lines
808+
total_lines = len(file_content.split("\n"))
809+
810+
trace_data["python_source"] = {
811+
"file_path": python_source_file,
812+
"start_line": 1,
813+
"end_line": total_lines,
814+
"code": file_content,
815+
# Add function range for frontend highlighting and scrolling
816+
"function_start_line": function_start_line,
817+
"function_end_line": function_end_line,
818+
}
819+
return
820+
except (OSError, UnicodeDecodeError) as e:
821+
log.warning(
822+
f"Failed to read full source file {python_source_file}: {e}. "
823+
f"Falling back to function-only mode."
824+
)
751825

752-
python_source_file = inspect.getfile(fn.fn)
753-
end_line_number = start_line_number + len(source_lines)
826+
# Default behavior: only extract function definition
754827
trace_data["python_source"] = {
755828
"file_path": python_source_file,
756-
"start_line": start_line_number,
757-
"end_line": end_line_number,
829+
"start_line": function_start_line,
830+
"end_line": function_end_line,
758831
"code": "".join(source_lines),
759832
}
760833

@@ -910,7 +983,7 @@ def get_root_dir(self):
910983
)
911984
elif not os.access(TRACE_LOG_DIR, os.W_OK):
912985
log.info(
913-
"TritonTraceHandler: disabled because %s is not writeable",
986+
"TritonTraceHandler: disabled because %s is not writable",
914987
TRACE_LOG_DIR,
915988
)
916989
else:

0 commit comments

Comments
 (0)