Skip to content

bambi.interpret.plot_predictions() fails when we condition on and color by the same categorical variable #870

@tomicapretto

Description

@tomicapretto

See this example:

import bambi as bmb
import numpy as np
import pandas as pd

rng = np.random.default_rng(1234)
levels = list("ABC")
df = pd.DataFrame({"y": rng.normal(size=100), "factor": rng.choice(levels, size=100)})

model = bmb.Model("y ~ factor", data=df)
idata = model.fit()

bmb.interpret.plot_predictions(
    model=model,
    idata=idata,
    conditional="factor",
    subplot_kwargs={"group": "factor"}
);
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 12
      9 model = bmb.Model("y ~ factor", data=df)
     10 idata = model.fit()
---> 12 bmb.interpret.plot_predictions(
     13     model=model,
     14     idata=idata,
     15     conditional="factor",
     16     subplot_kwargs={"group": "factor"}
     17 );

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/bambi/interpret/plotting.py:230, in plot_predictions(model, idata, conditional, average_by, target, sample_new_groups, pps, use_hdi, prob, transforms, legend, ax, fig_kwargs, subplot_kwargs)
    226     axes = plot_numeric(covariates, cap_data, transforms, legend, axes)
    227 elif is_categorical_dtype(cap_data[covariates.main]) or is_string_dtype(
    228     cap_data[covariates.main]
    229 ):
--> 230     axes = plot_categoric(covariates, cap_data, legend, axes)
    231 else:
    232     raise ValueError("Main covariate must be numeric or categoric.")

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/bambi/interpret/plot_types.py:223, in plot_categoric(covariates, plot_data, legend, axes)
    221         idx = (plot_data[color] == clr).to_numpy()
    222         idxs = idxs_main + colors_offset[i]
--> 223         ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
    224         ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
    225 elif not "group" in covariates and "panel" in covariates:

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/matplotlib/__init__.py:1473, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1470 @functools.wraps(func)
   1471 def inner(ax, *args, data=None, **kwargs):
   1472     if data is None:
-> 1473         return func(
   1474             ax,
   1475             *map(sanitize_sequence, args),
   1476             **{k: sanitize_sequence(v) for k, v in kwargs.items()})
   1478     bound = new_sig.bind(ax, *args, **kwargs)
   1479     auto_label = (bound.arguments.get(label_namer)
   1480                   or bound.kwargs.get(label_namer))

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/matplotlib/axes/_axes.py:4787, in Axes.scatter(self, x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, edgecolors, plotnonfinite, **kwargs)
   4785 y = np.ma.ravel(y)
   4786 if x.size != y.size:
-> 4787     raise ValueError("x and y must be the same size")
   4789 if s is None:
   4790     s = (20 if mpl.rcParams['_internal.classic_mode'] else
   4791          mpl.rcParams['lines.markersize'] ** 2.0)

ValueError: x and y must be the same size

These are the problematic lines:

idx = (plot_data[color] == clr).to_numpy()
idxs = idxs_main + colors_offset[i]
ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")

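The length of `idxs` is greater than `sum(idx)`, so the `x` and `y` arrays passed to `ax.scatter` have different sizes, which triggers the `ValueError` above. In this case we also need to slice `idxs` with `idx`, but I'm not sure whether that would cause issues in other scenarios.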

Metadata

Metadata

Labels

bug · good first issue — If you want to contribute but are not sure where to get started, this issue is for you!

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions