Skip to content

Commit

Permalink
Rename global model var
Browse files Browse the repository at this point in the history
  • Loading branch information
bhky committed Dec 1, 2023
1 parent 4d09ce6 commit 86c0a81
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions opennsfw2/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ._model import make_open_nsfw_model
from ._typing import NDFloat32Array

model: Optional[Model] = None
global_model: Optional[Model] = None


def predict_image(
Expand All @@ -32,17 +32,17 @@ def predict_image(
"""
pil_image = Image.open(image_path)
image = preprocess_image(pil_image, preprocessing)
global model
if model is None:
model = make_open_nsfw_model(weights_path=weights_path)
nsfw_probability = float(model(np.expand_dims(image, 0))[0][1])
global global_model
if global_model is None:
global_model = make_open_nsfw_model(weights_path=weights_path)
nsfw_probability = float(global_model(np.expand_dims(image, 0))[0][1])

if grad_cam_path is not None:
# TensorFlow will only be imported here.
from ._inspection import make_and_save_nsfw_grad_cam

make_and_save_nsfw_grad_cam(
pil_image, preprocessing, model, grad_cam_path, alpha
pil_image, preprocessing, global_model, grad_cam_path, alpha
)

return nsfw_probability
Expand Down Expand Up @@ -82,10 +82,10 @@ def predict_images(
preprocess_image(Image.open(image_path), preprocessing)
for image_path in image_paths
]
global model
if model is None:
model = make_open_nsfw_model(weights_path=weights_path)
predictions = _predict_images_in_batches(model, images, batch_size)
global global_model
if global_model is None:
global_model = make_open_nsfw_model(weights_path=weights_path)
predictions = _predict_images_in_batches(global_model, images, batch_size)
nsfw_probabilities: List[float] = predictions[:, 1].tolist()

if grad_cam_paths is not None:
Expand All @@ -94,7 +94,7 @@ def predict_images(

for image_path, grad_cam_path in zip(image_paths, grad_cam_paths):
make_and_save_nsfw_grad_cam(
Image.open(image_path), preprocessing, model,
Image.open(image_path), preprocessing, global_model,
grad_cam_path, alpha
)

Expand Down Expand Up @@ -141,9 +141,9 @@ def predict_video_frames(
cap = cv2.VideoCapture(video_path) # pylint: disable=no-member
fps = cap.get(cv2.CAP_PROP_FPS) # pylint: disable=no-member

global model
if model is None:
model = make_open_nsfw_model(weights_path=weights_path)
global global_model
if global_model is None:
global_model = make_open_nsfw_model(weights_path=weights_path)

video_writer: Optional[cv2.VideoWriter] = None # pylint: disable=no-member
input_frames: List[NDFloat32Array] = []
Expand Down Expand Up @@ -181,7 +181,7 @@ def predict_video_frames(

if frame_count == 1 or len(input_frames) >= aggregation_size:
predictions = _predict_images_in_batches(
model, input_frames, batch_size
global_model, input_frames, batch_size
)
agg_fn = _get_aggregation_fn(aggregation)
nsfw_probability = agg_fn(predictions[:, 1])
Expand Down

0 comments on commit 86c0a81

Please sign in to comment.