-
Notifications
You must be signed in to change notification settings - Fork 104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Convnext architecture dev #356
base: development
Are you sure you want to change the base?
Convnext architecture dev #356
Conversation
Test errors on CI: In my local, all tests run ok. My TF version is |
Hi @Lorenzobattistela, looks like convnext wasn't introduced until TF v2.10. I wonder if we can do a TF version check for this? The other option is we could provide a general architecture class that accepts a tf.keras.application as input and wraps it. I'm not sure if it would be simple to apply to all applications, but would be cleaner for the package and avoids the version issue altogether. |
So, @owenvallis I thought about what you said. Anyway, I updated the code to do some version checking on test (to skip if tf < 2.10), and maybe we can add some version checking to inform a more useful error to the user if it tries to use it with a minor tf version. However, I think the refactoring path to a wrapper for keras applications is the best approach. I'm willing to work on this, will start refactoring it. It is up to you to merge or not this, maybe we can use this as a "hotfix" and then refactor to something better. Thanks for the review. |
@owenvallis something is going wrong with isort, but i did ran it |
convnext.trainable = True | ||
for layer in convnext.layers: | ||
# freeze all layeres befor the last 3 blocks | ||
if not re.search("^block[5,6,7]|^top", layer.name): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm also trying out this architecture. But does this EfficientNetV2 layer naming apply to convnext?
model = tf.keras.applications.ConvNeXtBase()
[l.name for l in model.layers if re.search("^block[5,6,7]|^top", l.name)]
# this outputs []
The test also suggests partial
is not being applied as expected since the number of trainable layers is 0 with partial.
edit: another candidate might be "convnext_base_stage_3_block_2"
, also unfreezing the last layer norm since it comes after the final block.
model.trainable = True
for layer in model.layers:
# freeze all layers before the last block
if not re.search("^convnext_base_stage_3_block_2", layer.name):
layer.trainable = False
model.layers[-1].trainable = True
This results in about 10% of weights being unfrozen and only the final block [1].
Total params: 87566464 (334.04 MB)
Trainable params: 8450048 (32.23 MB)
Non-trainable params: 79116416 (301.81 MB)
Implementing ConvNeXt architecture referred in this paper and #353 .
Reviewer: @owenvallis
This PR re implements #354 but based on the development branch to fix some test and formatting issues.