Skip to content

Commit

Permalink
Support Polars (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Aug 15, 2024
1 parent e8f258a commit d89e0b6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
14 changes: 12 additions & 2 deletions pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ class BART(Distribution):
Parameters
----------
X : TensorLike
X : PyTensor Variable, Pandas/Polars DataFrame or Numpy array
The covariate matrix.
Y : TensorLike
Y : PyTensor Variable, Pandas/Polar DataFrame/Series,or Numpy array
The response vector.
m : int
Number of trees.
Expand Down Expand Up @@ -204,6 +204,16 @@ def preprocess_xy(
if isinstance(X, (Series, DataFrame)):
X = X.to_numpy()

try:
import polars as pl

if isinstance(X, (pl.Series, pl.DataFrame)):
X = X.to_numpy()
if isinstance(Y, (pl.Series, pl.DataFrame)):
Y = Y.to_numpy()
except ImportError:
pass

Y = Y.astype(float)
X = X.astype(float)

Expand Down
10 changes: 5 additions & 5 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def _prepare_plot_data(
Parameters
----------
X : PyTensor Variable, Pandas DataFrame or Numpy array
X : PyTensor Variable, Pandas DataFrame, Polars DataFrame or Numpy array
Input data.
Y : array-like
Target data.
Expand Down Expand Up @@ -585,9 +585,9 @@ def _prepare_plot_data(
if isinstance(X, Variable):
X = X.eval()

if hasattr(X, "columns") and hasattr(X, "values"):
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
x_names = list(X.columns)
X = X.values
X = X.to_numpy()
else:
x_names = []

Expand Down Expand Up @@ -750,9 +750,9 @@ def plot_variable_importance( # noqa: PLR0915
else:
shape = bartrv.eval().shape[0]

if hasattr(X, "columns") and hasattr(X, "values"):
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
labels = X.columns
X = X.values
X = X.to_numpy()

n_vars = X.shape[1]

Expand Down

0 comments on commit d89e0b6

Please sign in to comment.