tweak torch parameter registration mechanism #19908
base: master
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #19908 +/- ##
==========================================
- Coverage 79.01% 73.44% -5.57%
==========================================
Files 499 499
Lines 46441 46476 +35
Branches 8550 8556 +6
==========================================
- Hits 36694 34134 -2560
- Misses 8020 10670 +2650
+ Partials 1727 1672 -55
Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
The failing PyTorch test is actually passing on my env:
The uniqueness of the variable path should come from the parent object name, not from the variable name (e.g.
For the seed generator, if using path it will be
Variable names are never unique. For a unique string you can use
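For illustration, a minimal sketch of the name-vs-path distinction under discussion (assuming the Keras 3 `name` / `path` variable attributes; layer names below are just the defaults):

```python
import keras

inputs = keras.Input(shape=(4,))
x = keras.layers.Dense(8)(inputs)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

for v in model.weights:
    # v.name is the short name ("kernel" / "bias") and repeats across layers;
    # v.path is prefixed with the owning object, e.g. "dense/kernel" vs
    # "dense_1/kernel", which is what makes it unique.
    print(v.name, "->", v.path)
```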
I thought that, within the same layer, variable names (excluding variables from its sublayers) should be unique as an implicit requirement. Then I found out that seed generators can actually create variables with the same name when there are multiple seed generators in one layer, since a seed generator is not a layer.
And I do notice that I probably want to add a test with nested seed generators. In theory, seed states should be recursively collected by torch, since it basically gets all
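A hedged sketch of the collision described above: two `SeedGenerator`s owned by the same layer each create a state variable with the same short name (the class and attribute names here are illustrative, not from the PR):

```python
import keras
from keras import random

class NoisyLayer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # SeedGenerator is not a Layer, so both state variables are tracked
        # directly on this layer and end up with identical short names.
        self.seed_a = random.SeedGenerator(seed=1)
        self.seed_b = random.SeedGenerator(seed=2)

layer = NoisyLayer()
print([v.name for v in layer.variables])
# expected: ['seed_generator_state', 'seed_generator_state']
```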
Code looks good -- thanks for the changes. I will apply docstring fixes after merging.
Works for me locally as well. Might be a fluke.
There are actually various tests that reliably fail here: https://btx.cloud.google.com/invocations/c55a2ca4-5df3-411b-bd52-7c9873e839ce/targets/keras%2Fgithub%2Fubuntu%2Fgpu%2Ftorch%2Fpresubmit/log (not the numerical integration test)
I will address those today / tmr 👍 - and is it possible to configure CI to run pytest regardless of whether the integration test passes or not?
We'd have to move the integration testing to go after the general pytest command in |
I am seeing a weird issue on
I was able to isolate that the model JSON looks good, but the restored model here: https://github.com/keras-team/keras/blob/master/keras/src/saving/saving_lib.py#L242 has duplicated
I don't understand the connection. You could try pruning things from your change until the test passes, then you'll have a good idea what particular lines are causing the issue.
@fchollet most of the unit tests are fixed, with one issue left. torch basically requires the user to use
There are two options that I think could work, let me know your thoughts:
Let me know what you think.
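For context (this is the general torch constraint being worked around, not this PR's exact mechanism): torch only exposes parameters that are registered on the `Module`. A minimal demonstration:

```python
import torch

class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.zeros(3))       # registered
        self.hidden = torch.zeros(3)                      # plain attribute: invisible
        self.bag = torch.nn.ParameterDict(
            {"b": torch.nn.Parameter(torch.ones(3))}      # registered under "bag.b"
        )

m = Demo()
# Only registered parameters reach optimizers and state_dict:
print([name for name, _ in m.named_parameters()])  # ['w', 'bag.b']
```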
I think we could do this, via
In theory - let me try it
It technically works, but I think this will be a pretty impactful workflow change for PyTorch users:
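The workflow at stake is roughly the standard torch training idiom below (a sketch assuming `KERAS_BACKEND=torch`, under which a Keras model is also a `torch.nn.Module`):

```python
import torch
import keras  # assumes KERAS_BACKEND=torch is set

model = keras.Sequential([keras.layers.Dense(4), keras.layers.Dense(1)])
model.build((None, 8))

# Standard torch idioms that parameter registration needs to keep working:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
checkpoint = model.state_dict()
```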
I think supporting a
This is a follow-up from the #19885 discussion, where I am trying to make torch / keras play well together on tracking parameters.
The solution I ended up with:
- Rely on torch's recursive parameter collection (recurse=True) and key each layer's ParameterDict on variable.name, with just tracking the variables the current layer holds. However, the current seed generator actually creates duplicated variable names. If https://github.com/keras-team/keras/blob/master/keras/src/random/seed_generator.py#L80 can be changed to something like f"{self.name}_generator_state", it will work with the ParameterDict approach.
- In _post_track/untrack_variables, refresh the entire torch params and its sublayers. This could be changed to not re-create all sublayers if this function ever becomes too slow.

I also added a few torch-specific tests to reflect some of the assumptions and use cases that torch users might have, e.g. using state_dict.
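An illustrative sketch of the refresh-on-track idea described above (not the actual Keras implementation; the class and hook names below are toy stand-ins that mirror the description):

```python
import torch

class Layer(torch.nn.Module):
    """Toy stand-in for a Keras layer with torch-param refreshing."""

    def __init__(self):
        super().__init__()
        self.torch_params = torch.nn.ParameterDict()
        self._tracked = []  # (name, tensor) pairs this layer itself owns

    def track_variable(self, name, tensor):
        self._tracked.append((name, tensor))
        self._post_track_variable()

    def _post_track_variable(self):
        # Rebuild the whole dict on every change: simple and correct, and it
        # could be made incremental if this ever becomes too slow.
        self.torch_params = torch.nn.ParameterDict(
            {name: torch.nn.Parameter(t) for name, t in self._tracked}
        )

layer = Layer()
layer.track_variable("kernel", torch.zeros(2, 2))
layer.track_variable("seed_generator_state", torch.zeros(2))
print(list(layer.state_dict().keys()))
# ['torch_params.kernel', 'torch_params.seed_generator_state']
```

Keying the dict by variable name (rather than a flat list) is what makes the `state_dict` entries stable and readable, which is the torch-user expectation the added tests exercise.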