Skip to content
Open
16 changes: 16 additions & 0 deletions keras/src/trainers/data_adapters/py_dataset_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np

from keras import backend
from keras.src.api_export import keras_export
from keras.src.trainers.data_adapters import data_adapter_utils
from keras.src.trainers.data_adapters.data_adapter import DataAdapter
Expand Down Expand Up @@ -94,6 +95,21 @@ def __init__(self, workers=1, use_multiprocessing=False, max_queue_size=10):
self._workers = workers
self._use_multiprocessing = use_multiprocessing
self._max_queue_size = max_queue_size
backend_name = backend.backend()
if backend_name not in ("torch", "jax", "tensorflow", "numpy"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a no-op, all the backends are in the list of supported backends.

Copy link
Author

@Shekar-77 Shekar-77 Nov 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My code was raising an error if the dataset and the backend weren't from the same framework, ex: tensorflow dataset and tensorflow backend. But the code wasn't passing the tests, error basically said that tf dataset is ok to be used with any backend even though we get an error if the backend is not the same. If it's open I will would like to work on this issue.

raise ValueError(
f"PyDataset supports tf, torch, jax, numpy backend"
f"Received unsupported backend: '{backend_name}'."
)
# Optionally warn if using TF (since tf.data.Dataset is better)
if backend_name == "tensorflow":
import warnings

warnings.warn(
"You are using PyDataset with the TensorFlow backend. "
"Consider using `tf.data.Dataset` for better performance.",
stacklevel=2,
)

def _warn_if_super_not_called(self):
warn = False
Expand Down
15 changes: 14 additions & 1 deletion keras/src/trainers/data_adapters/tf_dataset_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,28 @@ def __init__(self, dataset, class_weight=None, distribution=None):
shard the input dataset into per worker/process dataset
instance.
"""
import keras
from keras.src.utils.module_utils import tensorflow as tf

# --- ✅ Backend compatibility check ---
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This comment appears to be for debugging or development purposes and adds some noise to the code. It would be cleaner to remove it.

backend = keras.backend.backend()
if backend not in ("tensorflow", "numpy", "torch", "jax"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a no-op, all the backends are in the list of supported backends.

raise ValueError(
f"Incompatible backend '{backend}' for TFDatasetAdapter. "
"This adapter only supports the TensorFlow , numpy , torch ,"
" jax backend."
)

# --- ✅ Dataset type validation ---
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This comment appears to be for debugging or development purposes and adds some noise to the code. It would be cleaner to remove it.

if not isinstance(
dataset, (tf.data.Dataset, tf.distribute.DistributedDataset)
):
raise ValueError(
"Expected argument `dataset` to be a tf.data.Dataset. "
"Expected argument `dataset` to be a tf.data.Dataset or "
"tf.distribute.DistributedDataset. "
f"Received: {dataset}"
)

if class_weight is not None:
dataset = dataset.map(
make_class_weight_map_fn(class_weight)
Expand Down
11 changes: 11 additions & 0 deletions keras/src/trainers/data_adapters/torch_data_loader_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,19 @@ class TorchDataLoaderAdapter(DataAdapter):
"""Adapter that handles `torch.utils.data.DataLoader`."""

def __init__(self, dataloader):
# --- ✅ Backend compatibility check ---
import torch

import keras

backend = keras.backend.backend()
if backend not in ("torch", "tensorflow", "numpy", "jax"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a no-op, all the backends are in the list of supported backends.

raise ValueError(
f"Incompatible backend '{backend}' for TorchDataLoaderAdapter. "
"This adapter only supports the PyTorch, tensorflow, jax, numpy"
" backend. "
)

if not isinstance(dataloader, torch.utils.data.DataLoader):
raise ValueError(
f"Expected argument `dataloader` to be an instance of"
Expand Down