-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Metax] fix accuracy error of scatter in metax #459
Conversation
src/flag_gems/ops/scatter.py
Outdated
@@ -17,6 +17,7 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: | |||
code.newline() | |||
code.writeline("from flag_gems.utils import libentry") | |||
code.writeline("from flag_gems import runtime") | |||
code.writeline("from flag_gems.runtime import device") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import flag_gems
and now has attr vendor_name
, metax can use flag_gems.vendor_name
,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
src/flag_gems/ops/scatter.py
Outdated
@@ -35,6 +36,9 @@ def generate_scatter_kernel( | |||
|
|||
code.writeline("def heur_block(args):") | |||
with code.indent(): | |||
code.writeline("if(device.get_vendor_name()=='metax'):") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
flag_gems.vendor_name
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
src/flag_gems/ops/scatter.py
Outdated
@@ -17,6 +17,7 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: | |||
code.newline() | |||
code.writeline("from flag_gems.utils import libentry") | |||
code.writeline("from flag_gems import runtime") | |||
code.writeline("from flag_gems") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import flag_gems, not from flag_gems
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
1b68e14
to
71f4ea0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
PR Category
Operater
Type of Change
Bug Fix
Description
We found that scatter with
heur_block=128
will calculate wrong answer, so need to change it to 256 while running in metax backend