Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support using PIL Image directly as input #26

Merged
merged 11 commits into from
Aug 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python: [ "3.9", "3.10", "3.11" ]
python: [ "3.9", "3.10", "3.11", "3.12" ]
backend: [ tensorflow, jax ]
include:
- backend: tensorflow
Expand Down
27 changes: 15 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,23 @@ For more details, please refer to the [API](#api) section.
```python
import opennsfw2 as n2

# To get the NSFW probability of a single image.
image_path = "path/to/your/image.jpg"
# To get the NSFW probability of a single image, provide your image file path,
# or a `PIL.Image.Image` object.
image_handle = "path/to/your/image.jpg"

nsfw_probability = n2.predict_image(image_path)
nsfw_probability = n2.predict_image(image_handle)

# To get the NSFW probabilities of a list of images.
# This is better than looping with `predict_image` as the model will only be instantiated once
# and batching is used during inference.
image_paths = [
# To get the NSFW probabilities of a list of images, provide a list of file paths,
# or a list of `PIL.Image.Image` objects.
# Using this function is better than looping with `predict_image` as the model
# will only be instantiated once and batching is done during inference.
image_handles = [
"path/to/your/image1.jpg",
"path/to/your/image2.jpg",
# ...
]

nsfw_probabilities = n2.predict_images(image_paths)
nsfw_probabilities = n2.predict_images(image_handles)
```

## Video
Expand Down Expand Up @@ -154,8 +156,8 @@ Create an instance of the NSFW model, optionally with pre-trained weights from Y
### `predict_image`
End-to-end pipeline function from the input image to the predicted NSFW probability.
- Parameters:
- `image_path` (`str`): Path to the input image file.
The image format must be supported by Pillow.
- `image_handle` (`Union[str, PIL.Image.Image]`):
Path to the input image file with a format supported by Pillow, or a `PIL.Image.Image` object.
- `preprocessing`: Same as that in `preprocess_image`.
- `weights_path`: Same as that in `make_open_nsfw_model`.
- `grad_cam_path` (`Optional[str]`, default `None`): If not `None`, e.g., `cam.jpg`,
Expand All @@ -171,8 +173,9 @@ End-to-end pipeline function from the input image to the predicted NSFW probabil
### `predict_images`
End-to-end pipeline function from the input images to the predicted NSFW probabilities.
- Parameters:
- `image_paths` (`Sequence[str]`): List of paths to the input image files.
The image format must be supported by Pillow.
- `image_handles` (`Union[Sequence[str], Sequence[PIL.Image.Image]]`):
List of paths to the input image files with formats supported by Pillow,
or list of `PIL.Image.Image` objects.
- `batch_size` (`int`, default `8`): Batch size to be used for model inference.
Choose a value that works the best with your device resources.
- `preprocessing`: Same as that in `preprocess_image`.
Expand Down
41 changes: 24 additions & 17 deletions opennsfw2/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Inference utilities.
"""
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Sequence, Tuple
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np
Expand All @@ -28,18 +28,24 @@ def _update_global_model_if_needed(weights_path: Optional[str]) -> None:
global_model_path = weights_path


def _load_pil_image(image_handle: Union[str, Image.Image]) -> Image.Image:
if isinstance(image_handle, Image.Image):
return image_handle
return Image.open(image_handle)


def predict_image(
image_path: str,
image_handle: Union[str, Image.Image],
preprocessing: Preprocessing = Preprocessing.YAHOO,
weights_path: Optional[str] = get_default_weights_path(),
grad_cam_path: Optional[str] = None,
alpha: float = 0.8
) -> float:
"""
Pipeline from single image path to predicted NSFW probability.
Pipeline from single image handle to predicted NSFW probability.
Optionally generate and save the Grad-CAM plot.
"""
pil_image = Image.open(image_path)
pil_image = _load_pil_image(image_handle)
image = preprocess_image(pil_image, preprocessing)
_update_global_model_if_needed(weights_path)
assert global_model is not None
Expand All @@ -56,9 +62,9 @@ def predict_image(
return nsfw_probability


def _predict_from_image_paths_in_batches(
def _predict_from_image_handles_in_batches(
model_: Model,
image_paths: Sequence[str],
image_handles: Union[Sequence[str], Sequence[Image.Image]],
batch_size: int,
preprocessing: Preprocessing
) -> NDFloat32Array:
Expand All @@ -68,43 +74,44 @@ def _predict_from_image_paths_in_batches(
https://keras.io/api/models/model_training_apis/#predict-method
"""
prediction_batches: List[Any] = []
for i in range(0, len(image_paths), batch_size):
path_batch = image_paths[i: i + batch_size]
for i in range(0, len(image_handles), batch_size):
handle_batch = image_handles[i: i + batch_size]
image_batch = [
preprocess_image(Image.open(path), preprocessing)
for path in path_batch
preprocess_image(_load_pil_image(handle), preprocessing)
for handle in handle_batch
]
prediction_batches.append(model_(np.array(image_batch)))
predictions: NDFloat32Array = np.concatenate(prediction_batches, axis=0)
return predictions


def predict_images(
image_paths: Sequence[str],
image_handles: Union[Sequence[str], Sequence[Image.Image]],
batch_size: int = 8,
preprocessing: Preprocessing = Preprocessing.YAHOO,
weights_path: Optional[str] = get_default_weights_path(),
grad_cam_paths: Optional[Sequence[str]] = None,
alpha: float = 0.8
) -> List[float]:
"""
Pipeline from image paths to predicted NSFW probabilities.
Pipeline from image handles to predicted NSFW probabilities.
Optionally generate and save the Grad-CAM plots.
"""
_update_global_model_if_needed(weights_path)
predictions = _predict_from_image_paths_in_batches(
global_model, image_paths, batch_size, preprocessing
predictions = _predict_from_image_handles_in_batches(
global_model, image_handles, batch_size, preprocessing
)
nsfw_probabilities: List[float] = predictions[:, 1].tolist()

if grad_cam_paths is not None:
# TensorFlow will only be imported here.
from ._inspection import make_and_save_nsfw_grad_cam

for image_path, grad_cam_path in zip(image_paths, grad_cam_paths):
for image_handle, grad_cam_path in zip(image_handles, grad_cam_paths):
assert isinstance(image_handle, (str, Image.Image)) # For mypy.
pil_image = _load_pil_image(image_handle)
make_and_save_nsfw_grad_cam(
Image.open(image_path), preprocessing, global_model,
grad_cam_path, alpha
pil_image, preprocessing, global_model, grad_cam_path, alpha
)

return nsfw_probabilities
Expand Down
6 changes: 6 additions & 0 deletions tests/_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional, Sequence

from keras import backend as keras_backend
from PIL import Image

import opennsfw2 as n2

Expand Down Expand Up @@ -61,9 +62,14 @@ def test_predict_images_simple_preprocessing(self) -> None:
self._assert(expected_probabilities, predicted_probabilities)

def test_predict_image(self) -> None:
# From path.
self.assertAlmostEqual(
0.983, n2.predict_image(IMAGE_PATHS[1]), places=3
)
# From PIL Image.
self.assertAlmostEqual(
0.983, n2.predict_image(Image.open(IMAGE_PATHS[1])), places=3
)

def test_predict_video_frames(self) -> None:
elapsed_seconds, nsfw_probabilities = n2.predict_video_frames(
Expand Down
2 changes: 1 addition & 1 deletion tests/run_code_checks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ set -e
find opennsfw2 -iname "*.py" | grep -v -e "__init__.py" | xargs -L 1 pylint --errors-only
find opennsfw2 -iname "*.py" | grep -v -e "__init__.py" | xargs -L 1 pylint --exit-zero
find opennsfw2 -iname "*.py" | grep -v -e "__init__.py" | xargs -L 1 mypy --strict --implicit-reexport --disable-error-code attr-defined --disable-error-code unused-ignore
find tests -iname "*.py" | xargs -L 1 python3 -m unittest
find tests -iname "*.py" | grep -v -e "__init__.py" | xargs -L 1 python3 -m unittest