Fix: handle 3D input for channels_first in preprocess_input() #21749 #21754
base: master
…andling and tests
Summary of Changes
Hello @utsab345, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request primarily addresses a critical bug in the
Highlights
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #21754 +/- ##
==========================================
+ Coverage 82.63% 84.83% +2.20%
==========================================
Files 572 572
Lines 58555 58566 +11
Branches 9153 9158 +5
==========================================
+ Hits 48385 49687 +1302
+ Misses 7843 6433 -1410
- Partials 2327 2446 +119
Code Review
This pull request correctly fixes a shape mismatch issue in preprocess_input for 3D inputs with data_format='channels_first'. The changes are well-tested and also include a nice improvement in save_img to normalize the 'jpg' format. I've suggested a couple of improvements in imagenet_utils.py to improve the error messages and reduce code duplication. Overall, this is a good contribution.
    if ndim == 3:
        mean_tensor = ops.reshape(mean_tensor, (3, 1, 1))
    elif ndim == 4:
        mean_tensor = ops.reshape(mean_tensor, (1, 3, 1, 1))
    else:
        raise ValueError(f"Unsupported shape for channels_first: {x.shape}")
While this logic correctly handles 3D and 4D inputs, the error message could be more informative. According to the Keras API design guidelines, a good error message should explain what was expected and how the user can fix it [1]. This message only states what was received. Consider clarifying that only 3D and 4D tensors are supported for channels_first.
Additionally, you can make the code slightly cleaner by checking for the invalid ndim case first.
    - if ndim == 3:
    -     mean_tensor = ops.reshape(mean_tensor, (3, 1, 1))
    - elif ndim == 4:
    -     mean_tensor = ops.reshape(mean_tensor, (1, 3, 1, 1))
    - else:
    -     raise ValueError(f"Unsupported shape for channels_first: {x.shape}")
    + if ndim not in (3, 4):
    +     raise ValueError(
    +         f"Unsupported tensor rank: {ndim}. With `data_format='channels_first'`, "
    +         "`preprocess_input` only supports 3D (single image) and 4D (batch of "
    +         f"images) tensors. Received tensor with shape: {x.shape}"
    +     )
    + if ndim == 3:
    +     mean_tensor = ops.reshape(mean_tensor, (3, 1, 1))
    + elif ndim == 4:
    +     mean_tensor = ops.reshape(mean_tensor, (1, 3, 1, 1))
Style Guide References
Footnotes
1. The style guide states that error messages should be contextual, informative, and actionable, explaining what happened, what was expected, and how to fix it.
      if data_format == "channels_first":
    -     std_tensor = ops.reshape(std_tensor, (-1, 1, 1))
    +     if ndim == 3:
    +         std_tensor = ops.reshape(std_tensor, (3, 1, 1))
    +     elif ndim == 4:
    +         std_tensor = ops.reshape(std_tensor, (1, 3, 1, 1))
    +     else:
    +         raise ValueError(
    +             f"Unsupported shape for channels_first: {x.shape}"
    +         )
      else:
          std_tensor = ops.reshape(std_tensor, (1,) * (ndim - 1) + (3,))
This logic for reshaping std_tensor is a duplicate of the logic used for mean_tensor above. To improve maintainability and reduce code duplication, consider refactoring this. You could determine the reshape_shape once at the start of the data_format == 'channels_first' block and reuse it for both tensors. This would also apply to the error handling logic.
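One possible shape for that refactor is sketched below. The helper name `channels_first_stat_shape` is hypothetical and not part of Keras or this PR; it only illustrates computing the reshape target and raising the error in one place, so that both `mean_tensor` and `std_tensor` can reuse the result.

```python
def channels_first_stat_shape(ndim, x_shape=None):
    """Return the broadcastable shape for per-channel statistics.

    Hypothetical helper (not an existing Keras function): it centralizes the
    rank check so the mean and std reshapes share one code path.
    """
    if ndim not in (3, 4):
        raise ValueError(
            f"Unsupported tensor rank: {ndim}. With "
            "`data_format='channels_first'`, `preprocess_input` only "
            "supports 3D (single image) and 4D (batch of images) tensors. "
            f"Received tensor with shape: {x_shape}"
        )
    # 3D: (C, H, W) -> (3, 1, 1); 4D: (N, C, H, W) -> (1, 3, 1, 1)
    return (3, 1, 1) if ndim == 3 else (1, 3, 1, 1)


# Intended use inside the channels_first branch (shown as comments because
# it depends on names from the surrounding function):
#   stat_shape = channels_first_stat_shape(ndim, x.shape)
#   mean_tensor = ops.reshape(mean_tensor, stat_shape)
#   std_tensor = ops.reshape(std_tensor, stat_shape)
```

With this layout the error message and the supported ranks are defined once, so a later change to either cannot drift between the mean and std paths.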
    """
    data_format = backend.standardize_data_format(data_format)
    # Normalize jpg → jpeg
    if file_format is not None and file_format.lower() == "jpg":
Please revert changes in this file
            ((50, 50, 4), "rgba.jpg"),
        ],
    )
    def test_save_jpg(tmp_path, shape, name):
Please use a TestCase subclass and use self.assert... methods instead of naked assert statements.
Thanks for the PR!
Summary
Fixed a shape mismatch issue when preprocess_input() receives 3D input (single image) with data_format='channels_first'. The bug occurred because the mean and std tensors were reshaped for 4D input only.
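The mismatch can be illustrated without any backend by checking shapes under NumPy-style broadcasting rules. This is a minimal sketch: `broadcast_shape` is a hypothetical helper written for the illustration, and the pre-fix formula `(1, 3) + (1,) * (ndim - 2)` is an assumption about a 4D-oriented reshape, not code quoted from this PR.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the NumPy-style broadcast shape of two shapes, or raise ValueError."""
    result = []
    # Compare dimensions from the trailing axis backwards; missing axes count as 1.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and da != 1 and db != 1:
            raise ValueError(f"Shapes {a} and {b} are not broadcastable")
        result.append(max(da, db))
    return tuple(reversed(result))


image_shape = (3, 224, 224)  # a single channels_first image
ndim = len(image_shape)

# Assumed pre-fix formula (hypothetical): correct for 4D, wrong for 3D.
old_stat_shape = (1, 3) + (1,) * (ndim - 2)  # evaluates to (1, 3, 1) for ndim == 3
# Shape used by the fix for 3D input.
new_stat_shape = (3, 1, 1)

try:
    broadcast_shape(image_shape, old_stat_shape)
    old_ok = True
except ValueError:
    # The channel axis of the statistics lines up with the height axis.
    old_ok = False

fixed_result = broadcast_shape(image_shape, new_stat_shape)
```

Under that assumption, the old shape puts the 3 channel values on the height axis of a 3D image, so the addition fails for any height other than 1 or 3, while `(3, 1, 1)` broadcasts cleanly against `(3, H, W)`.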
Changes
Test
Closes #21749