Skip to content

Commit b9c68a9

Browse files
committed
fix: names
1 parent 0718c34 commit b9c68a9

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

pydra_ml/report.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,10 @@ def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):
182182
indexes_all = {}
183183

184184
for model_results in results:
185-
model_name = model_results[0].get("ml_wf.clf_info")[1]
185+
model_name = model_results[0].get("ml_wf.clf_info")
186+
if isinstance(model_name[0], list):
187+
model_name = model_name[-1]
188+
model_name = model_name[1]
186189
indexes_all[model_name] = []
187190
shaps = model_results[
188191
1
@@ -312,13 +315,16 @@ def gen_report(
312315
if not isinstance(score, list):
313316
score = [score]
314317

315-
clf = val[0][prefix + ".clf_info"][1]
316-
if isinstance(clf, list):
318+
clf = val[0][prefix + ".clf_info"]
319+
if isinstance(clf[0], list):
317320
clf = clf[-1][1]
321+
else:
322+
clf = clf[1]
318323
if "Classifier" in clf:
319324
name = clf.split("Classifier")[0]
320325
else:
321326
name = clf.split("Regressor")[0]
327+
name = name.split("CV")[0]
322328
permute = val[0][prefix + ".permute"]
323329
for scoreval in score:
324330
for idx, metric in enumerate(metrics):

0 commit comments

Comments
 (0)