Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve inference in loops usage with a global model #21

Merged
merged 21 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ jobs:
backend: [ tensorflow, jax ]
include:
- backend: tensorflow
package: tensorflow
packages: tensorflow
- backend: jax
package: jax[cpu]
packages: tensorflow jax[cpu]
name: python-${{ matrix.python }}-${{ matrix.backend }}
steps:
- name: Checkout
Expand All @@ -26,7 +26,7 @@ jobs:
run: |
python3 -m pip install pylint
python3 -m pip install mypy
python3 -m pip install ${{ matrix.package }}
python3 -m pip install ${{ matrix.packages }}
- name: Test
run: |
./tests/install_from_local_and_test.sh ${{ matrix.backend }}
55 changes: 40 additions & 15 deletions opennsfw2/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@
from PIL import Image # type: ignore
from tqdm import tqdm # type: ignore

from keras_core import KerasTensor, Model # type: ignore

from ._download import get_default_weights_path
from ._image import preprocess_image, Preprocessing
from ._model import make_open_nsfw_model
from ._typing import NDFloat32Array

global_model: Optional[Model] = None


def predict_image(
image_path: str,
Expand All @@ -28,20 +32,40 @@ def predict_image(
"""
pil_image = Image.open(image_path)
image = preprocess_image(pil_image, preprocessing)
model = make_open_nsfw_model(weights_path=weights_path)
nsfw_probability = float(model(np.expand_dims(image, 0))[0][1])
global global_model
if global_model is None:
global_model = make_open_nsfw_model(weights_path=weights_path)
nsfw_probability = float(global_model(np.expand_dims(image, 0))[0][1])

if grad_cam_path is not None:
# TensorFlow will only be imported here.
from ._inspection import make_and_save_nsfw_grad_cam

make_and_save_nsfw_grad_cam(
pil_image, preprocessing, model, grad_cam_path, alpha
pil_image, preprocessing, global_model, grad_cam_path, alpha
)

return nsfw_probability


def _predict_images_in_batches(
model_: Model,
images: List[NDFloat32Array],
batch_size: int
) -> NDFloat32Array:
"""
It is on purpose not to use `model.predict` because many users would like
to use the API in loops. See here:
https://keras.io/api/models/model_training_apis/#predict-method
"""
prediction_batches: List[KerasTensor] = []
for i in range(0, len(images), batch_size):
batch = np.array(images[i: i + batch_size])
prediction_batches.append(model_(batch))
predictions: NDFloat32Array = np.concatenate(prediction_batches, axis=0)
return predictions


def predict_images(
image_paths: Sequence[str],
batch_size: int = 8,
Expand All @@ -54,12 +78,14 @@ def predict_images(
Pipeline from image paths to predicted NSFW probabilities.
Optionally generate and save the Grad-CAM plots.
"""
images = np.array([
images = [
preprocess_image(Image.open(image_path), preprocessing)
for image_path in image_paths
])
model = make_open_nsfw_model(weights_path=weights_path)
predictions = model.predict(images, batch_size=batch_size, verbose=0)
]
global global_model
if global_model is None:
global_model = make_open_nsfw_model(weights_path=weights_path)
predictions = _predict_images_in_batches(global_model, images, batch_size)
nsfw_probabilities: List[float] = predictions[:, 1].tolist()

if grad_cam_paths is not None:
Expand All @@ -68,7 +94,7 @@ def predict_images(

for image_path, grad_cam_path in zip(image_paths, grad_cam_paths):
make_and_save_nsfw_grad_cam(
Image.open(image_path), preprocessing, model,
Image.open(image_path), preprocessing, global_model,
grad_cam_path, alpha
)

Expand Down Expand Up @@ -115,7 +141,9 @@ def predict_video_frames(
cap = cv2.VideoCapture(video_path) # pylint: disable=no-member
fps = cap.get(cv2.CAP_PROP_FPS) # pylint: disable=no-member

model = make_open_nsfw_model(weights_path=weights_path)
global global_model
if global_model is None:
global_model = make_open_nsfw_model(weights_path=weights_path)

video_writer: Optional[cv2.VideoWriter] = None # pylint: disable=no-member
input_frames: List[NDFloat32Array] = []
Expand Down Expand Up @@ -152,12 +180,9 @@ def predict_video_frames(
input_frames.append(input_frame)

if frame_count == 1 or len(input_frames) >= aggregation_size:
prediction_batches: List[NDFloat32Array] = []
for i in range(0, len(input_frames), batch_size):
batch = np.array(input_frames[i: i + batch_size])
prediction_batches.append(model(batch))
predictions = np.concatenate(prediction_batches, axis=0)

predictions = _predict_images_in_batches(
global_model, input_frames, batch_size
)
agg_fn = _get_aggregation_fn(aggregation)
nsfw_probability = agg_fn(predictions[:, 1])
input_frames = []
Expand Down
2 changes: 1 addition & 1 deletion opennsfw2/_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import tensorflow as tf # type: ignore # pylint: disable=import-error
from keras_core import Model # type: ignore
from keras_core.preprocessing.image import array_to_img # type: ignore
from matplotlib import colormaps as cm # type: ignore
from matplotlib import colormaps as cm
from PIL import Image # type: ignore

from ._image import preprocess_image, Preprocessing
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ matplotlib>=3.0.0
numpy>=1.22.0
opencv-python>=4.0.0.0
Pillow>=8.0.0
scikit-image>=0.18.0
scikit-image==0.21.0 # Unpin when support for Python 3.8 will be dropped.
tqdm>=4.62
5 changes: 2 additions & 3 deletions tests/_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import unittest
from typing import Optional, Sequence

import keras_core
from keras_core import backend as keras_backend

import opennsfw2 as n2

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL = n2.make_open_nsfw_model()
IMAGE_PATHS = [
os.path.join(BASE_DIR, "test_image_1.jpg"),
os.path.join(BASE_DIR, "test_image_2.jpg"),
Expand Down Expand Up @@ -38,7 +37,7 @@ def _assert(
self.assertTrue(os.path.exists(paths[i]))

def test_predict_images_yahoo_preprocessing(self) -> None:
if keras_core.backend.backend() == "tensorflow":
if keras_backend.backend() == "tensorflow":
grad_cam_paths = OUTPUT_GRAD_CAM_PATHS
else:
grad_cam_paths = None
Expand Down
2 changes: 1 addition & 1 deletion tests/install_from_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ set -e

python3 -m pip install --upgrade pip
python3 -m pip install -r requirements.txt
python3 -m pip install opencv-stubs # For mypy.
python3 -m pip install opencv-stubs matplotlib-stubs # For mypy.
python3 -m pip install .
2 changes: 1 addition & 1 deletion tests/run_code_checks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ set -e

find opennsfw2 -iname "*.py" | grep -v -e "__init__.py" | xargs -L 1 pylint --errors-only
find opennsfw2 -iname "*.py" | grep -v -e "__init__.py" | xargs -L 1 pylint --exit-zero
find opennsfw2 -iname "*.py" | grep -v -e "__init__.py" | xargs -L 1 mypy --strict --implicit-reexport
find opennsfw2 -iname "*.py" | grep -v -e "__init__.py" | xargs -L 1 mypy --strict --implicit-reexport --disable-error-code attr-defined
find tests -iname "*.py" | xargs -L 1 python3 -m unittest