diff --git a/.github/workflows/ascend-build-and-test.yml b/.github/workflows/ascend-build-and-test.yml
index 50dad4678..3bdea8f15 100644
--- a/.github/workflows/ascend-build-and-test.yml
+++ b/.github/workflows/ascend-build-and-test.yml
@@ -12,8 +12,15 @@ concurrency:
 
 jobs:
   ascend-build-and-test:
-    runs-on: ascend
+    runs-on: flagtree-ascend
+    if: ${{ github.repository == 'FlagTree/flagtree' || github.repository == 'flagos-ai/flagtree' }}
     steps:
+      - name: Setup environment
+        shell: bash
+        run: |
+          source ~/env.sh
+          env | grep -E '^(http_proxy|https_proxy|all_proxy|no_proxy)=' >> $GITHUB_ENV || true
+
       - name: Checkout code (attempt 1)
         id: checkout1
         uses: actions/checkout@v6
@@ -72,7 +79,6 @@ jobs:
         run: |
           set -x
          source /usr/local/Ascend/ascend-toolkit/set_env.sh
-          python3 third_party/tests/ascend/vector-add.py
           python3 third_party/ascend/examples/tutorials/01-vector-add.py
           python3 third_party/ascend/examples/tutorials/02-fused-softmax.py
           python3 third_party/ascend/examples/tutorials/03-layer-norm.py
diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py
index 76674f770..4bda6d07a 100644
--- a/python/triton/compiler/code_generator.py
+++ b/python/triton/compiler/code_generator.py
@@ -15,6 +15,7 @@
 from ..runtime import JITFunction
 from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
 from types import ModuleType
+from .hint_manager import hint_trigger
 
 
 def mangle_ty(ty):
@@ -516,9 +517,8 @@ def visit_Assign(self, node):
                 value = language.semantic.to_tensor(value, self.builder)
             self.set_value(name, value)
 
-        #flagtree backend specialization
-        from triton.runtime.driver import spec
-        spec("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values)
+        # hint manager
+        hint_trigger("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values)
 
     def visit_AugAssign(self, node):
         name = node.target.id
@@ -953,10 +953,8 @@ def visit_For(self, node):
             step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
         else:
             raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
-
-        # flagtree backend specialization
-        from triton.runtime.driver import spec
-        new_bind_sub_block = spec("check_override_bind_sub_block", self, node, bind_sub_block)
+        # hint manager
+        new_bind_sub_block = hint_trigger("check_override_bind_sub_block", self, node, bind_sub_block)
         if new_bind_sub_block is not None:
             bind_sub_block = new_bind_sub_block
@@ -1026,10 +1024,9 @@ def visit_For(self, node):
         # flagtree backend specialization
         from triton.runtime.driver import spec
         spec("for_op_set_ext_attrs", for_op, self.builder, for_op_ext_attrs)
-        # flagtree backend specialization
+        # hint manager
         if bind_sub_block:
-            from triton.runtime.driver import spec
-            spec("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block)
+            hint_trigger("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block)
 
         self.scf_stack.append(node)
         self.builder.set_insertion_point_to_start(for_op.get_body(0))
diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py
new file mode 100644
index 000000000..719605175
--- /dev/null
+++ b/python/triton/compiler/hint_manager.py
@@ -0,0 +1,147 @@
+import os
+import sys
+import importlib
+
+
+class BaseHintHandler:
+    # dynamically look up the hook method by name and invoke it
+    def trigger(self, hook_name, *args, **kwargs):
+        if hasattr(self, hook_name):
+            method = getattr(self, hook_name)
+            if callable(method):
+                try:
+                    return method(*args, **kwargs)
+                except TypeError as e:
+                    import inspect
+
+                    try:
+                        sig = inspect.signature(method)
+                        expected = str(sig)
+                    except Exception:
+                        expected = "(unknown)"
+
+                    actual_args = f"{len(args)} positional"
+                    actual_kwargs = f"keys={list(kwargs.keys())}" if kwargs else "no keywords"
+
+                    print(f"\n[Hint Trigger Mismatch] {self.__class__.__name__}.{hook_name}")
+                    print(f"  > Expect : {expected}")
+                    print(f"  > Actual : {actual_args}, {actual_kwargs}")
+                    print(f"  > Reason : {e}\n")
+
+                    raise e
+        return None
+
+
+class HintManager:
+
+    def __init__(self, backend_name):
+        self.backend_name = backend_name
+        # load the handler that matches the backend name
+        self.handler = self._load_handler(backend_name)
+
+    def _load_handler(self, backend):
+        if backend == 'npu':
+            try:
+                module = importlib.import_module("third_party.ascend.backend.ascend_hint_handler")
+                return module.AscendHintHandler()
+            except ImportError as e:
+                print(f"[FlagTree] Warning: Failed to load Ascend Hint Handler: {e}", file=sys.stderr)
+                return BaseHintHandler()
+        elif backend == 'aipu':
+            from .backends.aipu import AipuHintHandler
+            return AipuHintHandler()
+        else:
+            return BaseHintHandler()
+
+
+# supported backends and the Triton versions they are paired with
+SUPPORTED_CONFIG = {
+    "cuda": {"3.5"},
+    "npu": {"3.2"},
+    "aipu": {"3.3"},
+}
+
+# backend name aliases
+BACKEND_ALIASES = {
+    "ascend": "npu",
+    "huawei": "npu",
+    "nv": "cuda",
+}
+
+
+def normalize_backend_name(name: str) -> str:
+    if not name:
+        return ""
+    name = name.lower()
+    return BACKEND_ALIASES.get(name, name)
+
+
+def hint_get_flagtree_backend() -> str:
+    detected_backend = ""
+
+    import torch
+    import triton
+
+    # Priority 1: Triton Driver
+    try:
+        from triton.runtime import driver
+        if hasattr(driver, 'active') and hasattr(driver.active, 'get_active_torch_device'):
+            device = driver.active.get_active_torch_device()
+            if isinstance(device, torch.device):
+                detected_backend = device.type
+            # some drivers return a plain string instead of a torch.device; not fully supported yet
+            elif isinstance(device, str):
+                detected_backend = device
+    except ImportError:
+        pass
+
+    # Priority 2: Torch Global State
+    if not detected_backend:
+        candidates = list(SUPPORTED_CONFIG.keys())
+        # check cuda last so that vendor backends take priority
+        candidates.sort(key=lambda x: 1 if x == "cuda" else 0)
+
+        # probe candidates in priority order
+        for candidate in candidates:
+            module_name = candidate
+            module = getattr(torch, module_name, None)
+            if module and hasattr(module, "is_available") and module.is_available():
+                detected_backend = candidate
+                break
+
+    # Priority 3: Environment Variable (TODO: remove this fallback)
+    if not detected_backend:
+        detected_backend = os.environ.get("FLAGTREE_BACKEND", "")
+
+    # Normalization and validation
+    canonical_backend = normalize_backend_name(detected_backend)
+
+    if not canonical_backend or canonical_backend not in SUPPORTED_CONFIG:
+        return ""
+
+    # verify that the detected backend matches the current Triton version
+    try:
+        current_triton_version = ".".join(triton.__version__.split(".")[:2])
+        supported_versions = SUPPORTED_CONFIG[canonical_backend]
+        if current_triton_version not in supported_versions:
+            msg = (f"[Flagtree] Hint ignored: Detected backend '{canonical_backend}' but current Triton version "
+                   f"'{current_triton_version}' is not among the supported versions {supported_versions}.")
+            print(msg, file=sys.stderr)
+            return ""
+    except Exception:
+        pass
+
+    return canonical_backend
+
+
+# lazily initialized on the first call to hint_trigger()
+_global_hint_manager = None
+
+
+def hint_trigger(hook_name, *args, **kwargs):
+    global _global_hint_manager
+
+    if _global_hint_manager is None:
+        _global_hint_manager = HintManager(hint_get_flagtree_backend())
+    return _global_hint_manager.handler.trigger(hook_name, *args, **kwargs)
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
index 105e26b24..0a8cc1442 100644
--- a/python/triton/runtime/jit.py
+++ b/python/triton/runtime/jit.py
@@ -796,17 +796,19 @@ def preload(self, specialization_data):
     # the user might want to monkey-patch self.src dynamically.
     # Our unit tests do this, for example.
     def parse(self):
-        # flagtree backend specialization
-        from triton.runtime.driver import spec
-        line_flagtree_hints = spec('maps_line_numbers_to_comment_hints', self)
+        # hint manager
+        # the flagtree backend specialization is now hidden behind the hint manager
+        from ..compiler.hint_manager import hint_trigger
+        line_flagtree_hints = hint_trigger("maps_line_numbers_to_comment_hints", self)
 
         tree = ast.parse(self.src)
         assert isinstance(tree, ast.Module)
         assert len(tree.body) == 1
         assert isinstance(tree.body[0], ast.FunctionDef)
-        # flagtree backend specialization
-        spec('attach_line_number_to_comment_mapping', tree, line_flagtree_hints)
+        # hint manager
+        # attach the line-number-to-comment mapping to the function definition node
+        hint_trigger('attach_line_number_to_comment_mapping', tree, line_flagtree_hints)
         return tree
diff --git a/third_party/ascend/backend/ascend_hint_handler.py b/third_party/ascend/backend/ascend_hint_handler.py
new file mode 100644
index 000000000..65e492c6c
--- /dev/null
+++ b/third_party/ascend/backend/ascend_hint_handler.py
@@ -0,0 +1,79 @@
+# backend hint handlers should be stored at third_party/???/backend/
+from triton.compiler.hint_manager import BaseHintHandler
+import triton.language as language
+import ast
+from triton.compiler.code_generator import _is_triton_value
+
+
+class AscendHintHandler(BaseHintHandler):
+
+    @staticmethod
+    def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values):
+        import ast
+        from triton.compiler.code_generator import _is_triton_value
+        # flagtree: after normal processing, check whether a hint annotation needs to be added
+        if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'):
+            line_num = node.lineno
+            # TODO: re-parsing is needed to handle complex cases; this will be redesigned later
+            function_def = code_generator.jit_fn.parse()
+            line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
+            flagtree_hints = line_flagtree_hints.get(line_num)
+
+            # Check if this is a tl.load call with the dot_pad_only_k hint
+            if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and isinstance(node.value, ast.Call)
+                    and isinstance(node.value.func, ast.Attribute) and isinstance(node.value.func.value, ast.Name)
+                    and node.value.func.value.id == 'tl' and node.value.func.attr == 'load'):
+
+                # Add the hint annotation to the loaded tensor(s)
+                for name, value in zip(names, values):
+                    if _is_triton_value(value):
+                        # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}")
+                        # Create the hint annotation
+                        hint_val = code_generator.builder.get_unit_attr()
+                        code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val)
+
+    @staticmethod
+    def check_override_bind_sub_block(code_generator, node, bind_sub_block):
+        # flagtree: after normal processing, check whether bind_sub_block needs to be overridden
+        if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'):
+            line_num = node.lineno
+            # TODO: re-parsing is needed to handle complex cases; this will be redesigned later
+            function_def = code_generator.jit_fn.parse()
+            line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
+            flagtree_hints = line_flagtree_hints.get(line_num)
+
+            # Check if this is a range/for loop with the bind_sub_block hint
+            if flagtree_hints and 'bind_sub_block' in flagtree_hints:
+                # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}")
+                return True
+        return bind_sub_block
+
+    @staticmethod
+    def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block):
+        for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block))
+
+    @staticmethod
+    def maps_line_numbers_to_comment_hints(jit_fn):
+        import tokenize
+        from io import StringIO
+        # Map line numbers to comment hints
+        line_flagtree_hints = {}
+        code_str = jit_fn.src
+        g = tokenize.generate_tokens(StringIO(code_str).readline)
+        for tok_type, tok_text, start, end, _ in g:
+            if tok_type == tokenize.COMMENT:
+                comment = tok_text.replace(" ", "").strip()
+                if comment.startswith('#@hint:'):
+                    flagtree_hints = comment[len('#@hint:'):].strip()
+                    # Record the line number of the comment
+                    line_num = start[0]
+                    line_flagtree_hints[line_num] = flagtree_hints
+
+                    # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}")
+
+        return line_flagtree_hints
+
+    @staticmethod
+    def attach_line_number_to_comment_mapping(tree, line_flagtree_hints):
+        # Attach the line-number-to-comment mapping to the function definition node
+        tree.body[0].line_flagtree_hints = line_flagtree_hints
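
Usage sketch (illustrative only, not part of the patch): with the Ascend handler loaded (backend detected as 'npu' and a matching Triton version per SUPPORTED_CONFIG), maps_line_numbers_to_comment_hints collects '#@hint:' comments from the kernel source during JITFunction.parse, and the code generator then reacts to them on the annotated lines: bind_sub_block on a range loop and dot_pad_only_k on an assignment from tl.load. The kernel below is a hypothetical, simplified single-tile matmul written only to show where the hint comments go; whether these hints are profitable on a given load or loop is up to the Ascend compiler.

    import triton
    import triton.language as tl


    # Hypothetical example kernel (not from this patch); assumes A is (BLOCK_M, K),
    # B is (K, BLOCK_N), C is (BLOCK_M, BLOCK_N), all row-major and unmasked.
    @triton.jit
    def matmul_kernel(a_ptr, b_ptr, c_ptr, K,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        offs_m = tl.arange(0, BLOCK_M)[:, None]
        offs_n = tl.arange(0, BLOCK_N)[None, :]
        offs_k = tl.arange(0, BLOCK_K)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # the bind_sub_block hint is read from the comment on the loop line itself
        for k in range(0, K, BLOCK_K):  # @hint: bind_sub_block
            # dot_pad_only_k must sit on the line of the `x = tl.load(...)` assignment
            a = tl.load(a_ptr + offs_m * K + (k + offs_k)[None, :])  # @hint: dot_pad_only_k
            b = tl.load(b_ptr + (k + offs_k)[:, None] * BLOCK_N + offs_n)
            acc += tl.dot(a, b)
        tl.store(c_ptr + offs_m * BLOCK_N + offs_n, acc)

Note that the comment parser strips spaces before matching, so '# @hint: bind_sub_block' and '#@hint:bind_sub_block' are treated the same, and the tl.load check requires the module to be imported under the alias 'tl'.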