diff --git a/opennsfw2/_inference.py b/opennsfw2/_inference.py index 241bd64..bd69e73 100644 --- a/opennsfw2/_inference.py +++ b/opennsfw2/_inference.py @@ -35,7 +35,7 @@ def _load_pil_image(image_handle: Union[str, Image.Image]) -> Image.Image: def predict_image( - image_handle: str, + image_handle: Union[str, Image.Image], preprocessing: Preprocessing = Preprocessing.YAHOO, weights_path: Optional[str] = get_default_weights_path(), grad_cam_path: Optional[str] = None, @@ -64,7 +64,7 @@ def predict_image( def _predict_from_image_handles_in_batches( model_: Model, - image_handles: Sequence[str], + image_handles: Union[Sequence[str], Sequence[Image.Image]], batch_size: int, preprocessing: Preprocessing ) -> NDFloat32Array: