diff --git a/llm/README.md b/llm/README.md
index 679a681..270497f 100644
--- a/llm/README.md
+++ b/llm/README.md
@@ -11,6 +11,7 @@ The following notebooks are actively maintained in sync with MindSpore and MindS
 | No. | Model | Description |
 | :-- | :---- | :----------------------- |
 | 1 | [t5](./t5/) | Includes notebooks for T5 finetuning and inference on tasks such as email summarization |
+| 2 | [distilgpt2](./distilgpt2/) | Includes notebooks for DistilGPT-2 finetuning and inference on causal language modeling (text generation) tasks |
 
 ### Community-Driven / Legacy Applications
diff --git a/llm/distilgpt2/finetune_distilgpt2_language_modeling.ipynb b/llm/distilgpt2/finetune_distilgpt2_language_modeling.ipynb
new file mode 100644
index 0000000..3c1a8eb
--- /dev/null
+++ b/llm/distilgpt2/finetune_distilgpt2_language_modeling.ipynb
@@ -0,0 +1,785 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "0",
+   "metadata": {},
+   "source": [
+    "# Fine-tuning DistilGPT-2 for Language Modeling and Text Generation with MindSpore\n",
+    "\n",
+    "## Case Overview\n",
+    "\n",
+    "This notebook demonstrates how to fine-tune a **causal LM (autoregressive language model)** in the **MindSpore + MindSpore NLP** ecosystem, and how to run text-generation (continuation) inference after training.\n",
+    "\n",
+    "The example covers:\n",
+    "\n",
+    "- Environment and dependency setup (version checks, optional installation)\n",
+    "- Dataset loading: Wikitext-2-raw-v1\n",
+    "- Pretrained model loading: DistilGPT-2\n",
+    "- Text preprocessing and language-modeling sample construction (shifted labels, padding, batching)\n",
+    "- Training and validation (loss monitoring)\n",
+    "- Inference/generation and model saving/loading (optional)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1",
+   "metadata": {},
+   "source": [
+    "## Model Overview\n",
+    "\n",
+    "We use **DistilGPT-2**, a compact model from the HuggingFace community that is small enough for quick experiments on a single device:\n",
+    "\n",
+    "- Model ID: `distilgpt2`\n",
+    "- Architecture: a lightweight, distilled version of GPT-2; an autoregressive (causal) language model\n",
+    "- Task: given the preceding context, predict the next token (word/subword)\n",
+    "\n",
+    "MindSpore NLP's `AutoTokenizer` and `AutoModelForCausalLM` provide a usage experience similar to HuggingFace Transformers.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2",
+   "metadata": {},
+   "source": [
+    "## Environment Setup\n",
+    "\n",
+    "Recommended runtime environment for this case (example):\n",
+    "\n",
+    "| Python | MindSpore | MindSpore NLP |\n",
+    "| :----- | :-------- | :---------------- |\n",
+ 
"| 3.10 | 2.7.0 | 0.5.1 |\n",
+    "\n",
+    "Recommended device: Ascend (e.g. the Atlas series). If you are using an online compute platform or a prebuilt image, you usually do not need to reinstall MindSpore; just install MindSpore NLP and the auxiliary dependencies as needed.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Check the mindspore / MindSpore NLP versions (skip this cell if they are not installed yet)\n",
+    "!pip show mindspore\n",
+    "!pip show mindnlp\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4",
+   "metadata": {},
+   "source": [
+    "If the current environment is missing dependencies or the versions do not match, install them as needed with the (commented-out) commands below:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Skip this cell if the required versions of mindspore and MindSpore NLP are already installed\n",
+    "# On the 昇思 AI Lab / Atlas servers, it is recommended to keep the MindSpore shipped with the image and install MindSpore NLP manually.\n",
+    "\n",
+    "# Install MindSpore NLP 0.5.1 (example: from PyPI)\n",
+    "# !pip install mindnlp==0.5.1 -i https://pypi.tuna.tsinghua.edu.cn/simple\n",
+    "\n",
+    "# To install HuggingFace datasets and evaluate:\n",
+    "# !pip install datasets evaluate tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "66945dbe",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import numpy as np\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "import mindspore as ms\n",
+    "from mindspore import context\n",
+    "\n",
+    "from mindnlp.dataset import load_dataset\n",
+    "from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLM\n",
+    "\n",
+    "print(\"MindSpore version:\", ms.__version__)\n",
+    "\n",
+    "import mindnlp\n",
+    "print(\"MindNLP version:\", mindnlp.__version__)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6",
+   "metadata": {},
+   "source": [
+    "#### **Set the MindSpore context**\n",
+    "\n",
+    "Ascend (Atlas) devices run in PYNATIVE_MODE by default"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "device_target = os.getenv(\"DEVICE_TARGET\", \"Ascend\")  # To run on CPU, change \"Ascend\" to \"CPU\"\n",
+    "print(\"Using device:\", device_target)\n",
+    "\n",
+    "context.set_context(\n",
+    "    device_target=device_target\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8",
+   "metadata": {},
+   "source": [
+    "## Data Loading and Preprocessing\n",
+    "\n",
+    "We use the **wikitext-2-raw-v1** dataset from HuggingFace as the language-modeling corpus.\n",
+    "It contains Wikipedia article text and is one of the common open-source benchmarks for language models.\n",
+    "\n",
+    "MindSpore NLP provides a `load_dataset` interface that can pull data directly from the HuggingFace Datasets hub."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9",
+   "metadata": {},
+   "source": [
+    "#### **Load the Wikitext-2-raw-v1 dataset**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "10",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load the train / validation / test splits\n",
+    "# The subset name 'wikitext-2-raw-v1' is specified here\n",
+    "wiki_ds_dict = load_dataset(\n",
+    "    \"wikitext\",\n",
+    "    name=\"wikitext-2-raw-v1\",\n",
+    "    split=[\"train\", \"validation\", \"test\"]\n",
+    ")\n",
+    "\n",
+    "train_raw = wiki_ds_dict[\"train\"]\n",
+    "valid_raw = wiki_ds_dict[\"validation\"]\n",
+    "test_raw = wiki_ds_dict[\"test\"]\n",
+    "\n",
+    "print(\"Train size:\", len(train_raw))\n",
+    "print(\"Valid size:\", len(valid_raw))\n",
+    "print(\"Test size:\", len(test_raw))\n",
+    "\n",
+    "# Inspect one sample\n",
+    "print(\"\\nExample sample from train:\")\n",
+    "print(train_raw[10])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "11",
+   "metadata": {},
+   "source": [
+    "## Model Construction\n",
+    "\n",
+    "#### **Load the pretrained model and tokenizer (DistilGPT-2)**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "12",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_name = \"distilgpt2\"  # other causal LM models such as \"gpt2\" also work\n",
+    "\n",
+    "# If you use a mirror in mainland China, set the HF_ENDPOINT environment variable or use the mirror/modelscope parameters, depending on your platform\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
+    "\n",
+    "print(\"Tokenizer vocab size:\", len(tokenizer))\n",
+    "print(\"Model loaded:\", type(model))\n",
+    "\n",
+    "# The GPT-2 family has no pad_token by default; reuse eos_token as pad_token to enable batched padding\n",
+ 
"if tokenizer.pad_token is None:\n",
+    "    tokenizer.pad_token = tokenizer.eos_token\n",
+    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
+    "\n",
+    "# Resize the model's token embeddings to match the tokenizer (e.g. in case a pad_token was added)\n",
+    "model.resize_token_embeddings(len(tokenizer))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "13",
+   "metadata": {},
+   "source": [
+    "#### **Text preprocessing and language-modeling data construction**\n",
+    "\n",
+    "For a causal LM, a training sample usually looks like:\n",
+    "\n",
+    "- Input: `input_ids = [w_1, w_2, ..., w_{n-1}]`\n",
+    "- Labels: `labels = [w_2, w_3, ..., w_n]`\n",
+    "\n",
+    "In most GPT-style implementations you can simply set **`labels` equal to `input_ids`**:\n",
+    "the model shifts the labels one position to the right internally when computing the loss, and ignores labels at padded positions (conventionally `-100`).\n",
+    "\n",
+    "In this experiment we simplify as follows:\n",
+    "\n",
+    "1. Tokenize and truncate each text individually (no cross-sample concatenation);\n",
+    "2. Set `labels = input_ids.copy()` and pad `labels` with the fill value `-100` so padding does not affect the loss;\n",
+    "3. Use the MindSpore dataset's `padded_batch` to pad every sequence to a fixed `max_seq_len`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "14",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import mindspore.dataset as ds\n",
+    "from mindspore.dataset import transforms\n",
+    "\n",
+    "max_seq_len = 128  # adjust according to device memory\n",
+    "train_batch_size = 64\n",
+    "eval_batch_size = 64\n",
+    "\n",
+    "def _to_py_str(x):\n",
+    "    if x is None:\n",
+    "        return \"\"\n",
+    "    if isinstance(x, str):\n",
+    "        return x\n",
+    "    if isinstance(x, bytes):\n",
+    "        return x.decode(\"utf-8\", errors=\"ignore\")\n",
+    "    if isinstance(x, (np.bytes_,)):\n",
+    "        return x.decode(\"utf-8\", errors=\"ignore\")\n",
+    "    if isinstance(x, (np.str_,)):\n",
+    "        return str(x)\n",
+    "    if isinstance(x, np.ndarray):\n",
+    "        if x.ndim == 0:\n",
+    "            return _to_py_str(x.item())\n",
+    "        return \" \".join(_to_py_str(t) for t in x.tolist())\n",
+    "    try:\n",
+    "        if isinstance(x, ms.Tensor):\n",
+    "            return _to_py_str(x.asnumpy())\n",
+    "    except Exception:\n",
+    "        pass\n",
+    "    return str(x)\n",
+    "\n",
+    "def process_lm_dataset(dataset,\n",
+    "                       tokenizer,\n",
+    "                       max_seq_len=128,\n",
+    "                       batch_size=64,\n",
+    "                       shuffle=False,\n",
+    "                       take_len=None):\n",
+    "    # Ensure GPT-2-family tokenizers have a pad_token\n",
+    "    if tokenizer.pad_token is None:\n",
+    "        tokenizer.pad_token = tokenizer.eos_token\n",
+    "        tokenizer.pad_token_id = tokenizer.eos_token_id\n",
+    "\n",
+    "    # Shuffle / take a subset\n",
+    "    if shuffle:\n",
+    "        dataset = dataset.shuffle(buffer_size=batch_size * 100)\n",
+    "    if take_len:\n",
+    "        dataset = dataset.take(take_len)\n",
+    "\n",
+    "    # Normalize the text column to Python str\n",
+    "    dataset = dataset.map(\n",
+    "        operations=[_to_py_str],\n",
+    "        input_columns=\"text\",\n",
+    "        output_columns=[\"text\"],\n",
+    "        num_parallel_workers=4,\n",
+    "    )\n",
+    "\n",
+    "    # Tokenize and build labels (guarantee at least one token)\n",
+    "    def tokenize_and_create_labels(text):\n",
+    "        text = _to_py_str(text)\n",
+    "        tokenized = tokenizer(\n",
+    "            text,\n",
+    "            truncation=True,\n",
+    "            max_length=max_seq_len,\n",
+    "            add_special_tokens=True,\n",
+    "        )\n",
+    "        ids = tokenized[\"input_ids\"]\n",
+    "        if len(ids) == 0:\n",
+    "            ids = [tokenizer.eos_token_id]\n",
+    "        input_ids = np.array(ids, dtype=np.int32)\n",
+    "        labels = input_ids.copy()\n",
+    "        return input_ids, labels\n",
+    "\n",
+    "    dataset = dataset.map(\n",
+    "        operations=[tokenize_and_create_labels],\n",
+    "        input_columns=\"text\",\n",
+    "        output_columns=[\"input_ids\", \"labels\"],\n",
+    "        num_parallel_workers=4,\n",
+    "    )\n",
+    "\n",
+    "    # Explicit dtype cast\n",
+    "    type_cast_op = transforms.TypeCast(ms.int32)\n",
+    "    dataset = dataset.map(operations=type_cast_op, input_columns=\"input_ids\", num_parallel_workers=4)\n",
+    "    dataset = dataset.map(operations=type_cast_op, input_columns=\"labels\", num_parallel_workers=4)\n",
+    "\n",
+    "    # Keep only the numeric columns\n",
+    "    dataset = dataset.project([\"input_ids\", \"labels\"])\n",
+    "\n",
+    "    # Fixed-length padded_batch, avoiding the ambiguity/compatibility issues of None shapes\n",
+    "    dataset = dataset.padded_batch(\n",
+    "        batch_size=batch_size,\n",
+    "        pad_info={\n",
+    "            \"input_ids\": ([max_seq_len], tokenizer.pad_token_id),\n",
+    "            \"labels\": ([max_seq_len], -100),\n",
+    "        },\n",
+    "        drop_remainder=False,  # change to True if needed\n",
+    "    )\n",
+    "    return dataset"
+   ]
+  },
+  {
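+   "cell_type": "markdown",
+   "id": "14b",
+   "metadata": {},
+   "source": [
+    "The construction above feeds `labels = input_ids` with `-100` at padded positions. For reference, the quantity the model then minimizes is the standard next-token cross-entropy (the notation here is ours, not part of the API):\n",
+    "\n",
+    "$$\\mathcal{L} = -\\frac{1}{N} \\sum_{t=1}^{N} \\log p_\\theta(w_{t+1} \\mid w_1, \\ldots, w_t)$$\n",
+    "\n",
+    "where the sum runs only over the $N$ positions whose label is not `-100`. The perplexity reported after training is simply $\\mathrm{PPL} = \\exp(\\mathcal{L})$.\n"
+   ]
+  },
+  {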
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "15",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_take_len = 2000\n",
+    "valid_take_len = 512\n",
+    "test_take_len = 512\n",
+    "\n",
+    "train_dataset = process_lm_dataset(\n",
+    "    train_raw,\n",
+    "    tokenizer,\n",
+    "    max_seq_len=max_seq_len,\n",
+    "    batch_size=train_batch_size,\n",
+    "    shuffle=True,\n",
+    "    take_len=train_take_len\n",
+    ")\n",
+    "\n",
+    "valid_dataset = process_lm_dataset(\n",
+    "    valid_raw,\n",
+    "    tokenizer,\n",
+    "    max_seq_len=max_seq_len,\n",
+    "    batch_size=eval_batch_size,\n",
+    "    shuffle=False,\n",
+    "    take_len=valid_take_len\n",
+    ")\n",
+    "\n",
+    "test_dataset = process_lm_dataset(\n",
+    "    test_raw,\n",
+    "    tokenizer,\n",
+    "    max_seq_len=max_seq_len,\n",
+    "    batch_size=eval_batch_size,\n",
+    "    shuffle=False,\n",
+    "    take_len=test_take_len\n",
+    ")\n",
+    "\n",
+    "print(\"Train dataset size (batches):\", train_dataset.get_dataset_size())\n",
+    "print(\"Valid dataset size (batches):\", valid_dataset.get_dataset_size())\n",
+    "print(\"Test dataset size (batches):\", test_dataset.get_dataset_size())\n",
+    "\n",
+    "for batch in train_dataset.create_dict_iterator():\n",
+    "    print(\"input_ids shape:\", batch[\"input_ids\"].shape)\n",
+    "    print(\"labels shape:\", batch[\"labels\"].shape)\n",
+    "    break"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "16",
+   "metadata": {},
+   "source": [
+    "### Model Training\n",
+    "\n",
+    "To simplify the training loop and reuse the HuggingFace-style API, we train with `mindnlp.transformers.Trainer`. To accommodate the MindSpore backend, the Trainer is adapted in the following key ways:\n",
+    "\n",
+    "1. **Data interface adapter (`MSMapDataset`)**\n",
+    "   - Wraps the native MindSpore streaming (iterable) dataset as a map-style dataset that supports index access.\n",
+    "   - Caches every batch as NumPy arrays in advance, so the Trainer's internal DataLoader can index and dispatch them correctly.\n",
+    "2. **Custom Trainer (`NoJitTrainer`)**\n",
+    "   - Inherits from the `mindnlp` `Trainer`, mainly to handle gradient computation in dynamic-graph (PyNative) mode.\n",
+    "   - **Overrides `training_step`**: removes the default JIT compilation (graph-mode acceleration) and performs an explicit `.backward()` pass, avoiding potential compatibility errors from `value_and_grad` under complex control flow.\n",
+    "   - **Overrides `compute_loss`**: forces `loss_type=\"ForCausalLMLoss\"` and adds a fallback that computes the cross-entropy manually, so training still works when the model output carries no loss field.\n",
+    "3. **Data collation (`passthrough_collator`)**\n",
+    "   - A pass-through collator that converts the data to `mindtorch` tensors (or `int64` NumPy arrays) without pinning a device, leaving device placement to the Trainer.\n",
+    "4. **Training configuration**\n",
+    "   - Hyperparameters are managed with `TrainingArguments`: the `adamw_torch` optimizer, `learning_rate=5e-5`, and evaluation enabled (`do_eval=True`).\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "17",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ===== Train with the MindSpore NLP Trainer =====\n",
+    "import math\n",
+    "import numpy as np\n",
+    "import mindspore as ms\n",
+    "from mindspore import context\n",
+    "from mindnlp.transformers import TrainingArguments, Trainer as _BaseTrainer\n",
+    "from transformers.trainer_callback import TrainerCallback\n",
+    "\n",
+    "# Wrap the MindSpore dataset (whose padded_batch output is already batched) as an indexable map-style dataset\n",
+    "class MSMapDataset:\n",
+    "    \"\"\"Cache every batch of a MindSpore dataset as numpy so the HF-style Trainer can index it.\"\"\"\n",
+    "    def __init__(self, ms_dataset):\n",
+    "        self.cache = []\n",
+    "        for b in ms_dataset.create_dict_iterator():\n",
+    "            def to_np(x):\n",
+    "                return x.asnumpy() if hasattr(x, \"asnumpy\") else np.asarray(x)\n",
+    "            self.cache.append({\n",
+    "                \"input_ids\": to_np(b[\"input_ids\"]),\n",
+    "                \"labels\": to_np(b[\"labels\"]),\n",
+    "            })\n",
+    "    def __len__(self):\n",
+    "        return len(self.cache)\n",
+    "    def __getitem__(self, idx):\n",
+    "        return self.cache[idx]\n",
+    "\n",
+    "train_map = MSMapDataset(train_dataset)\n",
+    "valid_map = MSMapDataset(valid_dataset)\n",
+    "\n",
+    "# Pass-through collator: no second padding/batching, no device pinning, emits int64 directly\n",
+    "def passthrough_collator(features):\n",
+    "    feat = features[0] if isinstance(features, list) and len(features) == 1 else features\n",
+    "    arr_ids, arr_lbl = feat[\"input_ids\"], feat[\"labels\"]\n",
+    "\n",
+    "    # Prefer mindtorch tensors; leave the device unset to avoid 'Ascend'-related errors\n",
+    "    try:\n",
+    "        import mindtorch as mt\n",
+    "        return {\n",
+    "            \"input_ids\": mt.tensor(arr_ids, dtype=mt.int64),\n",
+    "            \"labels\": mt.tensor(arr_lbl, dtype=mt.int64),\n",
+    "        }\n",
+    "    except Exception:\n",
+    "        # Fallback: numpy (the Trainer accepts it as well)\n",
+    "        return {\n",
+    "            \"input_ids\": np.asarray(arr_ids, dtype=np.int64),\n",
+    "            \"labels\": np.asarray(arr_lbl, dtype=np.int64),\n",
+    "        }\n",
+    "\n",
+    "# Custom Trainer: manual backward, avoiding value_and_grad / JIT\n",
+    "class NoJitTrainer(_BaseTrainer):\n",
+    "    def __init__(self, *args, loss_type: str = \"ForCausalLMLoss\", **kwargs):\n",
+    "        # Record the desired loss_type\n",
+    "        self._force_loss_type = loss_type\n",
+    "\n",
+    "        # Write config.loss_type as early as possible, before the parent __init__, to avoid an early warning\n",
+    "        model = kwargs.get(\"model\", args[0] if len(args) > 0 else None)\n",
+    "        if model is not None and hasattr(model, \"config\"):\n",
+    "            try:\n",
+    "                if getattr(model.config, \"loss_type\", None) != self._force_loss_type:\n",
+    "                    setattr(model.config, \"loss_type\", self._force_loss_type)\n",
+    "            except Exception:\n",
+    "                pass\n",
+    "\n",
+    "        super().__init__(*args, **kwargs)\n",
+    "\n",
+    "    def _ensure_loss_type(self, model):\n",
+    "        cfg = getattr(model, \"config\", None)\n",
+    "        if cfg is not None and self._force_loss_type:\n",
+    "            try:\n",
+    "                if getattr(cfg, \"loss_type\", None) != self._force_loss_type:\n",
+    "                    setattr(cfg, \"loss_type\", self._force_loss_type)\n",
+    "            except Exception:\n",
+    "                pass\n",
+    "\n",
+    "    # Accept num_items_in_batch and **kwargs for compatibility with Trainer call conventions\n",
+    "    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):\n",
+    "        self._ensure_loss_type(model)\n",
+    "\n",
+    "        outputs = model(**inputs)\n",
+    "        loss = getattr(outputs, \"loss\", None)\n",
+    "\n",
+    "        if loss is None:\n",
+    "            # Fallback: manual cross-entropy (shift + ignore_index=-100)\n",
+    "            import mindtorch as mt\n",
+    "            import 
mindtorch.nn.functional as F\n",
+    "            logits = outputs.logits  # [B, L, V]\n",
+    "            labels = inputs[\"labels\"]  # [B, L]\n",
+    "            shift_logits = logits[:, :-1, :].contiguous()\n",
+    "            shift_labels = labels[:, 1:].contiguous()\n",
+    "            loss = F.cross_entropy(\n",
+    "                shift_logits.reshape(-1, shift_logits.size(-1)),\n",
+    "                shift_labels.reshape(-1),\n",
+    "                ignore_index=-100\n",
+    "            )\n",
+    "\n",
+    "        return (loss, outputs) if return_outputs else loss\n",
+    "\n",
+    "    # Likewise keep the full signature; the Trainer sometimes passes this argument too\n",
+    "    def training_step(self, model, inputs, num_items_in_batch=None, **kwargs):\n",
+    "        model.train()\n",
+    "        loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)\n",
+    "\n",
+    "        if self.args.gradient_accumulation_steps > 1:\n",
+    "            loss = loss / self.args.gradient_accumulation_steps\n",
+    "\n",
+    "        loss.backward()\n",
+    "        return loss.detach()\n",
+    "\n",
+    "\n",
+    "# Configure TrainingArguments\n",
+    "num_epochs = 3  # 1-3 epochs are enough for the demo\n",
+    "training_args = TrainingArguments(\n",
+    "    output_dir=\"./outputs/gpt2_wikitext2_ms\",\n",
+    "    num_train_epochs=num_epochs,\n",
+    "    learning_rate=5e-5,\n",
+    "    optim=\"adamw_torch\",  # use a valid enum value; plain 'adamw' raises an error\n",
+    "    logging_steps=max(1, len(train_map)//20),  # log step-level loss about 20 times per epoch\n",
+    "\n",
+    "    # Upstream already did padded_batch, so each step consumes one pre-built batch\n",
+    "    per_device_train_batch_size=1,\n",
+    "    per_device_eval_batch_size=1,\n",
+    "\n",
+    "    dataloader_num_workers=0,\n",
+    "    remove_unused_columns=False,\n",
+    "    do_eval=True,  # enable evaluation (a callback forces one eval per epoch)\n",
+    ")\n",
+    "\n",
+    "trainer = NoJitTrainer(\n",
+    "    model=model,\n",
+    "    args=training_args,\n",
+    "    train_dataset=train_map,\n",
+    "    eval_dataset=valid_map,\n",
+    "    data_collator=passthrough_collator,\n",
+    "    loss_type=\"ForCausalLMLoss\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "18",
+   "metadata": {},
+   "source": [
+    "#### **Start training**\n",
+    "\n",
+    "To save time, we train for only **3 epochs** as a demonstration. For real tasks, increase the number of epochs as needed.\n",
+    "\n",
+    "During training the following are printed:\n",
+    "\n",
+    "- the average training loss of each epoch;\n",
+    "- the average validation loss at the end of each epoch."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "19",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ===== Launch training: print the average train/valid loss for each epoch =====\n",
+    "\n",
+    "class EpochAvgAndEval(TrainerCallback):\n",
+    "    \"\"\"Collect step-level losses; force an evaluate at the end of each epoch and print the average train loss and the valid loss.\"\"\"\n",
+    "    def __init__(self):\n",
+    "        self._loss_buf = []\n",
+    "\n",
+    "    def on_log(self, args, state, control, logs=None, **kwargs):\n",
+    "        if logs and \"loss\" in logs:\n",
+    "            try:\n",
+    "                self._loss_buf.append(float(logs[\"loss\"]))\n",
+    "            except Exception:\n",
+    "                pass\n",
+    "\n",
+    "    def on_epoch_end(self, args, state, control, **kwargs):\n",
+    "        ep = int(state.epoch) if state.epoch is not None else -1\n",
+    "        if self._loss_buf:\n",
+    "            avg_train = sum(self._loss_buf) / len(self._loss_buf)\n",
+    "            print(f\"Train loss (epoch {ep}): {avg_train:.4f}\")\n",
+    "        else:\n",
+    "            print(f\"Train loss (epoch {ep}): N/A\")\n",
+    "        self._loss_buf.clear()\n",
+    "\n",
+    "        # Force an evaluation at the end of this epoch (even without evaluation_strategy)\n",
+    "        control.should_evaluate = True\n",
+    "        return control\n",
+    "\n",
+    "    def on_evaluate(self, args, state, control, metrics=None, **kwargs):\n",
+    "        ep = int(state.epoch) if state.epoch is not None else -1\n",
+    "        if metrics and metrics.get(\"eval_loss\") is not None:\n",
+    "            print(f\"Valid loss (epoch {ep}): {metrics['eval_loss']:.4f}\")\n",
+    "\n",
+    "# Attach the callback\n",
+    "trainer.add_callback(EpochAvgAndEval())\n",
+    "\n",
+    "# (Optional) run a quick sanity eval before training\n",
+    "# _ = trainer.evaluate()\n",
+    "\n",
+    "# Start training (the train/valid loss is printed at the end of every epoch)\n",
+    "train_output = trainer.train()\n",
+    "\n",
+    "# Evaluate once more after training and print the perplexity\n",
+    "final_metrics = trainer.evaluate()\n",
+    "if final_metrics.get(\"eval_loss\") is not None:\n",
+    "    ppl = math.exp(final_metrics[\"eval_loss\"])\n",
+    "    print(f\"Final Eval loss: {final_metrics['eval_loss']:.4f} | PPL: {ppl:.2f}\")\n",
+    "else:\n",
+    "    print(\"Final eval_loss not found; skipping the PPL computation.\")"
+   ]
+  },
+  {
+ 
"cell_type": "markdown",
+   "id": "20",
+   "metadata": {},
+   "source": [
+    "## Model Inference\n",
+    "\n",
+    "After training, we use `model.generate` for automatic text continuation.\n",
+    "The overall flow:\n",
+    "\n",
+    "1. Prepare a starting prompt (in English, since DistilGPT-2 was trained on English text);\n",
+    "2. Encode it into `input_ids` with the tokenizer;\n",
+    "3. Call `model.generate` to produce a number of new tokens;\n",
+    "4. Decode the result back into readable text with the tokenizer."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "21",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def generate_text(model,\n",
+    "                  tokenizer,\n",
+    "                  prompt,\n",
+    "                  max_new_tokens=50,\n",
+    "                  do_sample=True,\n",
+    "                  top_p=0.9,\n",
+    "                  temperature=1.0):\n",
+    "    import numpy as np\n",
+    "    import mindtorch as mt\n",
+    "\n",
+    "    # Use the device of the model parameters to avoid device mismatches\n",
+    "    try:\n",
+    "        model_device = next(model.parameters()).device\n",
+    "    except Exception:\n",
+    "        model_device = mt.device(\"cpu\")  # last-resort fallback\n",
+    "\n",
+    "    # eval mode\n",
+    "    model.set_train(False)  # or model.eval()\n",
+    "\n",
+    "    # pad_token fallback (common for GPT-2)\n",
+    "    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:\n",
+    "        tokenizer.pad_token = tokenizer.eos_token\n",
+    "\n",
+    "    # Encode, then move the ids to the model's device\n",
+    "    enc = tokenizer(prompt, add_special_tokens=True)\n",
+    "    input_ids = mt.tensor(np.array([enc[\"input_ids\"]]), device=model_device).long()\n",
+    "\n",
+    "    # Generate\n",
+    "    with mt.no_grad():\n",
+    "        outputs = model.generate(\n",
+    "            input_ids=input_ids,\n",
+    "            max_new_tokens=max_new_tokens,\n",
+    "            do_sample=do_sample,\n",
+    "            top_p=top_p,\n",
+    "            temperature=temperature,\n",
+    "            pad_token_id=tokenizer.pad_token_id,\n",
+    "            eos_token_id=tokenizer.eos_token_id\n",
+    "        )\n",
+    "\n",
+    "    # Decode\n",
+    "    generated_ids = outputs[0].tolist()\n",
+    "    return tokenizer.decode(generated_ids, skip_special_tokens=True)\n",
+    "\n",
+    "\n",
+    "## Try a generation\n",
+    "\n",
+    "prompt = \"Deep learning has changed natural language processing because\"\n",
+    "print(\"Prompt:\\n\", prompt)\n",
+    "\n",
+    "generated = generate_text(\n",
+    "    model, tokenizer, prompt,\n",
+    "    max_new_tokens=60,\n",
+    "    do_sample=True,\n",
+    "    top_p=0.95,\n",
+    "    temperature=0.8\n",
+    ")\n",
+    "print(\"\\nGenerated text:\\n\", generated)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "22",
+   "metadata": {},
+   "source": [
+    "## Saving and Loading the Model (Optional)\n",
+    "\n",
+    "To reuse the fine-tuned result in later notebooks or in deployment, save the model weights and tokenizer files to a local directory,\n",
+    "then reload them later via `from_pretrained`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "23",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "save_dir = \"./distilgpt2-ms-finetuned-wikitext2\"\n",
+    "os.makedirs(save_dir, exist_ok=True)\n",
+    "\n",
+    "model.save_pretrained(save_dir)\n",
+    "tokenizer.save_pretrained(save_dir)\n",
+    "print(\"Model & tokenizer saved to:\", save_dir)\n",
+    "\n",
+    "## Reload (verification)\n",
+    "\n",
+    "from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLM\n",
+    "\n",
+    "loaded_tokenizer = AutoTokenizer.from_pretrained(save_dir)\n",
+    "loaded_model = AutoModelForCausalLM.from_pretrained(save_dir)\n",
+    "\n",
+    "test_prompt = \"Language models are\"\n",
+    "generated_loaded = generate_text(loaded_model, loaded_tokenizer, test_prompt)\n",
+    "print(\"\\n[Loaded model generation]\\n\", generated_loaded)"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "name": "Fine-tune a language model",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.19"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}