Skip to content

Commit

Permalink
Add cli option to enable check_add for shap.
Browse files Browse the repository at this point in the history
loucerac committed May 4, 2023
1 parent 04cbaa8 commit 5e47b8b
Showing 5 changed files with 19 additions and 6 deletions.
13 changes: 13 additions & 0 deletions drexml/cli/cli.py
Original file line number Diff line number Diff line change
@@ -87,6 +87,16 @@
]


_check_add_option = [
click.option(
"--add/--no-add",
is_flag=True,
default=True,
help="Check the additivity when computing the SHAP values.",
)
]


def copy_files(ctx, fnames):
"""Copy files from tmp to ml folder."""
for fname in fnames:
@@ -155,6 +165,7 @@ def build_cmd(ctx):
str(int(ctx["n_gpus"])),
str(ctx["n_cpus"]),
str(int(ctx["debug"])),
str(int(ctx["add"])),
ctx["mode"],
]

@@ -230,6 +241,7 @@ def stability(**kwargs):

@main.command()
@add_options(_debug_option)
@add_options(_check_add_option)
@add_options(_n_iters_option)
@add_options(_n_gpus_option)
@add_options(_n_cpus_option)
@@ -247,6 +259,7 @@ def explain(**kwargs):

@main.command()
@add_options(_debug_option)
@add_options(_check_add_option)
@add_options(_n_iters_option)
@add_options(_n_gpus_option)
@add_options(_n_cpus_option)
4 changes: 2 additions & 2 deletions drexml/cli/stab_explainer.py
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@
if __name__ == "__main__":
import sys

data_folder, n_iters, n_gpus, n_cpus, n_splits, debug = parse_stab(sys.argv)
data_folder, n_iters, n_gpus, n_cpus, n_splits, debug, add = parse_stab(sys.argv)
this_seed = 82
queue = multiprocessing.Queue()

@@ -112,7 +112,7 @@ def runner(model, bkg, new, check_add, use_gpu):
model=this_model,
bkg=features_bkg,
new=gb,
check_add=True,
check_add=add,
use_gpu=gpu,
)
for _, gb in features_val.groupby(
2 changes: 1 addition & 1 deletion drexml/cli/stab_scorer.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@

# client = Client('127.0.0.1:8786')
# pylint: disable=unbalanced-tuple-unpacking
data_folder, n_iters, n_gpus, n_cpus, n_splits, debug = parse_stab(sys.argv)
data_folder, n_iters, n_gpus, n_cpus, n_splits, debug, add = parse_stab(sys.argv)
model, stab_cv, features, targets = get_stab(
data_folder, n_splits, n_cpus, debug, n_iters
)
2 changes: 1 addition & 1 deletion drexml/cli/stab_trainer.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@

# client = Client('127.0.0.1:8786')
# pylint: disable=unbalanced-tuple-unpacking
data_path, n_iters, n_gpus, n_cpus, n_splits, debug = parse_stab(sys.argv)
data_path, n_iters, n_gpus, n_cpus, n_splits, debug, add = parse_stab(sys.argv)
model, stab_cv, X, Y = get_stab(data_path, n_splits, n_cpus, debug, n_iters)

for i, split in enumerate(stab_cv):
4 changes: 2 additions & 2 deletions drexml/utils.py
Original file line number Diff line number Diff line change
@@ -67,7 +67,7 @@ def parse_stab(argv):
bool
Debug flag.
"""
_, data_folder, n_iters, n_gpus, n_cpus, debug, mode = argv
_, data_folder, n_iters, n_gpus, n_cpus, debug, add, mode = argv
n_iters = int(n_iters)
data_folder = Path(data_folder)
n_gpus = int(n_gpus)
@@ -79,7 +79,7 @@ def parse_stab(argv):
else:
n_splits = 3 if debug else 100

return data_folder, n_iters, n_gpus, n_cpus, n_splits, debug
return data_folder, n_iters, n_gpus, n_cpus, n_splits, debug, add


def get_stab(data_folder, n_splits, n_cpus, debug, n_iters):

0 comments on commit 5e47b8b

Please sign in to comment.