
Commit 3bde35a

Author: Damien Sileo
Commit message: new tasks
1 parent 75d61c0 commit 3bde35a

File tree: 5 files changed (+1005, -611 lines)

Lines changed: 267 additions & 0 deletions
@@ -0,0 +1,267 @@
from collections.abc import Iterable
from dotwiz import DotWiz
from dataclasses import dataclass
from typing import Union
import itertools
import funcy as fc
import exrex
import magicattr
import numpy as np
import copy
import datasets
import time

def get_column_names(dataset):
    # column_names is a dict of per-split lists for a DatasetDict,
    # a flat list for a single Dataset
    cn = dataset.column_names
    if type(cn) == dict:
        return set(fc.flatten(cn.values()))
    else:
        return set(cn)


def sample_dataset(dataset, n=10000, n_eval=1000, seed=0):
    # cap each split: n rows for train, n_eval for validation/test
    for k in dataset:
        n_k = (n if k == 'train' else n_eval)
        if n_k and len(dataset[k]) > n_k:
            dataset[k] = dataset[k].train_test_split(train_size=n_k, seed=seed)['train']
    return dataset

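# Illustrative behaviour (example values are assumptions, not from this
# commit): given a DatasetDict with a 50k-row train split,
#   sample_dataset(ds, n=10000, n_eval=1000)
# keeps 10k train rows and at most 1k rows each in validation and test.
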
class Preprocessing(DotWiz):
    default_splits = ('train', 'validation', 'test')

    @staticmethod
    def __map_to_target(x, fn=lambda x: None, target=None):
        x[target] = fn(x)
        return x

    def load(self):
        return self(datasets.load_dataset(self.dataset_name, self.config_name))

    def __call__(self, dataset, max_rows=None, max_rows_eval=None, seed=0):
        dataset = self.pre_process(dataset)

        # manage splits: rename declared splits to the default names,
        # drop splits marked as unusable
        for k, v in zip(self.default_splits, self.splits):
            if v and k != v:
                dataset[k] = dataset[v]
                del dataset[v]
            if k in dataset and not v:  # obfuscated label
                del dataset[k]
        dataset = fix_splits(dataset)

        for k in list(dataset.keys()):
            if k not in self.default_splits:
                del dataset[k]
        dataset = sample_dataset(dataset, max_rows, max_rows_eval, seed=seed)

        # fields annotated with a string: rename source columns to field names
        substitutions = {v: k for k, v in self.to_dict().items()
                         if (k and k not in {'splits', 'dataset_name', 'config_name'}
                             and type(v) == str and k != v)}

        dataset = dataset.remove_columns([c for c in substitutions.values()
                                          if c in dataset['train'].features
                                          and c not in substitutions])
        dataset = dataset.rename_columns(substitutions)

        # fields annotated with a function: compute them with dataset.map
        for k in self.to_dict().keys():
            v = getattr(self, k)
            if callable(v) and k not in {"post_process", "pre_process", "load"}:
                dataset = dataset.map(self.__map_to_target,
                                      fn_kwargs={'fn': v, 'target': k})

        dataset = dataset.remove_columns(
            get_column_names(dataset) - set(self.to_dict().keys()))
        dataset = fix_labels(dataset)
        dataset = fix_splits(dataset)  # again: label mapping changed
        dataset = self.post_process(dataset)
        return dataset

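# A sketch of the string-annotation path above (field names are hypothetical):
# with sentence1='premise', sentence2='hypothesis', labels='label',
# `substitutions` is {'premise': 'sentence1', 'hypothesis': 'sentence2',
# 'label': 'labels'}, so source columns are renamed to the standard field
# names before the callable annotations run.
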
@dataclass
class cat(Preprocessing):
    fields: Union[str, list] = None
    separator: str = ' '

    def __call__(self, example=None):
        # concatenate the annotated fields with the separator
        y = [np.char.array(example[f]) + sep
             for f, sep in zip(self.fields[::-1], itertools.repeat(self.separator))]
        y = list(sum(*y))
        if len(y) == 1:
            y = y[0]
        return y

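# Intended use (a hedged example, not from the commit): with
# fields=['sentence1', 'sentence2'] and the default separator,
#   cat(['sentence1', 'sentence2'])(example)
# should return the two fields joined by a space, roughly
# 'A man sleeps. A person rests.' (modulo a trailing separator).
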
def pretty(f):
    # wrap a function factory so its instances have a readable repr
    class pretty_f(DotWiz):
        def __init__(self, *args):
            self.__f_arg = f(*args)
            for a in args:
                setattr(self, 'value', a)

        def __call__(self, *args, **kwargs):
            return self.__f_arg(*args, **kwargs)

        def __repr__(self):
            return f"{self.__f_arg.__qualname__.split('.')[0]}({self.value})"
    return pretty_f

class dotgetter:
    # builds an attribute/index path lazily; calling the result resolves
    # the path on an example with magicattr
    def __init__(self, path=''):
        self.path = path

    def __bool__(self):
        return bool(self.path)

    def __getattr__(self, k):
        return self.__class__(f'{self.path}.{k}'.lstrip('.'))

    def __getitem__(self, i):
        return self.__class__(f'{self.path}[{i}]')

    def __call__(self, example=None):
        return magicattr.get(DotWiz(example), self.path)

    def __hash__(self):
        return hash(self.path)


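# Example (illustrative): dotgetter().answers.text[0] builds the path
# 'answers.text[0]'; calling it on {'answers': {'text': ['x']}} returns 'x'
# via magicattr.
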
@dataclass
class ClassificationFields(Preprocessing):
    sentence1: str = 'sentence1'
    sentence2: str = 'sentence2'
    labels: str = 'labels'

@dataclass
class Seq2SeqLMFields(Preprocessing):
    prompt: str = 'prompt'
    output: str = 'output'

@dataclass
class TokenClassificationFields(Preprocessing):
    tokens: str = 'tokens'
    labels: str = 'labels'

@dataclass
class MultipleChoiceFields(Preprocessing):
    inputs: str = 'input'
    choices: Iterable = tuple()
    labels: str = 'labels'
    choices_list: str = None

    def __post_init__(self):
        for i, c in enumerate(self.choices):
            setattr(self, f'choice{i}', c)
        delattr(self, 'choices')
        if not self.choices_list:
            delattr(self, 'choices_list')

    def __call__(self, dataset, *args, **kwargs):
        dataset = super().__call__(dataset, *args, **kwargs)
        if self.choices_list:
            dataset = dataset.filter(lambda x: 1 < len(x['choices_list']))
            n_options = min([len(x) for k in dataset for x in dataset[k]['choices_list']])
            n_options = min(5, n_options)
            dataset = dataset.map(self.flatten, fn_kwargs={'n_options': n_options})
        return dataset

    @staticmethod
    def flatten(x, n_options=None):
        # move the gold choice to index 0 and expose choices as choice{i} columns
        n_neg = n_options - 1 if n_options else None
        choices = x['choices_list']
        label = x['labels']
        neg = choices[:label] + choices[label + 1:]
        pos = choices[label]
        x['labels'] = 0
        x['choices_list'] = [pos] + neg[:n_neg]
        for i, o in enumerate(x['choices_list']):
            x[f'choice{i}'] = o
        del x['choices_list']
        return x

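# flatten in action (illustrative): an example with
# choices_list=['a', 'b', 'c'] and labels=1 becomes
# choice0='b', choice1='a', choice2='c', labels=0
# (the gold choice is moved to position 0).
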
@dataclass
class SharedFields:
    splits: list = Preprocessing.default_splits
    dataset_name: str = None
    config_name: str = None
    pre_process: callable = lambda x: x
    post_process: callable = lambda x: x
    # language: str = "en"


@dataclass
class Classification(SharedFields, ClassificationFields): pass

@dataclass
class MultipleChoice(SharedFields, MultipleChoiceFields): pass

@dataclass
class TokenClassification(SharedFields, TokenClassificationFields): pass

@dataclass
class Seq2SeqLM(SharedFields, Seq2SeqLMFields): pass

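# A hypothetical task definition built on these dataclasses (dataset and
# column names are assumptions, not part of this commit):
#
#   snli = Classification(sentence1='premise', sentence2='hypothesis',
#                         labels='label', dataset_name='snli')
#   dataset = snli.load()  # load_dataset + split/column/label normalization
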
get = dotgetter()
constant = pretty(fc.constantly)
regen = lambda x: list(exrex.generate(x))

def name(label_name, classes):
    return lambda x: classes[x[label_name]]

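# Illustrative uses of the helpers above (values are assumptions):
#   constant('entailment')(x)      # -> 'entailment' for every example
#   name('label', ['neg', 'pos'])  # -> maps integer labels to class names
#   regen('choice[0-4]')           # -> ['choice0', ..., 'choice4'] via exrex
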
def fix_splits(dataset):
    # normalize a DatasetDict so that the standard splits exist, carving
    # missing splits out of the available ones deterministically

    if len(dataset) == 1 and "train" not in dataset:
        k = list(dataset)[0]
        dataset['train'] = copy.deepcopy(dataset[k])
        del dataset[k]

    if 'auxiliary_train' in dataset:
        del dataset['auxiliary_train']

    if 'test' in dataset:  # manage obfuscated labels
        if 'labels' in dataset['test'].features:
            if len(set(fc.flatten(dataset['test'].to_dict()['labels']))) == 1:
                del dataset['test']

    if 'validation' in dataset and 'train' not in dataset:
        train_validation = dataset['validation'].train_test_split(0.5, seed=0)
        dataset['train'] = train_validation['train']
        dataset['validation'] = train_validation['test']

    if 'validation' in dataset and 'test' not in dataset:
        validation_test = dataset['validation'].train_test_split(0.5, seed=0)
        dataset['validation'] = validation_test['train']
        dataset['test'] = validation_test['test']

    if 'train' in dataset and 'validation' not in dataset:
        train_val = dataset['train'].train_test_split(train_size=0.90, seed=0)
        dataset['train'] = train_val['train']
        dataset['validation'] = train_val['test']

    if 'test' in dataset and 'validation' not in dataset:
        validation_test = dataset['test'].train_test_split(0.5, seed=0)
        dataset['validation'] = validation_test['train']
        dataset['test'] = validation_test['test']

    if 'validation' not in dataset and 'test' not in dataset:
        train_val_test = dataset["train"].train_test_split(train_size=0.90, seed=0)
        val_test = train_val_test["test"].train_test_split(0.5, seed=0)
        dataset["train"] = train_val_test["train"]
        dataset["validation"] = val_test["train"]
        dataset["test"] = val_test["test"]

    return dataset

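# Outcome sketch (illustrative): a lone split is first renamed to 'train';
# {'train', 'test'} gains a validation split from a 90/10 split of train;
# {'train'} alone ends up as train/validation via the same 90/10 split.
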
def fix_labels(dataset, label_key='labels'):
    # numeric or list labels are left untouched; string labels are cast
    # to a ClassLabel feature with a deterministic name order
    if type(dataset['train'][label_key][0]) in [int, list, float]:
        return dataset
    labels = set(fc.flatten(dataset[k][label_key] for k in {"train"}))
    if labels == {'entailment', 'neutral', 'contradiction'}:
        # keep the canonical NLI order rather than alphabetical
        order = lambda x: dict(fc.flip(enumerate(['entailment', 'neutral', 'contradiction']))).get(x, x)
    else:
        order = str
    labels = sorted(labels, key=order)
    dataset = dataset.cast_column(label_key, datasets.ClassLabel(names=labels))
    return dataset

def concatenate_dataset_dict(l):
    """Concatenate a list of DatasetDict objects sharing the same splits and columns."""
    keys = l[0].keys()
    return datasets.DatasetDict({k: datasets.concatenate_datasets([x[k] for x in l]) for k in keys})
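
For context, a minimal end-to-end sketch of how this module might be driven
(task and dataset names are assumptions, not part of this commit):

    task = Classification(sentence1='premise', sentence2='hypothesis',
                          labels='label', dataset_name='snli')
    dataset = task.load()  # DatasetDict with train/validation/test
    merged = concatenate_dataset_dict([dataset, dataset])  # splits/columns must match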

0 commit comments