diff --git a/cv/README.md b/cv/README.md index d2ebb78..d8edae0 100644 --- a/cv/README.md +++ b/cv/README.md @@ -8,6 +8,7 @@ This directory contains ready-to-use Computer Vision application notebooks built | :-- | :---- | :-------------------------------- | | 1 | [ResNet](./resnet/) | Includes notebooks for ResNet finetuning on tasks such as chinese herbal classification | | 2 | [U-Net](./unet/) | Includes notebooks for U-Net training on tasks such as segmentation | +| 3 | [SAM](./sam/) | Includes notebooks for using SAM to inference | ## Contributing New CV Applications diff --git a/cv/sam/inference_sam_segmentation.ipynb b/cv/sam/inference_sam_segmentation.ipynb new file mode 100644 index 0000000..278e9dd --- /dev/null +++ b/cv/sam/inference_sam_segmentation.ipynb @@ -0,0 +1,359 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8285b580", + "metadata": {}, + "source": [ + "# 基于MindSpore 和 MindSpore NLP 的 Segment Anything Model(SAM)通用图像分割推理任务\n", + "\n", + "## 模型简介\n", + "\n", + "图像分割(Segmentation)旨在为图像中每个像素生成标签,输出与原图尺寸一致的掩码(mask)。\n", + "**Segment Anything Model(SAM)**(Meta AI, 2023)是一种“可提示(promptable)”的通用分割模型:可通过 **点 / 框 / 已有掩码** 等提示,在零样本条件下对任意目标生成分割结果。\n", + "\n", + "本 Notebook 参考 Hugging Face 的交互式推理流程,基于 **MindSpore** 与 **MindSpore NLP Transformers**,演示 **BBox 框提示** 的端到端推理与可视化。\n", + "\n", + "## 环境准备\n", + "\n", + "本案例的运行环境为:\n", + "\n", + "| Python | MindSpore | MindSpore NLP |\n", + "|--------|-----------|---------------|\n", + "| 3.10 | 2.7.0 | 0.5.1 |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b550a2d-7114-4a32-bcb6-4f6a8504da19", + "metadata": {}, + "outputs": [], + "source": [ + "!pip show mindspore\n", + "!pip show mindnlp" + ] + }, + { + "cell_type": "markdown", + "id": "f5a0708f", + "metadata": {}, + "source": [ + "如果你在如昇思大模型平台、华为云ModelArts、启智社区等算力平台的Jupyter在线编程环境中运行本案例,可取消如下代码的注释,进行依赖库安装:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "359d8d32-e0b4-484c-a0b5-0c4b2956253f", + "metadata": {}, + "outputs": [], + "source": [ + "# 安装mindspore==2.7.0版本,如需更换mindspore版本,可更改下面 MINDSPORE_VERSION 变量\n", + "# !pip uninstall mindspore -y\n", + "# %env MINDSPORE_VERSION=2.7.0\n", + "# !pip install mindspore==2.7.0 -i https://repo.mindspore.cn/pypi/simple --trusted-host repo.mindspore.cn --extra-index-url https://repo.huaweicloud.com/repository/pypi/simple\n", + "\n", + "# 安装mindnlp==0.5.1版本\n", + "# !pip uninstall mindnlp -y\n", + "# !pip install mindnlp==0.5.1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f78fd9fb-2b63-4372-adeb-a27331a805c3", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import requests\n", + "import numpy as np\n", + "import mindspore as ms\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.patches as patches\n", + "\n", + "from PIL import Image\n", + "from pathlib import Path\n", + "from mindnlp.transformers import SamModel, SamProcessor" + ] + }, + { + "cell_type": "markdown", + "id": "d735217e", + "metadata": {}, + "source": [ + "## 数据加载\n", + "\n", + "本案例使用 Meta 官方仓库提供的示例图片 `dog.jpg`。执行下方单元将图片下载到当前工作目录。\n", + "如你希望分割自己的图片,可替换 `image_url`,或直接将图片放在当前目录,并在后续单元修改 `img_path`。\n", + "\n", + "#### **数据下载**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "610eb8c7-73b3-47d3-ad43-698b7d9cfe1e", + "metadata": {}, + "outputs": [], + "source": [ + "def download_image(url: str, save_dir: str = \".\") -> str:\n", + " \"\"\"\n", + " 从 URL 下载图片到 save_dir,返回本地文件路径(字符串)。\n", + " \"\"\"\n", + " save_path = Path(save_dir)\n", + " save_path.mkdir(parents=True, exist_ok=True)\n", + "\n", + " filename = (url.rsplit(\"/\", 1)[-1] or \"image.jpg\")\n", + " dst = save_path / filename\n", + "\n", + " try:\n", + " resp = requests.get(url, timeout=30)\n", + " resp.raise_for_status()\n", + " dst.write_bytes(resp.content)\n", + " print(f\"示例图片已成功下载到: {dst}\")\n", + " return str(dst)\n", + " except Exception as e:\n", + " print(f\"下载示例图片时出错: {e}\")\n", + " return \"\"\n", + "\n", + "# 使用示例\n", + "image_url = \"https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg\"\n", + "downloaded_image = download_image(image_url)\n", + "if downloaded_image:\n", + " print(f\"示例图片已保存为: {downloaded_image}\")" + ] + }, + { + "cell_type": "markdown", + "id": "36ab043f-076b-4e0b-b47d-8d8ceed8e875", + "metadata": {}, + "source": [ + "#### **数据加载**\n", + "\n", + "读取图片并设置提示框(BBox)\n", + "- `bbox` 采用 **原图坐标系**:`[x1, y1, x2, y2]`\n", + "- 你可以通过修改 `bbox` 来框住想要分割的目标区域\n", + "- 本单元会把输入框画在原图上,便于检查是否框选正确" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45a22bc6-9399-4044-9897-a61dc2f158ad", + "metadata": {}, + "outputs": [], + "source": [ + "img_path = \"dog.jpg\" # <-- change to your image path\n", + "assert os.path.exists(img_path), f\"Image not found: {img_path}\"\n", + "\n", + "# Custom bbox in original image coordinates [x1,y1,x2,y2]\n", + "bbox = [0, 217, 450, 800] # <-- change if needed, ensure in-bounds\n", + "bbox = [int(x) for x in bbox] # keep as Python ints for clarity\n", + "\n", + "image = Image.open(img_path).convert(\"RGB\")\n", + "W, H = image.size\n", + "print(\"图片尺寸:\", (W, H), \"| BBox:\", bbox)\n", + "\n", + "plt.figure(figsize=(6,4))\n", + "plt.imshow(image)\n", + "ax = plt.gca()\n", + "rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2]-bbox[0], bbox[3]-bbox[1],\n", + " linewidth=2, edgecolor='yellow', facecolor='none')\n", + "ax.add_patch(rect)\n", + "plt.title(\"Original image with input box\")\n", + "plt.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "108b98f7-a054-4fa6-b4f9-2c139ca17355", + "metadata": {}, + "source": [ + "## 模型推理\n", + "\n", + "#### **加载模型**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d06aee8c", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_ID = \"facebook/sam-vit-base\"\n", + "CACHE_DIR = os.path.expanduser(\"~/.cache/mindnlp\") # or \"/tmp/mindnlp\"\n", + "\n", + "os.makedirs(CACHE_DIR, exist_ok=True)\n", + "\n", + "# (可选)将其他库的缓存目录对齐到同一路径\n", + "os.environ[\"HF_HOME\"] = CACHE_DIR\n", + "os.environ[\"MINDNLP_HOME\"] = CACHE_DIR\n", + "\n", + "print(\"正在加载 SAM ...\")\n", + "processor = SamProcessor.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)\n", + "model = SamModel.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)\n", + "model.set_train(False)\n", + "print(\"加载完成!\")" + ] + }, + { + "cell_type": "markdown", + "id": "9fa86db6", + "metadata": {}, + "source": [ + "#### **传入图像进行推理**\n", + "\n", + "`SamProcessor` 会自动完成:\n", + "\n", + "- 图像缩放 / 填充(padding)\n", + "- 输入提示(BBox)整理为模型需要的张量格式\n", + "\n", + "模型输出包含:\n", + "\n", + "- `pred_masks`:候选掩码\n", + "- `iou_scores`:每张候选掩码的 IoU 评分(可用于挑选最佳结果)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "411f4369", + "metadata": {}, + "outputs": [], + "source": [ + "inputs = processor(images=image, input_boxes=[[bbox]], return_tensors=\"pt\")\n", + "outputs = model(**inputs)" + ] + }, + { + "cell_type": "markdown", + "id": "59162075-e83d-46cb-b208-8cf70a92511e", + "metadata": {}, + "source": [ + "## 结果可视化展示\n", + "#### **候选掩码可视化**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88964116-f7f7-4d6d-b0f7-9f05494e7421", + "metadata": {}, + "outputs": [], + "source": [ + "pred_masks = outputs.pred_masks # (B, boxes, M, 256, 256)\n", + "iou_scores = outputs.iou_scores # (B, boxes, M)\n", + "\n", + "print(\"pred_masks 形状:\", tuple(pred_masks.shape))\n", + "print(\"iou_scores 形状:\", tuple(iou_scores.shape))\n", + "\n", + "# preview low-res candidate masks (M masks at 256x256)\n", + "pm = pred_masks[0, 0].asnumpy() # (M, 256, 256)\n", + "scores_np = iou_scores[0, 0].asnumpy()\n", + "M = pm.shape[0]\n", + "\n", + "fig, axes = plt.subplots(1, M, figsize=(4*M, 4))\n", + "if M == 1:\n", + " axes = [axes]\n", + "for i in range(M):\n", + " axes[i].imshow(pm[i] > 0, cmap=\"gray\")\n", + " axes[i].set_title(f\"Mask {i}\\nIoU≈{float(scores_np[i]):.3f}\")\n", + " axes[i].axis(\"off\")\n", + "plt.suptitle(\"Low-res candidate masks (256×256)\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a73bb2f6", + "metadata": {}, + "source": [ + "#### **最佳掩码与原图叠加的可视化展示**\n", + "\n", + "`processor.post_process_masks` 会根据预处理时的缩放/填充信息以及原图尺寸,将掩码映射回 **原图尺寸**,便于直接叠加可视化或后续保存。\n", + "\n", + "最后将最佳掩码以半透明方式叠加到原图并绘制输入框。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77b7e2f7", + "metadata": {}, + "outputs": [], + "source": [ + "upsampled_list = processor.post_process_masks(\n", + " pred_masks,\n", + " inputs[\"original_sizes\"], # [[H_orig, W_orig]]\n", + " inputs[\"reshaped_input_sizes\"], # [[H_in, W_in ]]\n", + ")\n", + "\n", + "m = upsampled_list[0] # (boxes, M, H, W) or (M, H, W)\n", + "if m.ndim == 4:\n", + " m = m[0]\n", + "\n", + "scores_np = iou_scores[0, 0].asnumpy()\n", + "best_idx = int(np.argmax(scores_np))\n", + "best_score = float(scores_np[best_idx])\n", + "best_mask = (m[best_idx].asnumpy() > 0) # (H, W) bool\n", + "\n", + "print(\"最佳索引:\", best_idx, \"| IoU:\", best_score)\n", + "print(\"最佳掩码形状 (H,W):\", best_mask.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e908d71", + "metadata": {}, + "outputs": [], + "source": [ + "H, W = best_mask.shape\n", + "overlay = np.zeros((H, W, 4), dtype=np.uint8)\n", + "overlay[best_mask] = np.array([255, 0, 0, 115], dtype=np.uint8) # ~45% alpha\n", + "\n", + "plt.figure(figsize=(8, 6))\n", + "plt.imshow(image)\n", + "plt.imshow(overlay)\n", + "import matplotlib.patches as patches\n", + "x1, y1, x2, y2 = bbox\n", + "rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,\n", + " linewidth=2, edgecolor='yellow', facecolor='none')\n", + "plt.gca().add_patch(rect)\n", + "plt.title(f\"IoU: {best_score:.3f}\")\n", + "plt.axis(\"off\")\n", + "plt.tight_layout()\n", + "save_path = \"dog_segmentation_result.png\"\n", + "plt.savefig(save_path, dpi=300, bbox_inches=\"tight\")\n", + "plt.show()\n", + "print(\"已保存:\", save_path)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "MindSpore (Python 3.10)", + "language": "python", + "name": "ms_env_py310" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}