Skip to content

Commit 6b863c1

Browse files
authored
use base64 encoding (#247)
1 parent 11b5d54 commit 6b863c1

File tree

2 files changed

+35
-35
lines changed

2 files changed

+35
-35
lines changed

pymc_bart/pgbart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class PGBART(ArrayStepShared):
118118
default_blocked = False
119119
generates_stats = True
120120
stats_dtypes_shapes: dict[str, tuple[type, list]] = {
121-
"variable_inclusion": (int, []),
121+
"variable_inclusion": (object, []),
122122
"tune": (bool, []),
123123
}
124124

pymc_bart/utils.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# pylint: disable=too-many-branches
22
"""Utility function for variable selection and bart interpretability."""
33

4+
import base64
45
import warnings
56
from collections.abc import Callable
67
from typing import Any, TypeVar
@@ -708,7 +709,7 @@ def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None
708709
"""
709710
n_vars = X.shape[1]
710711
vi_xarray = idata["sample_stats"]["variable_inclusion"]
711-
if "variable_inclusion_dim_0" in vi_xarray.coords:
712+
if vi_xarray.variable_inclusion_dim_0.size > 1:
712713
if model is None or bart_var_name is None:
713714
raise ValueError(
714715
"The InfereceData was generated from a model with multiple BART variables, \n"
@@ -727,13 +728,13 @@ def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None
727728
n_vars = len(indices)
728729

729730
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
730-
labels = list(X.columns)
731+
labels = list(X.columns[indices])
731732

732733
if labels is None:
733-
labels = [str(i) for i in range(n_vars)]
734+
labels = [str(i) for i in indices]
734735

735736
if to_kulprit:
736-
return [labels[:idx] for idx in range(n_vars)]
737+
return [labels[:idx] for idx in range(n_vars + 1)]
737738
else:
738739
return VI_norm[indices], labels
739740

@@ -884,7 +885,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
884885

885886
if method in ["VI", "backward_VI"]:
886887
vi_xarray = idata["sample_stats"]["variable_inclusion"]
887-
if "variable_inclusion_dim_0" in vi_xarray.coords:
888+
if vi_xarray.variable_inclusion_dim_0.size > 1:
888889
if model is None:
889890
raise ValueError(
890891
"The InfereceData was generated from a model with multiple BART variables, \n"
@@ -968,7 +969,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
968969

969970
# Save values for plotting later
970971
r2_mean[i_var - init] = max_r_2
971-
r2_hdi[i_var - init] = array_stats.hdi(r_2_without_least_important_vars)
972+
r2_hdi[i_var - init] = array_stats.hdi(
973+
r_2_without_least_important_vars, prob=rcParams["stats.ci_prob"]
974+
)
972975
preds[i_var - init] = least_important_samples.squeeze()
973976

974977
# extend current list of least important variable
@@ -1282,37 +1285,34 @@ def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax):
12821285
return ax
12831286

12841287

1285-
def _decode_vi(n: int, length: int) -> list[int]:
1286-
"""
1287-
Decode the variable inclusion from the BART model.
1288-
"""
1289-
bits = bin(n)[2:]
1290-
vi_list: list[int] = []
1288+
def _decode_vi(s: str, length: int) -> list[int]:
1289+
"""Decode base64 string back to vector."""
1290+
data = base64.b64decode(s)
1291+
result: list[int] = []
12911292
i = 0
1292-
while len(vi_list) < length:
1293-
# Count prefix ones
1294-
prefix_len = 0
1295-
while bits[i] == "1":
1296-
prefix_len += 1
1293+
while len(result) < length and i < len(data):
1294+
num = 0
1295+
shift = 0
1296+
while i < len(data):
1297+
byte = data[i]
12971298
i += 1
1298-
i += 1 # skip the '0'
1299-
b = bits[i : i + prefix_len]
1300-
vi_list.append(int(b, 2))
1301-
i += prefix_len
1302-
return vi_list
1299+
num |= (byte & 0x7F) << shift
1300+
if not (byte & 0x80):
1301+
break
1302+
shift += 7
1303+
result.append(num)
1304+
return result
13031305

13041306

1305-
def _encode_vi(vec: npt.NDArray) -> int:
1307+
def _encode_vi(vec: list[int]) -> str:
13061308
"""
1307-
Encode variable inclusion vector into a single integer.
1308-
1309-
The encoding is done by converting each element of the vector into a binary string,
1310-
where each element contributes a prefix of '1's followed by a '0' and its binary representation.
1311-
The final result is the integer representation of the concatenated binary string.
1309+
Encode vector to base64 string.
13121310
"""
1313-
bits = ""
1314-
for x in vec:
1315-
b = bin(x)[2:]
1316-
prefix = "1" * len(b) + "0"
1317-
bits += prefix + b
1318-
return int(bits, 2)
1311+
result = bytearray()
1312+
for num in vec:
1313+
n = num
1314+
while n > 127:
1315+
result.append((n & 0x7F) | 0x80)
1316+
n >>= 7
1317+
result.append(n & 0x7F)
1318+
return base64.b64encode(bytes(result)).decode("ascii")

0 commit comments

Comments
 (0)