Skip to content

Commit 7df7830

Browse files
FindHaometa-codesync[bot]
authored andcommitted
Variable loc info PR2: Add location alias parsing support in IR parser and trace processor (#187)
Summary: fix #86 This PR implements comprehensive support for parsing location aliases in TTIR/TTGIR files. It adds regex patterns to match alias definitions, implements alias resolution with cycle detection, tracks metadata (definition lines, alias names, alias targets), and propagates this information through the trace processor to the frontend. ## Background Recent Triton compiler updates introduced a new location pattern with aliases: - Named aliases: `#loc16 = loc("pid"(#loc2))` - references `#loc2` with name "pid" - Bare #loc aliases: `#loc13 = loc("x_ptr"(#loc))` - references main `#loc` with name "x_ptr" - Simple aliases: `#loc20 = loc(#loc16)` - direct reference without a name These aliases are used extensively for function parameters and variable names in the new TTIR format. ## Changes ### Backend - IR Parser (`tritonparse/ir_parser.py`) **New regex patterns:** - `ALIAS_WITH_NAME_PATTERN`: Matches `#loc13 = loc("x_ptr"(#loc))` - `ALIAS_SIMPLE_PATTERN`: Matches `#loc20 = loc(#loc16)` - Updated `LOC_PATTERN`: Now uses `\d+` (only matches numbered locs, not bare `#loc`) **Alias resolution logic:** - Build alias map: `alias_id` → `target_id` - Track definition lines for each location - Extract alias names from named aliases - Resolve alias chains recursively with cycle detection - Store metadata: `def_line`, `alias_name`, `alias_of` **Key implementation details:** - Empty string `""` represents bare `#loc` (from PR1) - Recursive `resolve_alias()` function follows chains to base locations - Cycle detection prevents infinite loops in malformed IR ### Backend - Trace Processor (`tritonparse/trace_processor.py`) **Metadata propagation:** - Add `alias_name` to code reference entries - Add `loc_id` to identify which location is being referenced - Create separate entries for loc definition lines with `kind="loc_def"` - Include all metadata (`alias_name`, `alias_of`, `loc_id`) for definition lines ### Testing (`tests/test_tritonparse.py`) **New test: `test_loc_alias_parsing()`** - Tests bare `#loc` references (e.g., function parameters) - Tests named aliases with numbered references - Tests simple aliases without names - Verifies definition line tracking - Validates alias chain resolution - Moved to `TestTritonparseCPU` class (doesn't require CUDA) **Test coverage:** - Bare #loc: `#loc13 = loc("x_ptr"(#loc))` - Numbered alias: `#loc16 = loc("y"(#loc1))` - Chain resolution: `#loc17 → #loc3 → file location` - Metadata: `alias_name`, `alias_of`, `def_line` ## Benefits 1. **Correct parsing**: Handles new Triton compiler location format 2. **Rich metadata**: Preserves alias names and relationships 3. **Robust**: Cycle detection prevents crashes on malformed IR 4. **Testable**: Comprehensive unit tests ensure correctness 5. **Foundation**: Enables frontend visualization (PR3) ## Backward Compatibility - ✅ Old TTIR format without aliases continues to work - ✅ New metadata fields are optional - ✅ No breaking changes to existing API ## Dependencies - **Requires**: PR1 (main #loc key fix) - **Enables**: PR3 (frontend UI visualization) ## Example Input TTIR: ```mlir #loc = loc("/scratch/test.py":308:0) #loc13 = loc("x_ptr"(#loc)) #loc1 = loc("/scratch/test.py":315:12) #loc16 = loc("y"(#loc1)) ``` Parsed locations: ```python { "": {"file": "/scratch/test.py", "line": 308, "def_line": 1}, "13": {"file": "/scratch/test.py", "line": 308, "alias_name": "x_ptr", "alias_of": "", "def_line": 2}, "1": {"file": "/scratch/test.py", "line": 315, "def_line": 3}, "16": {"file": "/scratch/test.py", "line": 315, "alias_name": "y", "alias_of": "1", "def_line": 4} } ``` ## Files Changed - `tritonparse/ir_parser.py`: +93 lines (regex patterns, alias resolution) - `tritonparse/trace_processor.py`: +27 lines (metadata propagation) - `tests/test_tritonparse.py`: +64 lines (comprehensive tests) ## Testing ```bash # Run the new test python -m unittest tests.test_tritonparse.TestTritonparseCPU.test_loc_alias_parsing -v ``` All tests pass ✓ Pull Request resolved: #187 Reviewed By: adamomainz Differential Revision: D85605707 Pulled By: FindHao fbshipit-source-id: 567595c6ef735a743014e0432cd88e7325ed0fa5
1 parent af7b65c commit 7df7830

File tree

3 files changed

+183
-2
lines changed

3 files changed

+183
-2
lines changed

tests/test_tritonparse.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,70 @@ class NestedDataClass:
187187

188188
assert convert(nested_structure) == expected_nested
189189

190+
def test_loc_alias_parsing(self):
191+
"""Test parsing of location aliases in TTIR/TTGIR"""
192+
from tritonparse.ir_parser import extract_loc_definitions
193+
194+
# Test case 1: Bare #loc reference (no number)
195+
ir_with_bare_loc = """
196+
module {
197+
#loc = loc("/tmp/test.py":10:5)
198+
#loc13 = loc("x_ptr"(#loc))
199+
func @kernel(%arg0: !tt.ptr<f32> loc(#loc13)) {
200+
return loc(#loc)
201+
}
202+
}
203+
"""
204+
locs = extract_loc_definitions(ir_with_bare_loc)
205+
# Main #loc should be stored with "" key
206+
assert "" in locs, "Main #loc not found"
207+
assert locs[""]["file"] == "/tmp/test.py"
208+
assert locs[""]["line"] == 10
209+
# Alias #loc13 should resolve to same location
210+
assert "13" in locs, "#loc13 not found"
211+
assert locs["13"]["file"] == "/tmp/test.py"
212+
assert locs["13"]["line"] == 10
213+
assert locs["13"]["alias_name"] == "x_ptr"
214+
assert locs["13"]["alias_of"] == ""
215+
216+
# Test case 2: Named alias with numbered reference
217+
ir_with_numbered_alias = """
218+
#loc = loc("/tmp/test.py":5:0)
219+
#loc2 = loc("/tmp/test.py":20:28)
220+
#loc16 = loc("pid"(#loc2))
221+
%0 = tt.get_program_id x : i32 loc(#loc16)
222+
"""
223+
locs = extract_loc_definitions(ir_with_numbered_alias)
224+
assert "2" in locs
225+
assert locs["2"]["line"] == 20
226+
assert "16" in locs
227+
assert locs["16"]["file"] == "/tmp/test.py"
228+
assert locs["16"]["line"] == 20
229+
assert locs["16"]["alias_name"] == "pid"
230+
assert locs["16"]["alias_of"] == "2"
231+
232+
# Test case 3: Simple alias (no name)
233+
ir_with_simple_alias = """
234+
#loc = loc("/tmp/test.py":1:1)
235+
#loc1 = loc("/tmp/test.py":15:10)
236+
#loc20 = loc(#loc1)
237+
%1 = arith.constant 0 : i32 loc(#loc20)
238+
"""
239+
locs = extract_loc_definitions(ir_with_simple_alias)
240+
assert "1" in locs
241+
assert "20" in locs
242+
assert locs["20"]["file"] == "/tmp/test.py"
243+
assert locs["20"]["line"] == 15
244+
assert locs["20"]["alias_of"] == "1"
245+
assert "alias_name" not in locs["20"]
246+
247+
# Test case 4: Definition line tracking
248+
assert "def_line" in locs[""]
249+
assert "def_line" in locs["1"]
250+
assert "def_line" in locs["20"]
251+
252+
print("✓ All loc alias parsing tests passed")
253+
190254

191255
class TestTritonparseCUDA(unittest.TestCase):
192256
"""CUDA tests (require GPU)"""

tritonparse/ir_parser.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
# the definition of the #loc directive. they are in the bottom of the IR files
1212
# Example:#loc2 = loc("/tmp/torchinductor_yhao/yp/abcdef.py":20:28)
13-
LOC_PATTERN = re.compile(r'#loc(\d*) = loc\("([^"]+)":(\d+):(\d+)\)')
13+
# Note: This should only match numbered locs like #loc1, #loc2, not bare #loc
14+
LOC_PATTERN = re.compile(r'#loc(\d+) = loc\("([^"]+)":(\d+):(\d+)\)')
1415

1516
# the reference to the #loc directive. they are in the end of lines of the IR files
1617
# Example: loc(#loc2)
@@ -33,6 +34,17 @@
3334
)
3435

