Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

configure random seed dtype based on backend #19928

Merged
merged 2 commits into from
Jun 28, 2024

Conversation

haohuanw
Copy link
Contributor

seeing below error on running tensorflow distributed training with multiple workers. the issue being seed state getting broadcasted is in a dtype that tensorflow doesn't support:

Exception encountered: ''Value for attr 'T' of uint32 is not in the list of allowed values: bool, float, half, double, int32, int64
	; NodeDef: {{node CollectiveBcastSend}}; Op<name=CollectiveBcastSend; signature=input:T -> data:T; attr=T:type,allowed=[DT_BOOL, DT_FLOAT, DT_HALF, DT_DOUBLE, DT_INT32, DT_INT64]; attr=group_size:int; attr=group_key:int; attr=instance_key:int; attr=shape:shape; attr=communication_hint:string,default="auto"; attr=timeout_seconds:float,default=0; is_stateful=true; is_distributed_communication=true> [Op:CollectiveBcastSend]''

this pr introduces a seed dtype function to customize for an ideal seed dtype for backends.

@codecov-commenter
Copy link

codecov-commenter commented Jun 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 78.97%. Comparing base (c8a7f28) to head (8033c45).
Report is 21 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19928      +/-   ##
==========================================
- Coverage   79.01%   78.97%   -0.04%     
==========================================
  Files         499      499              
  Lines       46441    46506      +65     
  Branches     8550     8561      +11     
==========================================
+ Hits        36694    36730      +36     
- Misses       8020     8044      +24     
- Partials     1727     1732       +5     
Flag Coverage Δ
keras 78.83% <100.00%> (-0.04%) ⬇️
keras-jax 62.41% <50.00%> (+<0.01%) ⬆️
keras-numpy 57.32% <58.33%> (+0.10%) ⬆️
keras-tensorflow 63.60% <79.16%> (-0.03%) ⬇️
keras-torch 62.38% <50.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. This is a reasonable change.

@@ -14,15 +14,15 @@ def tf_draw_seed(seed):

def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
dtype = dtype or floatx()
seed = tf_draw_seed(seed)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can remove the tf_draw_seed function then

@@ -346,6 +346,11 @@ def unstack(x, num=None, axis=0):
]


def random_seed_dtype():
# jax random seed uses uint32.
return standardize_dtype("uint32")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

standardize_dtype will be a no-op here. You can skip it.

@haohuanw haohuanw requested a review from fchollet June 28, 2024 04:46
@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Jun 28, 2024
Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Jun 28, 2024
PR Queue automation moved this from Assigned Reviewer to Approved by Reviewer Jun 28, 2024
@fchollet fchollet merged commit 272af9c into keras-team:master Jun 28, 2024
6 checks passed
PR Queue automation moved this from Approved by Reviewer to Merged Jun 28, 2024
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase kokoro:force-run labels Jun 28, 2024
@fchollet
Copy link
Member

It seems this breaks TF GPU CI due to a strange issue -- TF automatically places int32 constants on CPU (this is different compared to every other dtype). I think using int64 instead would work -- trying it now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
PR Queue
Merged
Development

Successfully merging this pull request may close these issues.

None yet

4 participants