Skip to content

Commit 03cfc99

Browse files
committed
fix poor llama ptb
1 parent 27e8523 commit 03cfc99

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

Diff for: owq/utils/datautils.py

+3 -1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
 import os
 import torch
 from datasets import load_dataset
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, LlamaTokenizer
 
 def get_wikitext2(nsamples, seed, seqlen, tokenizer, train):
     if train:
@@ -89,6 +89,8 @@ def get_loaders(
     name, nsamples=128, seed=0, seqlen=2048, model='', train=True
 ):
     tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+    if isinstance(tokenizer, LlamaTokenizer) and 'ptb' in name:
+        tokenizer.tokens_trie.data = {}
 
     if 'wikitext2' in name:
         return get_wikitext2(nsamples, seed, seqlen, tokenizer, train)

0 commit comments

Comments
 (0)