Skip to content

Commit

Permalink
add jitter to duplicated values for continuous splitting rule (#129)
Browse files Browse the repository at this point in the history
* jit continuous rule

* omit intermidiate step
  • Loading branch information
aloctavodia authored Nov 21, 2023
1 parent d202b07 commit c415074
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
22 changes: 18 additions & 4 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,9 @@ def __init__(
else:
self.split_rules = [ContinuousSplitRule] * self.X.shape[1]

jittered = np.random.normal(self.X, self.X.std(axis=0) / 12)
min_values = np.min(self.X, axis=0)
max_values = np.max(self.X, axis=0)
self.X = np.clip(jittered, min_values, max_values)
for idx, rule in enumerate(self.split_rules):
if rule is ContinuousSplitRule:
self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.std(self.X[:, idx]))

init_mean = self.bart.Y.mean()
self.num_observations = self.X.shape[0]
Expand Down Expand Up @@ -693,6 +692,21 @@ def inverse_cdf(
return new_indices


@njit
def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[np.float_]:
"""
Jitter duplicated values.
"""
seen = []
for idx, num in enumerate(array):
if num in seen:
array[idx] = num + np.random.normal(0, std / 12)
else:
seen.append(num)

return array


def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
"""Compile PyTensor function of the model and the input and output variables.
Expand Down
2 changes: 1 addition & 1 deletion pymc_bart/split_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def divide(available_splitting_values, split_value):
class SubsetSplitRule(SplitRule):
"""
Choose a random subset of the categorical values and branch on belonging to that set.
This is the approach taken by Sameer K. Deshpande.
This is the approach taken by Sameer K. Deshpande.
flexBART: Flexible Bayesian regression trees with categorical predictors. arXiv,
`link <https://arxiv.org/abs/2211.04459>`__
"""
Expand Down

0 comments on commit c415074

Please sign in to comment.