diff --git a/llm/README.md b/llm/README.md
index 679a681..351b1e9 100644
--- a/llm/README.md
+++ b/llm/README.md
@@ -11,6 +11,7 @@ The following notebooks are actively maintained in sync with MindSpore and MindS
 | No. | Model | Description |
 | :-- | :---- | :----------------------- |
 | 1 | [t5](./t5/) | Includes notebooks for T5 finetuning and inference on tasks such as email summarization |
+| 2 | [BERT (SWAG Multiple Choice)](./bert/finetune_bert_multiple_choice.ipynb) | Fine-tuning BERT on the SWAG dataset for multiple-choice tasks using MindSpore NLP |
 
 ### Community-Driven / Legacy Applications
 
diff --git a/llm/bert/finetune_bert_multiple_choice.ipynb b/llm/bert/finetune_bert_multiple_choice.ipynb
new file mode 100644
index 0000000..161d2b2
--- /dev/null
+++ b/llm/bert/finetune_bert_multiple_choice.ipynb
@@ -0,0 +1,528 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "0",
+   "metadata": {},
+   "source": [
+    "# BERT Multiple-Choice Reading Comprehension on SWAG with MindSpore\n",
+    "\n",
+    "## Overview\n",
+    "\n",
+    "**SWAG** (Situations With Adversarial Generations) is a large-scale adversarial dataset for commonsense-based natural language inference (NLI). Given a partially described event as context, the task is to pick the most plausible ending from four candidates.\n",
+    "\n",
+    "This notebook uses the **MindSpore** framework and the **MindSpore NLP** toolkit to fine-tune a pretrained **BERT** (Bidirectional Encoder Representations from Transformers) model on SWAG and run multiple-choice inference.\n",
+    "\n",
+    "**Workflow:**\n",
+    "1. **Environment setup**: configure the MindSpore runtime and the HF-Mirror download mirror.\n",
+    "2. **Data processing**: load the SWAG dataset, then tokenize, flatten, and dynamically pad it.\n",
+    "3. **Model construction**: load the pretrained BERT weights and build the multiple-choice classification network.\n",
+    "4. **Training**: configure `TrainingArguments` and `Trainer`, then fine-tune.\n",
+    "5. **Inference**: run end-to-end inference with the fine-tuned model.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1",
+   "metadata": {},
+   "source": [
+    "## Environment Setup\n",
+    "\n",
+    "This notebook requires the following environment:\n",
+    "\n",
+    "| Python | MindSpore | MindSpore NLP |\n",
+    "| :----- | :-------- | :------ |\n",
+    "| 3.9+ | >= 2.7.0 | >= 0.5.1 |\n",
+    "\n",
+    "First, import the required dependencies and set the environment variables so that downloads go through HF-Mirror, a mirror hosted in mainland China.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import mindspore as ms\n",
+    "import numpy as np\n",
+    "from dataclasses import dataclass\n",
+    "from typing import Any, Dict, List, Optional, Union\n",
+    "\n",
+    "# ----------------------------\n",
+    "# Environment (HF-Mirror)\n",
+    "# ----------------------------\n",
+    "os.environ.setdefault(\"HF_ENDPOINT\", \"https://hf-mirror.com\")\n",
+    "os.environ.setdefault(\"TOKENIZERS_PARALLELISM\", \"false\")\n",
+    "\n",
+    "# ----------------------------\n",
+    "# Import MindSpore NLP transformers\n",
+    "# ----------------------------\n",
+    "import mindnlp  # noqa: F401\n",
+    "\n",
+    "from mindnlp.transformers import (\n",
+    "    AutoTokenizer,\n",
+    "    AutoModelForMultipleChoice,\n",
+    "    Trainer,\n",
+    "    TrainingArguments,\n",
+    ")\n",
+    "\n",
+    "print(f\">>> MindSpore Version: {ms.__version__}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3",
+   "metadata": {},
+   "source": [
+    "#### **Helper Functions**\n",
+    "\n",
+    "To keep the code robust and compatible across hardware backends (Ascend/GPU/CPU), we define the following utilities:\n",
+    "- `set_ms_context`: sets the random seed, the MindSpore execution mode (PYNATIVE), and the target device (Ascend, device 0).\n",
+    "- `to_numpy`: a robust tensor-to-NumPy converter that handles both MindSpore and PyTorch tensors.\n",
+    "- `move_inputs_to_device`: ensures that inference inputs are on the same device as the model.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ----------------------------\n",
+    "# MindSpore context\n",
+    "# ----------------------------\n",
+    "def set_ms_context():\n",
+    "    ms.set_seed(42)\n",
+    "    ms.set_context(mode=ms.PYNATIVE_MODE)\n",
+    "    try:\n",
+    "        ms.set_device(\"Ascend\", 0)\n",
+    "    except AttributeError:\n",
+    "        ms.set_context(device_target=\"Ascend\", device_id=0)\n",
+    "\n",
+    "def to_numpy(x) -> Optional[np.ndarray]:\n",
+    "    if x is None:\n",
+    "        return None\n",
+    "    if isinstance(x, np.ndarray):\n",
+    "        return x\n",
+    "    if isinstance(x, (list, tuple)):\n",
+    "        return np.asarray(x)\n",
+    "\n",
+    "    # MindSpore Tensor\n",
+    "    if hasattr(x, \"asnumpy\"):\n",
+    "        try:\n",
+    "            return x.asnumpy()\n",
+    "        except Exception:\n",
+    "            pass\n",
+    "\n",
+    "    # (mind)torch Tensor\n",
+    "    if hasattr(x, \"detach\"):\n",
+    "        try:\n",
+    "            x = x.detach()\n",
+    "        except Exception:\n",
+    "            pass\n",
+    "    if hasattr(x, \"cpu\"):\n",
+    "        try:\n",
+    "            x = x.cpu()\n",
+    "        except Exception:\n",
+    "            pass\n",
+    "    if hasattr(x, \"numpy\"):\n",
+    "        try:\n",
+    "            return x.numpy()\n",
+    "        except Exception:\n",
+    "            pass\n",
+    "\n",
+    "    return np.asarray(x)\n",
+    "\n",
+    "def get_model_device(model):\n",
+    "    \"\"\"Read the device from the first parameter (for mindtorch/torch-style models).\"\"\"\n",
+    "    try:\n",
+    "        for p in model.parameters():\n",
+    "            return p.device\n",
+    "    except Exception:\n",
+    "        return None\n",
+    "    return None\n",
+    "\n",
+    "def move_inputs_to_device(inputs: Dict[str, Any], device):\n",
+    "    \"\"\"Move all batch inputs to the same device (only values that have a .to method).\"\"\"\n",
+    "    if device is None:\n",
+    "        return inputs\n",
+    "    out = {}\n",
+    "    for k, v in inputs.items():\n",
+    "        if hasattr(v, \"to\"):\n",
+    "            out[k] = v.to(device)\n",
+    "        else:\n",
+    "            out[k] = v\n",
+    "    return out\n",
+    "\n",
+    "# Initialize the context\n",
+    "set_ms_context()\n",
+    "print(\">>> Context set to PYNATIVE | Ascend:0\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5",
+   "metadata": {},
+   "source": [
+    "## Data Loading and Preprocessing\n",
+    "\n",
+    "We load the SWAG dataset with the Hugging Face `datasets` library.\n",
+    "To keep the demo fast, we train and evaluate on subsets taken from the original training and validation sets.\n",
+    "\n",
+    "- **model_checkpoint**: `google-bert/bert-base-uncased`.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Configuration\n",
+    "model_checkpoint = \"google-bert/bert-base-uncased\"\n",
+    "output_dir = \"./my_awesome_swag_model_ms\"\n",
+    "\n",
+    "# Sample counts\n",
+    "max_train_samples = 2000\n",
+    "max_eval_samples = 1000\n",
+    "\n",
+    "print(f\">>> Model: {model_checkpoint}\")\n",
+    "\n",
+    "# 1. Dataset\n",
+    "from datasets import load_dataset\n",
+    "raw = load_dataset(\"swag\", \"regular\")\n",
+    "\n",
+    "# Select subsets\n",
+    "raw[\"train\"] = raw[\"train\"].select(range(min(max_train_samples, len(raw[\"train\"]))))\n",
+    "raw[\"validation\"] = raw[\"validation\"].select(range(min(max_eval_samples, len(raw[\"validation\"]))))\n",
+    "\n",
+    "print(f\">>> Data: Train={len(raw['train'])}, Valid={len(raw['validation'])}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7",
+   "metadata": {},
+   "source": [
+    "#### **Preprocessing (Tokenization)**\n",
+    "\n",
+    "Preprocessing for multiple-choice tasks is slightly unusual. Each **context (sent1)** is paired with the **header (sent2)** concatenated with each of the 4 **ending** options, which yields 4 independent input sequences per example.\n",
+    "\n",
+    "1. **Flatten**: flatten the `(Batch, 4)` structure into `(Batch * 4)` for tokenization.\n",
+    "2. **Tokenize**: encode with the BERT tokenizer.\n",
+    "3. **Un-flatten**: regroup the encoded sequences into groups of 4, i.e. `(Batch, 4, Seq_Len)`; sequence lengths still vary until padding.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 2. Tokenizer\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
+    "\n",
+    "# 3. Preprocess\n",
+    "ending_names = [\"ending0\", \"ending1\", \"ending2\", \"ending3\"]\n",
+    "\n",
+    "def preprocess_function(examples):\n",
+    "    first_sentences = [[c] * 4 for c in examples[\"sent1\"]]\n",
+    "    question_headers = examples[\"sent2\"]\n",
+    "    second_sentences = [\n",
+    "        [f\"{h} {examples[end][i]}\" for end in ending_names]\n",
+    "        for i, h in enumerate(question_headers)\n",
+    "    ]\n",
+    "\n",
+    "    # Flatten\n",
+    "    first_sentences = sum(first_sentences, [])\n",
+    "    second_sentences = sum(second_sentences, [])\n",
+    "\n",
+    "    tokenized = tokenizer(first_sentences, second_sentences, truncation=True)\n",
+    "\n",
+    "    # Un-flatten\n",
+    "    result = {k: [v[i:i + 4] for i in range(0, len(v), 4)] for k, v in tokenized.items()}\n",
+    "    result[\"labels\"] = examples[\"label\"]\n",
+    "    return result\n",
+    "\n",
+    "# Run the map operation\n",
+    "encoded = raw.map(preprocess_function, batched=True, remove_columns=raw[\"train\"].column_names)\n",
+    "print(\"Data preprocessing completed.\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9",
+   "metadata": {},
+   "source": [
+    "#### **Defining the DataCollator**\n",
+    "\n",
+    "Define a data collator that applies **dynamic padding** at the batch level and converts the data to MindSpore tensors.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "10",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ----------------------------\n",
+    "# DataCollator\n",
+    "# ----------------------------\n",
+    "@dataclass\n",
+    "class DataCollatorForMultipleChoice:\n",
+    "    tokenizer: Any\n",
+    "    padding: Union[bool, str] = \"longest\"\n",
+    "    pad_to_multiple_of: Optional[int] = None\n",
+    "    label_dtype: ms.dtype = ms.int32\n",
+    "\n",
+    "    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:\n",
+    "        label_key = \"labels\" if \"labels\" in features[0] else (\"label\" if \"label\" in features[0] else None)\n",
+    "        labels = [f.pop(label_key) for f in features] if label_key else None\n",
+    "\n",
+    "        batch_size = len(features)\n",
+    "        num_choices = len(features[0][\"input_ids\"])\n",
+    "\n",
+    "        flattened = []\n",
+    "        for feat in features:\n",
+    "            for i in range(num_choices):\n",
+    "                flattened.append({k: v[i] for k, v in feat.items()})\n",
+    "\n",
+    "        # Training: prefer returning MindSpore tensors\n",
+    "        try:\n",
+    "            batch = self.tokenizer.pad(\n",
+    "                flattened,\n",
+    "                padding=self.padding,\n",
+    "                pad_to_multiple_of=self.pad_to_multiple_of,\n",
+    "                return_tensors=\"ms\",\n",
+    "            )\n",
+    "            out = {k: v.reshape((batch_size, num_choices, -1)) for k, v in batch.items()}\n",
+    "            if labels is not None:\n",
+    "                out[\"labels\"] = ms.Tensor(np.asarray(labels, dtype=np.int32), dtype=self.label_dtype)\n",
+    "            return out\n",
+    "        except Exception:\n",
+    "            # Fallback: pad to NumPy arrays, then convert to MindSpore tensors\n",
+    "            batch_np = self.tokenizer.pad(\n",
+    "                flattened,\n",
+    "                padding=self.padding,\n",
+    "                pad_to_multiple_of=self.pad_to_multiple_of,\n",
+    "                return_tensors=\"np\",\n",
+    "            )\n",
+    "            out = {}\n",
+    "            for k, v in batch_np.items():\n",
+    "                arr = np.asarray(v).reshape((batch_size, num_choices, -1))\n",
+    "                if k in (\"input_ids\", \"attention_mask\", \"token_type_ids\"):\n",
+    "                    arr = arr.astype(np.int32, copy=False)\n",
+    "                out[k] = ms.Tensor(arr)\n",
+    "            if labels is not None:\n",
+    "                out[\"labels\"] = ms.Tensor(np.asarray(labels, dtype=np.int32), dtype=self.label_dtype)\n",
+    "            return out"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "11",
+   "metadata": {},
+   "source": [
+    "## Model Construction\n",
+    "\n",
+    "Load the pretrained model with `AutoModelForMultipleChoice`.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "12",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 4. Model\n",
+    "model = AutoModelForMultipleChoice.from_pretrained(model_checkpoint)\n",
+    "print(\"Model loaded.\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "13",
+   "metadata": {},
+   "source": [
+    "## Training\n",
+    "\n",
+    "Configure `TrainingArguments` and initialize the `Trainer`.\n",
+    "We define a `compute_metrics` function to compute accuracy.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "14",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 5. Metrics\n",
+    "def compute_metrics(eval_predictions):\n",
+    "    if hasattr(eval_predictions, \"predictions\"):\n",
+    "        logits = eval_predictions.predictions\n",
+    "        labels = eval_predictions.label_ids\n",
+    "    else:\n",
+    "        logits, labels = eval_predictions\n",
+    "    preds = np.argmax(to_numpy(logits), axis=1)\n",
+    "    labels = to_numpy(labels)\n",
+    "    return {\"accuracy\": float((preds == labels).mean())}\n",
+    "\n",
+    "# 6. TrainingArguments\n",
+    "train_args = TrainingArguments(\n",
+    "    output_dir=output_dir,\n",
+    "    learning_rate=5e-5,\n",
+    "    per_device_train_batch_size=8,\n",
+    "    per_device_eval_batch_size=8,\n",
+    "    num_train_epochs=3,\n",
+    "    weight_decay=0.01,\n",
+    "    eval_strategy=\"epoch\",\n",
+    "    save_strategy=\"epoch\",\n",
+    "    save_total_limit=1,\n",
+    "    logging_steps=50,\n",
+    "    push_to_hub=False,\n",
+    "    remove_unused_columns=False,\n",
+    "    report_to=[],\n",
+    ")\n",
+    "\n",
+    "# 7. Trainer\n",
+    "trainer = Trainer(\n",
+    "    model=model,\n",
+    "    args=train_args,\n",
+    "    train_dataset=encoded[\"train\"],\n",
+    "    eval_dataset=encoded[\"validation\"],\n",
+    "    tokenizer=tokenizer,\n",
+    "    data_collator=DataCollatorForMultipleChoice(tokenizer),\n",
+    "    compute_metrics=compute_metrics,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "15",
+   "metadata": {},
+   "source": [
+    "#### **Training and Saving**\n",
+    "\n",
+    "Call `trainer.train()` to start training. Once it finishes, evaluate the model and save it together with the tokenizer.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "16",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 8. Run\n",
+    "print(\"\\n>>> Starting training...\")\n",
+    "trainer.train()\n",
+    "\n",
+    "print(\"\\n>>> Starting evaluation...\")\n",
+    "metrics = trainer.evaluate()\n",
+    "print(f\">>> Eval metrics: {metrics}\")\n",
+    "\n",
+    "# 9. Save\n",
+    "os.makedirs(output_dir, exist_ok=True)\n",
+    "trainer.save_model(output_dir)\n",
+    "try:\n",
+    "    tokenizer.save_pretrained(output_dir)\n",
+    "except Exception as e:\n",
+    "    print(f\">>> Warning: failed to save tokenizer: {e}\")\n",
+    "print(f\">>> Model saved to: {output_dir}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "17",
+   "metadata": {},
+   "source": [
+    "## Inference\n",
+    "\n",
+    "To check the model, we run one end-to-end inference demo.\n",
+    "\n",
+    "**Note (key fix):**\n",
+    "In the MindSpore NLP environment, we take the following steps to keep inference stable and compatible across backends:\n",
+    "1. **`return_tensors=\"pt\"`**: use the PyTorch-compatible tensor format (MindSpore NLP automatically proxies it to MindTorch).\n",
+    "2. **`move_inputs_to_device`**: explicitly move the inputs to the device that holds the model parameters, which avoids the \"All tensor arguments must be on the same device\" error.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "18",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 10. Inference Demo (FIXED)\n",
+    "print(\"\\n>>> Running Inference Demo...\")\n",
+    "\n",
+    "# Inference needs torch.no_grad; under MindSpore NLP, torch is proxied to mindtorch\n",
+    "import torch\n",
+    "\n",
+    "model.eval()\n",
+    "device = get_model_device(model)\n",
+    "print(f\">>> Inference model device: {device}\")\n",
+    "\n",
+    "# 1. Prepare a single sample\n",
+    "sample = raw[\"validation\"][0]\n",
+    "context = sample[\"sent1\"]\n",
+    "header = sample[\"sent2\"]\n",
+    "choices = [sample[e] for e in ending_names]\n",
+    "\n",
+    "first = [context] * 4\n",
+    "second = [f\"{header} {c}\" for c in choices]\n",
+    "\n",
+    "# 2. Tokenize\n",
+    "tok = tokenizer(first, second, truncation=True, padding=True, return_tensors=\"pt\")\n",
+    "inputs = {k: v.reshape((1, 4, -1)) for k, v in tok.items()}\n",
+    "inputs = move_inputs_to_device(inputs, device)\n",
+    "\n",
+    "# 3. Forward pass (no grad)\n",
+    "with torch.no_grad():\n",
+    "    outputs = model(**inputs)\n",
+    "\n",
+    "logits = outputs[\"logits\"] if isinstance(outputs, dict) else outputs.logits\n",
+    "pred = int(np.argmax(to_numpy(logits), axis=1)[0])\n",
+    "\n",
+    "# 4. Print the result\n",
+    "print(\"-\" * 50)\n",
+    "print(f\"Context: {context}\")\n",
+    "print(f\"Header : {header}\")\n",
+    "for i, c in enumerate(choices):\n",
+    "    mark = \"[x]\" if i == pred else \"[ ]\"\n",
+    "    print(f\" {mark} {c}\")\n",
+    "print(\"-\" * 50)\n",
+    "\n",
+    "gold = int(sample[\"label\"])\n",
+    "if pred == gold:\n",
+    "    print(f\"Result: CORRECT (Pred: {pred}, Gold: {gold})\")\n",
+    "else:\n",
+    "    print(f\"Result: INCORRECT (Pred: {pred}, Gold: {gold})\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}