From 71f4ea0a00aea5c9e4e435facac54715ce460ce7 Mon Sep 17 00:00:00 2001 From: mx-flaggems-user Date: Tue, 25 Feb 2025 15:22:34 +0800 Subject: [PATCH 1/2] [Metax] fix accuracy error of scatter in metax --- src/flag_gems/ops/scatter.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/flag_gems/ops/scatter.py b/src/flag_gems/ops/scatter.py index 4970395c0..d15e61334 100644 --- a/src/flag_gems/ops/scatter.py +++ b/src/flag_gems/ops/scatter.py @@ -17,6 +17,7 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.newline() code.writeline("from flag_gems.utils import libentry") code.writeline("from flag_gems import runtime") + code.writeline("import flag_gems") # code.writeline("from flag_gems.utils import triton_lang_extension as tle") code.newline() code.newline() @@ -35,6 +36,9 @@ def generate_scatter_kernel( code.writeline("def heur_block(args):") with code.indent(): + code.writeline("if(flag_gems.vendor_name=='metax'):") + with code.indent(): + code.writeline("return 256") code.writeline("return 128") code.newline() code.newline() From acc79c709c90c4754d5d928c8724c369852de490 Mon Sep 17 00:00:00 2001 From: mx-flaggems-user Date: Fri, 28 Feb 2025 11:33:53 +0800 Subject: [PATCH 2/2] [Metax] modify metax backend attention config --- src/flag_gems/runtime/backend/_metax/tune_configs.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/flag_gems/runtime/backend/_metax/tune_configs.yaml b/src/flag_gems/runtime/backend/_metax/tune_configs.yaml index 7afa01729..e61f44a7b 100644 --- a/src/flag_gems/runtime/backend/_metax/tune_configs.yaml +++ b/src/flag_gems/runtime/backend/_metax/tune_configs.yaml @@ -22,8 +22,7 @@ attention: - 8 stages: - 1 - - 2 - - 3 + bmm: - META: TILE_M: 32