1
+ from collections .abc import Iterable
2
+ from dotwiz import DotWiz
3
+ from dataclasses import dataclass
4
+ from typing import Union
5
+ import itertools
6
+ import funcy as fc
7
+ import exrex
8
+ import magicattr
9
+ import numpy as np
10
+ import copy
11
+ import datasets
12
+ import time
13
+
14
def get_column_names(dataset):
    """Return the set of column names of a Dataset or DatasetDict.

    For a DatasetDict, ``column_names`` is a dict mapping split name to a
    list of column names; the union over all splits is returned.

    :param dataset: object exposing a ``column_names`` attribute
        (datasets.Dataset gives a list, datasets.DatasetDict a dict).
    :return: set of column-name strings.
    """
    cn = dataset.column_names
    if isinstance(cn, dict):  # DatasetDict: {split: [column, ...]}
        # Flatten the per-split lists with the stdlib (was funcy.flatten;
        # values are flat lists of strings, so this is equivalent).
        return set(itertools.chain.from_iterable(cn.values()))
    return set(cn)
20
+
21
+
22
def sample_dataset(dataset, n=10000, n_eval=1000, seed=0):
    """Cap the size of each split of ``dataset`` in place.

    The train split is capped at ``n`` rows, every other split at
    ``n_eval``; a falsy cap (0/None) leaves the split untouched.
    Subsampling is a deterministic ``train_test_split`` at ``seed``.

    :return: the same (mutated) dataset mapping.
    """
    for split in dataset:
        cap = n if split == 'train' else n_eval
        if not cap or len(dataset[split]) <= cap:
            continue
        subsampled = dataset[split].train_test_split(train_size=cap, seed=seed)
        dataset[split] = subsampled['train']
    return dataset
28
+
29
class Preprocessing(DotWiz):
    """Base class for task preprocessing recipes.

    Subclasses declare fields whose value is either a string (the source
    column to be renamed to the field name) or a callable (a function of
    one example producing the field's value).  Calling an instance on a
    DatasetDict normalizes splits, applies the renames/mappings, fixes
    labels and returns the processed DatasetDict.
    """
    # Canonical split names every processed dataset ends up with.
    default_splits = ('train', 'validation', 'test')

    @staticmethod
    def __map_to_target(x, fn=lambda x: None, target=None):
        # datasets.map helper: write fn(example) into example[target].
        x[target] = fn(x)
        return x

    def load(self):
        # Fetch the raw dataset from the hub, then run this recipe on it.
        return self(datasets.load_dataset(self.dataset_name, self.config_name))

    def __call__(self, dataset, max_rows=None, max_rows_eval=None, seed=0):
        """Process ``dataset`` (a DatasetDict) according to this recipe.

        :param max_rows: cap for the train split (see sample_dataset).
        :param max_rows_eval: cap for the other splits.
        :param seed: seed for deterministic subsampling.
        """
        dataset = self.pre_process(dataset)

        # manage splits: rename the declared source split names to the
        # canonical ones; a falsy entry in self.splits drops that split.
        for k, v in zip(self.default_splits, self.splits):
            if v and k != v:
                dataset[k] = dataset[v]
                del dataset[v]
            if k in dataset and not v:  # obfuscated label
                del dataset[k]
        dataset = fix_splits(dataset)

        # Drop any remaining non-canonical splits, then subsample.
        for k in list(dataset.keys()):
            if k not in self.default_splits:
                del dataset[k]
        dataset = sample_dataset(dataset, max_rows, max_rows_eval, seed=seed)

        # field annotated with a string: build {source_column: field_name},
        # skipping config fields and identity renames.
        substitutions = {v: k for k, v in self.to_dict().items()
                         if (k and k not in {'splits', 'dataset_name', 'config_name'}
                             and type(v) == str and k != v)}

        # Remove target columns that already exist (and are not themselves
        # being renamed) so rename_columns cannot collide.
        dataset = dataset.remove_columns([c for c in substitutions.values() if c in dataset['train'].features and c not in substitutions])
        dataset = dataset.rename_columns(substitutions)

        # field annotated with a function: materialize it with map().
        for k in self.to_dict().keys():
            v = getattr(self, k)
            if callable(v) and k not in {"post_process", "pre_process", "load"}:
                dataset = dataset.map(self.__map_to_target,
                                      fn_kwargs={'fn': v, 'target': k})

        # Keep only the declared fields.
        dataset = dataset.remove_columns(
            get_column_names(dataset) - set(self.to_dict().keys()))
        dataset = fix_labels(dataset)
        dataset = fix_splits(dataset)  # again: label mapping changed
        dataset = self.post_process(dataset)
        return dataset
77
+
78
+
79
@dataclass
class cat(Preprocessing):
    """Annotator concatenating several example fields into one string.

    ``fields`` lists the example keys to join; ``separator`` is appended
    after each field's value.
    """
    fields: Union[str, list] = None
    separator: str = ' '

    def __call__(self, example=None):
        # Append the separator to each selected field value; fields are
        # traversed in reverse order.
        y = [np.char.array(example[f]) + sep
             for f, sep in zip(self.fields[::-1], itertools.repeat(self.separator))]
        # With several fields, sum(*y) uses the first array as the start
        # value, concatenating the char arrays element-wise.
        # NOTE(review): with a single field this becomes sum(arr) with
        # start=0 — confirm that path is actually exercised.
        y = list(sum(*y))
        if len(y) == 1:
            y = y[0]
        return y
91
+
92
+
93
def pretty(f):
    """Wrap callable-factory ``f`` so its products get a readable repr.

    ``pretty(f)(*args)`` builds ``f(*args)`` and forwards calls to it;
    the repr shows the factory's name and the last positional argument
    (presumably for display in task listings — derived from __qualname__).
    """
    class pretty_f(DotWiz):
        def __init__(self, *args):
            self.__f_arg = f(*args)
            # Keep the (last) positional argument around for __repr__.
            for a in args:
                setattr(self, 'value', a)

        def __call__(self, *args, **kwargs):
            return self.__f_arg(*args, **kwargs)

        def __repr__(self):
            return f"{self.__f_arg.__qualname__.split('.')[0]}({self.value})"
    return pretty_f
106
+
107
class dotgetter:
    """Lazily-built attribute/index path into a nested example.

    Attribute and item access extend the stored path string; calling the
    object resolves that path on an example via magicattr.
    """

    def __init__(self, path=''):
        self.path = path

    def __bool__(self):
        # The bare root getter (empty path) is falsy.
        return bool(self.path)

    def __getattr__(self, k):
        # Only reached for unknown attributes; extends the path by ".k".
        extended = '.'.join([self.path, k]).lstrip('.')
        return type(self)(extended)

    def __getitem__(self, i):
        # Extend the path with an index access "[i]".
        return type(self)('{}[{}]'.format(self.path, i))

    def __call__(self, example=None):
        # Resolve the accumulated path against the (dict) example.
        return magicattr.get(DotWiz(example), self.path)

    def __hash__(self):
        # Hash by path so equal paths land in the same bucket.
        return hash(self.path)
125
+
126
+
127
@dataclass
class ClassificationFields(Preprocessing):
    """Column annotations for (pairwise) sequence classification tasks."""
    sentence1: str = 'sentence1'  # first (or only) input text column
    sentence2: str = 'sentence2'  # second input text column, if any
    labels: str = 'labels'        # target label column
132
+
133
@dataclass
class Seq2SeqLMFields(Preprocessing):
    """Column annotations for sequence-to-sequence generation tasks."""
    prompt: str = 'prompt'  # input text column
    output: str = 'output'  # expected generation column
137
+
138
@dataclass
class TokenClassificationFields(Preprocessing):
    """Column annotations for token-level classification (e.g. tagging)."""
    tokens: str = 'tokens'  # token list column
    labels: str = 'labels'  # per-token label column