3536

37+
# alias loc definitions in TTGIR/TTIR
38+
# Example: #loc16 = loc("pid"(#loc2))
39+
# Example: #loc13 = loc("x_ptr"(#loc)) - bare #loc without number
40+
ALIAS_WITH_NAME_PATTERN = re.compile(
41+
r'#loc(\d+)\s*=\s*loc\("([^"]+)"\s*\(\s*#loc(\d*)\s*\)\s*\)'
42+
)
43+
44+
# Example: #loc20 = loc(#loc16)
45+
ALIAS_SIMPLE_PATTERN = re.compile(r"#loc(\d+)\s*=\s*loc\(\s*#loc(\d*)\s*\)")
46+
47+
3648
def extract_loc_definitions(ir_content: str) -> Dict[str, Dict[str, Any]]:
3749
"""
3850
Extracts location definitions from the given IR content.
@@ -50,6 +62,7 @@ def extract_loc_definitions(ir_content: str) -> Dict[str, Dict[str, Any]]:
5062
"""
5163
locations = {}
5264
# The first #loc directive is a special case. It locates at the top of the IR files
65+
# Store it with empty string "" as key to avoid conflict with #loc1
5366
main_match = re.search(r'#loc = loc\("([^"]+)":(\d+):(\d+)\)', ir_content)
5467
if main_match:
5568
locations[""] = {
@@ -61,6 +74,84 @@ def extract_loc_definitions(ir_content: str) -> Dict[str, Dict[str, Any]]:
6174
for loc_id, filename, line, col in LOC_PATTERN.findall(ir_content):
6275
key = loc_id
6376
locations[key] = {"file": filename, "line": int(line), "column": int(col)}
77+
78+
# Handle alias-style loc definitions that reference another #loc
79+
# Build alias map first: alias_id -> target_id
80+
alias_map: Dict[str, str] = {}
81+
for m in ALIAS_WITH_NAME_PATTERN.finditer(ir_content):
82+
alias_id, _name, target_id = m.groups()
83+
# Empty target_id means bare #loc, map to "" (main loc key)
84+
alias_map[alias_id] = target_id or ""
85+
for m in ALIAS_SIMPLE_PATTERN.finditer(ir_content):
86+
alias_id, target_id = m.groups()
87+
# Empty target_id means bare #loc, map to "" (main loc key)
88+
alias_map[alias_id] = target_id or ""
89+
90+
# Build definition line map and alias name map by scanning lines
91+
def_line_map: Dict[str, int] = {}
92+
alias_name_map: Dict[str, str] = {}
93+
main_loc_line: int = 0
94+
for i, line in enumerate(ir_content.split("\n"), start=1):
95+
if m := ALIAS_WITH_NAME_PATTERN.search(line):
96+
alias_id, name, target_id = m.groups()
97+
def_line_map[alias_id] = i
98+
alias_name_map[alias_id] = name
99+
# ensure alias map is populated even if only found in line scan
100+
# Empty target_id means bare #loc, map to "" (main loc key)
101+
alias_map.setdefault(alias_id, target_id or "")
102+
elif m := ALIAS_SIMPLE_PATTERN.search(line):
103+
alias_id, target_id = m.groups()
104+
def_line_map[alias_id] = i
105+
# Empty target_id means bare #loc, map to "" (main loc key)
106+
alias_map.setdefault(alias_id, target_id or "")
107+
if m2 := LOC_PATTERN.search(line):
108+
base_id, _fn, _ln, _col = m2.groups()
109+
def_line_map[base_id] = i
110+
if re.search(r'#loc\s*=\s*loc\("[^"]+":\d+:\d+\)', line):
111+
# main #loc = loc("file":line:col) without id
112+
main_loc_line = main_loc_line or i
113+
114+
# Resolve aliases to base locations (file/line/column)
115+
resolving_stack = set()
116+
117+
def resolve_alias(current_id: str) -> Dict[str, Any]:
118+
# Already a concrete location
119+
if current_id in locations:
120+
return locations[current_id]
121+
# Detect cycles
122+
if current_id in resolving_stack:
123+
return {}
124+
resolving_stack.add(current_id)
125+
parent_id = alias_map.get(current_id)
126+
result: Dict[str, Any] = {}
127+
if parent_id is not None:
128+
base = resolve_alias(parent_id)
129+
if base:
130+
# copy to avoid sharing the same dict by reference
131+
result = {
132+
"file": base.get("file"),
133+
"line": base.get("line"),
134+
"column": base.get("column"),
135+
}
136+
locations[current_id] = result
137+
resolving_stack.remove(current_id)
138+
return result
139+
140+
# Resolve aliases and attach alias metadata
141+
for alias_id, target_id in alias_map.items():
142+
if alias_id not in locations:
143+
resolve_alias(alias_id)
144+
if alias_id in locations:
145+
locations[alias_id]["alias_of"] = target_id
146+
if alias_id in alias_name_map:
147+
locations[alias_id]["alias_name"] = alias_name_map[alias_id]
148+
149+
# Attach definition line metadata
150+
for k, v in def_line_map.items():
151+
if k in locations:
152+
locations[k]["def_line"] = v
153+
if main_loc_line and "" in locations:
154+
locations[""]["def_line"] = main_loc_line
64155
return locations
65156

66157

tritonparse/trace_processor.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,38 @@ def generate_source_mappings(
7171
}
7272
elif loc_id in loc_defs:
7373
info = loc_defs[loc_id]
74-
mappings[str(ln)] = {
74+
entry = {
7575
"file": info["file"],
7676
"line": info["line"],
7777
"column": info["column"],
7878
f"{ir_type}_line": ln,
7979
}
80+
# Propagate alias metadata if present
81+
if "alias_name" in info:
82+
entry["alias_name"] = info["alias_name"]
83+
if "alias_of" in info:
84+
entry["loc_id"] = loc_id
85+
mappings[str(ln)] = entry
86+
87+
# Add separate entries for loc definition lines
88+
for loc_id, info in loc_defs.items():
89+
if "def_line" not in info:
90+
continue
91+
def_ln = info["def_line"]
92+
# Only create mapping if this line doesn't already have one
93+
if str(def_ln) not in mappings:
94+
entry = {
95+
"file": info["file"],
96+
"line": info["line"],
97+
"column": info["column"],
98+
f"{ir_type}_line": def_ln,
99+
"kind": "loc_def",
100+
}
101+
if "alias_name" in info:
102+
entry["alias_name"] = info["alias_name"]
103+
if "alias_of" in info:
104+
entry["loc_id"] = loc_id
105+
mappings[str(def_ln)] = entry
80106

81107
return mappings
82108

0 commit comments

Comments
 (0)