Skip to content

Commit

Permalink
handle mistralai/Mistral-7B-Instruct-v0.3 tokenizer correctly (NVIDIA…
Browse files Browse the repository at this point in the history
…#11839)

* handle mistralai/Mistral-7B-Instruct-v0.3 tokenizer correctly

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove manual token addition

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
  • Loading branch information
akoumpa and akoumpa authored Jan 13, 2025
1 parent 8279d3e commit d82f53a
Showing 1 changed file with 44 additions and 4 deletions.
48 changes: 44 additions & 4 deletions scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
"""


import hashlib
import json
import os
import re
from argparse import ArgumentParser
from collections import OrderedDict
from pathlib import Path
Expand All @@ -31,7 +33,7 @@
import torch.nn
from lightning.pytorch.core.saving import _load_state as ptl_load_state
from lightning.pytorch.trainer.trainer import Trainer
from omegaconf import OmegaConf
from omegaconf import OmegaConf, open_dict
from transformers import AutoModelForCausalLM, AutoTokenizer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
Expand All @@ -58,6 +60,7 @@ def get_args():
parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.")
parser.add_argument("--precision", type=str, default="bf16", help="Model precision")
parser.add_argument('--low-ram', '--low-mem', action='store_true', dest='low_ram')
parser.add_argument('--add-additional-tokens', action='store_true')
parser.add_argument('--tmp-dir', default='/tmp/mistral_ckpt_parts/')
args = parser.parse_args()
return args
Expand Down Expand Up @@ -430,6 +433,16 @@ def merge(a: dict, b: dict, path=[]):
return a


def md5_checksum(filepath):
if filepath is None:
return None
hash_md5 = hashlib.md5()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()


def save_to_nemo(args, checkpoint):
"""saves checkpoint to nemo format"""

Expand Down Expand Up @@ -464,15 +477,42 @@ def save_to_nemo(args, checkpoint):
# disable cpu init
model.cfg.use_cpu_initialization = False
model.cfg.perform_initialization = True
# If user has passed --add-additional-tokens or model is mistralai/Mistral-7B-Instruct-v0.3
if (
args.add_additional_tokens
or md5_checksum(getattr(tokenizer, 'vocab_file', None)) == '2bbc01eba250283314fdbd53d05de94b'
):

def make_token_name(token):
prefix = ''
if len(token) > 1 and token[1] == '/':
prefix = 'eos_'
else:
prefix = 'bos_'
return prefix + re.sub(r'\W', '_', token)

if len(tokenizer.added_tokens_decoder) > 0:
with open_dict(model.cfg.tokenizer):
model.cfg.tokenizer.sentencepiece_legacy = True
model.cfg.tokenizer.special_tokens = {}
model.cfg.tokenizer.special_tokens['bos_token'] = tokenizer.bos_token or "<s>"
model.cfg.tokenizer.special_tokens['eos_token'] = tokenizer.eos_token or "</s>"
model.cfg.tokenizer.special_tokens['pad_token'] = tokenizer.pad_token or "<pad>"
skip_tokens = set(model.cfg.tokenizer.special_tokens.values())
skip_tokens.add('<unk>')
for token_id, token in tokenizer.added_tokens_decoder.items():
token_name = make_token_name(token.content)
if token.content in skip_tokens:
continue
assert not token_name in model.cfg.tokenizer.special_tokens
model.cfg.tokenizer.special_tokens[token_name] = token.content

if getattr(tokenizer, 'chat_template', None) is not None:
import hashlib

template_hash = hashlib.md5(tokenizer.chat_template.encode('utf-8')).hexdigest()
if template_hash != "0b629f783db54e02509999196956ff40":
logging.warning("Got unkown chat template")
else:
from omegaconf import OmegaConf, open_dict

with open_dict(model.cfg):
model.cfg.tokenizer.chat_template = OmegaConf.create(
{
Expand Down

0 comments on commit d82f53a

Please sign in to comment.