diff --git a/examples/sam3_image_batched_inference.ipynb b/examples/sam3_image_batched_inference.ipynb index dddbfd83..96deb864 100644 --- a/examples/sam3_image_batched_inference.ipynb +++ b/examples/sam3_image_batched_inference.ipynb @@ -81,7 +81,8 @@ "torch.autocast(\"cuda\", dtype=torch.bfloat16).__enter__()\n", "\n", "# inference mode for the whole notebook. Disable if you need gradients\n", - "torch.inference_mode().__enter__()\n" + "ctx = torch.inference_mode()\n", + "ctx.__enter__()" ] }, { @@ -165,7 +166,7 @@ " # In practice you're free to set any size you want, just edit the rest of the function\n", " assert len(datapoint.images) == 1, \"please set the image first\"\n", "\n", - " w, h = datapoint.images[0].size\n", + " h, w = datapoint.images[0].size\n", " datapoint.find_queries.append(\n", " FindQueryLoaded(\n", " query_text=text_query,\n", @@ -177,7 +178,7 @@ " coco_image_id=GLOBAL_COUNTER,\n", " original_image_id=GLOBAL_COUNTER,\n", " original_category_id=1,\n", - " original_size=[w, h],\n", + " original_size=[h, w],\n", " object_id=0,\n", " frame_index=0,\n", " )\n", @@ -208,7 +209,7 @@ " labels = torch.tensor(labels, dtype=torch.bool).view(-1)\n", " if not labels.any().item() and text_prompt==\"visual\":\n", " print(\"Warning: you provided no positive box, nor any text prompt. The prompt is ambiguous and the results will be undefined\")\n", - " w, h = datapoint.images[0].size\n", + " h, w = datapoint.images[0].size\n", " datapoint.find_queries.append(\n", " FindQueryLoaded(\n", " query_text=text_prompt,\n", @@ -222,7 +223,7 @@ " coco_image_id=GLOBAL_COUNTER,\n", " original_image_id=GLOBAL_COUNTER,\n", " original_category_id=1,\n", - " original_size=[w, h],\n", + " original_size=[h, w],\n", " object_id=0,\n", " frame_index=0,\n", " )\n",