Skip to content

Commit

Permalink
Support using PIL Image objects as inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
bhky committed Aug 17, 2024
1 parent ef9c389 commit 25aedb6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 23 deletions.
19 changes: 10 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ For more details, please refer to the [API](#api) section.
import opennsfw2 as n2

# To get the NSFW probability of a single image.
image_path = "path/to/your/image.jpg"
image_handle = "path/to/your/image.jpg" # Alternatively, a `PIL.Image.Image` object.

nsfw_probability = n2.predict_image(image_path)
nsfw_probability = n2.predict_image(image_handle)

# To get the NSFW probabilities of a list of images.
# This is better than looping with `predict_image` as the model will only be instantiated once
# and batching is used during inference.
image_paths = [
"path/to/your/image1.jpg",
image_handles = [
"path/to/your/image1.jpg", # Alternatively, a list of `PIL.Image.Image` objects.
"path/to/your/image2.jpg",
# ...
]

nsfw_probabilities = n2.predict_images(image_paths)
nsfw_probabilities = n2.predict_images(image_handles)
```

## Video
Expand Down Expand Up @@ -154,8 +154,8 @@ Create an instance of the NSFW model, optionally with pre-trained weights from Y
### `predict_image`
End-to-end pipeline function from the input image to the predicted NSFW probability.
- Parameters:
- `image_path` (`str`): Path to the input image file.
The image format must be supported by Pillow.
- `image_handle` (`Union[str, PIL.Image.Image]`):
Path to the input image file with a format supported by Pillow, or a `PIL.Image.Image` object.
- `preprocessing`: Same as that in `preprocess_image`.
- `weights_path`: Same as that in `make_open_nsfw_model`.
- `grad_cam_path` (`Optional[str]`, default `None`): If not `None`, e.g., `cam.jpg`,
Expand All @@ -171,8 +171,9 @@ End-to-end pipeline function from the input image to the predicted NSFW probabil
### `predict_images`
End-to-end pipeline function from the input images to the predicted NSFW probabilities.
- Parameters:
- `image_paths` (`Sequence[str]`): List of paths to the input image files.
The image format must be supported by Pillow.
- `image_handles` (`Union[Sequence[str], Sequence[PIL.Image.Image]]`):
List of paths to the input image files with formats supported by Pillow,
or list of `PIL.Image.Image` objects.
- `batch_size` (`int`, default `8`): Batch size to be used for model inference.
Choose a value that works the best with your device resources.
- `preprocessing`: Same as that in `preprocess_image`.
Expand Down
34 changes: 20 additions & 14 deletions opennsfw2/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,24 @@ def _update_global_model_if_needed(weights_path: Optional[str]) -> None:
global_model_path = weights_path


def _load_pil_image(image_handle: str | Image.Image) -> Image.Image:
if isinstance(image_handle, Image.Image):
return image_handle
return Image.open(image_handle)


def predict_image(
image_path: str,
image_handle: str,
preprocessing: Preprocessing = Preprocessing.YAHOO,
weights_path: Optional[str] = get_default_weights_path(),
grad_cam_path: Optional[str] = None,
alpha: float = 0.8
) -> float:
"""
Pipeline from single image path to predicted NSFW probability.
Pipeline from single image handle to predicted NSFW probability.
Optionally generate and save the Grad-CAM plot.
"""
pil_image = Image.open(image_path)
pil_image = _load_pil_image(image_handle)
image = preprocess_image(pil_image, preprocessing)
_update_global_model_if_needed(weights_path)
assert global_model is not None
Expand All @@ -56,9 +62,9 @@ def predict_image(
return nsfw_probability


def _predict_from_image_paths_in_batches(
def _predict_from_image_handles_in_batches(
model_: Model,
image_paths: Sequence[str],
image_handles: Sequence[str],
batch_size: int,
preprocessing: Preprocessing
) -> NDFloat32Array:
Expand All @@ -68,10 +74,10 @@ def _predict_from_image_paths_in_batches(
https://keras.io/api/models/model_training_apis/#predict-method
"""
prediction_batches: List[Any] = []
for i in range(0, len(image_paths), batch_size):
path_batch = image_paths[i: i + batch_size]
for i in range(0, len(image_handles), batch_size):
path_batch = image_handles[i: i + batch_size]
image_batch = [
preprocess_image(Image.open(path), preprocessing)
preprocess_image(_load_pil_image(path), preprocessing)
for path in path_batch
]
prediction_batches.append(model_(np.array(image_batch)))
Expand All @@ -80,30 +86,30 @@ def _predict_from_image_paths_in_batches(


def predict_images(
image_paths: Sequence[str],
image_handles: Sequence[str],
batch_size: int = 8,
preprocessing: Preprocessing = Preprocessing.YAHOO,
weights_path: Optional[str] = get_default_weights_path(),
grad_cam_paths: Optional[Sequence[str]] = None,
alpha: float = 0.8
) -> List[float]:
"""
Pipeline from image paths to predicted NSFW probabilities.
Pipeline from image handles to predicted NSFW probabilities.
Optionally generate and save the Grad-CAM plots.
"""
_update_global_model_if_needed(weights_path)
predictions = _predict_from_image_paths_in_batches(
global_model, image_paths, batch_size, preprocessing
predictions = _predict_from_image_handles_in_batches(
global_model, image_handles, batch_size, preprocessing
)
nsfw_probabilities: List[float] = predictions[:, 1].tolist()

if grad_cam_paths is not None:
# TensorFlow will only be imported here.
from ._inspection import make_and_save_nsfw_grad_cam

for image_path, grad_cam_path in zip(image_paths, grad_cam_paths):
for image_handle, grad_cam_path in zip(image_handles, grad_cam_paths):
make_and_save_nsfw_grad_cam(
Image.open(image_path), preprocessing, global_model,
_load_pil_image(image_handle), preprocessing, global_model,
grad_cam_path, alpha
)

Expand Down

0 comments on commit 25aedb6

Please sign in to comment.