save qwen_model in meta file
mmoffatt2 committed Dec 10, 2024
1 parent 1ac1ef9 commit edd4653
Showing 2 changed files with 6 additions and 5 deletions. The commit saves the Qwen2 HuggingFace model name under a new "qwen2_model" key in the tokenizer meta file, renames the --qwen_model flag to --qwen2_model, and fixes two setup() references that pointed at a nonexistent qwen2_method attribute.
1 change: 1 addition & 0 deletions data/template/tokenizer_options.py
@@ -336,6 +336,7 @@ def tokenize(self, data):
         meta = {
             "vocab_size": self.vocab_size,
             "tokenizer": "qwen2",
+            "qwen2_model": self.huggingface_model_name,
             "special_tokens": self.special_tokens,
         }
         self.save_meta(meta)
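For context, here is a minimal sketch of the save path. It assumes save_meta pickles the dict to a meta.pkl file alongside the tokenized data; the file name, pickle format, and all values below are assumptions for illustration, and only the dict keys come from the diff:

    import pickle

    def save_meta(meta, out_path="meta.pkl"):
        # Assumed helper: persist tokenizer metadata so train.py can
        # rebuild the same tokenizer later. The "qwen2_model" key added
        # by this commit carries the HuggingFace model name.
        with open(out_path, "wb") as f:
            pickle.dump(meta, f)

    meta = {
        "vocab_size": 151936,              # illustrative Qwen2 vocab size
        "tokenizer": "qwen2",
        "qwen2_model": "Qwen/Qwen2-1.5B",  # e.g. self.huggingface_model_name
        "special_tokens": {},              # placeholder
    }
    save_meta(meta)

Whatever the exact serialization, the point of the new key is that the tokenizer can now be reconstructed from the meta file alone, without the user re-supplying the model name.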
10 changes: 5 additions & 5 deletions train.py
@@ -87,11 +87,11 @@ def parse_args():
     )
     training_group.add_argument('--gpt2_type', default='gpt2', type=str)
     training_group.add_argument(
-        '--qwen_model',
+        '--qwen2_model',
         default='qwen2_1p5b',
         choices=['qwen2_0p5b', 'qwen2_1p5b', 'qwen2_7b'],
         type=str,
-        help="Type of Qwen model to use for initialization."
+        help="Type of Qwen2 model to use for initialization."
     )
     training_group.add_argument('--prev_run_ckpt', default='', type=str)
     training_group.add_argument('--csv_ckpt_dir', default='', type=str)
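To see the rename in isolation, a runnable reproduction of just this argument (the training_group lines come from the diff; the surrounding parser setup is a sketch):

    import argparse

    parser = argparse.ArgumentParser()
    training_group = parser.add_argument_group("training")
    training_group.add_argument(
        '--qwen2_model',
        default='qwen2_1p5b',
        choices=['qwen2_0p5b', 'qwen2_1p5b', 'qwen2_7b'],
        type=str,
        help="Type of Qwen2 model to use for initialization."
    )

    # After the rename, downstream code must read args.qwen2_model;
    # the setup() hunk below fixes two leftover references.
    args = parser.parse_args(['--qwen2_model', 'qwen2_0p5b'])
    print(args.qwen2_model)  # qwen2_0p5b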
@@ -698,12 +698,12 @@ def setup(self):

         elif self.args.init_from.startswith('qwen2'):
 
-            assert self.args.qwen2_method in model_variation_dictionary
+            assert self.args.qwen2_model in model_variation_dictionary
 
             self.iter_num = 0 # for starting from scratch
             self.best_val_loss = 1e9 # really big number
 
-            variation_dict, huggingface_name = model_variation_dictionary[self.args.qwen2_method]
+            variation_dict, huggingface_name = model_variation_dictionary[self.args.qwen2_model]
             # NOTE: the hierarchy of parameters goes: 1)variation_dict >> 2)cmd-line args >> 3)GPTConfig defaults
             for k in variation_dict:
                 self.model_args[k] = variation_dict[k]
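The unpacking above implies that model_variation_dictionary maps each variant key to a (config-overrides, HuggingFace-name) tuple. A hypothetical sketch of that shape follows; the override keys and values are illustrative guesses based on the published Qwen2 configs, not taken from this repository:

    # Hypothetical shape of model_variation_dictionary; only the
    # (variation_dict, huggingface_name) tuple structure is implied by the diff.
    model_variation_dictionary = {
        "qwen2_0p5b": ({"n_layer": 24, "n_head": 14, "n_embd": 896},  "Qwen/Qwen2-0.5B"),
        "qwen2_1p5b": ({"n_layer": 28, "n_head": 12, "n_embd": 1536}, "Qwen/Qwen2-1.5B"),
        "qwen2_7b":   ({"n_layer": 28, "n_head": 28, "n_embd": 3584}, "Qwen/Qwen2-7B"),
    }

    variation_dict, huggingface_name = model_variation_dictionary["qwen2_1p5b"]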
@@ -783,7 +783,7 @@ def load_tokenizer(self):
             tokenizer = AutoTokenizer.from_pretrained(meta["qwen2_model"], trust_remote_code=True)
             self.encode = lambda s: tokenizer.encode(s, add_special_tokens=False)
             self.decode = lambda l: tokenizer.decode(l)
-            print(f"Using Qwen tokenizer: {meta['qwen2_model']}")
+            print(f"Using Qwen2 tokenizer: {meta['qwen2_model']}")
         elif 'tokenizer' in meta and meta['tokenizer'] == 'custom_char_with_byte_fallback':
             self.stoi = meta['stoi']
             self.itos = meta['itos']
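Putting both halves of the commit together, here is a self-contained sketch of the load path, assuming the meta file is a pickle named meta.pkl (the branch logic mirrors the diff; the file handling is an assumption):

    import pickle
    from transformers import AutoTokenizer

    with open("meta.pkl", "rb") as f:
        meta = pickle.load(f)

    if meta.get("tokenizer") == "qwen2":
        # The "qwen2_model" key saved by tokenizer_options.py names the
        # HuggingFace tokenizer, so encode/decode match what was used
        # to tokenize the training data.
        tokenizer = AutoTokenizer.from_pretrained(meta["qwen2_model"], trust_remote_code=True)
        encode = lambda s: tokenizer.encode(s, add_special_tokens=False)
        decode = lambda l: tokenizer.decode(l)
        print(f"Using Qwen2 tokenizer: {meta['qwen2_model']}")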
