From 9f7557e6da5f664d216b1ba1c3a9d13083360e49 Mon Sep 17 00:00:00 2001 From: Imen Masmoudi <83138804+ImenMasmoudiEm@users.noreply.github.com> Date: Wed, 9 Aug 2023 20:12:17 +0100 Subject: [PATCH] Update visualization/plot_image_gallery.py Update visualization/plot_image_gallery.py The guide has an issue when calling the visualization.plot_image_gallery function with its input being a list, it has to be of type array because it will be calling for the shape of it later in the function. The error occurred in the second code cell in the Beginner section, Ligne 155 in the object_detection_keras_cv.py file. This update should solve it. --- keras_cv/visualization/plot_image_gallery.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/visualization/plot_image_gallery.py b/keras_cv/visualization/plot_image_gallery.py index 1d98c20f53..3ab9f6ae55 100644 --- a/keras_cv/visualization/plot_image_gallery.py +++ b/keras_cv/visualization/plot_image_gallery.py @@ -133,7 +133,7 @@ def plot_image_gallery( ) # batch_size from within passed `tf.data.Dataset` else: batch_size = ( - images.shape[0] if len(images.shape) == 4 else 1 + np.asarray(images).shape[0] if len(images.shape) == 4 else 1 ) # batch_size from np.array or single image rows = rows or int(math.ceil(math.sqrt(batch_size)))