Skip to content

Commit

Permalink
fix(cli): fix saving plots when they are also shown immediately
Browse files Browse the repository at this point in the history
  • Loading branch information
nmaarnio committed Nov 15, 2024
1 parent 3f80048 commit 1783195
Showing 1 changed file with 36 additions and 31 deletions.
67 changes: 36 additions & 31 deletions eis_toolkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,17 +683,17 @@ def parallel_coordinates_cli(
curved_lines=curved_lines,
)
typer.echo("Progress: 75%")
if show_plot:
plt.show()

echo_str_end = "."
if output_file is not None:
dpi = "figure" if save_dpi is None else save_dpi
plt.savefig(output_file, dpi=dpi)
echo_str_end = f", output figure saved to {output_file}."
typer.echo("Progress: 100%")
typer.echo(f"Output figure saved to {output_file}.")

typer.echo("Parallel coordinates plot completed" + echo_str_end)
if show_plot:
plt.show()

typer.echo("Progress: 100%")
typer.echo("Parallel coordinates plot completed")


# PCA FOR RASTER DATA
Expand Down Expand Up @@ -3540,16 +3540,17 @@ def plot_roc_curve_cli(

_ = plot_roc_curve(y_true=y_true, y_prob=y_prob)
typer.echo("Progress: 75%")
if show_plot:
plt.show()

if output_file is not None:
dpi = "figure" if save_dpi is None else save_dpi
plt.savefig(output_file, dpi=dpi)
echo_str_end = f", output figure saved to {output_file}."
typer.echo("Progress: 100% \n")
typer.echo(f"Output figure saved to {output_file}.")

typer.echo("ROC curve plot completed" + echo_str_end)
if show_plot:
plt.show()

typer.echo("Progress: 100%")
typer.echo("ROC curve plot completed")


@app.command()
Expand Down Expand Up @@ -3580,16 +3581,17 @@ def plot_det_curve_cli(

_ = plot_det_curve(y_true=y_true, y_prob=y_prob)
typer.echo("Progress: 75%")
if show_plot:
plt.show()

if output_file is not None:
dpi = "figure" if save_dpi is None else save_dpi
plt.savefig(output_file, dpi=dpi)
echo_str_end = f", output figure saved to {output_file}."
typer.echo("Progress: 100% \n")
typer.echo(f"Output figure saved to {output_file}.")

typer.echo("DET curve plot completed" + echo_str_end)
if show_plot:
plt.show()

typer.echo("Progress: 100%")
typer.echo("DET curve plot completed")


@app.command()
Expand Down Expand Up @@ -3619,16 +3621,17 @@ def plot_precision_recall_curve_cli(

_ = plot_precision_recall_curve(y_true=y_true, y_prob=y_prob)
typer.echo("Progress: 75%")
if show_plot:
plt.show()

if output_file is not None:
dpi = "figure" if save_dpi is None else save_dpi
plt.savefig(output_file, dpi=dpi)
echo_str_end = f", output figure saved to {output_file}."
typer.echo("Progress: 100% \n")
typer.echo(f"Output figure saved to {output_file}.")

typer.echo("Precision-Recall curve plot completed" + echo_str_end)
if show_plot:
plt.show()

typer.echo("Progress: 100%")
typer.echo("Precision-Recall curve plot completed")


@app.command()
Expand Down Expand Up @@ -3658,16 +3661,17 @@ def plot_calibration_curve_cli(

_ = plot_calibration_curve(y_true=y_true, y_prob=y_prob, n_bins=n_bins)
typer.echo("Progress: 75%")
if show_plot:
plt.show()

if output_file is not None:
dpi = "figure" if save_dpi is None else save_dpi
plt.savefig(output_file, dpi=dpi)
echo_str_end = f", output figure saved to {output_file}."
typer.echo("Progress: 100% \n")
typer.echo(f"Output figure saved to {output_file}.")

typer.echo("Calibration curve plot completed" + echo_str_end)
if show_plot:
plt.show()

typer.echo("Progress: 100%")
typer.echo("Calibration curve plot completed")


@app.command()
Expand All @@ -3693,16 +3697,17 @@ def plot_confusion_matrix_cli(
matrix = confusion_matrix(y_true, y_pred)
_ = plot_confusion_matrix(confusion_matrix=matrix)
typer.echo("Progress: 75%")
if show_plot:
plt.show()

if output_file is not None:
dpi = "figure" if save_dpi is None else save_dpi
plt.savefig(output_file, dpi=dpi)
echo_str_end = f", output figure saved to {output_file}."
typer.echo("Progress: 100% \n")
typer.echo(f"Output figure saved to {output_file}.")

typer.echo("Confusion matrix plot completed" + echo_str_end)
if show_plot:
plt.show()

typer.echo("Progress: 100%")
typer.echo("Confusion matrix plot completed.")


@app.command()
Expand Down

0 comments on commit 1783195

Please sign in to comment.