diff --git a/llm/README.md b/llm/README.md index 679a681..767d964 100644 --- a/llm/README.md +++ b/llm/README.md @@ -11,7 +11,7 @@ The following notebooks are actively maintained in sync with MindSpore and MindS | No. | Model | Description | | :-- | :---- | :----------------------- | | 1 | [t5](./t5/) | Includes notebooks for T5 finetuning and inference on tasks such as email summarization | - +| 2 | [esmforproteinfolding](./esmforproteinfolding/) | Includes notebooks for EsmForProteinFolding finetuning and inference on tasks such as protein structure prediction | ### Community-Driven / Legacy Applications Additional community-contributed or legacy notebooks are stored under the [legacy](./legacy/) directory. These notebooks are not actively maintained and may rely on older APIs. diff --git a/llm/esmforproteinfolding/inference_esmforproteinfolding_prediction.ipynb b/llm/esmforproteinfolding/inference_esmforproteinfolding_prediction.ipynb new file mode 100644 index 0000000..608b357 --- /dev/null +++ b/llm/esmforproteinfolding/inference_esmforproteinfolding_prediction.ipynb @@ -0,0 +1,494 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "829cca7c", + "metadata": {}, + "source": [
+ "# ESMFold Protein Structure Prediction with MindSpore NLP\n",
+ "\n",
+ "## 1. Overview\n",
+ "Protein folding is a core problem in computational and structural biology: given an amino-acid sequence, accurately predicting its three-dimensional conformation bears directly on understanding protein function, designing new drugs, and assessing the pathogenicity of variants. Traditional approaches rely on homology modeling and costly database searches, which are slow and heavily dependent on external structure databases. In recent years, deep-learning models such as AlphaFold and ESMFold have lifted both the accuracy and the speed of protein structure prediction to a new level.\n",
+ "\n",
+ "\n",
+ "**ESMFold** is a protein folding model released by Meta AI. Unlike AlphaFold2, ESMFold predicts end-to-end on top of a large protein language model (ESM-2) and does **not** rely on external databases (MSA search); inference is roughly 60x faster while accuracy remains very high. This makes it practical to predict protein structures quickly on personal hardware.\n",
+ "\n",
+ "### What this notebook covers\n",
+ "Built on the **MindSpore NLP** toolkit, this notebook implements the following workflow:\n",
+ "1. **Environment setup**: install MindSpore and the bio-computing dependencies.\n",
+ "2. **Model loading**: load the `facebook/esmfold_v1` pretrained weights in a single call.\n",
+ "3. **Sequence inference**: predict 3D atom coordinates directly from an amino-acid sequence.\n",
+ "4. **PDB generation**: save the model output as a standard `.pdb` file.\n",
+ "5. **3D visualization**: inspect the predicted structure interactively in the notebook.\n",
+ "6. **Finetuning demo**: walk through a parameter-efficient finetuning (PEFT) implementation."
+ ] + }, + { + "cell_type": "markdown", + "id": "60fe0cd1", + "metadata": {}, + "source": [
+ "## 2. Environment Setup\n",
+ "This notebook was developed with the following environment:\n",
+ "| Python | MindSpore | MindSpore NLP | py3Dmol | biopython | diffusers |\n",
+ "| ----- | ----- | ----- | ----- | ----- | ----- |\n",
+ "| 3.10.19 | 2.7.0 | 0.5.1 | 2.5.3 | 1.86 | 0.35.2 |\n",
+ "\n",
+ "> Important: before launching the notebook on Ascend, make sure the CANN/driver environment variables are loaded (otherwise `libmindspore_ascend.so` / `libge_runner.so` cannot be found).\n",
+ "> \n",
+ "> - Typical setup: run `source /usr/local/Ascend/ascend-toolkit/set_env.sh` before starting `jupyter`, and add `/usr/local/Ascend/driver/lib64/driver` and `/usr/local/Ascend/driver/lib64/common` to `LD_LIBRARY_PATH`."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f95cc806-1cfb-4308-931a-4291ae0b1ecc", + "metadata": {}, + "outputs": [], + "source": [
+ "#!pip install mindspore==2.7.0\n",
+ "#!pip install mindnlp==0.5.1\n",
+ "#!pip install py3Dmol\n",
+ "#!pip install biopython\n",
+ "# Pin diffusers to 0.35.2 to avoid incompatibilities\n",
+ "#!pip install diffusers==0.35.2\n",
+ "print(\"Dependencies: mindspore / mindnlp / py3Dmol / biopython\")"
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc1c20ef", + "metadata": {}, + "outputs": [], + "source": [
+ "import mindspore as ms\n",
+ "\n",
+ "# Runtime configuration (Ascend environment by default)\n",
+ "# PYNATIVE_MODE is recommended for dynamic shapes; switch to GRAPH_MODE + AMP for maximum performance\n",
+ "# MindSpore selects dynamic-graph (PyNative) mode automatically, so it is not set explicitly here\n",
+ "# ms.set_context(mode=ms.PYNATIVE_MODE, device_target=\"Ascend\")\n",
+ "\n",
+ "print(f\"MindSpore Version: {ms.__version__}\")"
+ ] + }, + { + "cell_type": "markdown", + "id": "f4014443", + "metadata": {}, + "source": [
+ "## 3. Loading the ESMFold Model\n",
+ "\n",
+ "We use the `EsmForProteinFolding` interface from MindSpore NLP. Its design mirrors Hugging Face `transformers`, and it can load PyTorch weights directly with automatic conversion.\n",
+ "\n",
+ "* **Tokenizer**: converts the amino-acid sequence into token IDs the model understands.\n",
+ "* **Model**: contains the ESM-2 backbone and the Folding Trunk structure module."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41b5efbe", + "metadata": {}, + "outputs": [], + "source": [
+ "from mindnlp.transformers import EsmForProteinFolding, AutoTokenizer\n",
+ "\n",
+ "MODEL_NAME = \"facebook/esmfold_v1\"\n",
+ "\n",
+ "print(f\"Loading model: {MODEL_NAME} ...\")\n",
+ "# Load the tokenizer and the model\n",
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
+ "model = EsmForProteinFolding.from_pretrained(MODEL_NAME)\n",
+ "\n",
+ "# Optional: set a chunk size to reduce memory usage (for long sequences / tight memory)\n",
+ "model.trunk.set_chunk_size(64)\n",
+ "\n",
+ "print(\"✅ Model loaded!\")"
+ ] + },
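+ { + "cell_type": "markdown", + "id": "3a7c9e21", + "metadata": {}, + "source": [
+ "Before running inference it can be worth confirming that the checkpoint really loaded at the expected scale (ESMFold is roughly 3B parameters). The cell below is a small optional sanity check; it assumes the mindtorch-style `named_parameters()` / `numel()` API that the finetuning section also relies on, so skip it if your build exposes parameters differently."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d2f8b4c", + "metadata": {}, + "outputs": [], + "source": [
+ "# Optional sanity check (illustrative): count total vs. currently trainable parameters.\n",
+ "# Assumes parameters expose a torch-style numel(); adjust if your mindtorch build differs.\n",
+ "total_params = sum(p.numel() for _, p in model.named_parameters())\n",
+ "trainable_now = sum(p.numel() for _, p in model.named_parameters() if p.requires_grad)\n",
+ "print(f\"Total parameters: {total_params / 1e9:.2f} B\")\n",
+ "print(f\"Trainable parameters: {trainable_now / 1e9:.2f} B\")"
+ ] + },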
+ { + "cell_type": "markdown", + "id": "d0205a80", + "metadata": {}, + "source": [
+ "## 4. Protein Structure Prediction\n",
+ "\n",
+ "ESMFold accepts one or more amino-acid sequences as direct input."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66d4c37a", + "metadata": {}, + "outputs": [], + "source": [
+ "import time\n",
+ "import mindtorch as torch\n",
+ "\n",
+ "# Example sequence (human GNAT1 protein)\n",
+ "test_protein = \"MGAGASAEEKHSRELEKKLKEDAEKDARTVKLLLLGAGESGKSTIVKQMKIIHQDGYSLEECLEFIAIIYGNTLQSILAIVRAMTTLNIQYGDSARQDDARKLMHMADTIEEGTMPKEMSDIIQRLWKDSGIQACFERASEYQLNDSAGYYLSDLERLVTPGYVPTEQDVLRSRVKTTGIIETQFSFKDLNFRMFDVGGQRSERKKWIHCFEGVTCIIFIAALSAYDMVLVEDDEVNRMHESLHLFNSICNHRYFATTSIVLFLNKKDVFFEKIKKAHLSICFPDYDGPNTYEDAGNYIKVQFLELNMRRDVKEIYSHMTCATDTQNVKFVFDAVTDIIIKENLKDCGLF\"\n",
+ "\n",
+ "print(f\"Input sequence length: {len(test_protein)}\")\n",
+ "print(f\"First 60 residues: {test_protein[:60]}...\")\n",
+ "\n",
+ "# Inference: move the model and inputs to the NPU explicitly (otherwise they fall back to CPU dispatch and operators fail)\n",
+ "model.set_train(False)\n",
+ "if hasattr(model, \"to\"):\n",
+ "    model = model.to(\"npu\")\n",
+ "\n",
+ "# 1) Run inference to get positions / pLDDT (the return value may be a dict or an output object; handle both)\n",
+ "start_time = time.time()\n",
+ "tokenized_input = tokenizer([test_protein], return_tensors=\"pt\", add_special_tokens=False)\n",
+ "for k, v in list(tokenized_input.items()):\n",
+ "    if hasattr(v, \"to\"):\n",
+ "        tokenized_input[k] = v.to(\"npu\")\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    outputs = model(**tokenized_input)\n",
+ "end_time = time.time()\n",
+ "\n",
+ "# Handle both dict-style and object-style outputs\n",
+ "plddt = outputs[\"plddt\"] if isinstance(outputs, dict) and \"plddt\" in outputs else getattr(outputs, \"plddt\", None)\n",
+ "positions = outputs[\"positions\"] if isinstance(outputs, dict) and \"positions\" in outputs else getattr(outputs, \"positions\", None)\n",
+ "\n",
+ "print(f\"Inference time: {end_time - start_time:.2f} s\")\n",
+ "if positions is not None:\n",
+ "    print(f\"positions shape: {positions.shape} (typically [recycles, B, L, atom14, 3])\")\n",
+ "if plddt is not None:\n",
+ "    # In this implementation pLDDT is usually in 0-1; scale to 0-100 for readability\n",
+ "    plddt_mean = float(plddt.mean().cpu().numpy()) if hasattr(plddt, \"cpu\") else float(plddt.mean().numpy())\n",
+ "    print(f\"Mean confidence (pLDDT): {plddt_mean * 100:.2f} / 100\")\n",
+ "\n",
+ "# 2) mindnlp can also generate a PDB string directly\n",
+ "pdb_string = str(model.infer_pdb(test_protein))\n",
+ "print(f\"PDB preview (first 400 characters):\\n{pdb_string[:400]}\")"
+ ] + },
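+ { + "cell_type": "markdown", + "id": "8e4b1f62", + "metadata": {}, + "source": [
+ "pLDDT is reported per residue, so beyond the mean it is often useful to see where confidence drops (termini and flexible loops are typical low-confidence regions). The next cell is an illustrative sketch, not part of the original workflow: it flattens the per-residue scores with numpy and lists the least confident positions. It assumes the `plddt` tensor from the previous cell may carry a trailing per-atom dimension that can simply be averaged; the exact layout can differ between implementations."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c6d2e8f", + "metadata": {}, + "outputs": [], + "source": [
+ "import numpy as np\n",
+ "\n",
+ "# Illustrative sketch: per-residue confidence profile from the previous cell's plddt.\n",
+ "if plddt is not None:\n",
+ "    scores = plddt.cpu().numpy() if hasattr(plddt, \"cpu\") else plddt.numpy()\n",
+ "    scores = np.squeeze(scores)        # drop batch/recycle axes of size 1\n",
+ "    if scores.ndim > 1:\n",
+ "        scores = scores.mean(axis=-1)  # average any per-atom dimension -> one score per residue\n",
+ "    scores = scores * 100.0            # this implementation reports pLDDT in 0-1\n",
+ "    print(f\"Residues scored: {scores.shape[0]}, mean pLDDT: {scores.mean():.2f}\")\n",
+ "    worst = np.argsort(scores)[:10]\n",
+ "    print(\"Lowest-confidence residues (1-based position, residue, score):\")\n",
+ "    for idx in sorted(worst):\n",
+ "        print(f\"  {idx + 1:4d}  {test_protein[idx]}  {scores[idx]:.1f}\")"
+ ] + },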
+ { + "cell_type": "markdown", + "id": "f954f054", + "metadata": {}, + "source": [
+ "## 5. (Advanced) Performance Optimization: Mixed-Precision Inference (AMP)\n",
+ "\n",
+ "ESMFold is a large model with roughly 3 billion parameters. Running inference with **automatic mixed precision (AMP)** can significantly reduce memory usage and speed up computation.\n",
+ "\n",
+ "The inference backend of this notebook is `mindtorch`, so here we use `mindtorch.amp.autocast` to enable **bfloat16 (preferred) / float16 mixed precision on the NPU** and compare the runtime against FP32 inference."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07e4019c", + "metadata": {}, + "outputs": [], + "source": [
+ "import mindtorch as torch\n",
+ "from mindtorch import autocast\n",
+ "\n",
+ "print(\"=== Mixed-precision inference: FP32 vs AMP ===\")\n",
+ "\n",
+ "def _to_cpu_float(x):\n",
+ "    \"\"\"Convert the mean of a mindtorch tensor to a Python float, handling the cpu() interface.\"\"\"\n",
+ "    if hasattr(x, \"cpu\"):\n",
+ "        return float(x.mean().cpu().numpy())\n",
+ "    return float(x.mean().numpy())\n",
+ "\n",
+ "# Reuse the model and tokenized_input prepared in the previous section\n",
+ "# Make sure everything is on the NPU\n",
+ "model.set_train(False)\n",
+ "if hasattr(model, \"to\"):\n",
+ "    model = model.to(\"npu\")\n",
+ "\n",
+ "# Warm up once (so first-run compilation/caching does not skew the timing)\n",
+ "with torch.no_grad():\n",
+ "    _ = model(**tokenized_input)\n",
+ "\n",
+ "# FP32\n",
+ "with torch.no_grad():\n",
+ "    t0 = time.time()\n",
+ "    out_fp32 = model(**tokenized_input)\n",
+ "    t_fp32 = time.time() - t0\n",
+ "\n",
+ "plddt_fp32 = out_fp32[\"plddt\"] if isinstance(out_fp32, dict) else getattr(out_fp32, \"plddt\", None)\n",
+ "if plddt_fp32 is not None:\n",
+ "    print(f\"FP32 pLDDT(mean): {_to_cpu_float(plddt_fp32) * 100:.2f} / 100\")\n",
+ "print(f\"FP32 inference time: {t_fp32:.2f} s\")\n",
+ "\n",
+ "# AMP: prefer bfloat16 on Ascend (more stable); float16 can be unstable for some operators\n",
+ "amp_dtype = torch.bfloat16 if hasattr(torch, \"bfloat16\") else torch.float16\n",
+ "print(f\"AMP dtype: {amp_dtype}\")\n",
+ "\n",
+ "with torch.no_grad(), autocast(device_type=\"npu\", dtype=amp_dtype):\n",
+ "    t1 = time.time()\n",
+ "    out_amp = model(**tokenized_input)\n",
+ "    t_amp = time.time() - t1\n",
+ "\n",
+ "plddt_amp = out_amp[\"plddt\"] if isinstance(out_amp, dict) else getattr(out_amp, \"plddt\", None)\n",
+ "if plddt_amp is not None:\n",
+ "    print(f\"AMP({amp_dtype}) pLDDT(mean): {_to_cpu_float(plddt_amp) * 100:.2f} / 100\")\n",
+ "print(f\"AMP({amp_dtype}) inference time: {t_amp:.2f} s\")\n",
+ "\n",
+ "speedup = (t_fp32 / t_amp) if t_amp > 0 else float('inf')\n",
+ "print(f\"Approximate speedup: {speedup:.2f}x\")\n"
+ ] + }, + { + "cell_type": "markdown", + "id": "74411a0d-0696-42c8-8138-c38257064e69", + "metadata": {}, + "source": [
+ "Note: in this environment the export stage of `infer_pdb` involves Python string formatting and may hit compatibility issues inside autocast.\n",
+ "AMP is therefore applied only to the forward pass above; PDB export keeps using the FP32 `infer_pdb` result (`pdb_string`).\n",
+ "Tip: prefer FP32 for PDB export (stability first)."
+ ] + },
+ { + "cell_type": "markdown", + "id": "42971375", + "metadata": {}, + "source": [
+ "## 6. (Advanced) Multi-Chain Complex Prediction (Multimer / Complex)\n",
+ "\n",
+ "Proteins rarely work in isolation; they form complexes. ESMFold can predict how multiple peptide chains interact via a **linker** trick.\n",
+ "\n",
+ "We join the two chains with a stretch of glycines (Glycine, 'G') or a special mask, and the model infers how they pack together in space."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2ba34af", + "metadata": {}, + "outputs": [], + "source": [
+ "import mindtorch as torch\n",
+ "import py3Dmol\n",
+ "\n",
+ "# Multi-chain complex prediction (linker trick)\n",
+ "# Note: the usual way to run ESMFold on multimers is to concatenate the chains with a linker, then tell the chains apart by residue ranges when visualizing.\n",
+ "\n",
+ "model.set_train(False)\n",
+ "if hasattr(model, \"to\"):\n",
+ "    model = model.to(\"npu\")\n",
+ "\n",
+ "chain_A = \"MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG\"\n",
+ "chain_B = \"KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE\"\n",
+ "linker = \"G\" * 25\n",
+ "complex_sequence = chain_A + linker + chain_B\n",
+ "\n",
+ "print(f\"Chain A length: {len(chain_A)}\")\n",
+ "print(f\"Chain B length: {len(chain_B)}\")\n",
+ "print(f\"Total complex length (incl. linker): {len(complex_sequence)}\")\n",
+ "\n",
+ "# 1) Preprocessing (mindtorch tensors)\n",
+ "tokenized_complex = tokenizer([complex_sequence], return_tensors=\"pt\", add_special_tokens=False)\n",
+ "for k, v in list(tokenized_complex.items()):\n",
+ "    if hasattr(v, \"to\"):\n",
+ "        tokenized_complex[k] = v.to(\"npu\")\n",
+ "\n",
+ "# 2) Inference\n",
+ "print(\"Predicting complex structure...\")\n",
+ "with torch.no_grad():\n",
+ "    complex_outputs = model(**tokenized_complex)\n",
+ "\n",
+ "# 3) Use the model's built-in PDB export\n",
+ "complex_pdb = str(model.output_to_pdb(complex_outputs)[0])\n",
+ "\n",
+ "# 4) Visualization (distinguish the chains by residue range)\n",
+ "view = py3Dmol.view(width=900, height=420)\n",
+ "view.addModel(complex_pdb, \"pdb\")\n",
+ "\n",
+ "# Chain A: red\n",
+ "view.setStyle({\"resi\": f\"1-{len(chain_A)}\"}, {\"cartoon\": {\"color\": \"red\"}})\n",
+ "# Linker: black wireframe (could also be hidden)\n",
+ "view.setStyle({\"resi\": f\"{len(chain_A)+1}-{len(chain_A)+len(linker)}\"}, {\"line\": {\"color\": \"black\", \"opacity\": 0.5}})\n",
+ "# Chain B: blue\n",
+ "view.setStyle({\"resi\": f\"{len(chain_A)+len(linker)+1}-{len(complex_sequence)}\"}, {\"cartoon\": {\"color\": \"blue\"}})\n",
+ "\n",
+ "view.zoomTo()\n",
+ "view.show()\n",
+ "print(\"Red: Chain A, Blue: Chain B, Black: linker\")"
+ ] + }, + { + "cell_type": "markdown", + "id": "c293e375", + "metadata": {}, + "source": [
+ "## 7. Saving the PDB File\n",
+ "\n",
+ "`infer_pdb` returns a standard PDB string directly (with the per-residue confidence already written into the B-factor column)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a27c3a22", + "metadata": {}, + "outputs": [], + "source": [
+ "# Save the PDB file (write the infer_pdb output as-is)\n",
+ "output_filename = \"predicted.pdb\"\n",
+ "with open(output_filename, \"w\") as f:\n",
+ "    f.write(pdb_string)\n",
+ "\n",
+ "print(f\"PDB file written: {output_filename}\")"
+ ] + },
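+ { + "cell_type": "markdown", + "id": "b5e03d17", + "metadata": {}, + "source": [
+ "biopython appears in the environment table but has not been used yet. As an optional, illustrative check that the confidence values really landed in the B-factor column, the sketch below parses the file we just wrote with `Bio.PDB` and reports the mean stored B-factor; it is not part of the original workflow."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7f19a3e", + "metadata": {}, + "outputs": [], + "source": [
+ "from Bio.PDB import PDBParser\n",
+ "\n",
+ "# Illustrative check with biopython: read back predicted.pdb and average the\n",
+ "# B-factor column, which infer_pdb uses to store per-residue confidence.\n",
+ "parser = PDBParser(QUIET=True)\n",
+ "structure = parser.get_structure(\"esmfold_prediction\", output_filename)\n",
+ "\n",
+ "bfactors = [atom.get_bfactor() for atom in structure.get_atoms()]\n",
+ "print(f\"Atoms parsed: {len(bfactors)}\")\n",
+ "print(f\"Mean stored B-factor (confidence): {sum(bfactors) / len(bfactors):.2f}\")"
+ ] + },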
+ { + "cell_type": "markdown", + "id": "38112517-e67a-400e-86b3-74d73b4e7c28", + "metadata": { + "vscode": { + "languageId": "ini" + } + }, + "source": [
+ "## 8. Structure Visualization\n",
+ "\n",
+ "We use `py3Dmol` to render the 3D structure inside the notebook. The structure can be colored by **pLDDT (confidence)**: blue indicates high confidence, red indicates low confidence."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c78ef1f-fe40-43d3-95af-24414cf8d91f", + "metadata": {}, + "outputs": [], + "source": [
+ "# Single-chain structure visualization\n",
+ "# Note: infer_pdb has already written the confidence into the B-factor field, so py3Dmol can color by B-factor.\n",
+ "\n",
+ "def view_pdb_by_plddt(pdb_str):\n",
+ "    view = py3Dmol.view(width=900, height=420)\n",
+ "    view.addModel(pdb_str, 'pdb')\n",
+ "    # Color by B-factor (= pLDDT): blue = high, red = low\n",
+ "    view.setStyle({\"cartoon\": {\"colorscheme\": {\"prop\": \"b\", \"gradient\": \"roygb\", \"min\": 50, \"max\": 95}}})\n",
+ "    view.zoomTo()\n",
+ "    return view\n",
+ "\n",
+ "view = view_pdb_by_plddt(pdb_string)\n",
+ "view.show()"
+ ] + }, + { + "cell_type": "markdown", + "id": "61f1c056", + "metadata": {}, + "source": [
+ "## 9. Finetuning Demo\n",
+ "\n",
+ "In practice we may need to finetune the model on a specific protein family.\n",
+ "Because ESMFold is huge (about 3B parameters), full-parameter finetuning requires a lot of device memory. This section demonstrates a **parameter-efficient finetuning (PEFT)** strategy: **freeze the backbone and train only a projection layer of the structure module**.\n",
+ "\n",
+ "> **Note**: this section only shows how the training loop is built; no real training data is used. A sketch of persisting the updated weights follows the cell below."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0700e509", + "metadata": {}, + "outputs": [], + "source": [
+ "import mindtorch as torch\n",
+ "import mindtorch.nn as nn\n",
+ "import mindtorch.optim as optim\n",
+ "\n",
+ "print(\"=== Finetuning demo (minimal runnable version) ===\")\n",
+ "print(\"Note: fake target coordinates + an MSE loss are used to demonstrate the training loop; real training needs structural data and losses such as FAPE.\")\n",
+ "\n",
+ "# 1) Select the trainable parameters (parameter-efficient: train only a small projection layer)\n",
+ "for _, p in model.named_parameters():\n",
+ "    p.requires_grad = False\n",
+ "\n",
+ "trainable_params = []\n",
+ "for name, p in model.named_parameters():\n",
+ "    if \"esm_s_mlp\" in name:  # input projection of the structure module (example)\n",
+ "        p.requires_grad = True\n",
+ "        trainable_params.append(p)\n",
+ "\n",
+ "print(f\"Number of trainable parameter tensors: {len(trainable_params)}\")\n",
+ "\n",
+ "# 2) Use a short sequence to keep the training cost low\n",
+ "short_seq = \"MKTVRQERLKSIVRILERSKEPVSGAQLAEEL\"\n",
+ "train_inputs = tokenizer([short_seq], return_tensors=\"pt\", add_special_tokens=False)\n",
+ "for k, v in list(train_inputs.items()):\n",
+ "    if hasattr(v, \"to\"):\n",
+ "        train_inputs[k] = v.to(\"npu\")\n",
+ "\n",
+ "model = model.to(\"npu\")\n",
+ "model.train()\n",
+ "\n",
+ "# 3) Run inference once to get baseline coordinates, then build a fake target from them\n",
+ "with torch.no_grad():\n",
+ "    base_out = model(**train_inputs)\n",
+ "base_pos = base_out[\"positions\"] if isinstance(base_out, dict) else getattr(base_out, \"positions\", None)\n",
+ "base_pos_last = base_pos[-1, 0]  # [L, atom14, 3]\n",
+ "\n",
+ "target_pos = (base_pos_last + 0.01 * torch.randn_like(base_pos_last)).detach()\n",
+ "\n",
+ "# 4) Define the optimizer and the loss\n",
+ "optimizer = optim.Adam(trainable_params, lr=1e-4)\n",
+ "loss_fn = nn.MSELoss()\n",
+ "\n",
+ "# 5) Run 1-2 steps to verify backprop and the parameter update\n",
+ "for step in range(2):\n",
+ "    # mindtorch.optim.Optimizer.zero_grad may depend on the profiler in this environment, so clear the gradients manually\n",
+ "    for p in trainable_params:\n",
+ "        p.grad = None\n",
+ "\n",
+ "    out = model(**train_inputs)\n",
+ "    pos = out[\"positions\"] if isinstance(out, dict) else getattr(out, \"positions\", None)\n",
+ "    pred = pos[-1, 0]\n",
+ "    loss = loss_fn(pred, target_pos)\n",
+ "    loss.backward()\n",
+ "    optimizer.step()\n",
+ "\n",
+ "    loss_value = loss.detach().cpu().numpy() if hasattr(loss, \"cpu\") else loss.detach().numpy()\n",
+ "    print(f\"Step {step+1} loss: {float(loss_value):.6f}\")\n",
+ "\n",
+ "print(\"✅ Finetuning loop runs end-to-end\")"
+ ] + },
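+ { + "cell_type": "markdown", + "id": "7f3a5c90", + "metadata": {}, + "source": [
+ "If this demo were extended to real training, you would usually persist only the small set of weights that actually changed. The sketch below is one hedged way to do that with plain numpy, avoiding assumptions about a particular mindtorch checkpoint API; the output file name `esm_s_mlp_finetuned.npz` is just an example."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c8e6b41", + "metadata": {}, + "outputs": [], + "source": [
+ "import numpy as np\n",
+ "\n",
+ "# Illustrative sketch: save only the parameters that were unfrozen above,\n",
+ "# keyed by their parameter names, in a plain numpy archive.\n",
+ "to_save = {}\n",
+ "for name, p in model.named_parameters():\n",
+ "    if p.requires_grad:\n",
+ "        arr = p.detach().cpu().numpy() if hasattr(p, \"cpu\") else p.detach().numpy()\n",
+ "        to_save[name] = arr\n",
+ "\n",
+ "np.savez(\"esm_s_mlp_finetuned.npz\", **to_save)\n",
+ "print(f\"Saved {len(to_save)} trainable tensors to esm_s_mlp_finetuned.npz\")"
+ ] + },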
+ { + "cell_type": "markdown", + "id": "61653ce6", + "metadata": {}, + "source": [
+ "## 10. Summary\n",
+ "\n",
+ "This notebook implemented the full ESMFold workflow on MindSpore NLP:\n",
+ "1. **Ease of use**: the Hugging Face PyTorch weights were loaded seamlessly through the `MindSpore NLP` interface, with no manual conversion.\n",
+ "2. **Performance**: on Ascend/GPU hardware, protein structure inference completes on the order of seconds.\n",
+ "3. **Completeness**: the pipeline covers data preprocessing, inference, PDB generation, and a finetuning demo.\n",
+ "\n",
+ "This gives developers in computational biology an efficient foundation built on a domestically developed AI framework."
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}