-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Fix: added backend validation for dataset adapters across backends #21789
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Fix: added backend validation for dataset adapters across backends #21789
Conversation
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
Summary of ChangesHello @Shekar-77, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the robustness of Keras 3's data handling by implementing crucial backend compatibility validations within its Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request adds important backend validation to the dataset adapters, which is a great improvement for user experience by providing clear error messages on backend-dataset incompatibility. The changes in py_dataset_adapter.py, tf_dataset_adapter.py, and torch_data_loader_adapter.py are logical and well-implemented. I've added a few minor suggestions to improve code cleanliness by removing a duplicate import and some development-related comments.
| if not isinstance( | ||
| dataset, (tf.data.Dataset, tf.distribute.DistributedDataset) | ||
| ): | ||
| # --- ✅ Backend compatibility check --- |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| "This adapter only supports the TensorFlow backend." | ||
| ) | ||
|
|
||
| # --- ✅ Dataset type validation --- |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| import torch | ||
| import keras | ||
|
|
||
| # --- ✅ Backend compatibility check --- |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #21789 +/- ##
==========================================
- Coverage 82.63% 82.62% -0.02%
==========================================
Files 577 577
Lines 59316 59430 +114
Branches 9300 9317 +17
==========================================
+ Hits 49018 49106 +88
- Misses 7910 7916 +6
- Partials 2388 2408 +20
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Hi @Shekar-77 , Can you please sign the CLA? Thank you ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| self._use_multiprocessing = use_multiprocessing | ||
| self._max_queue_size = max_queue_size | ||
| backend_name = backend.backend() | ||
| if backend_name not in ("torch", "jax", "tensorflow", "numpy"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a no-op, all the backends are in the list of supported backends.
|
|
||
| # --- ✅ Backend compatibility check --- | ||
| backend = keras.backend.backend() | ||
| if backend not in ("tensorflow", "numpy", "torch", "jax"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a no-op, all the backends are in the list of supported backends.
| import keras | ||
|
|
||
| backend = keras.backend.backend() | ||
| if backend not in ("torch", "tensorflow", "numpy", "jax"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a no-op, all the backends are in the list of supported backends.
This PR improves backend compatibility checks for DatasetAdapter.
It raises clear errors when a backend (e.g., PyTorch) is used with an incompatible dataset type (e.g., TensorFlow or JAX).
As Keras 3 provides a collection of backend-agnostic operations, this PR ensures that the dataset type matches the active backend.
For example:
-Use tf.data.Dataset with the TensorFlow backend
-Use torch.utils.data.DataLoader with the PyTorch backend
-Use jax datasets with the JAX backend
Changes
Added backend-type validation to DatasetAdapter.init in:
-tf_dataset_adapter.py
-py_dataset_adapter.py
-torch_data_loader_adapter.py
Verified behavior using pytest.
When testing, the KERAS_BACKEND environment variable must match the dataset type.
For example:
-$env:KERAS_BACKEND = "jax"
-pytest keras/src/trainers/data_adapters/py_dataset_adapter_test.py -v
Related Issues
Fixes #21785