Skip to content

Commit

Permalink
[Metax] fix accuracy error of scatter in metax
Browse files Browse the repository at this point in the history
  • Loading branch information
mx-flaggems-user committed Feb 27, 2025
1 parent 43d3c3f commit 71f4ea0
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/flag_gems/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
code.newline()
code.writeline("from flag_gems.utils import libentry")
code.writeline("from flag_gems import runtime")
code.writeline("import flag_gems")
# code.writeline("from flag_gems.utils import triton_lang_extension as tle")
code.newline()
code.newline()
Expand All @@ -35,6 +36,9 @@ def generate_scatter_kernel(

code.writeline("def heur_block(args):")
with code.indent():
code.writeline("if(flag_gems.vendor_name=='metax'):")
with code.indent():
code.writeline("return 256")
code.writeline("return 128")
code.newline()
code.newline()
Expand Down

0 comments on commit 71f4ea0

Please sign in to comment.