Skip to content

Commit

Permalink
Use global model
Browse files Browse the repository at this point in the history
  • Loading branch information
bhky committed Dec 1, 2023
1 parent 2776f5f commit 6db3c35
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions opennsfw2/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
from ._model import make_open_nsfw_model
from ._typing import NDFloat32Array

model: Optional[Model] = None


def predict_image(
image_path: str,
preprocessing: Preprocessing = Preprocessing.YAHOO,
model: Optional[Model] = None,
weights_path: Optional[str] = get_default_weights_path(),
grad_cam_path: Optional[str] = None,
alpha: float = 0.8
Expand All @@ -31,6 +32,7 @@ def predict_image(
"""
pil_image = Image.open(image_path)
image = preprocess_image(pil_image, preprocessing)
global model
if model is None:
model = make_open_nsfw_model(weights_path=weights_path)
nsfw_probability = float(model(np.expand_dims(image, 0))[0][1])
Expand All @@ -50,7 +52,6 @@ def predict_images(
image_paths: Sequence[str],
batch_size: int = 8,
preprocessing: Preprocessing = Preprocessing.YAHOO,
model: Optional[Model] = None,
weights_path: Optional[str] = get_default_weights_path(),
grad_cam_paths: Optional[Sequence[str]] = None,
alpha: float = 0.8
Expand All @@ -63,6 +64,7 @@ def predict_images(
preprocess_image(Image.open(image_path), preprocessing)
for image_path in image_paths
])
global model
if model is None:
model = make_open_nsfw_model(weights_path=weights_path)
predictions = model.predict(images, batch_size=batch_size, verbose=0)
Expand Down Expand Up @@ -112,7 +114,6 @@ def predict_video_frames(
batch_size: int = 8,
output_video_path: Optional[str] = None,
preprocessing: Preprocessing = Preprocessing.YAHOO,
model: Optional[Model] = None,
weights_path: Optional[str] = get_default_weights_path(),
progress_bar: bool = True
) -> Tuple[List[float], List[float]]:
Expand All @@ -122,6 +123,7 @@ def predict_video_frames(
cap = cv2.VideoCapture(video_path) # pylint: disable=no-member
fps = cap.get(cv2.CAP_PROP_FPS) # pylint: disable=no-member

global model
if model is None:
model = make_open_nsfw_model(weights_path=weights_path)

Expand Down

0 comments on commit 6db3c35

Please sign in to comment.