@@ -139,6 +139,63 @@ def clear_all_caches(*kernels):
139139class TestTritonparseCPU (unittest .TestCase ):
140140 """CPU-only tests (no CUDA required)"""
141141
142+ def test_callsite_parsing (self ):
143+ """Test parsing of callsite locations in TTIR/TTGIR"""
144+ from tritonparse .ir_parser import extract_loc_definitions
145+ from tritonparse .trace_processor import generate_source_mappings
146+
147+ # Test MLIR callsite location definitions
148+ ir_with_callsite = """
149+ module {
150+ #loc7 = loc("/tmp/test.py":1091:8)
151+ #loc57 = loc("/tmp/test.py":421:16)
152+ #loc58 = loc("/tmp/test.py":853:16)
153+ #loc190 = loc(callsite(#loc58 at #loc7))
154+ #loc220 = loc(callsite(#loc57 at #loc190))
155+ %0 = tt.load %ptr loc(#loc220)
156+ }
157+ """
158+ # Extract loc definitions
159+ locs = extract_loc_definitions (ir_with_callsite )
160+
161+ # Verify loc220 (nested callsite)
162+ self .assertIn ("220" , locs )
163+ self .assertEqual (locs ["220" ]["file" ], "/tmp/test.py" )
164+ self .assertEqual (locs ["220" ]["line" ], 421 ) # Inherited from callee loc57
165+ self .assertEqual (locs ["220" ]["column" ], 16 )
166+ self .assertTrue (locs ["220" ].get ("is_callsite" ))
167+ self .assertEqual (locs ["220" ]["callsite_callee" ], "57" )
168+ self .assertEqual (locs ["220" ]["callsite_caller" ], "190" )
169+
170+ # Verify loc190 (simple callsite)
171+ self .assertIn ("190" , locs )
172+ self .assertEqual (locs ["190" ]["line" ], 853 ) # Inherited from callee loc58
173+ self .assertTrue (locs ["190" ].get ("is_callsite" ))
174+ self .assertEqual (locs ["190" ]["callsite_callee" ], "58" )
175+ self .assertEqual (locs ["190" ]["callsite_caller" ], "7" )
176+
177+ # Test source mappings generation
178+ mappings = generate_source_mappings (ir_with_callsite , "ttir" )
179+
180+ # Find the line with tt.load
181+ line_with_load = None
182+ for line_num , content in enumerate (ir_with_callsite .split ("\n " ), start = 1 ):
183+ if "tt.load" in content :
184+ line_with_load = str (line_num )
185+ break
186+
187+ self .assertIsNotNone (line_with_load )
188+ self .assertIn (line_with_load , mappings )
189+
190+ mapping = mappings [line_with_load ]
191+ self .assertEqual (mapping ["file" ], "/tmp/test.py" )
192+ self .assertEqual (mapping ["line" ], 421 ) # From loc220 -> loc57
193+ self .assertTrue (mapping .get ("is_callsite" ))
194+ self .assertEqual (mapping ["callsite_callee" ], "57" )
195+ self .assertEqual (mapping ["callsite_caller" ], "190" )
196+
197+ print ("✓ Callsite parsing tests passed" )
198+
142199 def test_convert (self ):
143200 """Test convert function with various data types"""
144201 # Test with primitive types
@@ -1506,6 +1563,63 @@ def test_loc_alias_parsing(self):
15061563
15071564 print ("✓ All loc alias parsing tests passed" )
15081565
1566+ def test_callsite_parsing (self ):
1567+ """Test parsing of callsite locations in TTIR/TTGIR"""
1568+ from tritonparse .ir_parser import extract_loc_definitions
1569+ from tritonparse .trace_processor import generate_source_mappings
1570+
1571+ # Test MLIR callsite location definitions
1572+ ir_with_callsite = """
1573+ module {
1574+ #loc7 = loc("/tmp/test.py":1091:8)
1575+ #loc57 = loc("/tmp/test.py":421:16)
1576+ #loc58 = loc("/tmp/test.py":853:16)
1577+ #loc190 = loc(callsite(#loc58 at #loc7))
1578+ #loc220 = loc(callsite(#loc57 at #loc190))
1579+ %0 = tt.load %ptr loc(#loc220)
1580+ }
1581+ """
1582+ # Extract loc definitions
1583+ locs = extract_loc_definitions (ir_with_callsite )
1584+
1585+ # Verify loc220 (nested callsite)
1586+ self .assertIn ("220" , locs )
1587+ self .assertEqual (locs ["220" ]["file" ], "/tmp/test.py" )
1588+ self .assertEqual (locs ["220" ]["line" ], 421 ) # Inherited from callee loc57
1589+ self .assertEqual (locs ["220" ]["column" ], 16 )
1590+ self .assertTrue (locs ["220" ].get ("is_callsite" ))
1591+ self .assertEqual (locs ["220" ]["callsite_callee" ], "57" )
1592+ self .assertEqual (locs ["220" ]["callsite_caller" ], "190" )
1593+
1594+ # Verify loc190 (simple callsite)
1595+ self .assertIn ("190" , locs )
1596+ self .assertEqual (locs ["190" ]["line" ], 853 ) # Inherited from callee loc58
1597+ self .assertTrue (locs ["190" ].get ("is_callsite" ))
1598+ self .assertEqual (locs ["190" ]["callsite_callee" ], "58" )
1599+ self .assertEqual (locs ["190" ]["callsite_caller" ], "7" )
1600+
1601+ # Test source mappings generation
1602+ mappings = generate_source_mappings (ir_with_callsite , "ttir" )
1603+
1604+ # Find the line with tt.load
1605+ line_with_load = None
1606+ for line_num , content in enumerate (ir_with_callsite .split ("\n " ), start = 1 ):
1607+ if "tt.load" in content :
1608+ line_with_load = str (line_num )
1609+ break
1610+
1611+ self .assertIsNotNone (line_with_load )
1612+ self .assertIn (line_with_load , mappings )
1613+
1614+ mapping = mappings [line_with_load ]
1615+ self .assertEqual (mapping ["file" ], "/tmp/test.py" )
1616+ self .assertEqual (mapping ["line" ], 421 ) # From loc220 -> loc57
1617+ self .assertTrue (mapping .get ("is_callsite" ))
1618+ self .assertEqual (mapping ["callsite_callee" ], "57" )
1619+ self .assertEqual (mapping ["callsite_caller" ], "190" )
1620+
1621+ print ("✓ Callsite parsing tests passed" )
1622+
15091623
15101624if __name__ == "__main__" :
15111625 unittest .main ()
0 commit comments