-
Notifications
You must be signed in to change notification settings - Fork 38
[FEAT] Triton v3.2.x hint manager #316
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
starrryz
wants to merge
17
commits into
triton_v3.2.x
Choose a base branch
from
triton_v3.2.x_hint_manager
base: triton_v3.2.x
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+248
−17
Open
Changes from 14 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
a4c2fb1
the initial design for unified hint framework
starrryz ca9c939
update the logic of how to call backend method in basehinthandler
starrryz ded0530
fix input attr bug
starrryz 6b204a7
update hintmanager, wrap additional code into hintmanager, back no-hi…
starrryz f9dbdb5
remove redundant code
starrryz 4112810
fix import and python bugs
starrryz 36edfa6
fix import and python bugs_2
starrryz dd100aa
apply code-format change
starrryz 623c42e
apply code-format change_2
starrryz 80f8ade
fix bug : circular import
starrryz cb0707b
fix bug : hintmanager name into hint_manager
starrryz 260043d
Merge branch 'triton_v3.2.x' into triton_v3.2.x_hint_manager
sunnycase 67bc525
fix bug : massive useless print
starrryz 8872723
Merge branch 'triton_v3.2.x' into triton_v3.2.x_hint_manager
starrryz 3117267
update CI
starrryz 730f25d
update CI 2
starrryz 3ace7be
update CI 3
starrryz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,147 @@ | ||
| import os | ||
| import sys | ||
| import importlib | ||
|
|
||
|
|
||
| class BaseHintHandler: | ||
| # dynamicly find method | ||
| def trigger(self, hook_name, *args, **kwargs): | ||
| if hasattr(self, hook_name): | ||
| method = getattr(self, hook_name) | ||
| if callable(method): | ||
| try: | ||
| return method(*args, **kwargs) | ||
|
|
||
| except TypeError as e: | ||
| import inspect | ||
|
|
||
| try: | ||
| sig = inspect.signature(method) | ||
| expected = str(sig) | ||
| except Exception: | ||
| expected = "(unknown)" | ||
|
|
||
| actual_args = f"{len(args)} positional" | ||
| actual_kwargs = f"keys={list(kwargs.keys())}" if kwargs else "no keywords" | ||
|
|
||
| print(f"\n[Hint Trigger Mismatch] {self.__class__.__name__}.{hook_name}") | ||
| print(f" > Expect : {expected}") | ||
| print(f" > Actual : {actual_args}, {actual_kwargs}") | ||
| print(f" > Reason : {e}\n") | ||
|
|
||
| raise e | ||
| return None | ||
|
|
||
|
|
||
| class HintManager: | ||
|
|
||
| def __init__(self, backend_name): | ||
| self.backend_name = backend_name | ||
| # load Handler with backend name | ||
| self.handler = self._load_handler(backend_name) | ||
|
|
||
| def _load_handler(self, backend): | ||
| if backend == 'npu': | ||
| try: | ||
| module = importlib.import_module("third_party.ascend.backend.ascend_hint_handler") | ||
| return module.AscendHintHandler() | ||
| except ImportError as e: | ||
| print(f"[FlagTree] Warning: Failed to load Ascend Hint Handler: {e}", file=sys.stderr) | ||
| return BaseHintHandler() | ||
| elif backend == 'aipu': | ||
| from .backends.aipu import AipuHintHandler | ||
| return AipuHintHandler() | ||
| else: | ||
| return BaseHintHandler() | ||
|
|
||
|
|
||
| # supported backend with matched version | ||
| SUPPORTED_CONFIG = { | ||
| "cuda": {"3.5"}, | ||
| "npu": {"3.2"}, | ||
| "aipu": {"3.3"}, | ||
| } | ||
|
|
||
| # mapping name | ||
| BACKEND_ALIASES = { | ||
| "ascend": "npu", | ||
| "huawei": "npu", | ||
| "nv": "cuda", | ||
| } | ||
|
|
||
|
|
||
| def normalize_backend_name(name: str) -> str: | ||
| if not name: | ||
| return "" | ||
| name = name.lower() | ||
| return BACKEND_ALIASES.get(name, name) | ||
|
|
||
|
|
||
| def hint_get_flagtree_backend() -> str: | ||
| detected_backend = "" | ||
|
|
||
| import torch | ||
| import triton | ||
|
|
||
| # Priority 1: Triton Driver | ||
| try: | ||
| from triton.runtime import driver | ||
| if hasattr(driver, 'active') and hasattr(driver.active, 'get_active_torch_device'): | ||
| device = driver.active.get_active_torch_device() | ||
| if isinstance(device, torch.device): | ||
| detected_backend = device.type | ||
| # unimplemented support | ||
| elif isinstance(device, str): | ||
| detected_backend = device | ||
| except ImportError: | ||
| pass | ||
|
|
||
| # Priority 2: Torch Global State | ||
| if not detected_backend: | ||
| candidates = list(SUPPORTED_CONFIG.keys()) | ||
| # cuda priority least | ||
| candidates.sort(key=lambda x: 1 if x == "cuda" else 0) | ||
|
|
||
| # 3. parse according to benefit | ||
| for candidate in candidates: | ||
| module_name = candidate | ||
| module = getattr(torch, module_name, None) | ||
| if module and hasattr(module, "is_available") and module.is_available(): | ||
| detected_backend = candidate | ||
| break | ||
|
|
||
| # Priority 3: Environment Variable (need to remove!!!) | ||
| if not detected_backend: | ||
| detected_backend = os.environ.get("FLAGTREE_BACKEND", "") | ||
|
|
||
| # (Normalization and Validation) | ||
| canonical_backend = normalize_backend_name(detected_backend) | ||
|
|
||
| if not canonical_backend or canonical_backend not in SUPPORTED_CONFIG: | ||
| return "" | ||
|
|
||
| # verify name and version match | ||
| try: | ||
| current_triton_version = ".".join(triton.__version__.split(".")[:2]) | ||
| supported_versions = SUPPORTED_CONFIG[canonical_backend] | ||
| if current_triton_version not in supported_versions: | ||
| msg = (f"[Flagtree] Hint ignored: Detected backend '{canonical_backend}' but current Triton version " | ||
| f"'{current_triton_version}' matches no supported versions {supported_versions}.") | ||
| print(msg, file=sys.stderr) | ||
| return "" | ||
| except Exception: | ||
| pass | ||
|
|
||
| return canonical_backend | ||
|
|
||
|
|
||
| # lazy load after first call hint trigger | ||
| _global_hint_manager = None | ||
|
|
||
|
|
||
| def hint_trigger(hook_name, *args, **kwargs): | ||
| global _global_hint_manager | ||
|
|
||
| if _global_hint_manager is None: | ||
| _global_hint_manager = HintManager(hint_get_flagtree_backend()) | ||
| return _global_hint_manager.handler.trigger(hook_name, *args, **kwargs) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # should store at thrid_party/???/backend/ | ||
| from triton.compiler.hint_manager import BaseHintHandler | ||
| import triton.language as language | ||
| import ast | ||
| from triton.compiler.code_generator import _is_triton_value | ||
|
|
||
|
|
||
| class AscendHintHandler(BaseHintHandler): | ||
|
|
||
| @staticmethod | ||
| def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values): | ||
| import ast | ||
| from triton.compiler.code_generator import _is_triton_value | ||
| # flagtree: After normal processing, check if we need to add hint annotation | ||
| if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): | ||
| line_num = node.lineno | ||
| # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later | ||
| function_def = code_generator.jit_fn.parse() | ||
| line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) | ||
| flagtree_hints = line_flagtree_hints.get(line_num) | ||
|
|
||
| # Check if this is a tl.load call with dot_pad_only_k hint | ||
| if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and isinstance(node.value, ast.Call) | ||
| and isinstance(node.value.func, ast.Attribute) and isinstance(node.value.func.value, ast.Name) | ||
| and node.value.func.value.id == 'tl' and node.value.func.attr == 'load'): | ||
|
|
||
| # Add hint annotation to the loaded tensor(s) | ||
| for name, value in zip(names, values): | ||
| if _is_triton_value(value): | ||
| # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") | ||
| # Create hint annotation | ||
| hint_val = code_generator.builder.get_unit_attr() | ||
| code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) | ||
|
|
||
| @staticmethod | ||
| def check_override_bind_sub_block(code_generator, node, bind_sub_block): | ||
| # flagtree: After normal processing, check if we need to override bind_sub_block | ||
| if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): | ||
| line_num = node.lineno | ||
| # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later | ||
| function_def = code_generator.jit_fn.parse() | ||
| line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) | ||
| flagtree_hints = line_flagtree_hints.get(line_num) | ||
|
|
||
| # Check if this is a range/for loop with bind_sub_block hint | ||
| if flagtree_hints and 'bind_sub_block' in flagtree_hints: | ||
| return True | ||
| # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") | ||
starrryz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return bind_sub_block | ||
|
|
||
| @staticmethod | ||
| def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block): | ||
| for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block)) | ||
|
|
||
| @staticmethod | ||
| def maps_line_numbers_to_comment_hints(jit_fn): | ||
| import tokenize | ||
| from io import StringIO | ||
| # Maps line numbers to comment hints | ||
| line_flagtree_hints = {} | ||
| code_str = jit_fn.src | ||
| g = tokenize.generate_tokens(StringIO(code_str).readline) | ||
| for tok_type, tok_text, start, end, _ in g: | ||
| if tok_type == tokenize.COMMENT: | ||
| comment = tok_text.replace(" ", "").strip() | ||
| if comment.startswith('#@hint:'): | ||
| flagtree_hints = comment[len('#@hint:'):].strip() | ||
| # Record the line number of the comment | ||
| line_num = start[0] | ||
| line_flagtree_hints[line_num] = flagtree_hints | ||
|
|
||
| # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") | ||
|
|
||
| return line_flagtree_hints | ||
|
|
||
| @staticmethod | ||
| def attach_line_number_to_comment_mapping(tree, line_flagtree_hints): | ||
| # Attach the line number to comment mapping to the function definition node | ||
| tree.body[0].line_flagtree_hints = line_flagtree_hints | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.