-
Notifications
You must be signed in to change notification settings - Fork 281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
use "true" and "false" instead of 0 and 1 #890
base: main
Are you sure you want to change the base?
Conversation
Jax/XLA seem to expect true and false strings instead of using 0 and 1.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Is there a test that can added or updated?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 on adding test case.
@@ -24,10 +24,10 @@ def f(x: Tensor) -> Tensor: | |||
) | |||
self.assertEqual(f_compiled(5), 15) | |||
|
|||
def atest_xla_flags_from_options(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this test was never being run because of the a
before test
in function name.
I updated the existing test and renamed the test so it's being run correctly now. |
@@ -144,7 +144,7 @@ def xla_flags_from_options(xla_options: dict[str, Union[str, bool, int]]) -> str | |||
flags = [] | |||
for k, v in xla_options.items(): | |||
if isinstance(v, bool): | |||
v = "1" if v else "0" | |||
v = "true" if v else "false" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we convert all the "true" and "false" strings in default_xla_options
to bools to be consistent?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that should be preferred. I just wanted to keep this PR small to make merging and reviewing easier. Happy to include that change in this PR or do a follow up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I think it makes sense to include in this PR.
@@ -144,7 +144,7 @@ def xla_flags_from_options(xla_options: dict[str, Union[str, bool, int]]) -> str | |||
flags = [] | |||
for k, v in xla_options.items(): | |||
if isinstance(v, bool): | |||
v = "1" if v else "0" | |||
v = "true" if v else "false" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it make sense to do an aot compilation or run the trainer for a few steps to confirm these changes work? (If you haven't already?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did test it on v6e afair last year, but any additional runs would be great.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe test it again on v5p using the jax/jaxlib version that is currently pinned in AXLearn? (0.4.33 IIRC)
Jax/XLA expects true and false strings instead of using 0 and 1 for certain kind of boolean flags (TriStateFlag which also accepts "auto").
This should solve this issue:
Confirmed by Jax team that we should use "true" and "false" so TriStateflags which are like boolean flags work as well.