Skip to content

Commit cfff66b

Browse files
authored
Update convert_prompt.py (#736)
* Update convert_prompt.py * Update convert_prompt.py
1 parent 74594b5 commit cfff66b

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

src/convert_prompt.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,36 @@
3030
required=False,
3131
help="The subset you want to convert")
3232

33+
parser.add_argument("--sample",
34+
nargs=1,
35+
type=int,
36+
required=False,
37+
help="Number of samples from the dataset")
38+
3339
args = parser.parse_args()
3440
if args.dataset:
3541
dataset_name = args.dataset[0]
3642

3743
if args.subset:
3844
subset_name = args.subset[0]
3945

46+
if args.sample:
47+
sample_num = args.sample[0]
48+
4049
def get_dataset(dataset_name):
4150
dataset = load_dataset(dataset_name, split="train")
51+
if args.sample:
52+
cap = min(sample_num, len(dataset))
53+
dataset = random.choices(dataset, k = cap)
4254
# Load prompts for this dataset
4355
dataset_prompts = DatasetTemplates(dataset_name)
4456
return dataset, dataset_prompts
4557

4658
def get_subset(dataset_name, subset_name):
4759
dataset = load_dataset(dataset_name,subset_name, split="train")
60+
if args.sample:
61+
cap = min(sample_num, len(dataset))
62+
dataset = random.choices(dataset, k = cap)
4863
# Load prompts for this dataset and subset
4964
dataset_prompts = DatasetTemplates(f"{dataset_name}/{subset_name}")
5065
return dataset, dataset_prompts
@@ -59,6 +74,8 @@ def create_task(dataset, dataset_name, dataset_prompts):
5974
prompt = dataset_prompts[prompt_name]
6075
# Apply the prompt to the dataset
6176
data = {}
77+
data["Prompt Name"] = [prompt_name]
78+
data["Prompt id"] = [id]
6279
data["Contributors"] = []
6380
data["Source"] = [dataset_name]
6481
data["Categories"] = []
@@ -71,7 +88,7 @@ def create_task(dataset, dataset_name, dataset_prompts):
7188
data["Positive Examples"] = []
7289
data["Negative Examples"] = []
7390
data["Instances"] = []
74-
for i in range(min(6500,len(dataset))):
91+
for i in range(len(dataset)):
7592
result = prompt.apply(dataset[i])
7693
if len(result)==2:
7794
data["Instances"].append({
@@ -120,4 +137,4 @@ def save_json(data, dataset_name, prompt_name):
120137
dataset, dataset_prompts = get_dataset(dataset_name)
121138
if args.subset:
122139
dataset, dataset_prompts = get_subset(dataset_name, subset_name)
123-
create_task(dataset, dataset_name, dataset_prompts)
140+
create_task(dataset, dataset_name, dataset_prompts)

0 commit comments

Comments
 (0)