Skip to content

Commit d224cf6

Browse files
authored
Merge pull request #24 from satra/enh-split-tasks
several pydra related updates to reflect better usage
2 parents f47b7d8 + b5fa649 commit d224cf6

File tree

5 files changed

+102
-35
lines changed

5 files changed

+102
-35
lines changed

pydra_ml/classifier.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,41 @@
11
#!/usr/bin/env python
22

33
import pydra
4+
from pydra.mark import task, annotate
5+
from pydra.utils.messenger import AuditFlag, FileMessenger
6+
import typing as ty
47
import os
58
from .tasks import read_file, gen_splits, train_test_kernel, calc_metric, get_shap
69
from .report import gen_report
710

11+
# Create pydra tasks
12+
read_file_pdt = task(
13+
annotate(
14+
{
15+
"return": {
16+
"X": ty.Any,
17+
"Y": ty.Any,
18+
"groups": ty.Any,
19+
"feature_names": ty.Any,
20+
}
21+
}
22+
)(read_file)
23+
)
24+
25+
gen_splits_pdt = task(
26+
annotate({"return": {"splits": ty.Any, "split_indices": ty.Any}})(gen_splits)
27+
)
28+
29+
train_test_kernel_pdt = task(
30+
annotate({"return": {"output": ty.Any, "model": ty.Any}})(train_test_kernel)
31+
)
32+
33+
calc_metric_pdt = task(
34+
annotate({"return": {"score": ty.Any, "output": ty.Any}})(calc_metric)
35+
)
36+
37+
get_shap_pdt = task(annotate({"return": {"shaps": ty.Any}})(get_shap))
38+
839

