Skip to content

Commit 8e9c54d

Browse files
author
Damien Sileo
committed
fixed acces
1 parent a72c2f4 commit 8e9c54d

File tree

1 file changed

+80
-79
lines changed

1 file changed

+80
-79
lines changed

src/tasksource/access.py

Lines changed: 80 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,81 @@
1-
from .preprocess import Preprocessing
2-
import re
3-
import pandas as pd
4-
from . import tasks
5-
from .metadata import dataset_rank
6-
from datasets import load_dataset
7-
import funcy as fc
8-
import os
9-
import copy
10-
from sorcery import dict_of
11-
12-
def parse_var_name(s):
13-
config_name,task_name = None,None
14-
if '__' in s and '___' not in s: # dataset__task
15-
dataset_name, task_name = s.split('__')
16-
elif '__' not in s.replace('___','') and '___' in s: #dataset___config
17-
dataset_name, config_name = s.split('___')
18-
elif '___' in s and '__' in s.split('___')[1]: #dataset___config__task
19-
dataset_name, config_task=s.split('___')
20-
config_name,task_name = config_task.split('__')
21-
else: # dataset
22-
dataset_name = s
23-
return dataset_name,config_name,task_name
24-
25-
def pretty_name(x):
26-
dn = x.dataset_name.split("/")[-1]
27-
cn = x.config_name if x.config_name else ""
28-
tn = x.task_name if x.task_name else ""
29-
return f"{dn}/{cn}/{tn}".replace('//','/').rstrip('/')
30-
31-
def list_tasks(tasks_path=f'{os.path.dirname(__file__)}/tasks.py'):
32-
task_order = open(tasks_path).readlines()
33-
task_order = [x.split('=')[0].rstrip() for x in task_order if '=' in x]
34-
task_order = [x for x in task_order if x.isidentifier()]
35-
task_order = fc.flip(dict(enumerate(task_order)))
36-
37-
l = []
38-
for key in dir(tasks):
39-
if key not in task_order:
40-
continue
41-
value=getattr(tasks, key)
42-
if isinstance(value,Preprocessing):
43-
dataset_name, config_name, task_name = parse_var_name(key)
44-
dataset_name = (value.dataset_name if value.dataset_name else dataset_name)
45-
config_name = (value.config_name if value.config_name else config_name)
46-
hasattr(value,key)
47-
l+=[{'dataset_name': dataset_name,
48-
'config_name' : config_name,
49-
'task_name': task_name,
50-
'preprocessing_name': key,
51-
'task_type': value.__class__.__name__,'mapping': value,
52-
'rank':task_order.get(key,None)}]
53-
df=pd.DataFrame(l).explode('config_name')
54-
df = df.sort_values('rank').reset_index(drop=True)
55-
df['id'] = df.apply(lambda x: pretty_name(x), axis=1)
56-
df.insert(0, 'id', df.pop('id'))
57-
del df['rank']
58-
return df
59-
60-
task_df = list_tasks()
61-
62-
def dict_to_query(d=dict(), **kwargs):
63-
d={**d,**kwargs}
64-
return '&'.join([f'`{k}`=="{v}"' for k,v in d.items()])
65-
66-
def load_preprocessing(tasks=tasks, **kwargs):
67-
y = task_df.copy().query(dict_to_query(**kwargs)).iloc[0]
68-
preprocessing= copy.deepcopy(getattr(tasks, y.preprocessing_name))
69-
for c in 'dataset_name','config_name':
70-
if not isinstance(getattr(preprocessing,c), str):
71-
setattr(preprocessing,c,getattr(y,c))
72-
return preprocessing
73-
74-
def load_task(id=None, dataset_name=None,config_name=None,task_name=None,preprocessing_name=None,
75-
max_rows=None, max_rows_eval=None):
76-
query = dict_of(id, dataset_name, config_name, task_name,preprocessing_name)
77-
query = {k:v for k,v in query.items() if v}
78-
preprocessing = load_preprocessing(**query)
79-
dataset = load_dataset(preprocessing.dataset_name, preprocessing.config_name)
1+
from .preprocess import Preprocessing
2+
import re
3+
import pandas as pd
4+
from . import tasks
5+
from .metadata import dataset_rank
6+
from datasets import load_dataset
7+
import funcy as fc
8+
import os
9+
import copy
10+
from sorcery import dict_of
11+
12+
def parse_var_name(s):
13+
config_name,task_name = None,None
14+
if '__' in s and '___' not in s: # dataset__task
15+
dataset_name, task_name = s.split('__')
16+
elif '__' not in s.replace('___','') and '___' in s: #dataset___config
17+
dataset_name, config_name = s.split('___')
18+
elif '___' in s and '__' in s.split('___')[1]: #dataset___config__task
19+
dataset_name, config_task=s.split('___')
20+
config_name,task_name = config_task.split('__')
21+
else: # dataset
22+
dataset_name = s
23+
return dataset_name,config_name,task_name
24+
25+
def pretty_name(x):
26+
dn = x.dataset_name.split("/")[-1]
27+
cn = x.config_name if x.config_name else ""
28+
tn = x.task_name if x.task_name else ""
29+
return f"{dn}/{cn}/{tn}".replace('//','/').rstrip('/')
30+
31+
def list_tasks(tasks_path=f'{os.path.dirname(__file__)}/tasks.py'):
32+
task_order = open(tasks_path).readlines()
33+
task_order = [x.split('=')[0].rstrip() for x in task_order if '=' in x]
34+
task_order = [x for x in task_order if x.isidentifier()]
35+
task_order = fc.flip(dict(enumerate(task_order)))
36+
37+
l = []
38+
for key in dir(tasks):
39+
if key not in task_order:
40+
continue
41+
value=getattr(tasks, key)
42+
if isinstance(value,Preprocessing):
43+
dataset_name, config_name, task_name = parse_var_name(key)
44+
dataset_name = (value.dataset_name if value.dataset_name else dataset_name)
45+
config_name = (value.config_name if value.config_name else config_name)
46+
hasattr(value,key)
47+
l+=[{'dataset_name': dataset_name,
48+
'config_name' : config_name,
49+
'task_name': task_name,
50+
'preprocessing_name': key,
51+
'task_type': value.__class__.__name__,'mapping': value,
52+
'rank':task_order.get(key,None)}]
53+
df=pd.DataFrame(l).explode('config_name')
54+
df = df.sort_values('rank').reset_index(drop=True)
55+
df['id'] = df.apply(lambda x: pretty_name(x), axis=1)
56+
df.insert(0, 'id', df.pop('id'))
57+
del df['rank']
58+
return df
59+
60+
task_df = list_tasks()
61+
62+
def dict_to_query(d=dict(), **kwargs):
63+
d={**d,**kwargs}
64+
return '&'.join([f'`{k}`=="{v}"' for k,v in d.items()])
65+
66+
def load_preprocessing(tasks=tasks, **kwargs):
67+
y = task_df.copy().query(dict_to_query(**kwargs)).iloc[0]
68+
preprocessing= copy.copy(getattr(tasks, y.preprocessing_name))
69+
#preprocessing= getattr(tasks, y.preprocessing_name)
70+
for c in 'dataset_name','config_name':
71+
if not isinstance(getattr(preprocessing,c), str):
72+
setattr(preprocessing,c,getattr(y,c))
73+
return preprocessing
74+
75+
def load_task(id=None, dataset_name=None,config_name=None,task_name=None,preprocessing_name=None,
76+
max_rows=None, max_rows_eval=None):
77+
query = dict_of(id, dataset_name, config_name, task_name,preprocessing_name)
78+
query = {k:v for k,v in query.items() if v}
79+
preprocessing = load_preprocessing(**query)
80+
dataset = load_dataset(preprocessing.dataset_name, preprocessing.config_name)
8081
return preprocessing(dataset,max_rows, max_rows_eval)

0 commit comments

Comments
 (0)