diff --git a/pdpbox/pdp.py b/pdpbox/pdp.py index de5a031..c545e7e 100644 --- a/pdpbox/pdp.py +++ b/pdpbox/pdp.py @@ -641,6 +641,8 @@ def pdp_interact_plot(pdp_interact_out, feature_names, plot_type='contour', x_qu 'inter_fill_alpha': 0.8, # fontsize for interact plot text 'inter_fontsize': 9, + # custom matplotlib.colors.Normalize object used for normalizing 'marginal effect' color scale + 'norm': None, } Returns @@ -752,6 +754,7 @@ def pdp_interact_plot(pdp_interact_out, feature_names, plot_type='contour', x_qu _plot_title(title=title, subtitle=subtitle, title_ax=title_ax, plot_params=plot_params) + norm = plot_params.get('norm', None) inter_params = {'plot_type': plot_type, 'x_quantile': x_quantile, 'plot_params': plot_params} if num_charts == 1: feature_names_adj = feature_names @@ -766,7 +769,7 @@ def pdp_interact_plot(pdp_interact_out, feature_names, plot_type='contour', x_qu else: inter_ax = plt.subplot(outer_grid[1]) fig.add_subplot(inter_ax) - _pdp_inter_one(pdp_interact_out=pdp_interact_plot_data[0], inter_ax=inter_ax, norm=None, + _pdp_inter_one(pdp_interact_out=pdp_interact_plot_data[0], inter_ax=inter_ax, norm=norm, feature_names=feature_names_adj, **inter_params) else: wspace = 0.3 @@ -788,7 +791,7 @@ def pdp_interact_plot(pdp_interact_out, feature_names, plot_type='contour', x_qu inner_inter_ax = plt.subplot(inner_grid[inner_idx]) fig.add_subplot(inner_inter_ax) _pdp_inter_one(pdp_interact_out=pdp_interact_plot_data[inner_idx], inter_ax=inner_inter_ax, - norm=None, feature_names=feature_names_adj, **inter_params) + norm=norm, feature_names=feature_names_adj, **inter_params) inter_ax.append(inner_inter_ax) axes = {'title_ax': title_ax, 'pdp_inter_ax': inter_ax}