142
+
143
@dataclass
class MultipleChoiceFields(Preprocessing):
    """Column annotations for multiple-choice tasks.

    Fixed choices are given via ``choices`` and expanded into numbered
    ``choice{i}`` fields; per-example choices come via ``choices_list``
    and are flattened into ``choice{i}`` columns after processing.
    """
    inputs: str = 'input'
    choices: Iterable = tuple()
    labels: str = 'labels'
    choices_list: str = None

    def __post_init__(self):
        # Expand the fixed choices into choice0, choice1, ... fields so
        # the generic Preprocessing machinery handles them like any field.
        for i, c in enumerate(self.choices):
            setattr(self, f'choice{i}', c)
        # Remove bookkeeping attributes that should not become columns.
        delattr(self, 'choices')
        if not self.choices_list:
            delattr(self, 'choices_list')

    def __call__(self, dataset, *args, **kwargs):
        dataset = super().__call__(dataset, *args, **kwargs)
        if self.choices_list:
            # Keep only examples with at least two candidate answers.
            dataset = dataset.filter(lambda x: 1 < len(x['choices_list']))
            # Use the smallest choice count across all splits, capped at 5.
            n_options = min([len(x) for k in dataset for x in dataset[k]['choices_list']])
            n_options = min(5, n_options)
            dataset = dataset.map(self.flatten, fn_kwargs={'n_options': n_options})
        return dataset

    @staticmethod
    def flatten(x, n_options=None):
        """Turn an example's choices_list into choice{i} columns.

        The correct answer is moved to choice0 (labels becomes 0) and the
        negatives are truncated to n_options - 1.
        """
        n_neg = n_options - 1 if n_options else None
        choices = x['choices_list']
        label = x['labels']
        # Negatives = all choices except the gold one.
        neg = choices[:label] + choices[label + 1:]
        pos = choices[label]
        x['labels'] = 0
        x['choices_list'] = [pos] + neg[:n_neg]
        for i, o in enumerate(x['choices_list']):
            x[f'choice{i}'] = o
        del x['choices_list']
        return x
178
+
179
@dataclass
class SharedFields:
    """Configuration fields shared by every task recipe."""
    # Source split names aligned with Preprocessing.default_splits; a falsy
    # entry drops the corresponding canonical split (see Preprocessing.__call__).
    splits: list = Preprocessing.default_splits
    dataset_name: str = None   # hub dataset identifier
    config_name: str = None    # hub dataset config, if any
    # Hooks applied to the DatasetDict before/after the standard pipeline.
    pre_process: callable = lambda x: x
    post_process: callable = lambda x: x
    #language:str="en"
187
+
188
+
189
@dataclass
class Classification(SharedFields, ClassificationFields):
    """Ready-to-use recipe for sentence(-pair) classification tasks."""
    pass
191
+
192
@dataclass
class MultipleChoice(SharedFields, MultipleChoiceFields):
    """Ready-to-use recipe for multiple-choice tasks."""
    pass
194
+
195
@dataclass
class TokenClassification(SharedFields, TokenClassificationFields):
    """Ready-to-use recipe for token classification tasks."""
    pass
197
+
198
@dataclass
class Seq2SeqLM(SharedFields, Seq2SeqLMFields):
    """Ready-to-use recipe for sequence-to-sequence LM tasks."""
    pass
200
+
201
# Ready-made annotators usable as field values in Preprocessing subclasses.
get = dotgetter()  # attribute/index path extractor, e.g. get.answers.text
constant = pretty(fc.constantly)  # constant-valued field with a readable repr
# Expand a regex into the list of strings it generates (via exrex).
regen = lambda x: list(exrex.generate(x))
204
+
205
def name(label_name, classes):
    """Return an annotator mapping an example's raw label to its class name.

    The returned callable indexes ``classes`` with the value found at
    ``example[label_name]``.
    """
    def to_class_name(example):
        return classes[example[label_name]]
    return to_class_name
