Skip to content

Commit

Permalink
fix poor llama ptb
Browse files Browse the repository at this point in the history
  • Loading branch information
jinjungyu committed Mar 7, 2024
1 parent 27e8523 commit 03cfc99
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion owq/utils/datautils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoTokenizer, LlamaTokenizer

def get_wikitext2(nsamples, seed, seqlen, tokenizer, train):
if train:
Expand Down Expand Up @@ -89,6 +89,8 @@ def get_loaders(
name, nsamples=128, seed=0, seqlen=2048, model='', train=True
):
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
if isinstance(tokenizer, LlamaTokenizer) and 'ptb' in name:
tokenizer.tokens_trie.data = {}

if 'wikitext2' in name:
return get_wikitext2(nsamples, seed, seqlen, tokenizer, train)
Expand Down

0 comments on commit 03cfc99

Please sign in to comment.