1 change: 1 addition & 0 deletions cv/README.md
@@ -8,6 +8,7 @@ This directory contains ready-to-use Computer Vision application notebooks built
| :-- | :---- | :-------------------------------- |
| 1 | [ResNet](./resnet/) | Includes notebooks for ResNet finetuning on tasks such as Chinese herbal classification |
| 2 | [U-Net](./unet/) | Includes notebooks for U-Net training on tasks such as segmentation |
| 3 | [SAM](./sam/) | Includes notebooks for SAM inference on tasks such as promptable image segmentation |

## Contributing New CV Applications

359 changes: 359 additions & 0 deletions cv/sam/inference_sam_segmentation.ipynb
@@ -0,0 +1,359 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "8285b580",
"metadata": {},
"source": [
"# 基于MindSpore 和 MindSpore NLP 的 Segment Anything Model(SAM)通用图像分割推理任务\n",
"\n",
"## 模型简介\n",
"\n",
"图像分割(Segmentation)旨在为图像中每个像素生成标签,输出与原图尺寸一致的掩码(mask)。\n",
"**Segment Anything Model(SAM)**(Meta AI, 2023)是一种“可提示(promptable)”的通用分割模型:可通过 **点 / 框 / 已有掩码** 等提示,在零样本条件下对任意目标生成分割结果。\n",
"\n",
"本 Notebook 参考 Hugging Face 的交互式推理流程,基于 **MindSpore** 与 **MindSpore NLP Transformers**,演示 **BBox 框提示** 的端到端推理与可视化。\n",
"\n",
"## 环境准备\n",
"\n",
"本案例的运行环境为:\n",
"\n",
"| Python | MindSpore | MindSpore NLP |\n",
"|--------|-----------|---------------|\n",
"| 3.10 | 2.7.0 | 0.5.1 |"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b550a2d-7114-4a32-bcb6-4f6a8504da19",
"metadata": {},
"outputs": [],
"source": [
"!pip show mindspore\n",
"!pip show mindnlp"
]
},
{
"cell_type": "markdown",
"id": "f5a0708f",
"metadata": {},
"source": [
"如果你在如昇思大模型平台、华为云ModelArts、启智社区等算力平台的Jupyter在线编程环境中运行本案例,可取消如下代码的注释,进行依赖库安装:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "359d8d32-e0b4-484c-a0b5-0c4b2956253f",
"metadata": {},
"outputs": [],
"source": [
"# 安装mindspore==2.7.0版本,如需更换mindspore版本,可更改下面 MINDSPORE_VERSION 变量\n",
"# !pip uninstall mindspore -y\n",
"# %env MINDSPORE_VERSION=2.7.0\n",
"# !pip install mindspore==2.7.0 -i https://repo.mindspore.cn/pypi/simple --trusted-host repo.mindspore.cn --extra-index-url https://repo.huaweicloud.com/repository/pypi/simple\n",
"\n",
"# 安装mindnlp==0.5.1版本\n",
"# !pip uninstall mindnlp -y\n",
"# !pip install mindnlp==0.5.1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f78fd9fb-2b63-4372-adeb-a27331a805c3",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import requests\n",
"import numpy as np\n",
"import mindspore as ms\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.patches as patches\n",
"\n",
"from PIL import Image\n",
"from pathlib import Path\n",
"from mindnlp.transformers import SamModel, SamProcessor"
]
},
{
"cell_type": "markdown",
"id": "d735217e",
"metadata": {},
"source": [
"## 数据加载\n",
"\n",
"本案例使用 Meta 官方仓库提供的示例图片 `dog.jpg`。执行下方单元将图片下载到当前工作目录。\n",
"如你希望分割自己的图片,可替换 `image_url`,或直接将图片放在当前目录,并在后续单元修改 `img_path`。\n",
"\n",
"#### **数据下载**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "610eb8c7-73b3-47d3-ad43-698b7d9cfe1e",
"metadata": {},
"outputs": [],
"source": [
"def download_image(url: str, save_dir: str = \".\") -> str:\n",
" \"\"\"\n",
" 从 URL 下载图片到 save_dir,返回本地文件路径(字符串)。\n",
" \"\"\"\n",
" save_path = Path(save_dir)\n",
" save_path.mkdir(parents=True, exist_ok=True)\n",
"\n",
" filename = (url.rsplit(\"/\", 1)[-1] or \"image.jpg\")\n",
" dst = save_path / filename\n",
"\n",
" try:\n",
" resp = requests.get(url, timeout=30)\n",
" resp.raise_for_status()\n",
" dst.write_bytes(resp.content)\n",
" print(f\"示例图片已成功下载到: {dst}\")\n",
" return str(dst)\n",
" except Exception as e:\n",
" print(f\"下载示例图片时出错: {e}\")\n",
" return \"\"\n",
"\n",
"# 使用示例\n",
"image_url = \"https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg\"\n",
"downloaded_image = download_image(image_url)\n",
"if downloaded_image:\n",
" print(f\"示例图片已保存为: {downloaded_image}\")"
]
},
{
"cell_type": "markdown",
"id": "36ab043f-076b-4e0b-b47d-8d8ceed8e875",
"metadata": {},
"source": [
"#### **数据加载**\n",
"\n",
"读取图片并设置提示框(BBox)\n",
"- `bbox` 采用 **原图坐标系**:`[x1, y1, x2, y2]`\n",
"- 你可以通过修改 `bbox` 来框住想要分割的目标区域\n",
"- 本单元会把输入框画在原图上,便于检查是否框选正确"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "45a22bc6-9399-4044-9897-a61dc2f158ad",
"metadata": {},
"outputs": [],
"source": [
"img_path = \"dog.jpg\" # <-- change to your image path\n",
"assert os.path.exists(img_path), f\"Image not found: {img_path}\"\n",
"\n",
"# Custom bbox in original image coordinates [x1,y1,x2,y2]\n",
"bbox = [0, 217, 450, 800] # <-- change if needed, ensure in-bounds\n",
"bbox = [int(x) for x in bbox] # keep as Python ints for clarity\n",
"\n",
"image = Image.open(img_path).convert(\"RGB\")\n",
"W, H = image.size\n",
"print(\"图片尺寸:\", (W, H), \"| BBox:\", bbox)\n",
"\n",
"plt.figure(figsize=(6,4))\n",
"plt.imshow(image)\n",
"ax = plt.gca()\n",
"rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2]-bbox[0], bbox[3]-bbox[1],\n",
" linewidth=2, edgecolor='yellow', facecolor='none')\n",
"ax.add_patch(rect)\n",
"plt.title(\"Original image with input box\")\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "108b98f7-a054-4fa6-b4f9-2c139ca17355",
"metadata": {},
"source": [
"## 模型推理\n",
"\n",
"#### **加载模型**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d06aee8c",
"metadata": {},
"outputs": [],
"source": [
"MODEL_ID = \"facebook/sam-vit-base\"\n",
"CACHE_DIR = os.path.expanduser(\"~/.cache/mindnlp\") # or \"/tmp/mindnlp\"\n",
"\n",
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
"\n",
"# (可选)将其他库的缓存目录对齐到同一路径\n",
"os.environ[\"HF_HOME\"] = CACHE_DIR\n",
"os.environ[\"MINDNLP_HOME\"] = CACHE_DIR\n",
"\n",
"print(\"正在加载 SAM ...\")\n",
"processor = SamProcessor.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)\n",
"model = SamModel.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)\n",
"model.set_train(False)\n",
"print(\"加载完成!\")"
]
},
{
"cell_type": "markdown",
"id": "9fa86db6",
"metadata": {},
"source": [
"#### **传入图像进行推理**\n",
"\n",
"`SamProcessor` 会自动完成:\n",
"\n",
"- 图像缩放 / 填充(padding)\n",
"- 输入提示(BBox)整理为模型需要的张量格式\n",
"\n",
"模型输出包含:\n",
"\n",
"- `pred_masks`:候选掩码\n",
"- `iou_scores`:每张候选掩码的 IoU 评分(可用于挑选最佳结果)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "411f4369",
"metadata": {},
"outputs": [],
"source": [
"inputs = processor(images=image, input_boxes=[[bbox]], return_tensors=\"pt\")\n",
"outputs = model(**inputs)"
]
},
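{
"cell_type": "markdown",
"id": "2f7c9a10",
"metadata": {},
"source": [
"As noted in the introduction, SAM also accepts **point prompts** in addition to boxes. The commented-out cell below is a minimal sketch of a point-prompt variant, assuming the mindnlp `SamProcessor` mirrors the Hugging Face `input_points` argument; the point coordinates are only an illustrative guess inside the target region. Uncomment it to try; it is not required for the BBox workflow in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b3d5e92",
"metadata": {},
"outputs": [],
"source": [
"# Optional point-prompt sketch (uncomment to try).\n",
"# Assumes mindnlp's SamProcessor accepts `input_points` like the Hugging Face API;\n",
"# the (x, y) point is an arbitrary illustrative location inside the target region.\n",
"# point_prompt = [[[350, 500]]]  # nesting: [images][points per image][x, y], original-image coordinates\n",
"# point_inputs = processor(images=image, input_points=point_prompt, return_tensors=\"pt\")\n",
"# point_outputs = model(**point_inputs)\n",
"# print(point_outputs.iou_scores)\n",
"# print(point_outputs.pred_masks.shape)"
]
},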
{
"cell_type": "markdown",
"id": "59162075-e83d-46cb-b208-8cf70a92511e",
"metadata": {},
"source": [
"## 结果可视化展示\n",
"#### **候选掩码可视化**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88964116-f7f7-4d6d-b0f7-9f05494e7421",
"metadata": {},
"outputs": [],
"source": [
"pred_masks = outputs.pred_masks # (B, boxes, M, 256, 256)\n",
"iou_scores = outputs.iou_scores # (B, boxes, M)\n",
"\n",
"print(\"pred_masks 形状:\", tuple(pred_masks.shape))\n",
"print(\"iou_scores 形状:\", tuple(iou_scores.shape))\n",
"\n",
"# preview low-res candidate masks (M masks at 256x256)\n",
"pm = pred_masks[0, 0].asnumpy() # (M, 256, 256)\n",
"scores_np = iou_scores[0, 0].asnumpy()\n",
"M = pm.shape[0]\n",
"\n",
"fig, axes = plt.subplots(1, M, figsize=(4*M, 4))\n",
"if M == 1:\n",
" axes = [axes]\n",
"for i in range(M):\n",
" axes[i].imshow(pm[i] > 0, cmap=\"gray\")\n",
" axes[i].set_title(f\"Mask {i}\\nIoU≈{float(scores_np[i]):.3f}\")\n",
" axes[i].axis(\"off\")\n",
"plt.suptitle(\"Low-res candidate masks (256×256)\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a73bb2f6",
"metadata": {},
"source": [
"#### **最佳掩码与原图叠加的可视化展示**\n",
"\n",
"`processor.post_process_masks` 会根据预处理时的缩放/填充信息以及原图尺寸,将掩码映射回 **原图尺寸**,便于直接叠加可视化或后续保存。\n",
"\n",
"最后将最佳掩码以半透明方式叠加到原图并绘制输入框。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77b7e2f7",
"metadata": {},
"outputs": [],
"source": [
"upsampled_list = processor.post_process_masks(\n",
" pred_masks,\n",
" inputs[\"original_sizes\"], # [[H_orig, W_orig]]\n",
" inputs[\"reshaped_input_sizes\"], # [[H_in, W_in ]]\n",
")\n",
"\n",
"m = upsampled_list[0] # (boxes, M, H, W) or (M, H, W)\n",
"if m.ndim == 4:\n",
" m = m[0]\n",
"\n",
"scores_np = iou_scores[0, 0].asnumpy()\n",
"best_idx = int(np.argmax(scores_np))\n",
"best_score = float(scores_np[best_idx])\n",
"best_mask = (m[best_idx].asnumpy() > 0) # (H, W) bool\n",
"\n",
"print(\"最佳索引:\", best_idx, \"| IoU:\", best_score)\n",
"print(\"最佳掩码形状 (H,W):\", best_mask.shape)"
]
},
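{
"cell_type": "markdown",
"id": "c4a8f017",
"metadata": {},
"source": [
"Optionally, the binary mask itself can be written to disk as an 8-bit PNG. This is a small sketch using the already-imported PIL and NumPy; the output filename is arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e19b6d23",
"metadata": {},
"outputs": [],
"source": [
"# Save the best binary mask as a grayscale PNG (255 = foreground, 0 = background)\n",
"mask_path = \"dog_best_mask.png\"\n",
"Image.fromarray(best_mask.astype(np.uint8) * 255).save(mask_path)\n",
"print(\"Mask saved to:\", mask_path)"
]
},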
{
"cell_type": "code",
"execution_count": null,
"id": "4e908d71",
"metadata": {},
"outputs": [],
"source": [
"H, W = best_mask.shape\n",
"overlay = np.zeros((H, W, 4), dtype=np.uint8)\n",
"overlay[best_mask] = np.array([255, 0, 0, 115], dtype=np.uint8) # ~45% alpha\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"plt.imshow(image)\n",
"plt.imshow(overlay)\n",
"import matplotlib.patches as patches\n",
"x1, y1, x2, y2 = bbox\n",
"rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,\n",
" linewidth=2, edgecolor='yellow', facecolor='none')\n",
"plt.gca().add_patch(rect)\n",
"plt.title(f\"IoU: {best_score:.3f}\")\n",
"plt.axis(\"off\")\n",
"plt.tight_layout()\n",
"save_path = \"dog_segmentation_result.png\"\n",
"plt.savefig(save_path, dpi=300, bbox_inches=\"tight\")\n",
"plt.show()\n",
"print(\"已保存:\", save_path)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "MindSpore (Python 3.10)",
"language": "python",
"name": "ms_env_py310"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}