207
+
208
def fix_splits(dataset):
    """Normalize a DatasetDict so it ends with train/validation/test splits.

    Mutates and returns ``dataset``.  Missing splits are carved out of
    existing ones with deterministic (seed=0) train_test_split calls.
    The branches below are order-dependent: each may create a split that
    a later branch tests for.
    """

    # A single non-train split becomes the train split.
    if len(dataset) == 1 and "train" not in dataset:
        k = list(dataset)[0]
        dataset['train'] = copy.deepcopy(dataset[k])
        del dataset[k]

    # Drop the auxiliary_train split some datasets ship with.
    if 'auxiliary_train' in dataset:
        del dataset['auxiliary_train']

    if 'test' in dataset:  # manage obfuscated labels
        if 'labels' in dataset['test'].features:
            # A test split whose labels are all identical is treated as
            # obfuscated (placeholder labels) and dropped entirely.
            if len(set(fc.flatten(dataset['test'].to_dict()['labels']))) == 1:
                del dataset['test']

    # Only validation exists: split it 50/50 into train/validation.
    if 'validation' in dataset and 'train' not in dataset:
        train_validation = dataset['validation'].train_test_split(0.5, seed=0)
        dataset['train'] = train_validation['train']
        dataset['validation'] = train_validation['test']

    # No test: carve it out of validation (50/50).
    if 'validation' in dataset and 'test' not in dataset:
        validation_test = dataset['validation'].train_test_split(0.5, seed=0)
        dataset['validation'] = validation_test['train']
        dataset['test'] = validation_test['test']

    # No validation: take 10% of train.
    if 'train' in dataset and 'validation' not in dataset:
        train_val = dataset['train'].train_test_split(train_size=0.90, seed=0)
        dataset['train'] = train_val['train']
        dataset['validation'] = train_val['test']

    # Still no validation (only test existed): split test 50/50.
    if 'test' in dataset and 'validation' not in dataset:
        validation_test = dataset['test'].train_test_split(0.5, seed=0)
        dataset['validation'] = validation_test['train']
        dataset['test'] = validation_test['test']

    # Neither validation nor test: derive a 90/5/5 split from train.
    if 'validation' not in dataset and 'test' not in dataset:
        train_val_test = dataset["train"].train_test_split(train_size=0.90, seed=0)
        val_test = train_val_test["test"].train_test_split(0.5, seed=0)
        dataset["train"] = train_val_test["train"]
        dataset["validation"] = val_test["train"]
        dataset["test"] = val_test["test"]

    return dataset
251
+
252
def fix_labels(dataset, label_key='labels'):
    """Cast string labels to a datasets.ClassLabel feature.

    Already-numeric (or list-valued, e.g. token-level) labels are returned
    untouched.  NOTE(review): the exact ``type(...) in [...]`` check (not
    isinstance) also excludes bool and str/int subclasses — keep as is.
    """
    if type(dataset['train'][label_key][0]) in [int, list, float]:
        return dataset
    # Collect the distinct label strings from the train split only.
    labels = set(fc.flatten(dataset[k][label_key] for k in {"train"}))
    if set(labels) == {'entailment', 'neutral', 'contradiction'}:
        # Keep the conventional NLI class order instead of alphabetical.
        order = lambda x: dict(fc.flip(enumerate(['entailment', 'neutral', 'contradiction']))).get(x, x)
    else:
        order = str
    labels = sorted(labels, key=order)
    dataset = dataset.cast_column(label_key, datasets.ClassLabel(names=labels))
    return dataset
263
+
264
def concatenate_dataset_dict(l):
    """Concatenate a list of DatasetDict objects sharing the same splits and columns."""
    # Splits are taken from the first element; all entries are assumed to
    # expose the same split names.
    split_names = l[0].keys()
    merged = {}
    for split in split_names:
        merged[split] = datasets.concatenate_datasets([dd[split] for dd in l])
    return datasets.DatasetDict(merged)
0 commit comments