Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..runtime import JITFunction
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
from types import ModuleType
from .hint_manager import hint_trigger


def mangle_ty(ty):
Expand Down Expand Up @@ -516,9 +517,8 @@ def visit_Assign(self, node):
value = language.semantic.to_tensor(value, self.builder)
self.set_value(name, value)

#flagtree backend specialization
from triton.runtime.driver import spec
spec("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values)
# switch into hintmanager
hint_trigger("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values)

def visit_AugAssign(self, node):
name = node.target.id
Expand Down Expand Up @@ -953,10 +953,8 @@ def visit_For(self, node):
step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
else:
raise RuntimeError('Only `range` and `static_range` iterators are currently supported')

# flagtree backend specialization
from triton.runtime.driver import spec
new_bind_sub_block = spec("check_override_bind_sub_block", self, node, bind_sub_block)
# hint manager
new_bind_sub_block = hint_trigger("check_override_bind_sub_block", self, node, bind_sub_block)
if new_bind_sub_block is not None:
bind_sub_block = new_bind_sub_block

Expand Down Expand Up @@ -1026,10 +1024,9 @@ def visit_For(self, node):
# flagtree backend specialization
from triton.runtime.driver import spec
spec("for_op_set_ext_attrs", for_op, self.builder, for_op_ext_attrs)
# flagtree backend specialization
# hint manager
if bind_sub_block:
from triton.runtime.driver import spec
spec("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block)
hint_trigger("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block)

self.scf_stack.append(node)
self.builder.set_insertion_point_to_start(for_op.get_body(0))
Expand Down
147 changes: 147 additions & 0 deletions python/triton/compiler/hint_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import os
import sys
import importlib


class BaseHintHandler:
# dynamicly find method
def trigger(self, hook_name, *args, **kwargs):
if hasattr(self, hook_name):
method = getattr(self, hook_name)
if callable(method):
try:
return method(*args, **kwargs)

except TypeError as e:
import inspect

try:
sig = inspect.signature(method)
expected = str(sig)
except Exception:
expected = "(unknown)"

actual_args = f"{len(args)} positional"
actual_kwargs = f"keys={list(kwargs.keys())}" if kwargs else "no keywords"

print(f"\n[Hint Trigger Mismatch] {self.__class__.__name__}.{hook_name}")
print(f" > Expect : {expected}")
print(f" > Actual : {actual_args}, {actual_kwargs}")
print(f" > Reason : {e}\n")

raise e
return None


class HintManager:

def __init__(self, backend_name):
self.backend_name = backend_name
# load Handler with backend name
self.handler = self._load_handler(backend_name)

def _load_handler(self, backend):
if backend == 'npu':
try:
module = importlib.import_module("third_party.ascend.backend.ascend_hint_handler")
return module.AscendHintHandler()
except ImportError as e:
print(f"[FlagTree] Warning: Failed to load Ascend Hint Handler: {e}", file=sys.stderr)
return BaseHintHandler()
elif backend == 'aipu':
from .backends.aipu import AipuHintHandler
return AipuHintHandler()
else:
return BaseHintHandler()


# supported backend with matched version
SUPPORTED_CONFIG = {
"cuda": {"3.5"},
"npu": {"3.2"},
"aipu": {"3.3"},
}

# mapping name
BACKEND_ALIASES = {
"ascend": "npu",
"huawei": "npu",
"nv": "cuda",
}


def normalize_backend_name(name: str) -> str:
if not name:
return ""
name = name.lower()
return BACKEND_ALIASES.get(name, name)


def hint_get_flagtree_backend() -> str:
detected_backend = ""

import torch
import triton

# Priority 1: Triton Driver
try:
from triton.runtime import driver
if hasattr(driver, 'active') and hasattr(driver.active, 'get_active_torch_device'):
device = driver.active.get_active_torch_device()
if isinstance(device, torch.device):
detected_backend = device.type
# unimplemented support
elif isinstance(device, str):
detected_backend = device
except ImportError:
pass

# Priority 2: Torch Global State
if not detected_backend:
candidates = list(SUPPORTED_CONFIG.keys())
# cuda priority least
candidates.sort(key=lambda x: 1 if x == "cuda" else 0)

# 3. parse according to benefit
for candidate in candidates:
module_name = candidate
module = getattr(torch, module_name, None)
if module and hasattr(module, "is_available") and module.is_available():
detected_backend = candidate
break

# Priority 3: Environment Variable (need to remove!!!)
if not detected_backend:
detected_backend = os.environ.get("FLAGTREE_BACKEND", "")

# (Normalization and Validation)
canonical_backend = normalize_backend_name(detected_backend)

if not canonical_backend or canonical_backend not in SUPPORTED_CONFIG:
return ""

# verify name and version match
try:
current_triton_version = ".".join(triton.__version__.split(".")[:2])
supported_versions = SUPPORTED_CONFIG[canonical_backend]
if current_triton_version not in supported_versions:
msg = (f"[Flagtree] Hint ignored: Detected backend '{canonical_backend}' but current Triton version "
f"'{current_triton_version}' matches no supported versions {supported_versions}.")
print(msg, file=sys.stderr)
return ""
except Exception:
pass

return canonical_backend


# lazy load after first call hint trigger
_global_hint_manager = None


def hint_trigger(hook_name, *args, **kwargs):
global _global_hint_manager

if _global_hint_manager is None:
_global_hint_manager = HintManager(hint_get_flagtree_backend())
return _global_hint_manager.handler.trigger(hook_name, *args, **kwargs)
12 changes: 7 additions & 5 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,17 +796,19 @@ def preload(self, specialization_data):
# the user might want to monkey-patch self.src dynamically.
# Our unit tests do this, for example.
def parse(self):
# flagtree backend specialization
from triton.runtime.driver import spec
line_flagtree_hints = spec('maps_line_numbers_to_comment_hints', self)
# hint manager
# after removing flagtree backend specialization, hiding the implementation into hintmanager
from ..compiler.hint_manager import hint_trigger
line_flagtree_hints = hint_trigger("maps_line_numbers_to_comment_hints", self)

tree = ast.parse(self.src)
assert isinstance(tree, ast.Module)
assert len(tree.body) == 1
assert isinstance(tree.body[0], ast.FunctionDef)

# flagtree backend specialization
spec('attach_line_number_to_comment_mapping', tree, line_flagtree_hints)
# hint manager
# Attach the line number to comment mapping to the function definition node
hint_trigger('attach_line_number_to_comment_mapping', tree, line_flagtree_hints)

return tree

Expand Down
79 changes: 79 additions & 0 deletions third_party/ascend/backend/ascend_hint_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# should store at thrid_party/???/backend/
from triton.compiler.hint_manager import BaseHintHandler
import triton.language as language
import ast
from triton.compiler.code_generator import _is_triton_value


class AscendHintHandler(BaseHintHandler):

@staticmethod
def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values):
import ast
from triton.compiler.code_generator import _is_triton_value
# flagtree: After normal processing, check if we need to add hint annotation
if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'):
line_num = node.lineno
# TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
function_def = code_generator.jit_fn.parse()
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
flagtree_hints = line_flagtree_hints.get(line_num)

# Check if this is a tl.load call with dot_pad_only_k hint
if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and isinstance(node.value, ast.Call)
and isinstance(node.value.func, ast.Attribute) and isinstance(node.value.func.value, ast.Name)
and node.value.func.value.id == 'tl' and node.value.func.attr == 'load'):

# Add hint annotation to the loaded tensor(s)
for name, value in zip(names, values):
if _is_triton_value(value):
# print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}")
# Create hint annotation
hint_val = code_generator.builder.get_unit_attr()
code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val)

@staticmethod
def check_override_bind_sub_block(code_generator, node, bind_sub_block):
# flagtree: After normal processing, check if we need to override bind_sub_block
if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'):
line_num = node.lineno
# TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
function_def = code_generator.jit_fn.parse()
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
flagtree_hints = line_flagtree_hints.get(line_num)

# Check if this is a range/for loop with bind_sub_block hint
if flagtree_hints and 'bind_sub_block' in flagtree_hints:
return True
# print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}")
return bind_sub_block

@staticmethod
def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block):
for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block))

@staticmethod
def maps_line_numbers_to_comment_hints(jit_fn):
import tokenize
from io import StringIO
# Maps line numbers to comment hints
line_flagtree_hints = {}
code_str = jit_fn.src
g = tokenize.generate_tokens(StringIO(code_str).readline)
for tok_type, tok_text, start, end, _ in g:
if tok_type == tokenize.COMMENT:
comment = tok_text.replace(" ", "").strip()
if comment.startswith('#@hint:'):
flagtree_hints = comment[len('#@hint:'):].strip()
# Record the line number of the comment
line_num = start[0]
line_flagtree_hints[line_num] = flagtree_hints

# print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}")

return line_flagtree_hints

@staticmethod
def attach_line_number_to_comment_mapping(tree, line_flagtree_hints):
# Attach the line number to comment mapping to the function definition node
tree.body[0].line_flagtree_hints = line_flagtree_hints
Loading