diff --git a/opennsfw2/_inference.py b/opennsfw2/_inference.py index a3fa7dc..9c59c9f 100644 --- a/opennsfw2/_inference.py +++ b/opennsfw2/_inference.py @@ -9,7 +9,7 @@ from PIL import Image # type: ignore from tqdm import tqdm # type: ignore -from keras_core import Model # type: ignore +from keras_core import KerasTensor, Model # type: ignore from ._download import get_default_weights_path from ._image import preprocess_image, Preprocessing @@ -58,7 +58,7 @@ def _predict_images_in_batches( to use the API in loops. See here: https://keras.io/api/models/model_training_apis/#predict-method """ - prediction_batches: List[NDFloat32Array] = [] + prediction_batches: List[KerasTensor] = [] for i in range(0, len(images), batch_size): batch = np.array(images[i: i + batch_size]) prediction_batches.append(model_(batch))