Skip to content

Commit c7585a5

Browse files
committed
Add MLIR callsite location parsing support
Implement parsing for MLIR callsite locations (e.g., `loc(callsite(#loc57 at #loc190))`) to preserve function call chain information for inlined functions.

Changes:
- Add CALLSITE_PATTERN regex to recognize callsite definitions
- Extend extract_loc_definitions() to collect and resolve callsite references
- Update generate_source_mappings() to propagate callsite metadata
- Add comprehensive unit tests in TestTritonparseCPU::test_callsite_parsing

Callsite locations inherit file/line/column from the callee (the code that is actually executing) while preserving caller references as metadata for call stack traversal. This enables complete source mapping for inlined functions without breaking existing tools.

Test Plan:
python -m pytest tests/test_tritonparse.py::TestTritonparseCPU::test_callsite_parsing -v
1 parent 908a9a2 commit c7585a5

File tree

3 files changed

+165
-0
lines changed

3 files changed

+165
-0
lines changed

tests/test_tritonparse.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,63 @@ def clear_all_caches(*kernels):
139139
class TestTritonparseCPU(unittest.TestCase):
140140
"""CPU-only tests (no CUDA required)"""
141141

142+
def test_callsite_parsing(self):
    """Check that MLIR callsite locs are parsed and reach the source mappings.

    The fixture defines a plain callsite (#loc190) and a callsite whose
    caller is itself a callsite (#loc220). Each callsite entry is expected
    to report the callee's file/line/column and to record the callee and
    caller loc ids as metadata.
    """
    from tritonparse.ir_parser import extract_loc_definitions
    from tritonparse.trace_processor import generate_source_mappings

    # IR fixture with two chained callsite definitions
    ir_with_callsite = """
module {
#loc7 = loc("/tmp/test.py":1091:8)
#loc57 = loc("/tmp/test.py":421:16)
#loc58 = loc("/tmp/test.py":853:16)
#loc190 = loc(callsite(#loc58 at #loc7))
#loc220 = loc(callsite(#loc57 at #loc190))
%0 = tt.load %ptr loc(#loc220)
}
"""
    locs = extract_loc_definitions(ir_with_callsite)

    # Nested callsite #loc220: position is taken from callee #loc57
    self.assertIn("220", locs)
    nested = locs["220"]
    self.assertEqual(nested["file"], "/tmp/test.py")
    self.assertEqual(nested["line"], 421)
    self.assertEqual(nested["column"], 16)
    self.assertTrue(nested.get("is_callsite"))
    self.assertEqual(nested["callsite_callee"], "57")
    self.assertEqual(nested["callsite_caller"], "190")

    # Simple callsite #loc190: position is taken from callee #loc58
    self.assertIn("190", locs)
    simple = locs["190"]
    self.assertEqual(simple["line"], 853)
    self.assertTrue(simple.get("is_callsite"))
    self.assertEqual(simple["callsite_callee"], "58")
    self.assertEqual(simple["callsite_caller"], "7")

    # Source mappings are looked up by 1-based IR line number, as a string
    mappings = generate_source_mappings(ir_with_callsite, "ttir")

    # First IR line that contains the tt.load op (None if there is none)
    line_with_load = next(
        (
            str(idx)
            for idx, text in enumerate(ir_with_callsite.split("\n"), start=1)
            if "tt.load" in text
        ),
        None,
    )

    self.assertIsNotNone(line_with_load)
    self.assertIn(line_with_load, mappings)

    # The op references #loc220, so it resolves through callee #loc57
    mapping = mappings[line_with_load]
    self.assertEqual(mapping["file"], "/tmp/test.py")
    self.assertEqual(mapping["line"], 421)
    self.assertTrue(mapping.get("is_callsite"))
    self.assertEqual(mapping["callsite_callee"], "57")
    self.assertEqual(mapping["callsite_caller"], "190")

    print("✓ Callsite parsing tests passed")
198+
142199
def test_convert(self):
143200
"""Test convert function with various data types"""
144201
# Test with primitive types
@@ -1506,6 +1563,63 @@ def test_loc_alias_parsing(self):
15061563

15071564
print("✓ All loc alias parsing tests passed")
15081565

1566+
def test_callsite_parsing(self):
    """Check that MLIR callsite locs are parsed and reach the source mappings.

    The fixture defines a plain callsite (#loc190) and a callsite whose
    caller is itself a callsite (#loc220). Each callsite entry is expected
    to report the callee's file/line/column and to record the callee and
    caller loc ids as metadata.
    """
    from tritonparse.ir_parser import extract_loc_definitions
    from tritonparse.trace_processor import generate_source_mappings

    # IR fixture with two chained callsite definitions
    ir_with_callsite = """
module {
#loc7 = loc("/tmp/test.py":1091:8)
#loc57 = loc("/tmp/test.py":421:16)
#loc58 = loc("/tmp/test.py":853:16)
#loc190 = loc(callsite(#loc58 at #loc7))
#loc220 = loc(callsite(#loc57 at #loc190))
%0 = tt.load %ptr loc(#loc220)
}
"""
    locs = extract_loc_definitions(ir_with_callsite)

    # Nested callsite #loc220: position is taken from callee #loc57
    self.assertIn("220", locs)
    nested = locs["220"]
    self.assertEqual(nested["file"], "/tmp/test.py")
    self.assertEqual(nested["line"], 421)
    self.assertEqual(nested["column"], 16)
    self.assertTrue(nested.get("is_callsite"))
    self.assertEqual(nested["callsite_callee"], "57")
    self.assertEqual(nested["callsite_caller"], "190")

    # Simple callsite #loc190: position is taken from callee #loc58
    self.assertIn("190", locs)
    simple = locs["190"]
    self.assertEqual(simple["line"], 853)
    self.assertTrue(simple.get("is_callsite"))
    self.assertEqual(simple["callsite_callee"], "58")
    self.assertEqual(simple["callsite_caller"], "7")

    # Source mappings are looked up by 1-based IR line number, as a string
    mappings = generate_source_mappings(ir_with_callsite, "ttir")

    # First IR line that contains the tt.load op (None if there is none)
    line_with_load = next(
        (
            str(idx)
            for idx, text in enumerate(ir_with_callsite.split("\n"), start=1)
            if "tt.load" in text
        ),
        None,
    )

    self.assertIsNotNone(line_with_load)
    self.assertIn(line_with_load, mappings)

    # The op references #loc220, so it resolves through callee #loc57
    mapping = mappings[line_with_load]
    self.assertEqual(mapping["file"], "/tmp/test.py")
    self.assertEqual(mapping["line"], 421)
    self.assertTrue(mapping.get("is_callsite"))
    self.assertEqual(mapping["callsite_callee"], "57")
    self.assertEqual(mapping["callsite_caller"], "190")

    print("✓ Callsite parsing tests passed")
1622+
15091623

15101624
if __name__ == "__main__":
15111625
unittest.main()

tritonparse/ir_parser.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@
4444
# Example: #loc20 = loc(#loc16)
4545
ALIAS_SIMPLE_PATTERN = re.compile(
    r"#loc(\d+)\s*=\s*"  # group 1: id of the alias being defined
    r"loc\(\s*#loc(\d*)\s*\)"  # group 2: target id; empty for a bare "#loc"
)
4646

47+
# Matches a callsite loc definition in TTIR/TTGIR, e.g.
#   #loc220 = loc(callsite(#loc57 at #loc190))
# Groups, in order: id of the loc being defined, callee loc id, caller loc id.
# The callee/caller groups use (\d*) so that a bare "#loc" reference (no
# digits) still matches and is captured as an empty string.
CALLSITE_PATTERN = re.compile(
    r"#loc(\d+)\s*=\s*"  # "#loc<N> ="
    r"loc\(\s*callsite\(\s*"  # "loc(callsite("
    r"#loc(\d*)\s+at\s+#loc(\d*)"  # "<callee> at <caller>"
    r"\s*\)\s*\)"  # closing "))"
)
54+
4755

4856
def extract_loc_definitions(ir_content: str) -> Dict[str, Dict[str, Any]]:
4957
"""
@@ -141,6 +149,44 @@ def resolve_alias(current_id: str) -> Dict[str, Any]:
141149
if alias_id not in locations:
142150
resolve_alias(alias_id)
143151

152+
# Collect callsite definitions
153+
callsite_defs = []
154+
for i, line in enumerate(ir_content.split("\n"), start=1):
155+
if m := CALLSITE_PATTERN.search(line):
156+
loc_id, callee_id, caller_id = m.groups()
157+
# Empty strings map to main loc key ""
158+
callsite_defs.append((loc_id, callee_id or "", caller_id or "", i))
159+
160+
# Resolve callsite definitions
161+
# A callsite inherits the location from its callee (the code being called)
162+
# and stores a reference to its caller (the code doing the calling)
163+
for loc_id, callee_id, caller_id, def_line in callsite_defs:
164+
if loc_id not in locations: # Avoid overwriting existing definitions
165+
if callee_id in locations:
166+
# Inherit location info from callee
167+
callee_info = locations[callee_id]
168+
locations[loc_id] = {
169+
"file": callee_info["file"],
170+
"line": callee_info["line"],
171+
"column": callee_info["column"],
172+
"def_line": def_line,
173+
"is_callsite": True,
174+
"callsite_callee": callee_id,
175+
"callsite_caller": caller_id,
176+
}
177+
else:
178+
logger.warning(
179+
f"Callsite #loc{loc_id} references undefined callee #loc{callee_id}"
180+
)
181+
# Note: We don't add this callsite to locations since callee is missing
182+
183+
# Verify caller references (warning only, don't block)
184+
for loc_id, _callee_id, caller_id, _def_line in callsite_defs:
185+
if loc_id in locations and caller_id and caller_id not in locations:
186+
logger.warning(
187+
f"Callsite #loc{loc_id} references undefined caller #loc{caller_id}"
188+
)
189+
144190
# Attach definition line and alias metadata
145191
for k, v in def_line_map.items():
146192
if k in locations:

tritonparse/trace_processor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ def generate_source_mappings(
7777
"column": info["column"],
7878
f"{ir_type}_line": ln,
7979
}
80+
# Propagate callsite metadata if present
81+
if info.get("is_callsite"):
82+
entry["is_callsite"] = True
83+
entry["callsite_callee"] = info["callsite_callee"]
84+
entry["callsite_caller"] = info["callsite_caller"]
8085
# Propagate alias metadata if present
8186
if "alias_name" in info:
8287
entry["alias_name"] = info["alias_name"]

0 commit comments

Comments
 (0)