use "true" and "false" instead of 0 and 1 #890

Open
samos123 wants to merge 3 commits into base: main

Contributor

@samos123 samos123 commented Dec 12, 2024

JAX/XLA expects the strings "true" and "false" instead of 0 and 1 for certain kinds of boolean flags (TriStateFlag, which also accepts "auto").

This should fix the following error:

ERROR: Illegal value '1' specified for flag 'xla_should_add_loop_invariant_op_in_chain'; expected one of true/enabled, false/disabled or auto

The JAX team confirmed that we should use "true" and "false" so that TriStateFlags, which behave like boolean flags, work as well.
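
For illustration, the error message above suggests a parser that accepts only named values and rejects numeric ones. The following is a minimal Python model of that behavior, not XLA's actual implementation: `parse_tri_state` is a hypothetical name, the accepted spellings are taken only from the error text, and exact-match (case-sensitive) comparison is an assumption.

```python
def parse_tri_state(flag_name: str, raw: str) -> str:
    """Hypothetical model of a tri-state flag parser (not XLA's real code)."""
    # Accepted spellings are copied from the error message quoted in this PR.
    accepted = {
        "true": "true",
        "enabled": "true",
        "false": "false",
        "disabled": "false",
        "auto": "auto",
    }
    try:
        # Assumption: matching is exact and case-sensitive.
        return accepted[raw]
    except KeyError:
        raise ValueError(
            f"Illegal value '{raw}' specified for flag '{flag_name}'; "
            "expected one of true/enabled, false/disabled or auto"
        ) from None
```

Under this model, passing "1" for `xla_should_add_loop_invariant_op_in_chain` raises, while "true" is accepted, which matches the failure described above.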

Jax/XLA seem to expect true and false strings instead of using 0 and 1.
Contributor

@ruomingp ruomingp left a comment

Thanks! Is there a test that can be added or updated?

Contributor

@markblee markblee left a comment

+1 on adding a test case.

@samos123 samos123 requested a review from a team as a code owner January 17, 2025 21:20
@@ -24,10 +24,10 @@ def f(x: Tensor) -> Tensor:
        )
        self.assertEqual(f_compiled(5), 15)

    def atest_xla_flags_from_options(self):
Contributor Author

This test was never being run because of the "a" before "test" in the function name.

@samos123
Contributor Author

I updated the existing test and renamed it so that it is now actually run.
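
As a small sketch of why the misnamed test was silently skipped: the standard `unittest` loader collects only methods whose names start with "test" by default. The class and method names below are hypothetical and stand in for the real test file.

```python
import unittest


class FlagTest(unittest.TestCase):
    """Hypothetical test case illustrating name-based test collection."""

    def atest_never_collected(self):
        # The name does not start with "test", so the loader skips it
        # without any warning and this failure is never reported.
        self.fail("never runs")

    def test_collected(self):
        self.assertEqual("true" if True else "false", "true")


# Ask the default loader which methods it would run for this class.
names = unittest.TestLoader().getTestCaseNames(FlagTest)
```

Here `names` contains only `test_collected`, so a broken assertion inside `atest_never_collected` would go unnoticed until the method is renamed.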

@samos123 samos123 requested a review from markblee January 17, 2025 21:21
@@ -144,7 +144,7 @@ def xla_flags_from_options(xla_options: dict[str, Union[str, bool, int]]) -> str
     flags = []
     for k, v in xla_options.items():
         if isinstance(v, bool):
-            v = "1" if v else "0"
+            v = "true" if v else "false"
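
The hunk above shows only part of the function. A self-contained sketch of the changed behavior might look like the following; the `--name=value` rendering and the space-joined return value are assumptions, since the rest of the function body is not shown in this PR view.

```python
from typing import Union


def xla_flags_from_options(xla_options: dict[str, Union[str, bool, int]]) -> str:
    """Sketch: render an options dict as a single flag string."""
    flags = []
    for k, v in xla_options.items():
        if isinstance(v, bool):
            # bool needs explicit handling: str(True) would give "True",
            # and the old "1"/"0" form is rejected by tri-state flags.
            v = "true" if v else "false"
        # Assumption: each flag is rendered as --name=value.
        flags.append(f"--{k}={v}")
    # Assumption: flags are joined with single spaces.
    return " ".join(flags)
```

For example, `xla_flags_from_options({"xla_should_add_loop_invariant_op_in_chain": True})` yields `--xla_should_add_loop_invariant_op_in_chain=true` under these assumptions, while integer and string values pass through unchanged.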
Contributor

Should we convert all the "true" and "false" strings in default_xla_options to bools to be consistent?

Contributor Author

I think that should be preferred. I just wanted to keep this PR small to make merging and reviewing easier. Happy to include that change in this PR or do a follow-up.

Contributor

Thanks, I think it makes sense to include in this PR.
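
One way the suggested cleanup could be checked is with a small normalizer that turns "true"/"false" strings into Python bools and leaves everything else alone. This is only a sketch: `normalize_bool_strings` is a hypothetical helper, and the contents of `default_xla_options` are not shown in this thread, so the example uses made-up keys.

```python
from typing import Union

XlaOption = Union[str, bool, int]


def normalize_bool_strings(options: dict[str, XlaOption]) -> dict[str, XlaOption]:
    """Hypothetical helper: convert "true"/"false" strings to Python bools."""
    mapping = {"true": True, "false": False}
    return {
        # Only exact "true"/"false" strings change; "auto", ints, and
        # existing bools are passed through untouched.
        k: mapping.get(v, v) if isinstance(v, str) else v
        for k, v in options.items()
    }


# Keys below are placeholders, not real XLA flag names.
example = normalize_bool_strings({"flag_a": "true", "flag_b": "auto", "flag_c": 4})
```

With bools stored in the defaults, the bool branch in `xla_flags_from_options` becomes the single place where the "true"/"false" spelling is decided.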

@@ -144,7 +144,7 @@ def xla_flags_from_options(xla_options: dict[str, Union[str, bool, int]]) -> str
     flags = []
     for k, v in xla_options.items():
         if isinstance(v, bool):
-            v = "1" if v else "0"
+            v = "true" if v else "false"
Contributor

Would it make sense to do an AOT compilation or run the trainer for a few steps to confirm these changes work (if you haven't already)?

Contributor Author

I tested it on v6e last year, as far as I recall, but any additional runs would be great.

Contributor

@apghml apghml Jan 17, 2025

Maybe test it again on v5p using the jax/jaxlib version that is currently pinned in AXLearn? (0.4.33 IIRC)

4 participants