Skip to content

Commit a5f7643

Browse files
author
Damien Sileo
committed
new tasks
1 parent 606727e commit a5f7643

File tree

5 files changed

+1688
-448
lines changed

5 files changed

+1688
-448
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import random
2+
from datasets import DatasetDict, Dataset
3+
from sorcery import dict_of
4+
import string
5+
6+
improper_labels =['recast/recast_kg_relations','linguisticprobing',"lex_glue/scotus",'lexical_relation_classification/ROOT09',"pragmeval/squinky","pragmeval/emobank",'pragmeval/persuasiveness']
7+
improper_labels += ['glue/stsb', 'sick/relatedness', 'joci', 'utilitarianism', 'amazon_counterfactual/en', 'toxic_conversations', 'ethos/multilabel', 'lex_glue/eurlex', 'lex_glue/unfair_tos', 'app_reviews', 'humicroedit/subtask-1', 'stackoverflow-questions', 'go_emotions/simplified', 'google_wellformed_query', 'has_part', 'blog_authorship_corpus/age', 'promptCoherence', 'Sarcasm_News_Headline', 'auditor_review/demo-org--auditor_review', 'Dynasent_Disagreement', 'Politeness_Disagreement', 'SBIC_Disagreement', 'SChem_Disagreement', 'Dilemmas_Disagreement', 'sts-companion', 'acceptability-prediction', 'chaos-mnli-ambiguity', 'headline_cause/en_simple', 'oasst1_dense_flat', 'civil_comments']
8+
9+
improper_labels += ['stsb_multi_mt','MLMA_hate_speech','icl-symbol-tuning-instruct','zero-shot-label-nli']
10+
11+
def render_options(options):
12+
options = [f'"{x}"' for x in options]
13+
return f"{', '.join(options[:-1])} or {options[-1]}"
14+
15+
def render_classification(text,options,answer):
16+
example = 'A→B' if text.startswith('A:') else 'the following'
17+
inputs = f'With no explanation, label {example} with either {render_options(options)}.\n{text}'
18+
targets = f"{answer}."
19+
return dict_of(inputs,targets)
20+
21+
def render_token_classification(tokens,options,labels):
22+
prefix = f'With no explanation, label each line with {render_options(options)} preceded by ":".\n'
23+
inputs = prefix+"\n".join(tokens)
24+
targets = "\n".join([':'.join(x) for x in zip(tokens,labels)])
25+
return dict_of(inputs,targets)
26+
27+
def render_multiple_choice(prompt, options, labels):
28+
inputs=(prompt+'\n' if prompt else '')
29+
letters = string.ascii_uppercase[:len(options)]
30+
inputs=f'With no explanation, chose the best option from {render_options(letters)}. {inputs}'
31+
for letter, option in zip(letters, options):
32+
inputs+=f'\n{letter}: {option}'
33+
targets = f'{letters[labels]}.'
34+
return dict_of(inputs, targets)
35+
36+
def negative_sample_options(y, labels,N=4):
37+
if len(labels)<N:
38+
return labels
39+
else:
40+
return [y]+random.sample([x for x in labels if x!=y], N-1)
41+
42+
def shuffle_choices(x):
43+
choices = sorted([k for k in x if 'choice' in k])
44+
choices_texts = [x[c] for c in choices]
45+
correct_choice =choices_texts[x['labels']]
46+
random.shuffle(choices_texts)
47+
for c, ct in zip(choices, choices_texts):
48+
x[c]=ct
49+
x["labels"]=choices_texts.index(correct_choice)
50+
return x
51+
52+
def recast_dataset_classification_to_mc(dataset,sep="[SEP]",N=4):
53+
54+
def recast_split(d,N=N):
55+
labels = d.features['labels']
56+
df=d.to_pandas()
57+
df['inputs'] = df.sentence1
58+
if "sentence2" in df:
59+
df['inputs'] +=sep + df.sentence2
60+
61+
N=min(N, len(labels.names))
62+
df['choices']=df.apply(lambda x:negative_sample_options(labels.int2str(x['labels']), labels.names,N),axis=1)
63+
df['labels']=df.apply(lambda x:x['choices'].index(labels.int2str(x['labels'])),axis=1)
64+
65+
for i in range(N):
66+
df[f'choice{i}']= "This example is " + df.choices.map(lambda x:x[i])
67+
68+
choices = [f'choice{i}' for i in range(N)]
69+
return Dataset.from_pandas(df[['inputs',*choices,'labels']],preserve_index=False)
70+
71+
return DatasetDict({k: recast_split(v) for k,v in dataset.items()})
72+
73+
74+
def recast_instruct(dataset):
75+
features = dataset['train'].features
76+
labels = features['labels']
77+
78+
if "sentence1" in features:
79+
task_type='Classification'
80+
if "choice0" in features:
81+
task_type = "MultipleChoice"
82+
if "tokens" in features:
83+
task_type = "TokenClassification"
84+
85+
def recast_MultipleChoice(x):
86+
x=shuffle_choices(x)
87+
choices = sorted([k for k in x if 'choice' in k])
88+
if all([x[c] in x['inputs'] for c in choices]):
89+
return {"inputs":x['inputs'], 'targets': x[f"choice{x['labels']}"].strip()+"."}
90+
else:
91+
return render_multiple_choice(x['inputs'],[x[c] for c in choices],x['labels'])
92+
93+
def recast_TokenClassification(x):
94+
distractors = list(labels.feature.names)
95+
x_labels = [labels.feature.int2str(y) for y in x['labels']]
96+
labels_set= list({labels.feature.int2str(y) for y in x['labels']})
97+
options=list(dict.fromkeys(labels_set+distractors))[:max(len(labels_set),10)]
98+
return render_token_classification(x['tokens'],options,x_labels)
99+
100+
def recast_Classification(x):
101+
if 'sentence2' in x:
102+
text=f"A: {x['sentence1']}\nB: {x['sentence2']}"
103+
else:
104+
text=x['sentence1']
105+
106+
answer=labels.int2str(x['labels']).strip()
107+
options= negative_sample_options(answer, labels._int2str)
108+
return render_classification(text, options, answer)
109+
110+
dataset = dataset.map(eval(f"recast_{task_type}"))
111+
dataset = dataset.remove_columns([k for k in features if k not in ['inputs','targets']])
112+
return dataset
113+

0 commit comments

Comments
 (0)