configure random seed dtype based on backend #19928
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

@@            Coverage Diff             @@
##           master   #19928      +/-   ##
==========================================
- Coverage   79.01%   78.97%   -0.04%
==========================================
  Files         499      499
  Lines       46441    46506      +65
  Branches     8550     8561      +11
==========================================
+ Hits        36694    36730      +36
- Misses       8020     8044      +24
- Partials     1727     1732      +5
Thanks for the PR. This is a reasonable change.
@@ -14,15 +14,15 @@ def tf_draw_seed(seed):

def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    dtype = dtype or floatx()
    seed = tf_draw_seed(seed)
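For illustration, a backend-agnostic seed-drawing helper along these lines might look as follows. This is a hedged sketch, not the PR's actual code: the name `draw_seed`, the NumPy implementation, and the default `"uint32"` dtype are assumptions used only to show the idea of drawing a seed in whatever dtype the backend prefers.

```python
import numpy as np

# Hypothetical sketch, not the PR's actual code: draw a seed in the
# dtype the backend prefers instead of hard-coding a single dtype.
def draw_seed(seed, seed_dtype="uint32"):
    if seed is None:
        # No seed given: draw a fresh random one in the seed dtype.
        info = np.iinfo(seed_dtype)
        return np.random.randint(0, info.max, dtype=seed_dtype)
    # Seed given: cast it to the backend's seed dtype.
    return np.asarray(seed, dtype=seed_dtype)
```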
In that case, you can remove the tf_draw_seed function.
keras/src/backend/jax/core.py (outdated)
@@ -346,6 +346,11 @@ def unstack(x, num=None, axis=0):
    ]


def random_seed_dtype():
    # jax random seed uses uint32.
    return standardize_dtype("uint32")
standardize_dtype will be a no-op here. You can skip it.
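To see why it is a no-op, here is an illustrative sketch of a standardize_dtype-style helper. This is not Keras's actual implementation, and the alias table is invented; the point is only that an already-canonical dtype string like "uint32" passes through unchanged.

```python
# Illustrative only -- not Keras's actual standardize_dtype. A helper of
# this shape maps dtype aliases/objects to canonical string names.
def standardize_dtype(dtype):
    aliases = {"float": "float32", "double": "float64", "long": "int64"}
    # dtype objects expose .name; fall back to str() for plain strings.
    name = getattr(dtype, "name", None) or str(dtype)
    return aliases.get(name, name)

# An already-canonical name comes back untouched, i.e. a no-op:
standardize_dtype("uint32")
```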
LGTM, thank you
It seems this breaks TF GPU CI due to a strange issue -- TF automatically places
Seeing the error below when running TensorFlow distributed training with multiple workers; the issue is that the seed state being broadcast has a dtype that TensorFlow doesn't support.

This PR introduces a seed dtype function so each backend can customize its ideal seed dtype.
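The idea can be sketched roughly as follows. This is a minimal illustration, not the PR's API: each backend reports a seed dtype, and the seed state is created in that dtype so collective ops can broadcast it. The "uint32" value for JAX comes from the diff above; the other entries, the dict, and the function names here are assumptions.

```python
import numpy as np

# Hypothetical per-backend seed dtypes. "jax": "uint32" matches the diff
# in this PR; the remaining values are illustrative assumptions only.
_SEED_DTYPES = {
    "tensorflow": "int32",
    "jax": "uint32",
    "numpy": "uint32",
}

def random_seed_dtype(backend):
    # Fall back to "uint32" for backends not listed above.
    return _SEED_DTYPES.get(backend, "uint32")

def make_seed_state(seed, backend):
    # Two-element state [seed, counter], stored in the backend's seed
    # dtype so collective ops can broadcast it without dtype errors.
    return np.array([seed, 0], dtype=random_seed_dtype(backend))

state = make_seed_state(42, "tensorflow")
```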