Skip to content

Commit

Permalink
Do not use model.predict method
Browse files Browse the repository at this point in the history
  • Loading branch information
bhky committed Dec 1, 2023
1 parent f52d14a commit 4d09ce6
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions opennsfw2/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ def predict_image(
return nsfw_probability


def _predict_images_in_batches(
model_: Model,
images: List[NDFloat32Array],
batch_size: int
) -> NDFloat32Array:
"""
It is on purpose not to use `model.predict` because many users would like
to use the API in loops. See here:
https://keras.io/api/models/model_training_apis/#predict-method
"""
prediction_batches: List[NDFloat32Array] = []
for i in range(0, len(images), batch_size):
batch = np.array(images[i: i + batch_size])
prediction_batches.append(model_(batch))
predictions = np.concatenate(prediction_batches, axis=0)
return predictions


def predict_images(
image_paths: Sequence[str],
batch_size: int = 8,
Expand All @@ -60,14 +78,14 @@ def predict_images(
Pipeline from image paths to predicted NSFW probabilities.
Optionally generate and save the Grad-CAM plots.
"""
images = np.array([
images = [
preprocess_image(Image.open(image_path), preprocessing)
for image_path in image_paths
])
]
global model
if model is None:
model = make_open_nsfw_model(weights_path=weights_path)
predictions = model.predict(images, batch_size=batch_size, verbose=0)
predictions = _predict_images_in_batches(model, images, batch_size)
nsfw_probabilities: List[float] = predictions[:, 1].tolist()

if grad_cam_paths is not None:
Expand Down Expand Up @@ -162,12 +180,9 @@ def predict_video_frames(
input_frames.append(input_frame)

if frame_count == 1 or len(input_frames) >= aggregation_size:
prediction_batches: List[NDFloat32Array] = []
for i in range(0, len(input_frames), batch_size):
batch = np.array(input_frames[i: i + batch_size])
prediction_batches.append(model(batch))
predictions = np.concatenate(prediction_batches, axis=0)

predictions = _predict_images_in_batches(
model, input_frames, batch_size
)
agg_fn = _get_aggregation_fn(aggregation)
nsfw_probability = agg_fn(predictions[:, 1])
input_frames = []
Expand Down

0 comments on commit 4d09ce6

Please sign in to comment.