Skip to content

Commit 39af585

Browse files
committed
more flexible access to tasks, new tasks
1 parent 97c5996 commit 39af585

File tree

5 files changed

+370
-348
lines changed

5 files changed

+370
-348
lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ install_requires =
2323
pandas
2424
numpy
2525
scipy
26+
sorcery
2627

2728
[options.packages.find]
2829
where = src

src/tasksource/access.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from datasets import load_dataset
77
import funcy as fc
88
import os
9-
9+
from sorcery import dict_of
1010

1111
def parse_var_name(s):
1212
config_name,task_name = None,None
@@ -29,7 +29,6 @@ def pretty_name(x):
2929

3030
def list_tasks(tasks_path=f'{os.path.dirname(__file__)}/tasks.py'):
3131
task_order = open(tasks_path).readlines()
32-
task_order= task_order[:task_order.index('###END\n')]
3332
task_order = [x.split('=')[0].rstrip() for x in task_order if '=' in x]
3433
task_order = [x for x in task_order if x.isidentifier()]
3534
task_order = fc.flip(dict(enumerate(task_order)))
@@ -59,17 +58,22 @@ def list_tasks(tasks_path=f'{os.path.dirname(__file__)}/tasks.py'):
5958

6059
task_df = list_tasks()
6160

62-
def load_preprocessing(dataset_name, config_name=None, task_name=None):
63-
y = task_df
64-
y = y[y.dataset_name.map(lambda x:x==dataset_name)]
65-
y = y[y.config_name.map(lambda x:x==config_name)]
66-
y = y[y.task_name.map(lambda x:x==task_name)]
67-
return getattr(tasks,y.preprocessing_name.iloc[0])
68-
61+
def dict_to_query(d=dict(), **kwargs):
62+
d={**d,**kwargs}
63+
return '&'.join([f'`{k}`=="{v}"' for k,v in d.items()])
6964

65+
def load_preprocessing(tasks=tasks, **kwargs):
66+
y = task_df.query(dict_to_query(**kwargs)).iloc[0]
67+
preprocessing= getattr(tasks, y.preprocessing_name)
68+
for c in 'dataset_name','config_name':
69+
if not isinstance(getattr(preprocessing,c), str):
70+
setattr(preprocessing,c,getattr(y,c))
71+
return preprocessing
7072

71-
def load_task(dataset_name,config_name=None,task_name=None,
73+
def load_task(id=None, dataset_name=None,config_name=None,task_name=None,preprocessing_name=None,
7274
max_rows=None, max_rows_eval=None):
73-
dataset = load_dataset(dataset_name,config_name)
74-
preprocessing = load_preprocessing(dataset_name,config_name,task_name)
75+
query = dict_of(id, dataset_name, config_name, task_name,preprocessing_name)
76+
query = {k:v for k,v in query.items() if v}
77+
preprocessing = load_preprocessing(**query)
78+
dataset = load_dataset(preprocessing.dataset_name, preprocessing.config_name)
7579
return preprocessing(dataset,max_rows, max_rows_eval)

src/tasksource/preprocess.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ def __map_to_target(x,fn=lambda x:None, target=None):
3333
x[target]=fn(x)
3434
return x
3535

36+
def load(self):
37+
return self(datasets.load_dataset(self.dataset_name,self.config_name))
38+
3639
def __call__(self,dataset, max_rows=None, max_rows_eval=None):
3740
dataset = self.pre_process(dataset)
3841
for k,v in zip(self.default_splits, self.splits):
@@ -53,7 +56,7 @@ def __call__(self,dataset, max_rows=None, max_rows_eval=None):
5356
and type(v)==str and k!=v)})
5457
for k in self.to_dict().keys():
5558
v=getattr(self, k)
56-
if callable(v) and k not in {"post_process","pre_process"}:
59+
if callable(v) and k not in {"post_process","pre_process","load"}:
5760
dataset=dataset.map(self.__map_to_target,
5861
fn_kwargs={'fn':v,'target':k})
5962

src/tasksource/tasks.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,8 @@
1717
config_name=set(get_dataset_config_names("metaeval/babi_nli"))-{"agents-motivations"}
1818
) # agents-motivations task is not as clear-cut as the others
1919

20-
def ling_nli_postprocess(ds):
21-
return ds.cast_column('labels', ClassLabel(
22-
names=['entailment','neutral','contradiction']))
2320

