From 710635753a6201ec06072f71b71f99db277ee0cb Mon Sep 17 00:00:00 2001 From: Yang Date: Fri, 15 Dec 2023 11:32:22 +0100 Subject: [PATCH] get axes --- dianna/visualization/tabular.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dianna/visualization/tabular.py b/dianna/visualization/tabular.py index 0e6cc9a7..0a4b42b6 100644 --- a/dianna/visualization/tabular.py +++ b/dianna/visualization/tabular.py @@ -31,7 +31,7 @@ def plot_tabular( """ if not num_features: num_features = len(x) - fig = plt.figure() + fig, ax = plt.subplots() abs_values = [abs(i) for i in x] top_values = [x for _, x in sorted(zip(abs_values, x), reverse=True)][:num_features] top_features = [x for _, x in sorted(zip(abs_values, y), reverse=True)][ @@ -39,9 +39,9 @@ def plot_tabular( ] colors = ["r" if x >= 0 else "b" for x in top_values] - plt.barh(top_features, top_values, color=colors) - plt.xlabel(x_label) - plt.ylabel(y_label) + ax.barh(top_features, top_values, color=colors) + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) if show_plot: plt.show()