diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter.py b/keras/src/trainers/data_adapters/py_dataset_adapter.py
index 18865af026cf..1599cf61f1b8 100644
--- a/keras/src/trainers/data_adapters/py_dataset_adapter.py
+++ b/keras/src/trainers/data_adapters/py_dataset_adapter.py
@@ -9,6 +9,7 @@
 
 import numpy as np
 
+from keras.src import backend
 from keras.src.api_export import keras_export
 from keras.src.trainers.data_adapters import data_adapter_utils
 from keras.src.trainers.data_adapters.data_adapter import DataAdapter
@@ -94,6 +95,21 @@ def __init__(self, workers=1, use_multiprocessing=False, max_queue_size=10):
         self._workers = workers
         self._use_multiprocessing = use_multiprocessing
         self._max_queue_size = max_queue_size
+        backend_name = backend.backend()
+        if backend_name not in ("torch", "jax", "tensorflow", "numpy"):
+            raise ValueError(
+                "PyDataset supports the tensorflow, torch, jax and numpy "
+                f"backends. Received unsupported backend: '{backend_name}'."
+            )
+        # Warn on the TF backend, where tf.data.Dataset is usually faster.
+        if backend_name == "tensorflow":
+            import warnings
+
+            warnings.warn(
+                "You are using PyDataset with the TensorFlow backend. "
+                "Consider using `tf.data.Dataset` for better performance.",
+                stacklevel=2,
+            )
 
     def _warn_if_super_not_called(self):
         warn = False
diff --git a/keras/src/trainers/data_adapters/tf_dataset_adapter.py b/keras/src/trainers/data_adapters/tf_dataset_adapter.py
index 492deb764c3e..aed249812dd6 100644
--- a/keras/src/trainers/data_adapters/tf_dataset_adapter.py
+++ b/keras/src/trainers/data_adapters/tf_dataset_adapter.py
@@ -17,15 +17,28 @@ def __init__(self, dataset, class_weight=None, distribution=None):
             shard the input dataset into per worker/process dataset
             instance.
         """
+        from keras.src import backend
         from keras.src.utils.module_utils import tensorflow as tf
 
+        # Backend compatibility check.
+        backend_name = backend.backend()
+        if backend_name not in ("tensorflow", "numpy", "torch", "jax"):
+            raise ValueError(
+                f"Incompatible backend '{backend_name}' for TFDatasetAdapter. "
+                "This adapter only supports the tensorflow, numpy, torch "
+                "and jax backends."
+            )
+
+        # Dataset type validation.
         if not isinstance(
             dataset, (tf.data.Dataset, tf.distribute.DistributedDataset)
         ):
             raise ValueError(
-                "Expected argument `dataset` to be a tf.data.Dataset. "
+                "Expected argument `dataset` to be a tf.data.Dataset or "
+                "tf.distribute.DistributedDataset. "
                 f"Received: {dataset}"
             )
+
         if class_weight is not None:
             dataset = dataset.map(
                 make_class_weight_map_fn(class_weight)
diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py
index f0b2f524f4dd..2600c9b58308 100644
--- a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py
+++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py
@@ -11,8 +11,19 @@
 class TorchDataLoaderAdapter(DataAdapter):
     """Adapter that handles `torch.utils.data.DataLoader`."""
 
     def __init__(self, dataloader):
+        # Backend compatibility check.
         import torch
+        from keras.src import backend
+
+        backend_name = backend.backend()
+        if backend_name not in ("torch", "tensorflow", "numpy", "jax"):
+            raise ValueError(
+                f"Incompatible backend '{backend_name}' for "
+                "TorchDataLoaderAdapter. This adapter only supports the "
+                "torch, tensorflow, jax and numpy backends."
+            )
+
         if not isinstance(dataloader, torch.utils.data.DataLoader):
             raise ValueError(
                 f"Expected argument `dataloader` to be an instance of"
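
For reviewers, a minimal sketch of how the new `PyDataset` validation surfaces to user code, assuming this patch is applied. `ExamplePyDataset` and its random data are illustrative only and not part of the change:

```python
import numpy as np

from keras.utils import PyDataset


class ExamplePyDataset(PyDataset):
    """Illustrative PyDataset subclass (not part of this patch)."""

    def __init__(self, x, y, batch_size=32, **kwargs):
        # super().__init__() must run so the new backend check executes and
        # `workers` / `use_multiprocessing` / `max_queue_size` are applied.
        super().__init__(**kwargs)
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        return self.x[lo:hi], self.y[lo:hi]


# On the "tensorflow" backend this constructor now emits a UserWarning
# pointing to tf.data.Dataset; on a backend outside the allowed set it
# raises ValueError before any training starts.
ds = ExamplePyDataset(np.random.rand(256, 8), np.random.rand(256, 1))
```

Raising in `__init__` rather than at iteration time means the failure happens where the object is created, which should make the traceback easier to act on.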