Commit 96dd28a

Merge pull request #27 from satra/fix-group
enh: add group variable, regression test, update README
2 parents f283288 + 5819e3e

8 files changed: +199 −63 lines changed

README.md

Lines changed: 55 additions & 21 deletions
@@ -29,22 +29,22 @@ pip install pydra-ml
 
 This repo installs `pydraml` a CLI to allow usage without any programming.
 
-To test the CLI for a classification example, copy the `pydra_ml/tests/data/breast_cancer.csv` and
+To test the CLI for a classification example, copy the `pydra_ml/tests/data/breast_cancer.csv` and
 `short-spec.json.sample` to a folder and run.
 
 ```
 $ pydraml -s short-spec.json.sample
 ```
-To check a regression example, copy `pydra_ml/tests/data/diabetes_table.csv` and `diabetes_spec.json`
-to a folder and run.
+To check a regression example, copy `pydra_ml/tests/data/diabetes_table.csv` and
+`diabetes_spec.json` to a folder and run.
 
 ```
 $ pydraml -s diabetes_spec.json
 ```
 
-For each case pydra-ml will generate a result folder with the spec file name that includes
-`test-{metric}-{timestamp}.png` file for each metric together with a pickled results file
-containing all the scores from the model evaluations.
+For each case pydra-ml will generate a result folder with the spec file name that
+includes `test-{metric}-{timestamp}.png` file for each metric together with a
+pickled results file containing all the scores from the model evaluations.
 
 ```
 $ pydraml --help
@@ -82,14 +82,17 @@ will want to generate `x_indices` programmatically.
   group.
 - *x_indices*: Numeric (0-based) or string list of columns to use as input features
 - *target_vars*: String list of target variable (at present only one is supported)
+- *group_var*: String to indicate column to use for grouping
 - *n_splits*: Number of shuffle split iterations to use
 - *test_size*: Fraction of data to use for test set in each iteration
 - *clf_info*: List of scikit-learn classifiers to use.
 - *permute*: List of booleans to indicate whether to generate a null model or not
 - *gen_shap*: Boolean indicating whether shap values are generated
 - *nsamples*: Number of samples to use for shap estimation
 - *l1_reg*: Type of regularizer to use for shap estimation
-- *plot_top_n_shap*: Number or proportion of top SHAP values to plot (e.g., 16 or 0.1 for top 10%). Set to 1.0 (float) to plot all features or 1 (int) to plot top first feature.
+- *plot_top_n_shap*: Number or proportion of top SHAP values to plot (e.g., 16
+  or 0.1 for top 10%). Set to 1.0 (float) to plot all features or 1 (int) to plot
+  top first feature.
 - *metrics*: scikit-learn metric to use
 
 ## `clf_info` specification
@@ -113,6 +116,7 @@ then an empty dictionary **MUST** be provided as parameter 3.
 "x_indices": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
               18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
 "target_vars": ["target"],
+"group_var": null,
 "n_splits": 100,
 "test_size": 0.2,
 "clf_info": [
@@ -140,25 +144,51 @@ then an empty dictionary **MUST** be provided as parameter 3.
 
 ## Output:
 The workflow will output:
-- `results-{timestamp}.pkl` containing 1 list per model used. For example, if assigned to variable `results`, it is accessed through `results[0]` to `results[N]`
-(if `permute: [false,true]` then it will output the model trained on the labels first `results[0]` and the model trained on permuted labels second `results[1]`.
+- `results-{timestamp}.pkl` containing 1 list per model used. For example, if
+  assigned to variable `results`, it is accessed through `results[0]` to `results[N]`
+  (if `permute: [false,true]` then it will output the model trained on the labels
+  first `results[0]` and the model trained on permuted labels second `results[1]`).
 Each model contains:
-- `dict` accesed through `results[0][0]` with model information: `{'ml_wf.clf_info': ['sklearn.neural_network', 'MLPClassifier', {'alpha': 1, 'max_iter': 1000}], 'ml_wf.permute': False}`
-- `pydra Result obj` accesed through `results[0][1]` with attribute `output` which itself has attributes:
+- `dict` accessed through `results[0][0]` with model information:
+  `{'ml_wf.clf_info': ['sklearn.neural_network', 'MLPClassifier',
+  {'alpha': 1, 'max_iter': 1000}], 'ml_wf.permute': False}`
+- `pydra Result obj` accessed through `results[0][1]` with attribute `output`
+  which itself has attributes:
   - `feature_names`: from the columns of the data csv.
   And the following attributes organized in N lists for N bootstrapping samples:
   - `output`: N lists, each one with two lists for true and predicted labels.
   - `score`: N lists each one containing M different metric scores.
-- `shaps`: N lists each one with a list of shape (P,F) where P is the amount of predictions and F the different SHAP values for each feature. `shaps` is empty if `gen_shap` is set to `false` or if `permute` is set to true.
-- One figure per metric with performance distribution across splits (with or without null distribution trained on permuted labels)
+  - `shaps`: N lists each one with a list of shape (P,F) where P is the
+    number of predictions and F the different SHAP values for each feature.
+    `shaps` is empty if `gen_shap` is set to `false` or if `permute` is set
+    to true.
+- One figure per metric with performance distribution across splits (with or
+  without null distribution trained on permuted labels)
+- One figure per any metric with the word `score` in it reporting the results of
+  a Wilcoxon signed rank test. The figure reports one-sided stats values as the
+  color of each cell and the corresponding `-log10(pvalue)` as the annotation.
+  Higher numbers indicate stronger effect (color) and lower p-values (annotation).
+  The actual numeric values are stored in a correspondingly named pkl file.
 - `shap-{timestamp}` dir
   - SHAP values are computed for each prediction in each split's test set
-(e.g., 30 bootstrapping splits with 100 prediction will create (30,100) array). The mean is taken across predictions for each split (e.g., resulting in a (64,30) array for 64 features and 30 bootstrapping samples).
-- For binary classification, a more accurate display of feature importance obtained by splitting predictions into TP, TN, FP, and FN,
-which in turn can allow for error auditing (i.e., what a model pays attention to when making incorrect/false predictions)
-- `quadrant_indexes.pkl`: The TP, TN, FP, FN indexes are saved in as a `dict` with one `key` per model (permuted models without SHAP values will be skipped automatically), and each key `values` being a bootstrapping split.
-- `summary_values_shap_{model_name}_{prediction_type}.csv` contains all SHAP values and summary statistics ranked by the mean SHAP value across bootstrapping splits. A sample_n column can be empty or NaN if this split did not have the type of prediction in the filename (e.g., you may not have FNs or FPs in a given split with high performance).
-- `summary_shap_{model_name}_{plot_top_n_shap}.png` contains SHAP value summary statistics for all features (set to 1.0) or only the top N most important features for better visualization.
+    (e.g., 30 bootstrapping splits with 100 predictions will create a (30,100)
+    array). The mean is taken across predictions for each split (e.g., resulting
+    in a (64,30) array for 64 features and 30 bootstrapping samples).
+  - For binary classification, a more accurate display of feature importance
+    obtained by splitting predictions into TP, TN, FP, and FN, which in turn can
+    allow for error auditing (i.e., what a model pays attention to when making
+    incorrect/false predictions)
+  - `quadrant_indexes.pkl`: The TP, TN, FP, FN indexes are saved as a
+    `dict` with one `key` per model (permuted models without SHAP values will
+    be skipped automatically), and each key's `values` being a bootstrapping split.
+  - `summary_values_shap_{model_name}_{prediction_type}.csv` contains all
+    SHAP values and summary statistics ranked by the mean SHAP value across
+    bootstrapping splits. A sample_n column can be empty or NaN if this split
+    did not have the type of prediction in the filename (e.g., you may not
+    have FNs or FPs in a given split with high performance).
+  - `summary_shap_{model_name}_{plot_top_n_shap}.png` contains SHAP value
+    summary statistics for all features (set to 1.0) or only the top N most
+    important features for better visualization.
 
 
 ## Developer installation
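As a quick illustration of the structure described above, a sketch of inspecting a results pickle (the timestamped filename here is hypothetical):

```python
import pickle

with open("results-20200101T000000.pkl", "rb") as f:
    results = pickle.load(f)

model_info, result = results[0]  # first model from the spec's clf_info list
print(model_info)  # e.g. {'ml_wf.clf_info': [...], 'ml_wf.permute': False}

out = result.output
print(out.feature_names)        # columns of the data csv
print(len(out.score))           # N bootstrapping splits, each with M metric scores
y_true, y_pred = out.output[0]  # true and predicted labels of the first split
```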
@@ -171,10 +201,14 @@ cd pydra-ml
 pip install -e .[dev]
 ```
 
-It is also useful to install pre-commit:
+It is also useful to install pre-commit, which takes care of styling when
+committing code. When pre-commit is used you may have to run git commit twice,
+since pre-commit may make additional changes to your code for styling and will
+not commit these changes by default:
+
 ```
 pip install pre-commit
-pre-commit
+pre-commit install
 ```
 
 ### Project structure

diabetes_spec.json

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
-{"filename": "./diabetes_table.csv",
+{"filename": "diabetes_table.csv",
 "x_indices": [0,1,2,3,4,5,6,7,8,9],
 "target_vars": ["target"],
+"group_var": null,
 "n_splits": 4,
 "test_size": 0.2,
 "clf_info": [
@@ -14,4 +15,4 @@
 "l1_reg": "aic",
 "plot_top_n_shap": 10,
 "metrics":["explained_variance_score","mean_squared_error","mean_absolute_error"]
-}
+}

long-spec.json.sample

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 "x_indices": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
               18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
 "target_vars": ["target"],
+"group_var": null,
 "n_splits": 100,
 "test_size": 0.2,
 "clf_info": [

pydra_ml/classifier.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ def gen_workflow(inputs, cache_dir=None, cache_locations=None):
             filename=wf.lzin.filename,
             x_indices=wf.lzin.x_indices,
             target_vars=wf.lzin.target_vars,
+            group=wf.lzin.group_var,
         )
     )
     wf.add(
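This wires the spec's `group_var` through to `read_file` (see `pydra_ml/tasks.py` below), which returns per-row group labels. The point of such labels, sketched here with scikit-learn's `GroupShuffleSplit` (assuming a grouped splitter of this kind sits downstream), is that rows sharing a group id never straddle a train/test boundary:

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

X = np.arange(12).reshape(6, 2)
y = np.array([0, 1, 0, 1, 0, 1])
groups = np.array(["s1", "s1", "s2", "s2", "s3", "s3"])  # e.g. subject ids

gss = GroupShuffleSplit(n_splits=3, test_size=0.3, random_state=0)
for train_idx, test_idx in gss.split(X, y, groups=groups):
    # no subject appears on both sides of a split
    assert set(groups[train_idx]).isdisjoint(groups[test_idx])
```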

pydra_ml/report.py

Lines changed: 84 additions & 19 deletions
@@ -4,7 +4,8 @@
 import pickle
 import pandas as pd
 import numpy as np
-from sklearn.metrics import accuracy_score, explained_variance_score
+from sklearn.metrics import explained_variance_score
+from scipy.stats import wilcoxon
 import seaborn as sns
 import matplotlib.pyplot as plt
 
@@ -42,7 +43,6 @@ def plot_summary(summary, output_dir=None, filename="shap_plot", plot_top_n_shap
     plt.tight_layout()
     plt.show(block=False)
     plt.savefig(output_dir + f"summary_{filename}.png", dpi=100)
-    return
 
 
 def shaps_to_summary(
@@ -73,7 +73,6 @@ def shaps_to_summary(
         filename=filename,
         plot_top_n_shap=plot_top_n_shap,
     )
-    return
 
 
 def gen_report_shap_class(results, output_dir="./", plot_top_n_shap=16):
@@ -112,7 +111,7 @@ def gen_report_shap_class(results, output_dir="./", plot_top_n_shap=16):
         shaps_i = shaps[split_i]  # all shap values for this bootstrapping split
         y_true = y_true_and_preds[split_i][0]
         y_pred = y_true_and_preds[split_i][1]
-        #split_performance = accuracy_score(y_true, y_pred)
+        # split_performance = accuracy_score(y_true, y_pred)
         split_performance = explained_variance_score(y_true, y_pred)
 
         # split prediction indexes into TP, TN, FP, FN, good for error auditing
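The error-auditing split mentioned in the comment above amounts to partitioning test indices by confusion quadrant; a self-contained sketch of the idea (binary labels assumed):

```python
import numpy as np

y_true = np.array([1, 0, 1, 0, 1])
y_pred = np.array([1, 0, 0, 1, 1])

quadrants = {
    "TP": np.where((y_true == 1) & (y_pred == 1))[0],
    "TN": np.where((y_true == 0) & (y_pred == 0))[0],
    "FP": np.where((y_true == 0) & (y_pred == 1))[0],
    "FN": np.where((y_true == 1) & (y_pred == 0))[0],
}
print(quadrants)  # indices per quadrant, e.g. TP -> [0, 4]
```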
@@ -166,7 +165,7 @@ def gen_report_shap_class(results, output_dir="./", plot_top_n_shap=16):
            plot_top_n_shap=plot_top_n_shap,
        )
    save_obj(indexes_all, shap_dir + "indexes_quadrant.pkl")
-    return
+
 
 def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):
     # Create shap_dir
@@ -197,7 +196,7 @@
             "lp": [],
             "lm": [],
             "um": [],
-            "up": []
+            "up": [],
         }  # this is key with shape (F, N) where F is feature_names, N is mean shap values across splits
         # Obtain values for each bootstrapping split, then append summary statistics to shaps_n_splits
         for split_i in range(n_splits):
@@ -208,8 +207,8 @@
 
         # split prediction indexes into upper, median, lower, good for error auditing
         indexes = {"lp": [], "lm": [], "um": [], "up": []}
-        q=np.array([25,50,75])
-        prc=np.percentile(y_true,q)
+        q = np.array([25, 50, 75])
+        prc = np.percentile(y_true, q)
         for i in range(len(y_true)):
             if prc[0] >= y_pred[i]:
                 indexes["lp"].append(i)
@@ -259,7 +258,45 @@
             plot_top_n_shap=plot_top_n_shap,
         )
     save_obj(indexes_all, shap_dir + "indexes_quadrant.pkl")
-    return
+
+
+def compute_pairwise_stats(df):
+    """Run Wilcoxon signed rank tests across pairs of classifiers.
+
+    When comparing a classifier to itself, compare to its null distribution.
+    A one sided test is used.
+
+    Assumes that the dataframe has three keys: Classifier, type, and score
+    with type referring to either the data distribution or the null distribution
+
+    """
+    N = len(df.Classifier.unique())
+    effects = np.zeros((N, N)) * np.nan
+    pvalues = np.zeros((N, N)) * np.nan
+    for idx1, group1 in enumerate(df.groupby("Classifier")):
+        filter = group1[1].apply(lambda x: x.type == "data", axis=1).values
+        group1df = group1[1].iloc[filter, :]
+        filter = group1[1].apply(lambda x: x.type == "null", axis=1).values
+        group1nulldf = group1[1].iloc[filter, :]
+        for idx2, group2 in enumerate(df.groupby("Classifier")):
+            filter = group2[1].apply(lambda x: x.type == "data", axis=1).values
+            group2df = group2[1].iloc[filter, :]
+            if group1[0] != group2[0]:
+                stat, pval = wilcoxon(
+                    group1df["score"].values,
+                    group2df["score"].values,
+                    alternative="greater",
+                )
+            else:
+                stat, pval = wilcoxon(
+                    group1df["score"].values,
+                    group1nulldf["score"].values,
+                    alternative="greater",
+                )
+            effects[idx1, idx2] = stat
+            pvalues[idx1, idx2] = pval
+    return effects, pvalues
 
 
 def gen_report(
     results, prefix, metrics, gen_shap=True, output_dir="./", plot_top_n_shap=16
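A hedged usage sketch for `compute_pairwise_stats`: it expects a long-format frame with `Classifier`, `type` ("data" or "null"), and `score` columns, with scores aligned across classifiers so the signed-rank pairing is meaningful (the toy data below is fabricated). For reading the resulting heatmap, note that -log10(0.05) ≈ 1.3, so annotated cells of 2 or more correspond to p < 0.01:

```python
import numpy as np
import pandas as pd

rng = np.random.RandomState(0)
rows = []
for clf in ["MLPClassifier", "SVC"]:  # hypothetical classifier names
    for kind, loc in [("data", 0.8), ("null", 0.5)]:
        for score in rng.normal(loc, 0.05, size=30):  # 30 aligned splits
            rows.append({"Classifier": clf, "type": kind, "score": score})
df = pd.DataFrame(rows)

effects, pvalues = compute_pairwise_stats(df)
print(np.fix(-np.log10(pvalues)))  # the annotations drawn on the stats heatmap
```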
@@ -284,6 +321,7 @@
             },
             ignore_index=True,
         )
+    order = [group[0] for group in df.groupby("Classifier")]
     for name, subdf in df.groupby("metric"):
         sns.set(style="whitegrid", palette="pastel", color_codes=True)
         sns.set_context("talk")
@@ -296,7 +334,7 @@
             split=True,
             inner="quartile",
             hue_order=["data", "null"],
-            order=[group[0] for group in df.groupby("Classifier")],
+            order=order,
         )
         ax.set_ylabel(name)
         sns.despine(left=True)
@@ -306,16 +344,43 @@
         timestamp = timestamp.replace(":", "").replace("-", "")
         plt.savefig(f"test-{name}-{timestamp}.png")
 
+        # Create comparison stats table if the metric is a score
+        if "score" in name:
+            effects, pvalues, = compute_pairwise_stats(subdf)
+            plt.figure(figsize=(8, 8))
+            ax = sns.heatmap(
+                effects,
+                annot=np.fix(-np.log10(pvalues)),
+                yticklabels=order,
+                xticklabels=order,
+                cbar=True,
+                square=True,
+            )
+            ax.xaxis.set_ticks_position("top")
+            plt.savefig(f"stats-{name}-{timestamp}.png")
+            save_obj(
+                dict(effects=effects, pvalues=pvalues, order=order),
+                f"stats-{name}-{timestamp}.pkl",
+            )
+
     # create SHAP summary csv and figures
     if gen_shap:
-        reg_metrics=["explained_variance_score","max_error",
-                     "mean_absolute_error","mean_squared_error",
-                     "mean_squared_log_error","median_absolute_error",
-                     "r2_score","mean_poisson_deviance",
-                     "mean_gamma_deviance"
-                     ]
+        reg_metrics = [
+            "explained_variance_score",
+            "max_error",
+            "mean_absolute_error",
+            "mean_squared_error",
+            "mean_squared_log_error",
+            "median_absolute_error",
+            "r2_score",
+            "mean_poisson_deviance",
+            "mean_gamma_deviance",
+        ]
         if any([True for x in metrics if x in reg_metrics]):
-            gen_report_shap_regres(results, output_dir=output_dir, plot_top_n_shap=plot_top_n_shap)
+            gen_report_shap_regres(
+                results, output_dir=output_dir, plot_top_n_shap=plot_top_n_shap
+            )
         else:
-            gen_report_shap_class(results, output_dir=output_dir, plot_top_n_shap=plot_top_n_shap)
-
+            gen_report_shap_class(
+                results, output_dir=output_dir, plot_top_n_shap=plot_top_n_shap
+            )

pydra_ml/tasks.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def read_file(filename, x_indices=None, target_vars=None, group=None):
     if group is None:
         groups = list(range(X.shape[0]))
     else:
-        groups = data[:, [group]]
+        groups = data[group].values
     feature_names = list(X.columns)
     return X.values, Y.values, groups, feature_names
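The fix reflects that `read_file` loads the csv with pandas: the old NumPy-style slice `data[:, [group]]` fails on a DataFrame, while selection by column name returns the per-row group labels. A small sketch (the `site` column is hypothetical):

```python
import pandas as pd

data = pd.DataFrame(
    {"age": [63, 45, 71], "site": ["A", "B", "A"], "target": [1, 0, 1]}
)

# data[:, ["site"]]         # raises a TypeError on a DataFrame
groups = data["site"].values  # array(['A', 'B', 'A'], dtype=object)
```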
