diff --git a/.gitignore b/.gitignore index afd700b49952..416f213f2c82 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__ **/.vscode test/** **/.vscode-smoke/** **/.venv*/ +venv bin/** build/** obj/** diff --git a/integration_tests/test_save_img.py b/integration_tests/test_save_img.py new file mode 100644 index 000000000000..6ec7951564cb --- /dev/null +++ b/integration_tests/test_save_img.py @@ -0,0 +1,27 @@ +import os + +import numpy as np +import pytest + +from keras.utils import img_to_array +from keras.utils import load_img +from keras.utils import save_img + + +@pytest.mark.parametrize( + "shape, name", + [ + ((50, 50, 3), "rgb.jpg"), + ((50, 50, 4), "rgba.jpg"), + ], +) +def test_save_jpg(tmp_path, shape, name): + img = np.random.randint(0, 256, size=shape, dtype=np.uint8) + path = tmp_path / name + save_img(path, img, file_format="jpg") + assert os.path.exists(path) + + # Check that the image was saved correctly and converted to RGB if needed. + loaded_img = load_img(path) + loaded_array = img_to_array(loaded_img) + assert loaded_array.shape == (50, 50, 3) diff --git a/keras/src/applications/imagenet_utils.py b/keras/src/applications/imagenet_utils.py index f88c0af64d88..7af0c659f62f 100644 --- a/keras/src/applications/imagenet_utils.py +++ b/keras/src/applications/imagenet_utils.py @@ -278,14 +278,28 @@ def _preprocess_tensor_input(x, data_format, mode): # Zero-center by mean pixel if data_format == "channels_first": - mean_tensor = ops.reshape(mean_tensor, (1, 3) + (1,) * (ndim - 2)) + if ndim == 3: + mean_tensor = ops.reshape(mean_tensor, (3, 1, 1)) + elif ndim == 4: + mean_tensor = ops.reshape(mean_tensor, (1, 3, 1, 1)) + else: + raise ValueError(f"Unsupported shape for channels_first: {x.shape}") else: mean_tensor = ops.reshape(mean_tensor, (1,) * (ndim - 1) + (3,)) x += mean_tensor if std is not None: std_tensor = ops.convert_to_tensor(np.array(std), dtype=x.dtype) if data_format == "channels_first": - std_tensor = ops.reshape(std_tensor, (-1, 1, 1)) + if ndim == 3: + std_tensor = ops.reshape(std_tensor, (3, 1, 1)) + elif ndim == 4: + std_tensor = ops.reshape(std_tensor, (1, 3, 1, 1)) + else: + raise ValueError( + f"Unsupported shape for channels_first: {x.shape}" + ) + else: + std_tensor = ops.reshape(std_tensor, (1,) * (ndim - 1) + (3,)) x /= std_tensor return x diff --git a/keras/src/utils/image_utils.py b/keras/src/utils/image_utils.py index ca8289c9f9b7..abf5c413fde0 100644 --- a/keras/src/utils/image_utils.py +++ b/keras/src/utils/image_utils.py @@ -175,10 +175,13 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs): **kwargs: Additional keyword arguments passed to `PIL.Image.save()`. """ data_format = backend.standardize_data_format(data_format) + # Normalize jpg → jpeg + if file_format is not None and file_format.lower() == "jpg": + file_format = "jpeg" img = array_to_img(x, data_format=data_format, scale=scale) - if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"): + if img.mode == "RGBA" and file_format == "jpeg": warnings.warn( - "The JPG format does not support RGBA images, converting to RGB." + "The JPEG format does not support RGBA images, converting to RGB." ) img = img.convert("RGB") img.save(path, format=file_format, **kwargs)