22"""Utility function for variable selection and bart interpretability."""
33
44import warnings
5- from typing import Any , Callable , Optional , Union
5+ from collections .abc import Callable
6+ from typing import Any , TypeVar
67
78import matplotlib .pyplot as plt
89import numpy as np
1819
1920from .tree import Tree
2021
21- TensorLike = Union [ npt .NDArray , pt .TensorVariable ]
22+ TensorLike = TypeVar ( "TensorLike" , npt .NDArray , pt .TensorVariable )
2223
2324
2425def _sample_posterior (
2526 all_trees : list [list [Tree ]],
2627 X : TensorLike ,
2728 rng : np .random .Generator ,
28- size : Optional [ Union [ int , tuple [int , ...]]] = None ,
29- excluded : Optional [ list [int ]] = None ,
29+ size : int | tuple [int , ...] | None = None ,
30+ excluded : list [int ] | None = None ,
3031 shape : int = 1 ,
3132) -> npt .NDArray :
3233 """
@@ -51,7 +52,7 @@ def _sample_posterior(
5152 X = X .eval ()
5253
5354 if size is None :
54- size_iter : Union [ list , tuple ] = (1 ,)
55+ size_iter : list | tuple = (1 ,)
5556 elif isinstance (size , int ):
5657 size_iter = [size ]
5758 else :
@@ -78,9 +79,9 @@ def _sample_posterior(
7879
7980def plot_convergence (
8081 idata : Any ,
81- var_name : Optional [ str ] = None ,
82+ var_name : str | None = None ,
8283 kind : str = "ecdf" ,
83- figsize : Optional [ tuple [float , float ]] = None ,
84+ figsize : tuple [float , float ] | None = None ,
8485 ax = None ,
8586) -> None :
8687 """
@@ -114,23 +115,23 @@ def plot_convergence(
114115def plot_ice (
115116 bartrv : Variable ,
116117 X : npt .NDArray ,
117- Y : Optional [ npt .NDArray ] = None ,
118- var_idx : Optional [ list [int ]] = None ,
119- var_discrete : Optional [ list [int ]] = None ,
120- func : Optional [ Callable ] = None ,
121- centered : Optional [ bool ] = True ,
118+ Y : npt .NDArray | None = None ,
119+ var_idx : list [int ] | None = None ,
120+ var_discrete : list [int ] | None = None ,
121+ func : Callable | None = None ,
122+ centered : bool | None = True ,
122123 samples : int = 100 ,
123124 instances : int = 30 ,
124- random_seed : Optional [ int ] = None ,
125+ random_seed : int | None = None ,
125126 sharey : bool = True ,
126127 smooth : bool = True ,
127128 grid : str = "long" ,
128129 color = "C0" ,
129130 color_mean : str = "C0" ,
130131 alpha : float = 0.1 ,
131- figsize : Optional [ tuple [float , float ]] = None ,
132- smooth_kwargs : Optional [ dict [str , Any ]] = None ,
133- ax : Optional [ plt .Axes ] = None ,
132+ figsize : tuple [float , float ] | None = None ,
133+ smooth_kwargs : dict [str , Any ] | None = None ,
134+ ax : plt .Axes | None = None ,
134135) -> list [plt .Axes ]:
135136 """
136137 Individual conditional expectation plot.
@@ -258,24 +259,24 @@ def identity(x):
258259def plot_pdp (
259260 bartrv : Variable ,
260261 X : npt .NDArray ,
261- Y : Optional [ npt .NDArray ] = None ,
262+ Y : npt .NDArray | None = None ,
262263 xs_interval : str = "quantiles" ,
263- xs_values : Optional [ Union [ int , list [float ]]] = None ,
264- var_idx : Optional [ list [int ]] = None ,
265- var_discrete : Optional [ list [int ]] = None ,
266- func : Optional [ Callable ] = None ,
264+ xs_values : int | list [float ] | None = None ,
265+ var_idx : list [int ] | None = None ,
266+ var_discrete : list [int ] | None = None ,
267+ func : Callable | None = None ,
267268 samples : int = 200 ,
268269 ref_line : bool = True ,
269- random_seed : Optional [ int ] = None ,
270+ random_seed : int | None = None ,
270271 sharey : bool = True ,
271272 smooth : bool = True ,
272273 grid : str = "long" ,
273274 color = "C0" ,
274275 color_mean : str = "C0" ,
275276 alpha : float = 0.1 ,
276- figsize : Optional [ tuple [float , float ]] = None ,
277- smooth_kwargs : Optional [ dict [str , Any ]] = None ,
278- ax : Optional [ plt .Axes ] = None ,
277+ figsize : tuple [float , float ] | None = None ,
278+ smooth_kwargs : dict [str , Any ] | None = None ,
279+ ax : plt .Axes = None ,
279280) -> list [plt .Axes ]:
280281 """
281282 Partial dependence plot.
@@ -425,8 +426,8 @@ def _create_figure_axes(
425426 var_idx : list [int ],
426427 grid : str = "long" ,
427428 sharey : bool = True ,
428- figsize : Optional [ tuple [float , float ]] = None ,
429- ax : Optional [ plt .Axes ] = None ,
429+ figsize : tuple [float , float ] | None = None ,
430+ ax : plt .Axes | None = None ,
430431) -> tuple [plt .Figure , list [plt .Axes ], int ]:
431432 """
432433 Create and return the figure and axes objects for plotting the variables.
@@ -506,11 +507,11 @@ def _get_axes(grid, n_plots, sharex, sharey, figsize):
506507
507508def _prepare_plot_data (
508509 X : npt .NDArray ,
509- Y : Optional [ npt .NDArray ] = None ,
510+ Y : npt .NDArray | None = None ,
510511 xs_interval : str = "quantiles" ,
511- xs_values : Optional [ Union [ int , list [float ]]] = None ,
512- var_idx : Optional [ list [int ]] = None ,
513- var_discrete : Optional [ list [int ]] = None ,
512+ xs_values : int | list [float ] | None = None ,
513+ var_idx : list [int ] | None = None ,
514+ var_discrete : list [int ] | None = None ,
514515) -> tuple [
515516 npt .NDArray ,
516517 list [str ],
@@ -519,7 +520,7 @@ def _prepare_plot_data(
519520 list [int ],
520521 list [int ],
521522 str ,
522- Union [ int , None , list [float ] ],
523+ int | None | list [float ],
523524]:
524525 """
525526 Prepare data for plotting.
@@ -600,7 +601,7 @@ def _prepare_plot_data(
600601def _create_pdp_data (
601602 X : npt .NDArray ,
602603 xs_interval : str ,
603- xs_values : Optional [ Union [ int , list [float ]]] = None ,
604+ xs_values : int | list [float ] | None = None ,
604605) -> npt .NDArray :
605606 """
606607 Create data for partial dependence plot.
@@ -636,7 +637,7 @@ def _smooth_mean(
636637 new_x : npt .NDArray ,
637638 p_di : npt .NDArray ,
638639 kind : str = "neutral" ,
639- smooth_kwargs : Optional [ dict [str , Any ]] = None ,
640+ smooth_kwargs : dict [str , Any ] | None = None ,
640641) -> tuple [np .ndarray , np .ndarray ]:
641642 """
642643 Smooth the mean data for plotting.
@@ -805,7 +806,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
805806 fixed : int = 0 ,
806807 samples : int = 50 ,
807808 random_seed : int | None = None ,
808- ) -> dict [str , object ]:
809+ ) -> dict [str , npt . NDArray ]:
809810 """
810811 Estimates variable importance from the BART-posterior.
811812
@@ -1026,11 +1027,11 @@ def vi_to_kulprit(vi_results: dict) -> list[list[str]]:
10261027
10271028def plot_variable_importance (
10281029 vi_results : dict ,
1029- submodels : Optional [ Union [ list [int ], np .ndarray , tuple [int , ...]]] = None ,
1030- labels : Optional [ list [str ]] = None ,
1031- figsize : Optional [ tuple [float , float ]] = None ,
1032- plot_kwargs : Optional [ dict [str , Any ]] = None ,
1033- ax : Optional [ plt .Axes ] = None ,
1030+ submodels : list [int ] | np .ndarray | tuple [int , ...] | None = None ,
1031+ labels : list [str ] | None = None ,
1032+ figsize : tuple [float , float ] | None = None ,
1033+ plot_kwargs : dict [str , Any ] | None = None ,
1034+ ax : plt .Axes | None = None ,
10341035):
10351036 """
10361037 Estimates variable importance from the BART-posterior.
@@ -1128,13 +1129,13 @@ def plot_variable_importance(
11281129
11291130def plot_scatter_submodels (
11301131 vi_results : dict ,
1131- func : Optional [ Callable ] = None ,
1132- submodels : Optional [ Union [ list [int ], np .ndarray ]] = None ,
1132+ func : Callable | None = None ,
1133+ submodels : list [int ] | np .ndarray | None = None ,
11331134 grid : str = "long" ,
1134- labels : Optional [ list [str ]] = None ,
1135- figsize : Optional [ tuple [float , float ]] = None ,
1136- plot_kwargs : Optional [ dict [str , Any ]] = None ,
1137- ax : Optional [ plt .Axes ] = None ,
1135+ labels : list [str ] | None = None ,
1136+ figsize : tuple [float , float ] | None = None ,
1137+ plot_kwargs : dict [str , Any ] | None = None ,
1138+ ax : plt .Axes | None = None ,
11381139) -> list [plt .Axes ]:
11391140 """
11401141 Plot submodel's predictions against reference-model's predictions.
0 commit comments