Skip to content

Commit

Permalink
Merge pull request #26 from bhky/change-image-path-to-image-handle
Browse files Browse the repository at this point in the history
Support using PIL Image directly as input
  • Loading branch information
bhky authored Aug 18, 2024
2 parents ef9c389 + 4cd597c commit 587af95
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python: [ "3.9", "3.10", "3.11" ]
python: [ "3.9", "3.10", "3.11", "3.12" ]
backend: [ tensorflow, jax ]
include:
- backend: tensorflow
Expand Down
27 changes: 15 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,23 @@ For more details, please refer to the [API](#api) section.
```python
import opennsfw2 as n2

# To get the NSFW probability of a single image.
image_path = "path/to/your/image.jpg"
# To get the NSFW probability of a single image, provide your image file path,
# or a `PIL.Image.Image` object.
image_handle = "path/to/your/image.jpg"

nsfw_probability = n2.predict_image(image_path)
nsfw_probability = n2.predict_image(image_handle)

# To get the NSFW probabilities of a list of images.
# This is better than looping with `predict_image` as the model will only be instantiated once
# and batching is used during inference.
image_paths = [
# To get the NSFW probabilities of a list of images, provide a list of file paths,
# or a list of `PIL.Image.Image` objects.
# Using this function is better than looping with `predict_image` as the model
# will only be instantiated once and batching is done during inference.
image_handles = [
"path/to/your/image1.jpg",
"path/to/your/image2.jpg",
# ...
]

nsfw_probabilities = n2.predict_images(image_paths)
nsfw_probabilities = n2.predict_images(image_handles)
```

## Video
Expand Down Expand Up @@ -154,8 +156,8 @@ Create an instance of the NSFW model, optionally with pre-trained weights from Y
### `predict_image`
End-to-end pipeline function from the input image to the predicted NSFW probability.
- Parameters:
- `image_path` (`str`): Path to the input image file.
The image format must be supported by Pillow.
- `image_handle` (`Union[str, PIL.Image.Image]`):
Path to the input image file with a format supported by Pillow, or a `PIL.Image.Image` object.
- `preprocessing`: Same as that in `preprocess_image`.
- `weights_path`: Same as that in `make_open_nsfw_model`.
- `grad_cam_path` (`Optional[str]`, default `None`): If not `None`, e.g., `cam.jpg`,
Expand All @@ -171,8 +173,9 @@ End-to-end pipeline function from the input image to the predicted NSFW probabil
### `predict_images`
End-to-end pipeline function from the input images to the predicted NSFW probabilities.
- Parameters:
- `image_paths` (`Sequence[str]`): List of paths to the input image files.
The image format must be supported by Pillow.
- `image_handles` (`Union[Sequence[str], Sequence[PIL.Image.Image]]`):
List of paths to the input image files with formats supported by Pillow,
or list of `PIL.Image.Image` objects.
- `batch_size` (`int`, default `8`): Batch size to be used for model inference.
Choose a value that works the best with your device resources.
- `preprocessing`: Same as that in `preprocess_image`.
Expand Down
41 changes: 24 additions & 17 deletions opennsfw2/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Inference utilities.
"""
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Sequence, Tuple
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np
Expand All @@ -28,18 +28,24 @@ def _update_global_model_if_needed(weights_path: Optional[str]) -> None:
global_model_path = weights_path


def _load_pil_image(image_handle: Union[str, Image.Image]) -> Image.Image:
if isinstance(image_handle, Image.Image):
return image_handle
return Image.open(image_handle)


def predict_image(
image_path: str,
image_handle: Union[str, Image.Image],
preprocessing: Preprocessing = Preprocessing.YAHOO,
weights_path: Optional[str] = get_default_weights_path(),
grad_cam_path: Optional[str] = None,
alpha: float = 0.8
) -> float:
"""
Pipeline from single image path to predicted NSFW probability.
Pipeline from single image handle to predicted NSFW probability.
Optionally generate and save the Grad-CAM plot.
"""
pil_image = Image.open(image_path)
pil_image = _load_pil_image(image_handle)
image = preprocess_image(pil_image, preprocessing)
_update_global_model_if_needed(weights_path)
assert global_model is not None
Expand All @@ -56,9 +62,9 @@ def predict_image(
return nsfw_probability


def _predict_from_image_paths_in_batches(
def _predict_from_image_handles_in_batches(
model_: Model,
image_paths: Sequence[str],
image_handles: Union[Sequence[str], Sequence[Image.Image]],
batch_size: int,
preprocessing: Preprocessing
) -> NDFloat32Array:
Expand All @@ -68,43 +74,44 @@ def _predict_from_image_paths_in_batches(
https://keras.io/api/models/model_training_apis/#predict-method
"""
prediction_batches: List[Any] = []
for i in range(0, len(image_paths), batch_size):
path_batch = image_paths[i: i + batch_size]
for i in range(0, len(image_handles), batch_size):
handle_batch = image_handles[i: i + batch_size]
image_batch = [
preprocess_image(Image.open(path), preprocessing)
for path in path_batch
preprocess_image(_load_pil_image(handle), preprocessing)
for handle in handle_batch
]
prediction_batches.append(model_(np.array(image_batch)))
predictions: NDFloat32Array = np.concatenate(prediction_batches, axis=0)
return predictions


def predict_images(
image_paths: Sequence[str],
image_handles: Union[Sequence[str], Sequence[Image.Image]],
batch_size: int = 8,
preprocessing: Preprocessing = Preprocessing.YAHOO,
weights_path: Optional[str] = get_default_weights_path(),
grad_cam_paths: Optional[Sequence[str]] = None,
alpha: float = 0.8
) -> List[float]:
"""
Pipeline from image paths to predicted NSFW probabilities.
Pipeline from image handles to predicted NSFW probabilities.
Optionally generate and save the Grad-CAM plots.
"""
_update_global_model_if_needed(weights_path)
predictions = _predict_from_image_paths_in_batches(
global_model, image_paths, batch_size, preprocessing
predictions = _predict_from_image_handles_in_batches(
global_model, image_handles, batch_size, preprocessing
)
nsfw_probabilities: List[float] = predictions[:, 1].tolist()

if grad_cam_paths is not None:
# TensorFlow will only be imported here.
from ._inspection import make_and_save_nsfw_grad_cam

for image_path, grad_cam_path in zip(image_paths, grad_cam_paths):
for image_handle, grad_cam_path in zip(image_handles, grad_cam_paths):
assert isinstance(image_handle, (str, Image.Image)) # For mypy.
pil_image = _load_pil_image(image_handle)
make_and_save_nsfw_grad_cam(
Image.open(image_path), preprocessing, global_model,
grad_cam_path, alpha
pil_image, preprocessing, global_model, grad_cam_path, alpha
)

return nsfw_probabilities
Expand Down
6 changes: 6 additions & 0 deletions tests/_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional, Sequence

from keras import backend as keras_backend
from PIL import Image

import opennsfw2 as n2

Expand Down Expand Up @@ -61,9 +62,14 @@ def test_predict_images_simple_preprocessing(self) -> None:
self._assert(expected_probabilities, predicted_probabilities)

def test_predict_image(self) -> None:
# From path.
self.assertAlmostEqual(
0.983, n2.predict_image(IMAGE_PATHS[1]), places=3
)
# From PIL Image.
self.assertAlmostEqual(
0.983, n2.predict_image(Image.open(IMAGE_PATHS[1])), places=3
)

def test_predict_video_frames(self) -> None:
elapsed_seconds, nsfw_probabilities = n2.predict_video_frames(
Expand Down
2 changes: 1 addition & 1 deletion tests/run_code_checks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ set -e
find opennsfw2 -iname "*.py" | grep -v -e "__init__.py" | xargs -L 1 pylint --errors-only
find opennsfw2 -iname "*.py" | grep -v -e "__init__.py" | xargs -L 1 pylint --exit-zero
find opennsfw2 -iname "*.py" | grep -v -e "__init__.py" | xargs -L 1 mypy --strict --implicit-reexport --disable-error-code attr-defined --disable-error-code unused-ignore
find tests -iname "*.py" | xargs -L 1 python3 -m unittest
find tests -iname "*.py" | grep -v -e "__init__.py" | xargs -L 1 python3 -m unittest

0 comments on commit 587af95

Please sign in to comment.