Skip to content

Commit

Permalink
Improve batching logic for predicting images from image paths
Browse files Browse the repository at this point in the history
  • Loading branch information
bhky committed Dec 7, 2023
1 parent a4fef36 commit 6b76f40
Showing 1 changed file with 28 additions and 12 deletions.
40 changes: 28 additions & 12 deletions opennsfw2/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,25 @@ def predict_image(
return nsfw_probability


def _predict_images_in_batches(
def _predict_from_image_paths_in_batches(
model_: Model,
images: List[NDFloat32Array],
batch_size: int
image_paths: Sequence[str],
batch_size: int,
preprocessing: Preprocessing
) -> NDFloat32Array:
"""
It is on purpose not to use `model.predict` because many users would like
to use the API in loops. See here:
https://keras.io/api/models/model_training_apis/#predict-method
"""
prediction_batches: List[KerasTensor] = []
for i in range(0, len(images), batch_size):
batch = np.array(images[i: i + batch_size])
prediction_batches.append(model_(batch))
for i in range(0, len(image_paths), batch_size):
path_batch = image_paths[i: i + batch_size]
image_batch = [
preprocess_image(Image.open(path), preprocessing)
for path in path_batch
]
prediction_batches.append(model_(np.array(image_batch)))
predictions: NDFloat32Array = np.concatenate(prediction_batches, axis=0)
return predictions

Expand All @@ -78,14 +83,12 @@ def predict_images(
Pipeline from image paths to predicted NSFW probabilities.
Optionally generate and save the Grad-CAM plots.
"""
images = [
preprocess_image(Image.open(image_path), preprocessing)
for image_path in image_paths
]
global global_model
if global_model is None:
global_model = make_open_nsfw_model(weights_path=weights_path)
predictions = _predict_images_in_batches(global_model, images, batch_size)
predictions = _predict_from_image_paths_in_batches(
global_model, image_paths, batch_size, preprocessing
)
nsfw_probabilities: List[float] = predictions[:, 1].tolist()

if grad_cam_paths is not None:
Expand Down Expand Up @@ -124,6 +127,19 @@ def fn(x: NDFloat32Array) -> float:
return fn


def _predict_preprocessed_images_in_batches(
model_: Model,
images: List[NDFloat32Array],
batch_size: int
) -> NDFloat32Array:
prediction_batches: List[KerasTensor] = []
for i in range(0, len(images), batch_size):
batch = np.array(images[i: i + batch_size])
prediction_batches.append(model_(batch))
predictions: NDFloat32Array = np.concatenate(prediction_batches, axis=0)
return predictions


def predict_video_frames(
video_path: str,
frame_interval: int = 8,
Expand Down Expand Up @@ -180,7 +196,7 @@ def predict_video_frames(
input_frames.append(input_frame)

if frame_count == 1 or len(input_frames) >= aggregation_size:
predictions = _predict_images_in_batches(
predictions = _predict_preprocessed_images_in_batches(
global_model, input_frames, batch_size
)
agg_fn = _get_aggregation_fn(aggregation)
Expand Down

0 comments on commit 6b76f40

Please sign in to comment.