diff --git a/nlp/README.md b/nlp/README.md index a2c9dae..32be5fd 100644 --- a/nlp/README.md +++ b/nlp/README.md @@ -6,7 +6,7 @@ This directory contains ready-to-use Natural Language Processing application not | No. | Model | Description | | :-- | :---- | :------------------------------ | -| 1 | / | This section is empty for now — feel free to contribute your first application! | +| 1 | BERT-QA | An extractive question answering system based on BERT-Base and SQuAD v1.1. Features a sliding-window mechanism for long documents. | ## Contributing New NLP Applications diff --git a/nlp/question_answering_bert.ipynb b/nlp/question_answering_bert.ipynb new file mode 100644 index 0000000..1bc4384 --- /dev/null +++ b/nlp/question_answering_bert.ipynb @@ -0,0 +1,987 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b5dc2eca", + "metadata": {}, + "source": [ + "# Fine-Tuning a BERT Extractive Question Answering Model with MindNLP\n", + "\n", + "## 1. Project Background\n", + "Machine Reading Comprehension (MRC) is one of the core tasks in natural language processing. This experiment uses the **MindSpore 2.7.0** deep learning framework and the **MindNLP 0.5.1** NLP toolkit to fine-tune a BERT model on the classic **SQuAD (Stanford Question Answering Dataset)** dataset and build an end-to-end extractive question answering (Extractive QA) system.\n", + "\n", + "## 2. Objectives and Key Techniques\n", + "* **Model architecture**: Use BERT-Base as the backbone and fine-tune it on the downstream task to predict the answer span.\n", + "* **Long-text handling**: Implement a **sliding window** mechanism to work around BERT's input sequence length limit (512 tokens).\n", + "* **Heterogeneous compute adaptation**: Resolve tensor device mismatch issues in MindSpore's dynamic graph mode (PyNative).\n", + "* **Performance evaluation**: Qualitatively analyze validation-set samples to check the model's inference behavior on real examples.\n", + "\n", + "## 3. 
Environment Dependencies\n", + "* Hardware: Ascend 910 / GPU (CUDA)\n", + "* Framework: MindSpore >= 2.7.0\n", + "* Toolkit: MindNLP == 0.5.1\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a1d256e", + "metadata": {}, + "outputs": [], + "source": [ + "# Install the dependencies required for this experiment\n", + "# Note: on the first run, uncomment and execute the lines below. If mindnlp fails after installation, download its source and run pip install -e .\n", + "# !pip install mindspore==2.7.0\n", + "# !pip install mindnlp==0.5.1\n", + "# !pip install datasets tqdm nbformat\n", + "# !pip install diffusers==0.35.2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3dc6e1a6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/zdai/miniconda3/envs/QA/lib/python3.11/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for type is zero.\n", + " setattr(self, word, getattr(machar, word).flat[0])\n", + "/home/zdai/miniconda3/envs/QA/lib/python3.11/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", + " return self._float_to_str(self.smallest_subnormal)\n", + "/home/zdai/miniconda3/envs/QA/lib/python3.11/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for type is zero.\n", + " setattr(self, word, getattr(machar, word).flat[0])\n", + "/home/zdai/miniconda3/envs/QA/lib/python3.11/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", + " return self._float_to_str(self.smallest_subnormal)\n", + "/home/zdai/miniconda3/envs/QA/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Modular Diffusers is currently an experimental feature under active development. 
The API is subject to breaking changes in future releases.\n", + "[WARNING] ME(3907224:281473775931424,MainProcess):2026-01-04-19:56:53.460.000 [mindspore/context.py:1412] For 'context.set_context', the parameter 'device_target' will be deprecated and removed in a future version. Please use the api mindspore.set_device() instead.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Current device: Ascend\n", + "Current mode: 1\n" + ] + } + ], + "source": [ + "import os\n", + "import numpy as np\n", + "import mindspore as ms\n", + "from mindnlp.transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n", + "from datasets import load_dataset\n", + "\n", + "# --- Global environment configuration ---\n", + "# Select the compute device: prefer the Ascend NPU, then fall back to CPU\n", + "try:\n", + "    ms.set_context(device_target=\"Ascend\")\n", + "except Exception:\n", + "    ms.set_context(device_target=\"CPU\")\n", + "\n", + "# Set the execution mode: PYNATIVE_MODE (dynamic graph mode) is recommended for easier debugging\n", + "ms.set_context(mode=ms.PYNATIVE_MODE)\n", + "\n", + "print(f\"Current device: {ms.get_context('device_target')}\")\n", + "print(f\"Current mode: {ms.get_context('mode')}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4eac9854", + "metadata": {}, + "source": [ + "## 4. 
Data Loading and Preprocessing\n", + "\n", + "This experiment loads the SQuAD dataset with the Hugging Face `datasets` library. Each example contains a question, a context, and an answer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75231c33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading dataset (SQuAD)...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using the latest cached version of the dataset since squad couldn't be found on the Hugging Face Hub\n", + "Found the latest cached dataset configuration 'plain_text' at /home/zdai/.cache/huggingface/datasets/squad/plain_text/0.0.0/7b6d24c440a36b6815f21b70d25016731768db1f (last modified on Sun Jan 4 19:55:04 2026).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training samples: 5000\n", + "Validation samples: 1500\n", + "Sample example: {'id': '5733be284776f41900661182', 'title': 'University_of_Notre_Dame', 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. 
At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?', 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}\n" + ] + } + ], + "source": [ + "# Load the SQuAD dataset\n", + "print(\"Loading dataset (SQuAD)...\")\n", + "squad_dataset = load_dataset(\"squad\")\n", + "\n", + "# === Experiment data sampling ===\n", + "# To keep the demo fast, only a subset of the data is used for training and validation\n", + "# For a full training run, use the complete dataset\n", + "train_dataset = squad_dataset[\"train\"].select(range(5000))  # training subset: 5000 examples\n", + "eval_dataset = squad_dataset[\"validation\"].select(range(1500))  # validation subset: 1500 examples\n", + "\n", + "print(f\"Training samples: {len(train_dataset)}\")\n", + "print(f\"Validation samples: {len(eval_dataset)}\")\n", + "print(f\"Sample example: {train_dataset[0]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "b36d584c", + "metadata": {}, + "source": [ + "### 4.1 Preprocessing: Tokenization and Sliding Window\n", + "\n", + "Some SQuAD passages exceed the maximum input length used for BERT (typically 512 or 384 tokens), so we apply a **sliding window** strategy:\n", + "1. **Truncation and stride**: When the text is too long, split it into multiple features whose segments overlap.\n", + "2. 
**Label alignment**: Compute the answer's new start and end positions (Start/End Position) within each feature segment produced by the split.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "bffc14de", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running data preprocessing (Map)...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|██████████| 5000/5000 [00:03<00:00, 1541.10 examples/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preprocessing complete!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Model and preprocessing hyperparameters\n", + "model_checkpoint = \"bert-base-uncased\"\n", + "batch_size = 16\n", + "max_length = 384  # maximum input sequence length\n", + "doc_stride = 128  # sliding-window stride (number of overlapping tokens between adjacent windows)\n", + "\n", + "# Initialize the tokenizer\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n", + "\n", + "def prepare_train_features(examples):\n", + "    '''\n", + "    Core preprocessing function:\n", + "    1. Encode the question and the context.\n", + "    2. Handle long-text overflow by applying the sliding window.\n", + "    3. 
Map the character-level answer position (character index) to token-level indices.\n", + "    '''\n", + "    # Strip whitespace from both ends of each question\n", + "    questions = [q.strip() for q in examples[\"question\"]]\n", + "\n", + "    # Tokenize\n", + "    inputs = tokenizer(\n", + "        questions,\n", + "        examples[\"context\"],\n", + "        max_length=max_length,\n", + "        truncation=\"only_second\",  # truncate only the context, keep the question intact\n", + "        stride=doc_stride,  # apply the sliding window\n", + "        return_overflowing_tokens=True,  # allow several segments per example\n", + "        return_offsets_mapping=True,  # return character offsets, used to locate the answer\n", + "        padding=\"max_length\"\n", + "    )\n", + "\n", + "    # Retrieve the mappings\n", + "    sample_map = inputs.pop(\"overflow_to_sample_mapping\")\n", + "    offset_mapping = inputs.pop(\"offset_mapping\")\n", + "\n", + "    start_positions = []\n", + "    end_positions = []\n", + "\n", + "    for i, offsets in enumerate(offset_mapping):\n", + "        sample_idx = sample_map[i]\n", + "        answers = examples[\"answers\"][sample_idx]\n", + "\n", + "        # Edge case: if the sample has no answer, label it with the CLS position (0, 0)\n", + "        if len(answers[\"answer_start\"]) == 0:\n", + "            start_positions.append(0)\n", + "            end_positions.append(0)\n", + "            continue\n", + "\n", + "        # Character-level start and end of the answer in the original text\n", + "        start_char = answers[\"answer_start\"][0]\n", + "        end_char = start_char + len(answers[\"text\"][0])\n", + "\n", + "        # Distinguish the question part from the context part of the sequence\n", + "        sequence_ids = inputs.sequence_ids(i)\n", + "\n", + "        # Find the first and last token indices of the context\n", + "        idx = 0\n", + "        while sequence_ids[idx] != 1:\n", + "            idx += 1\n", + "        context_start = idx\n", + "        while sequence_ids[idx] == 1:\n", + "            idx += 1\n", + "        context_end = idx - 1\n", + "\n", + "        # If the answer is not fully contained in the current window, label it (0, 0)\n", + "        if not (offsets[context_start][0] <= start_char and offsets[context_end][1] >= end_char):\n", + "            start_positions.append(0)\n", + "            end_positions.append(0)\n", + "        else:\n", + "            # Otherwise, find the start and end token indices of the answer\n", + "            idx = context_start\n", + "            while idx <= context_end and offsets[idx][0] <= start_char:\n", + "                idx += 1\n", + "            start_positions.append(idx - 1)\n", + "\n", + "            idx = context_end\n", + "            while 
idx >= context_start and offsets[idx][1] >= end_char:\n", + "                idx -= 1\n", + "            end_positions.append(idx + 1)\n", + "\n", + "    inputs[\"start_positions\"] = start_positions\n", + "    inputs[\"end_positions\"] = end_positions\n", + "    return inputs\n", + "\n", + "print(\"Running data preprocessing (Map)...\")\n", + "tokenized_train = train_dataset.map(prepare_train_features, batched=True, remove_columns=train_dataset.column_names)\n", + "tokenized_eval = eval_dataset.map(prepare_train_features, batched=True, remove_columns=eval_dataset.column_names)\n", + "print(\"Preprocessing complete!\")" + ] + }, + { + "cell_type": "markdown", + "id": "2b7a7ebb", + "metadata": {}, + "source": [ + "## 5. Model Fine-tuning\n", + "\n", + "Load the pretrained BERT model, configure the training arguments, and start training through MindNLP's `Trainer` interface." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ab1784b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[MS_ALLOC_CONF]Runtime config: enable_vmm:True vmm_align_size:2MB\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "/tmp/ipykernel_3907224/2377589481.py:18: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", + " trainer = Trainer(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>> Starting model training...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [957/957 07:52, Epoch 3/3]\n", + "
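| Step | Training Loss |
| :-- | :-- |
| 10 | 5.826700 |
| 20 | 5.353700 |
| 30 | 4.728800 |
| 40 | 4.384300 |
| 50 | 4.057100 |
| 60 | 3.816100 |
| 70 | 3.714400 |
| 80 | 3.448000 |
| 90 | 3.292400 |
| 100 | 3.003000 |
| 110 | 2.879200 |
| 120 | 2.679500 |
| 130 | 2.753000 |
| 140 | 2.521200 |
| 150 | 2.350600 |
| 160 | 2.485800 |
| 170 | 2.567700 |
| 180 | 2.503700 |
| 190 | 2.168200 |
| 200 | 2.260100 |
| 210 | 2.331600 |
| 220 | 2.239200 |
| 230 | 1.947200 |
| 240 | 2.063500 |
| 250 | 2.218200 |
| 260 | 1.996000 |
| 270 | 1.921200 |
| 280 | 1.600500 |
| 290 | 1.844700 |
| 300 | 1.861700 |
| 310 | 1.638900 |
| 320 | 2.075100 |
| 330 | 1.511900 |
| 340 | 1.536700 |
| 350 | 1.470500 |
| 360 | 1.312800 |
| 370 | 1.297800 |
| 380 | 1.265400 |
| 390 | 1.342900 |
| 400 | 1.436300 |
| 410 | 1.385500 |
| 420 | 1.531800 |
| 430 | 1.195400 |
| 440 | 1.527800 |
| 450 | 1.447600 |
| 460 | 1.470700 |
| 470 | 1.343100 |
| 480 | 1.245600 |
| 490 | 1.247800 |
| 500 | 1.428800 |
| 510 | 1.120100 |
| 520 | 1.442700 |
| 530 | 1.222900 |
| 540 | 1.123600 |
| 550 | 1.416900 |
| 560 | 1.212700 |
| 570 | 1.267700 |
| 580 | 1.340000 |
| 590 | 1.230200 |
| 600 | 1.456800 |
| 610 | 1.284200 |
| 620 | 1.356900 |
| 630 | 1.254500 |
| 640 | 1.133900 |
| 650 | 1.131500 |
| 660 | 1.089800 |
| 670 | 0.958400 |
| 680 | 1.028200 |
| 690 | 0.871000 |
| 700 | 0.935000 |
| 710 | 0.870400 |
| 720 | 0.921400 |
| 730 | 0.973600 |
| 740 | 0.886300 |
| 750 | 1.091400 |
| 760 | 0.856100 |
| 770 | 0.899600 |
| 780 | 0.965600 |
| 790 | 0.999900 |
| 800 | 0.970100 |
| 810 | 0.874400 |
| 820 | 0.737900 |
| 830 | 0.913100 |
| 840 | 1.015800 |
| 850 | 0.753200 |
| 860 | 0.894500 |
| 870 | 0.942900 |
| 880 | 0.830200 |
| 890 | 0.980500 |
| 900 | 0.884100 |
| 910 | 0.929100 |
| 920 | 1.153400 |
| 930 | 0.994200 |
| 940 | 0.816800 |
| 950 | 0.992300 |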
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>> Model training finished.\n" + ] + } + ], + "source": [ + "# Load the pretrained model (AutoModelForQuestionAnswering)\n", + "model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)\n", + "\n", + "# Configure the training arguments (TrainingArguments)\n", + "args = TrainingArguments(\n", + "    output_dir=\"./bert_qa_output\",\n", + "    eval_strategy=\"no\",  # skip evaluation during the demo to save time\n", + "    learning_rate=2e-5,  # learning rate\n", + "    per_device_train_batch_size=batch_size,\n", + "    per_device_eval_batch_size=batch_size,\n", + "    num_train_epochs=3,  # number of training epochs\n", + "    weight_decay=0.01,\n", + "    save_strategy=\"no\",  # do not save checkpoints\n", + "    logging_steps=10\n", + ")\n", + "\n", + "# Initialize the Trainer\n", + "trainer = Trainer(\n", + "    model=model,\n", + "    args=args,\n", + "    train_dataset=tokenized_train,\n", + "    eval_dataset=tokenized_eval,\n", + "    tokenizer=tokenizer,\n", + ")\n", + "\n", + "# Start training\n", + "print(\">>> Starting model training...\")\n", + "trainer.train()\n", + "print(\">>> Model training finished.\")" + ] + }, + { + "cell_type": "markdown", + "id": "9538f1b9", + "metadata": {}, + "source": [ + "## 6. Model Inference and Application (Inference)\n", + "\n", + "To check the model's behavior, we define an end-to-end prediction function.\n", + "\n", + "**Technical note**:\n", + "When MindNLP interacts with MindSpore, pay particular attention to **tensor device synchronization**. Tensors fed to the model must be explicitly moved to the same compute device as the model (Ascend/GPU); otherwise a `ValueError: All tensor arguments must be on the same device` is raised." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30b89700", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Single-example inference test ---\n", + "Q: What is MindSpore?\n", + "A: deep learning training / inference\n" + ] + } + ], + "source": [ + "import mindspore as ms\n", + "import numpy as np\n", + "\n", + "def predict_answer(question, context):\n", + "    '''\n", + "    End-to-end prediction: takes a question and a context and returns the predicted answer text.\n", + "    Includes automatic device adaptation logic.\n", + "    '''\n", + "    # 1. 
Switch the model to evaluation mode\n", + "    model.set_train(False)\n", + "\n", + "    # 2. Encode the inputs\n", + "    inputs = tokenizer(\n", + "        question,\n", + "        context,\n", + "        return_tensors=\"ms\",\n", + "        max_length=max_length,\n", + "        truncation=\"only_second\"\n", + "    )\n", + "\n", + "    # 3. Device synchronization (critical step)\n", + "    # Get the device the model currently lives on (Ascend/CPU)\n", + "    try:\n", + "        target_device = model.device\n", + "    except Exception:\n", + "        # Compatibility fallback: if the attribute is unavailable, infer the device from the parameters or the context\n", + "        try:\n", + "            target_device = next(model.get_parameters()).device\n", + "        except Exception:\n", + "            target_device = ms.get_context(\"device_target\")\n", + "\n", + "    # Move the input tensors to the target device\n", + "    input_ids = ms.Tensor(inputs[\"input_ids\"].asnumpy(), dtype=ms.int32).to(target_device)\n", + "    attention_mask = ms.Tensor(inputs[\"attention_mask\"].asnumpy(), dtype=ms.int32).to(target_device)\n", + "    token_type_ids = ms.Tensor(inputs[\"token_type_ids\"].asnumpy(), dtype=ms.int32).to(target_device)\n", + "\n", + "    # 4. Forward pass\n", + "    outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n", + "\n", + "    # 5. 
Parse the results (logits -> text)\n", + "    start_logits = outputs.start_logits\n", + "    end_logits = outputs.end_logits\n", + "\n", + "    # Take the highest-scoring start and end position indices\n", + "    start_idx = np.argmax(start_logits.asnumpy(), axis=-1)[0]\n", + "    end_idx = np.argmax(end_logits.asnumpy(), axis=-1)[0]\n", + "\n", + "    # Simple sanity check: if the end precedes the start, fall back to a span ending 10 tokens after the start\n", + "    if end_idx < start_idx:\n", + "        end_idx = start_idx + 10\n", + "\n", + "    # Decode the answer\n", + "    answer_ids = inputs[\"input_ids\"].asnumpy()[0][start_idx : end_idx + 1]\n", + "    answer = tokenizer.decode(answer_ids, skip_special_tokens=True)\n", + "    return answer\n", + "\n", + "# Single-example test\n", + "print(\"--- Single-example inference test ---\")\n", + "context_demo = \"MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.\"\n", + "question_demo = \"What is MindSpore?\"\n", + "print(f\"Q: {question_demo}\")\n", + "print(f\"A: {predict_answer(question_demo, context_demo)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "978732e8", + "metadata": {}, + "source": [ + "## 7. 
Sampled Evaluation on the Validation Set\n", + "\n", + "Take the first few samples from the validation set and compare the model's predicted answer with the dataset's reference answer to qualitatively assess model performance." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c5174d7b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Evaluating 3 samples from the validation set ===\n", + "\n", + "[Sample Case 1]\n", + "📌 Question: Which NFL team represented the AFC at Super Bowl 50?\n", + "📄 Context: Super Bowl 50 was an American football game to determine the champion of the Nat...\n", + "--------------------------------------------------\n", + "🟢 Reference answer: Denver Broncos\n", + "🔴 Model prediction: denver broncos\n", + "==================================================\n", + "\n", + "[Sample Case 2]\n", + "📌 Question: Which NFL team represented the NFC at Super Bowl 50?\n", + "📄 Context: Super Bowl 50 was an American football game to determine the champion of the Nat...\n", + "--------------------------------------------------\n", + "🟢 Reference answer: Carolina Panthers\n", + "🔴 Model prediction: carolina panthers\n", + "==================================================\n", + "\n", + "[Sample Case 3]\n", + "📌 Question: Where did Super Bowl 50 take place?\n", + "📄 Context: Super Bowl 50 was an American football game to determine the champion of the Nat...\n", + "--------------------------------------------------\n", + "🟢 Reference answer: Santa Clara, California\n", + "🔴 Model prediction: santa clara, california\n", + "==================================================\n", + "\n" + ] + } + ], + "source": [ + "# Number of samples to test\n", + "num_samples = 3\n", + "print(f\"=== Evaluating {num_samples} samples from the validation set ===\\n\")\n", + "\n", + "# Iterate over the first N validation samples\n", + "# Note: index the Hugging Face Dataset directly rather than using create_dict_iterator\n", + "for i in range(num_samples):\n", + "    batch = eval_dataset[i]\n", + "\n", + "    # Extract the raw text\n", + "    # Data loaded through Hugging Face Dataset is plain Python str by default\n", + "    question = batch['question']\n", + "    context = batch['context']\n", + "\n", + "    # Parse the reference answer (the SQuAD answer format is nested, so only the first reference text is shown)\n", + "    try:\n", + "        ref_text_list = batch['answers']['text']\n", + "        ref_display = 
ref_text_list[0] if len(ref_text_list) > 0 else \"<no answer>\"\n", + "    except Exception as e:\n", + "        ref_display = f\"Answer parsing error: {str(e)}\"\n", + "\n", + "    # Run the model prediction\n", + "    pred_answer = predict_answer(question, context)\n", + "\n", + "    # Print the comparison\n", + "    print(f\"[Sample Case {i+1}]\")\n", + "    print(f\"📌 Question: {question}\")\n", + "    # Truncate long contexts for display\n", + "    print(f\"📄 Context: {context[:80]}...\")\n", + "    print(\"--------------------------------------------------\")\n", + "    print(f\"🟢 Reference answer: {ref_display}\")\n", + "    print(f\"🔴 Model prediction: {pred_answer}\")\n", + "    print(\"==================================================\\n\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "QA", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
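
The three sampled predictions in the notebook differ from their references only in casing (for example `denver broncos` versus `Denver Broncos`), which is what one would expect from `bert-base-uncased`. A minimal sketch of how such pairs could be scored numerically, assuming SQuAD v1.1 style normalization (lowercasing, dropping punctuation and English articles) followed by exact match and token-level F1. The helper names `normalize_answer`, `exact_match`, and `token_f1` are illustrative; they are not part of the notebook or of MindNLP.

```python
import re
import string
from collections import Counter

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, reference: str) -> float:
    """1.0 if the normalized strings are identical, otherwise 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(reference))


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    # Multiset intersection counts each shared token at most as often as it appears in both.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


# Prediction/reference pairs taken from the notebook's sample output.
pairs = [
    ("denver broncos", "Denver Broncos"),
    ("carolina panthers", "Carolina Panthers"),
    ("santa clara, california", "Santa Clara, California"),
]
for pred, ref in pairs:
    print(f"EM={exact_match(pred, ref):.1f} F1={token_f1(pred, ref):.2f} | {pred!r} vs {ref!r}")
```

Under this normalization, case-only differences score as full matches, while a partial span such as `santa clara` against `Santa Clara, California` gets no exact-match credit but partial F1. Averaging these two numbers over `eval_dataset` would give a quantitative complement to the qualitative comparison.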