Commit ad41299

FindHao authored and meta-codesync[bot] committed
Add Callsite Location Support in TTIR/TTGIR (#190)
Summary:

## Overview

This PR adds support for parsing MLIR callsite locations in Triton IR (TTIR/TTGIR), enabling tritonparse to capture and preserve call stack information for inlined functions.

## Problem

Previously, tritonparse could not parse callsite location definitions like:

```mlir
#loc220 = loc(callsite(#loc57 at #loc190))
```

This caused incomplete source code mappings, losing critical information about:

- Function call chains (which function called which)
- Inlining context (how nested functions are represented)
- Complete source attribution (full path through code to reach an operation)

## Solution

Implemented a **hybrid approach** that:

1. Parses callsite definitions and inherits location from the callee (code being called)
2. Preserves caller references as metadata for call stack traversal
3. Propagates callsite information through the entire mapping pipeline

### Key Design Decisions

- **Callee as primary location**: Maps to the actual code being executed (most relevant for debugging)
- **Metadata preservation**: Stores references (loc IDs) rather than fully expanding call stacks
- **Backward compatible**: Adds optional fields without breaking existing tools
- **Extensible**: Future enhancements can traverse and expand call chains on demand

## Implementation Details

### 1. Added Callsite Pattern (`ir_parser.py`)

```python
CALLSITE_PATTERN = re.compile(
    r"#loc(\d+)\s*=\s*loc\(\s*callsite\(\s*#loc(\d*)\s+at\s+#loc(\d*)\s*\)\s*\)"
)
```

### 2. Enhanced `extract_loc_definitions()` (`ir_parser.py`)

- Collects all callsite definitions during IR parsing
- Resolves callsite references by inheriting location info from callee
- Stores callsite metadata: `is_callsite`, `callsite_callee`, `callsite_caller`
- Validates references with warning messages for undefined locs

### 3. Updated `generate_source_mappings()` (`trace_processor.py`)

- Propagates callsite metadata from `loc_defs` to final `mappings`
- Enables downstream tools to identify and traverse call chains

## Data Structure Example

For a nested callsite like:

```mlir
#loc7 = loc("file.py":1091:8)
#loc57 = loc("file.py":421:16)
#loc58 = loc("file.py":853:16)
#loc190 = loc(callsite(#loc58 at #loc7))
#loc220 = loc(callsite(#loc57 at #loc190))
%0 = tt.load %ptr loc(#loc220)
```

The resulting mapping for line 131 (where `tt.load` is):

```json
{
  "file": "file.py",
  "line": 421,               // From callee (loc57) - actual code executing
  "column": 16,
  "ttir_line": 131,
  "is_callsite": true,
  "callsite_callee": "57",   // Reference to called code
  "callsite_caller": "190"   // Reference to caller (can traverse chain)
}
```

**Call chain represented**: `_ragged_hstu_attn_fwd` (1091:8) → `_ragged_hstu_attn_fwd_compute` (853:16) → `_ragged_hstu_attn_fwd_one_block` (421:16) ← **executing here**

## Testing

Added comprehensive unit tests in `tests/test_tritonparse.py`:

- `TestTritonparseCPU::test_callsite_parsing`
  - Validates nested callsite parsing
  - Verifies metadata propagation to mappings
  - Tests both simple and nested callsite scenarios

**All tests pass ✅**

## Impact

### Benefits

- ✅ Complete source mapping for inlined functions
- ✅ Preserves call stack information for debugging
- ✅ Enables future call chain visualization
- ✅ Backward compatible with existing tools

### No Breaking Changes

- Only adds optional fields to existing mappings
- Existing code that doesn't check for callsite fields continues to work
- No changes to public APIs

## Files Changed

1. `tritonparse/ir_parser.py` - Added callsite parsing logic
2. `tritonparse/trace_processor.py` - Propagate callsite metadata
3. `tests/test_tritonparse.py` - Added unit tests
4. `CALLSITE_IMPLEMENTATION.md` - Detailed implementation documentation

## Future Work

Potential enhancements (not in this PR):

1. Automatic call stack expansion utility functions
2. Call stack caching for performance optimization
3. Frontend UI support for call chain visualization
4. Cycle detection for complex callsite graphs

Pull Request resolved: #190

Reviewed By: njriasan

Differential Revision: D85681542

Pulled By: FindHao

fbshipit-source-id: aeadd4cb65a453f50a8d48ca7ea68a40515a1f07
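
The commit message says callers are stored as references so that a call chain can be traversed on demand, and lists expansion and cycle detection as future work. The following is a minimal sketch of what such a traversal might look like. It is not part of this commit: `expand_call_chain` and `EXAMPLE_LOCS` are hypothetical names, and the dictionary shape is assumed to match what `extract_loc_definitions()` returns in the test added by this PR.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical data shaped like the output of extract_loc_definitions()
# for the nested callsite example in the commit message.
EXAMPLE_LOCS: Dict[str, Dict[str, Any]] = {
    "7": {"file": "file.py", "line": 1091, "column": 8},
    "57": {"file": "file.py", "line": 421, "column": 16},
    "58": {"file": "file.py", "line": 853, "column": 16},
    "190": {
        "file": "file.py", "line": 853, "column": 16,
        "is_callsite": True, "callsite_callee": "58", "callsite_caller": "7",
    },
    "220": {
        "file": "file.py", "line": 421, "column": 16,
        "is_callsite": True, "callsite_callee": "57", "callsite_caller": "190",
    },
}


def expand_call_chain(
    locs: Dict[str, Dict[str, Any]], loc_id: str
) -> List[Tuple[str, int, int]]:
    """Follow callsite_caller references, innermost frame first.

    Hypothetical helper: stops at the first non-callsite loc, at an
    undefined loc, or when a loc id repeats (a simple cycle guard).
    """
    chain: List[Tuple[str, int, int]] = []
    seen = set()
    current = loc_id
    while current in locs and current not in seen:
        seen.add(current)
        info = locs[current]
        chain.append((info["file"], info["line"], info["column"]))
        if not info.get("is_callsite"):
            break  # reached a plain file:line:column location
        current = info["callsite_caller"]
    return chain
```

Calling `expand_call_chain(EXAMPLE_LOCS, "220")` yields the frames at lines 421, 853, and 1091 in that order, which is the call chain from the commit message read from the executing code outward.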
1 parent cd46292 commit ad41299

3 files changed: +114 -0 lines changed

tests/test_tritonparse.py

Lines changed: 57 additions & 0 deletions
@@ -139,6 +139,63 @@ def clear_all_caches(*kernels):
 class TestTritonparseCPU(unittest.TestCase):
     """CPU-only tests (no CUDA required)"""
 
+    def test_callsite_parsing(self):
+        """Test parsing of callsite locations in TTIR/TTGIR"""
+        from tritonparse.ir_parser import extract_loc_definitions
+        from tritonparse.trace_processor import generate_source_mappings
+
+        # Test MLIR callsite location definitions
+        ir_with_callsite = """
+        module {
+          #loc7 = loc("/tmp/test.py":1091:8)
+          #loc57 = loc("/tmp/test.py":421:16)
+          #loc58 = loc("/tmp/test.py":853:16)
+          #loc190 = loc(callsite(#loc58 at #loc7))
+          #loc220 = loc(callsite(#loc57 at #loc190))
+          %0 = tt.load %ptr loc(#loc220)
+        }
+        """
+        # Extract loc definitions
+        locs = extract_loc_definitions(ir_with_callsite)
+
+        # Verify loc220 (nested callsite)
+        self.assertIn("220", locs)
+        self.assertEqual(locs["220"]["file"], "/tmp/test.py")
+        self.assertEqual(locs["220"]["line"], 421)  # Inherited from callee loc57
+        self.assertEqual(locs["220"]["column"], 16)
+        self.assertTrue(locs["220"].get("is_callsite"))
+        self.assertEqual(locs["220"]["callsite_callee"], "57")
+        self.assertEqual(locs["220"]["callsite_caller"], "190")
+
+        # Verify loc190 (simple callsite)
+        self.assertIn("190", locs)
+        self.assertEqual(locs["190"]["line"], 853)  # Inherited from callee loc58
+        self.assertTrue(locs["190"].get("is_callsite"))
+        self.assertEqual(locs["190"]["callsite_callee"], "58")
+        self.assertEqual(locs["190"]["callsite_caller"], "7")
+
+        # Test source mappings generation
+        mappings = generate_source_mappings(ir_with_callsite, "ttir")
+
+        # Find the line with tt.load
+        line_with_load = None
+        for line_num, content in enumerate(ir_with_callsite.split("\n"), start=1):
+            if "tt.load" in content:
+                line_with_load = str(line_num)
+                break
+
+        self.assertIsNotNone(line_with_load)
+        self.assertIn(line_with_load, mappings)
+
+        mapping = mappings[line_with_load]
+        self.assertEqual(mapping["file"], "/tmp/test.py")
+        self.assertEqual(mapping["line"], 421)  # From loc220 -> loc57
+        self.assertTrue(mapping.get("is_callsite"))
+        self.assertEqual(mapping["callsite_callee"], "57")
+        self.assertEqual(mapping["callsite_caller"], "190")
+
+        print("✓ Callsite parsing tests passed")
+
     def test_convert(self):
         """Test convert function with various data types"""
         # Test with primitive types

tritonparse/ir_parser.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@
4444
# Example: #loc20 = loc(#loc16)
4545
ALIAS_SIMPLE_PATTERN = re.compile(r"#loc(\d+)\s*=\s*loc\(\s*#loc(\d*)\s*\)")
4646

47+
# Callsite loc definitions in TTIR/TTGIR
48+
# Example: #loc220 = loc(callsite(#loc57 at #loc190))
49+
# Captures: loc_id, callee_loc_id, caller_loc_id
50+
# Note: Uses (\d*) to match optional numbers (for bare #loc references)
51+
CALLSITE_PATTERN = re.compile(
52+
r"#loc(\d+)\s*=\s*loc\(\s*callsite\(\s*#loc(\d*)\s+at\s+#loc(\d*)\s*\)\s*\)"
53+
)
54+
4755

4856
def extract_loc_definitions(ir_content: str) -> Dict[str, Dict[str, Any]]:
4957
"""
@@ -141,6 +149,50 @@ def resolve_alias(current_id: str) -> Dict[str, Any]:
141149
for alias_id, target_id in alias_map.items():
142150
if alias_id not in locations:
143151
resolve_alias(alias_id)
152+
153+
# Collect callsite definitions
154+
callsite_defs = []
155+
for i, line in enumerate(ir_content.split("\n"), start=1):
156+
if m := CALLSITE_PATTERN.search(line):
157+
loc_id, callee_id, caller_id = m.groups()
158+
# Empty strings map to main loc key ""
159+
callsite_defs.append((loc_id, callee_id or "", caller_id or "", i))
160+
161+
# Resolve callsite definitions
162+
# A callsite inherits the location from its callee (the code being called)
163+
# and stores a reference to its caller (the code doing the calling)
164+
for loc_id, callee_id, caller_id, def_line in callsite_defs:
165+
if loc_id not in locations: # Avoid overwriting existing definitions
166+
if callee_id in locations:
167+
# Inherit location info from callee
168+
callee_info = locations[callee_id]
169+
locations[loc_id] = {
170+
"file": callee_info["file"],
171+
"line": callee_info["line"],
172+
"column": callee_info["column"],
173+
"def_line": def_line,
174+
"is_callsite": True,
175+
"callsite_callee": callee_id,
176+
"callsite_caller": caller_id,
177+
}
178+
else:
179+
logger.warning(
180+
f"Callsite #loc{loc_id} references undefined callee #loc{callee_id}"
181+
)
182+
# Note: We don't add this callsite to locations since callee is missing
183+
184+
# Verify caller references (warning only, don't block)
185+
for loc_id, _callee_id, caller_id, _def_line in callsite_defs:
186+
if loc_id in locations and caller_id and caller_id not in locations:
187+
logger.warning(
188+
f"Callsite #loc{loc_id} references undefined caller #loc{caller_id}"
189+
)
190+
191+
# Attach definition line and alias metadata
192+
for k, v in def_line_map.items():
193+
if k in locations:
194+
locations[k]["def_line"] = v
195+
for alias_id, target_id in alias_map.items():
144196
if alias_id in locations:
145197
locations[alias_id]["alias_of"] = target_id
146198
if alias_id in alias_name_map:
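
To make the capture groups of `CALLSITE_PATTERN` concrete, here is a small standalone sketch that applies the same regular expression from this diff to a few sample lines. The helper name `parse_callsite` is hypothetical and exists only for illustration; the sample IR lines are taken from the commit message, plus one made-up line with a bare `#loc` callee to show the `(\d*)` case mentioned in the comment.

```python
import re
from typing import Optional, Tuple

# Same pattern as added in tritonparse/ir_parser.py by this commit.
CALLSITE_PATTERN = re.compile(
    r"#loc(\d+)\s*=\s*loc\(\s*callsite\(\s*#loc(\d*)\s+at\s+#loc(\d*)\s*\)\s*\)"
)


def parse_callsite(line: str) -> Optional[Tuple[str, str, str]]:
    """Return (loc_id, callee_id, caller_id), or None if not a callsite.

    Hypothetical helper for illustration only. A bare `#loc` reference
    yields an empty string for that group.
    """
    m = CALLSITE_PATTERN.search(line)
    if m is None:
        return None
    loc_id, callee_id, caller_id = m.groups()
    return loc_id, callee_id, caller_id


if __name__ == "__main__":
    samples = [
        "#loc220 = loc(callsite(#loc57 at #loc190))",  # nested callsite
        "#loc5 = loc(callsite(#loc at #loc3))",        # made-up bare callee
        '#loc7 = loc("file.py":1091:8)',               # plain location
    ]
    for s in samples:
        print(s, "->", parse_callsite(s))
```

The plain file location does not match, so only genuine callsite definitions end up in `callsite_defs`.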

tritonparse/trace_processor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ def generate_source_mappings(
7777
"column": info["column"],
7878
f"{ir_type}_line": ln,
7979
}
80+
# Propagate callsite metadata if present
81+
if info.get("is_callsite"):
82+
entry["is_callsite"] = True
83+
entry["callsite_callee"] = info["callsite_callee"]
84+
entry["callsite_caller"] = info["callsite_caller"]
8085
# Propagate alias metadata if present
8186
if "alias_name" in info:
8287
entry["alias_name"] = info["alias_name"]
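
Because the callsite fields are only added when `is_callsite` is set, a consumer of the mappings can treat them as optional. The sketch below illustrates one way a downstream tool might do that. `describe_mapping` is a hypothetical function, not part of tritonparse, and the entry shapes are assumed from the data structure example in the commit message.

```python
from typing import Any, Dict


def describe_mapping(entry: Dict[str, Any]) -> str:
    """Render a mapping entry as text, with callsite info when present.

    Hypothetical consumer: uses .get() so that entries produced before
    this change (without callsite fields) are handled the same as before.
    """
    base = f'{entry["file"]}:{entry["line"]}:{entry["column"]}'
    if entry.get("is_callsite"):
        callee = entry["callsite_callee"]
        caller = entry["callsite_caller"]
        return f"{base} (inlined: callee #loc{callee}, caller #loc{caller})"
    return base


if __name__ == "__main__":
    plain = {"file": "file.py", "line": 1091, "column": 8, "ttir_line": 12}
    inlined = {
        "file": "file.py", "line": 421, "column": 16, "ttir_line": 131,
        "is_callsite": True, "callsite_callee": "57", "callsite_caller": "190",
    }
    print(describe_mapping(plain))
    print(describe_mapping(inlined))
```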
