Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pdpbox/pdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,8 @@ def pdp_interact_plot(pdp_interact_out, feature_names, plot_type='contour', x_qu
'inter_fill_alpha': 0.8,
# fontsize for interact plot text
'inter_fontsize': 9,
# custom matplotlib.colors.Normalize object used for normalizing 'marginal effect' color scale
'norm': None,
}

Returns
Expand Down Expand Up @@ -752,6 +754,7 @@ def pdp_interact_plot(pdp_interact_out, feature_names, plot_type='contour', x_qu

_plot_title(title=title, subtitle=subtitle, title_ax=title_ax, plot_params=plot_params)

norm = plot_params.get('norm', None)
inter_params = {'plot_type': plot_type, 'x_quantile': x_quantile, 'plot_params': plot_params}
if num_charts == 1:
feature_names_adj = feature_names
Expand All @@ -766,7 +769,7 @@ def pdp_interact_plot(pdp_interact_out, feature_names, plot_type='contour', x_qu
else:
inter_ax = plt.subplot(outer_grid[1])
fig.add_subplot(inter_ax)
_pdp_inter_one(pdp_interact_out=pdp_interact_plot_data[0], inter_ax=inter_ax, norm=None,
_pdp_inter_one(pdp_interact_out=pdp_interact_plot_data[0], inter_ax=inter_ax, norm=norm,
feature_names=feature_names_adj, **inter_params)
else:
wspace = 0.3
Expand All @@ -788,7 +791,7 @@ def pdp_interact_plot(pdp_interact_out, feature_names, plot_type='contour', x_qu
inner_inter_ax = plt.subplot(inner_grid[inner_idx])
fig.add_subplot(inner_inter_ax)
_pdp_inter_one(pdp_interact_out=pdp_interact_plot_data[inner_idx], inter_ax=inner_inter_ax,
norm=None, feature_names=feature_names_adj, **inter_params)
norm=norm, feature_names=feature_names_adj, **inter_params)
inter_ax.append(inner_inter_ax)

axes = {'title_ax': title_ax, 'pdp_inter_ax': inter_ax}
Expand Down