24-
ling_nli = Classification("premise_original","hypothesis_original","label",
25-
dataset_name="metaeval/lingnli", post_process=ling_nli_postprocess
26-
)
21+
ling_nli = Classification("premise","hypothesis","label",dataset_name="metaeval/lingnli")
2722

2823

2924
sick__label = Classification('sentence_A','sentence_B','label')
@@ -124,23 +119,23 @@ def ling_nli_postprocess(ds):
124119
add_one_rte = Classification("premise","hypothesis","label",
125120
dataset_name="pietrolesci/add_one_rte",splits=["train","dev","test"])
126121

127-
def imppres_post_process(ds,prefix=''):
122+
def _imppres_post_process(ds,prefix=''):
128123
# imppres entailment definition is either purely semantic or purely pragmatic
129124
# because of that, we differentiate the labels from anli/mnli notation
130125
return ds.cast_column('labels', ClassLabel(
131126
names=[f'imppres{prefix}_entailment',f'imppres{prefix}_neutral',f'imppres{prefix}_contradiction']))
132127

133128
imppres__presupposition = imppres__prag = Classification("premise","hypothesis","gold_label",
134129
dataset_name="metaeval/imppres", config_name=imppres_presupposition,
135-
post_process=imppres_post_process)
130+
post_process=_imppres_post_process)
136131

137132
imppres__prag = Classification("premise","hypothesis","gold_label_prag",
138133
dataset_name="metaeval/imppres", config_name=imppres_implicature,
139-
post_process=lambda x: imppres_post_process(x,'_prag'))
134+
post_process=lambda x: _imppres_post_process(x,'_prag'))
140135

141136
imppres__log = Classification("premise","hypothesis","gold_label_log",
142137
dataset_name="metaeval/imppres", config_name=imppres_implicature,
143-
post_process=lambda x: imppres_post_process(x,'_log'))
138+
post_process=lambda x: _imppres_post_process(x,'_log'))
144139

145140

146141
glue__diagnostics = Classification("premise","hypothesis","label",
@@ -312,13 +307,13 @@ def imppres_post_process(ds,prefix=''):
312307

313308
swag=MultipleChoice(cat(["sent1","sent2"]),regen("ending[0-3]"),"label")
314309

315-
def split_choices(s):
310+
def _split_choices(s):
316311
import re
317312
return [x.rstrip(', ') for x in re.split(r'[a-e] \) (.*?)',s) if x.strip(', ')]
318313

319314
math_qa = MultipleChoice(
320315
'Problem',
321-
choices_list = lambda x: split_choices(x['options']),
316+
choices_list = lambda x: _split_choices(x['options']),
322317
labels = lambda x:'abcde'.index(x['correct'])
323318
)
324319

@@ -500,15 +495,14 @@ def split_choices(s):
500495

501496
metaeval_linguisticprobing = Classification("sentence", labels="label", dataset_name="metaeval/linguisticprobing",
502497
config_name=['subj_number',
503-
'word_content',
504498
'obj_number',
505499
'past_present',
506500
'sentence_length',
507501
'top_constituents',
508502
'tree_depth',
509503
'coordination_inversion',
510504
'odd_man_out',
511-
'bigram_shift']
505+
'bigram_shift']#+['word_content'] #too many labels
512506
)
513507

514508
metaeval_crowdflower = Classification("text", labels="label",
@@ -664,5 +658,22 @@ def split_choices(s):
664658
config_name='proto_qa'
665659
)
666660

667-
###END
668-
################### END OF SUPPORT ######################
661+
wiki_qa = Classification("question","answer","label")
662+
663+
cycic_classification = Classification("question",labels="correct_answer",
664+
dataset_name = "metaeval/cycic_classification")
665+
cycic_mc = MultipleChoice("question", choices=regen('answer\_option[0-4]'), labels="correct_answer",
666+
dataset_name = "metaeval/cycic_multiplechoice")
667+
668+
669+
def _preprocess_chatgpt_detection(ex):
670+
import random
671+
label=random.random()<=0.5
672+
ex['label']=label
673+
ex['answer']=[ex['human_answers'],ex['chatgpt_answers']][label]
674+
return ex
675+
676+
chatgpt_detection = Classification("question","answer","label",
677+
dataset_name = 'Hello-SimpleAI/HC3', config_name="all",
678+
pre_process=lambda dataset:dataset.map(_preprocess_chatgpt_detection)
679+
)

0 commit comments

Comments
 (0)