940
def gen_workflow(inputs, cache_dir=None, cache_locations=None):
1041
wf = pydra.Workflow(
@@ -13,18 +44,21 @@ def gen_workflow(inputs, cache_dir=None, cache_locations=None):
1344
**inputs,
1445
cache_dir=cache_dir,
1546
cache_locations=cache_locations,
47+
audit_flags=AuditFlag.ALL,
48+
messengers=FileMessenger(),
49+
messenger_args={"message_dir": os.path.join(os.getcwd(), "messages")},
1650
)
1751
wf.split(["clf_info", "permute"])
1852
wf.add(
19-
read_file(
53+
read_file_pdt(
2054
name="readcsv",
2155
filename=wf.lzin.filename,
2256
x_indices=wf.lzin.x_indices,
2357
target_vars=wf.lzin.target_vars,
2458
)
2559
)
2660
wf.add(
27-
gen_splits(
61+
gen_splits_pdt(
2862
name="gensplit",
2963
n_splits=wf.lzin.n_splits,
3064
test_size=wf.lzin.test_size,
@@ -34,26 +68,25 @@ def gen_workflow(inputs, cache_dir=None, cache_locations=None):
3468
)
3569
)
3670
wf.add(
37-
train_test_kernel(
71+
train_test_kernel_pdt(
3872
name="fit_clf",
3973
X=wf.readcsv.lzout.X,
4074
y=wf.readcsv.lzout.Y,
4175
train_test_split=wf.gensplit.lzout.splits,
4276
split_index=wf.gensplit.lzout.split_indices,
4377
clf_info=wf.lzin.clf_info,
4478
permute=wf.lzin.permute,
45-
metrics=wf.lzin.metrics,
4679
)
4780
)
4881
wf.fit_clf.split("split_index")
4982
wf.add(
50-
calc_metric(
83+
calc_metric_pdt(
5184
name="metric", output=wf.fit_clf.lzout.output, metrics=wf.lzin.metrics
5285
)
5386
)
5487
wf.metric.combine("fit_clf.split_index")
5588
wf.add(
56-
get_shap(
89+
get_shap_pdt(
5790
name="shap",
5891
X=wf.readcsv.lzout.X,
5992
permute=wf.lzin.permute,
@@ -75,16 +108,21 @@ def gen_workflow(inputs, cache_dir=None, cache_locations=None):
75108
return wf
76109

77110

78-
def run_workflow(wf, plugin, plugin_args):
111+
def run_workflow(wf, plugin, plugin_args, specfile="localspec"):
79112
cwd = os.getcwd()
80113
with pydra.Submitter(plugin=plugin, **plugin_args) as sub:
81114
sub(runnable=wf)
82115
results = wf.result(return_inputs=True)
83116
os.chdir(cwd)
117+
84118
import pickle as pk
85119
import datetime
86120

87121
timestamp = datetime.datetime.utcnow().isoformat()
122+
timestamp = timestamp.replace(":", "").replace("-", "")
123+
result_dir = f"out-{os.path.basename(specfile)}-{timestamp}"
124+
os.makedirs(result_dir)
125+
os.chdir(result_dir)
88126
with open(f"results-{timestamp}.pkl", "wb") as fp:
89127
pk.dump(results, fp)
90128

@@ -95,4 +133,5 @@ def run_workflow(wf, plugin, plugin_args):
95133
gen_shap=wf.inputs.gen_shap,
96134
plot_top_n_shap=wf.inputs.plot_top_n_shap,
97135
)
136+
os.chdir(cwd)
98137
return results

pydra_ml/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,4 @@ def main(specfile, plugin, cache):
4545
if plugin[0] == "cf" and key == "n_procs":
4646
value = int(value)
4747
plugin_args[key] = value
48-
run_workflow(wf, plugin[0], plugin_args)
48+
run_workflow(wf, plugin[0], plugin_args, specfile)

pydra_ml/report.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,7 @@ def save_obj(obj, path):
1414
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
1515

1616

17-
def plot_summary(
18-
summary,
19-
output_dir=None,
20-
filename="shap_LogisticRegression_all_predictions",
21-
plot_top_n_shap=16,
22-
):
17+
def plot_summary(summary, output_dir=None, filename="shap_plot", plot_top_n_shap=16):
2318
plt.clf()
2419
plt.figure(figsize=(8, 12))
2520
# plot without all bootstrapping values
@@ -54,7 +49,7 @@ def shaps_to_summary(
5449
shaps_n_splits,
5550
feature_names=None,
5651
output_dir=None,
57-
filename="shap_LogisticRegression_all_predictions",
52+
filename="shap_summary",
5853
plot_top_n_shap=16,
5954
):
6055
shaps_n_splits.columns = [
@@ -84,6 +79,7 @@ def shaps_to_summary(
8479
def gen_report_shap(results, output_dir="./", plot_top_n_shap=16):
8580
# Create shap_dir
8681
timestamp = datetime.datetime.utcnow().isoformat()
82+
timestamp = timestamp.replace(":", "").replace("-", "")
8783
shap_dir = output_dir + f"shap-{timestamp}/"
8884
os.mkdir(shap_dir)
8985

@@ -213,6 +209,7 @@ def gen_report(
213209
import datetime
214210

215211
timestamp = datetime.datetime.utcnow().isoformat()
212+
timestamp = timestamp.replace(":", "").replace("-", "")
216213
plt.savefig(f"test-{name}-{timestamp}.png")
217214

218215
# create SHAP summary csv and figures

pydra_ml/tasks.py

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
#!/usr/bin/env python
22

3-
import pydra
4-
import typing as ty
5-
import numpy as np
63

4+
def read_file(filename, x_indices=None, target_vars=None, group=None):
5+
"""Read a CSV data file
76
8-
@pydra.mark.task
9-
@pydra.mark.annotate(
10-
{"return": {"X": ty.Any, "Y": ty.Any, "groups": ty.Any, "feature_names": ty.Any}}
11-
)
12-
def read_file(filename, x_indices=None, target_vars=None, group="groups"):
7+
:param filename: CSV filename containing a column header
8+
:param x_indices: integer or string indices
9+
:param target_vars: Target variables to use
10+
:param group: CSV column name containing grouping information
11+
:return: Tuple containing train data, target data, groups, features
12+
"""
1313
import pandas as pd
1414

1515
data = pd.read_csv(filename)
@@ -20,17 +20,27 @@ def read_file(filename, x_indices=None, target_vars=None, group="groups"):
2020
else:
2121
raise ValueError(f"{x_indices} is not a list of string or ints")
2222
Y = data[target_vars]
23-
if group in data.keys():
24-
groups = data[:, [group]]
25-
else:
23+
if group is None:
2624
groups = list(range(X.shape[0]))
25+
else:
26+
groups = data[:, [group]]
2727
feature_names = list(X.columns)
2828
return X.values, Y.values, groups, feature_names
2929

3030

31-
@pydra.mark.task
32-
@pydra.mark.annotate({"return": {"splits": ty.Any, "split_indices": ty.Any}})
3331
def gen_splits(n_splits, test_size, X, Y, groups=None, random_state=0):
32+
"""Generate train-test splits for the data.
33+
34+
Uses GroupShuffleSplit from scikit-learn
35+
36+
:param n_splits: Number of splits
37+
:param test_size: fractional test size
38+
:param X: Sample feature data
39+
:param Y: Sample target data
40+
:param groups: Grouping of sample data for shufflesplit
41+
:param random_state: randomization for shuffling (default 0)
42+
:return: splits and indices to splits
43+
"""
3444
from sklearn.model_selection import GroupShuffleSplit
3545

3646
gss = GroupShuffleSplit(
@@ -41,9 +51,17 @@ def gen_splits(n_splits, test_size, X, Y, groups=None, random_state=0):
4151
return train_test_splits, split_indices
4252

4353

44-
@pydra.mark.task
45-
@pydra.mark.annotate({"return": {"output": ty.Any, "model": ty.Any}})
46-
def train_test_kernel(X, y, train_test_split, split_index, clf_info, permute, metrics):
54+
def train_test_kernel(X, y, train_test_split, split_index, clf_info, permute):
55+
"""Core model fitting and predicting function
56+
57+
:param X: Input features
58+
:param y: Target variables
59+
:param train_test_split: split indices
60+
:param split_index: which index to use
61+
:param clf_info: how to construct the classifier
62+
:param permute: whether to run it in permuted mode or not
63+
:return: outputs, trained classifier with sample indices
64+
"""
4765
from sklearn.preprocessing import StandardScaler
4866
from sklearn.pipeline import Pipeline
4967
import numpy as np
@@ -68,9 +86,13 @@ def train_test_kernel(X, y, train_test_split, split_index, clf_info, permute, me
6886
return (y[test_index], predicted), (pipe, train_index, test_index)
6987

7088

71-
@pydra.mark.task
72-
@pydra.mark.annotate({"return": {"score": ty.Any, "output": ty.Any}})
7389
def calc_metric(output, metrics):
90+
"""Calculate the scores for the predicted outputs
91+
92+
:param output: true, predicted output
93+
:param metrics: list of metrics to evaluate
94+
:return: list of scores and pass the output
95+
"""
7496
score = []
7597
for metric in metrics:
7698
metric_mod = __import__("sklearn.metrics", fromlist=[metric])
@@ -79,9 +101,17 @@ def calc_metric(output, metrics):
79101
return score, output
80102

81103

82-
@pydra.mark.task
83-
@pydra.mark.annotate({"return": {"shaps": ty.Any}})
84104
def get_shap(X, permute, model, gen_shap=False, nsamples="auto", l1_reg="aic"):
105+
"""Compute shap information for the test data
106+
107+
:param X: sample data
108+
:param permute: whether model was permuted or not
109+
:param model: model containing trained classifier and train/test index
110+
:param gen_shap: whether to generate shap features
111+
:param nsamples: number of samples for shap evaluation
112+
:param l1_reg: L1 regularization for shap evaluation
113+
:return: shap values for each test sample
114+
"""
85115
if permute or not gen_shap:
86116
return []
87117
pipe, train_index, test_index = model

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ classifiers =
2424
python_requires = >= 3.7
2525
install_requires =
2626
pydra >= 0.6
27+
psutil
2728
scikit-learn
2829
seaborn
2930
click

0 commit comments

Comments
 (0)