forked from meta-llama/llama-recipes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
alpaca_dataset.py
68 lines (56 loc) · 2.22 KB
/
alpaca_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html
import copy
import json
import torch
from torch.utils.data import Dataset
# Alpaca-style prompt templates, selected by whether a record carries a
# non-empty "input" field. Each template is filled with str.format_map using
# the record's "instruction" (and, for "prompt_input", "input") values, and
# ends exactly where the model's response text should begin.
PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response:"
    ),
)
class InstructionDataset(Dataset):
    """Instruction-tuning dataset loaded from a single JSON file.

    The file at ``dataset_config.data_path`` is read with ``json.load`` and is
    assumed to hold a list of records (Alpaca format — TODO confirm against the
    data source). Each record is read through its ``"instruction"``,
    ``"output"`` and optional ``"input"`` keys. The first ``eval_size`` records
    form the evaluation split and the remaining records form the training split.

    Args:
        dataset_config: Object exposing a ``data_path`` attribute.
        tokenizer: Object providing ``encode(str)`` returning a list of token
            ids and an ``eos_token_id`` attribute (presumably a Hugging Face
            tokenizer — verify against callers).
        partition: ``"train"`` selects the training split; any other value
            selects the evaluation split.
        eval_size: Number of leading records reserved for evaluation
            (default 200).
    """

    # Label value skipped by the loss; -100 is the default ``ignore_index``
    # of torch.nn.CrossEntropyLoss.
    IGNORE_INDEX = -100

    def __init__(self, dataset_config, tokenizer, partition="train", eval_size=200):
        # Context manager closes the file promptly; explicit UTF-8 avoids
        # depending on the platform's default encoding.
        with open(dataset_config.data_path, encoding="utf-8") as f:
            self.ann = json.load(f)
        if partition == "train":
            self.ann = self.ann[eval_size:]
        else:
            self.ann = self.ann[:eval_size]
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of records in the selected split."""
        return len(self.ann)

    def __getitem__(self, index):
        """Tokenize one record into model inputs and loss labels.

        Returns:
            A dict of plain Python lists:
            ``input_ids``: token ids of prompt + output, followed by the EOS id;
            ``labels``: the same ids with every prompt position replaced by
            ``IGNORE_INDEX``, so only the response contributes to the loss;
            ``attention_mask``: booleans, True wherever the token id is >= 0.
        """
        ann = self.ann[index]
        # Missing, empty, or null "input" all use the shorter template, so a
        # null value is never rendered as the literal text "None".
        if not ann.get("input"):
            prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
        else:
            prompt = PROMPT_DICT["prompt_input"].format_map(ann)
        example = prompt + ann["output"]

        # Only the prompt's token count is needed, to know how many leading
        # positions to mask out of the labels.
        prompt_len = len(self.tokenizer.encode(prompt))
        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)

        labels = example.clone()
        labels[:prompt_len] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        # Defensive: any negative id in the inputs is zeroed (a tokenizer is
        # not expected to emit one).
        example[~example_mask] = 0
        labels[~label_mask] = self.IGNORE_INDEX
        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }