Skip to content

Commit 273e2b0

Browse files
authored
Merge pull request #28 from satra/fix-figures
Fix figures
2 parents 96dd28a + 8d57049 commit 273e2b0

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,6 @@ pre-commit install
213213

214214
### Project structure
215215

216-
- `tasks.py` contain the annotated Pydra tasks.
217-
- `classifier.py` contains the Pydra workflow.
216+
- `tasks.py` contain the Python functions.
217+
- `classifier.py` contains the Pydra workflow and the annotated tasks.
218218
- `report.py` contains report generation code.

pydra_ml/report.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,28 +336,44 @@ def gen_report(
336336
hue_order=["data", "null"],
337337
order=order,
338338
)
339+
ax.xaxis.set_ticks_position("top")
340+
ax.set_xticklabels(ax.get_xticklabels(), rotation=90, ha="center")
339341
ax.set_ylabel(name)
342+
ax.legend(loc="center right", bbox_to_anchor=(1.2, 0.5), ncol=1)
343+
ax.tick_params(axis="both", which="both", length=0)
340344
sns.despine(left=True)
345+
plt.tight_layout()
346+
341347
import datetime
342348

343349
timestamp = datetime.datetime.utcnow().isoformat()
344350
timestamp = timestamp.replace(":", "").replace("-", "")
345351
plt.savefig(f"test-{name}-{timestamp}.png")
352+
plt.close()
346353

347354
# Create comparison stats table if the metric is a score
348355
if "score" in name:
349356
effects, pvalues, = compute_pairwise_stats(subdf)
350-
plt.figure(figsize=(8, 8))
357+
sns.set(style="whitegrid", palette="pastel", color_codes=True)
358+
sns.set_context("talk")
359+
plt.figure(figsize=(2 * len(order), 2 * len(order)))
360+
# plt.figure(figsize=(8, 8))
351361
ax = sns.heatmap(
352362
effects,
353363
annot=np.fix(-np.log10(pvalues)),
354364
yticklabels=order,
355365
xticklabels=order,
356366
cbar=True,
367+
cbar_kws={"shrink": 0.7},
357368
square=True,
358369
)
359370
ax.xaxis.set_ticks_position("top")
371+
ax.set_xticklabels(ax.get_xticklabels(), rotation=90, ha="center")
372+
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha="right")
373+
ax.tick_params(axis="both", which="both", length=0)
374+
plt.tight_layout()
360375
plt.savefig(f"stats-{name}-{timestamp}.png")
376+
plt.close()
361377
save_obj(
362378
dict(effects=effects, pvalues=pvalues, order=order),
363379
f"stats-{name}-{timestamp}.pkl",

0 commit comments

Comments
 (0)