Skip to content

Commit

Permalink
jitter array of whole numbers (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Dec 22, 2023
1 parent 83f2409 commit 2925162
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,16 +697,24 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[
"""
Jitter duplicated values.
"""
seen = []
for idx, num in enumerate(array):
if num in seen:
array[idx] = num + np.random.normal(0, std / 12)
else:
seen.append(num)
if are_whole_number(array):
seen = []
for idx, num in enumerate(array):
if num in seen:
array[idx] = num + np.random.normal(0, std / 12)
else:
seen.append(num)

return array


@njit
def are_whole_number(array: npt.NDArray[np.float_]) -> np.bool_:
"""Check if all values in array are whole numbers"""
new_array = np.mod(array, 1)
return np.all(new_array == 0)


def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
"""Compile PyTensor function of the model and the input and output variables.
Expand Down

0 comments on commit 2925162

Please sign in to comment.