Skip to content

Commit

Permalink
[BUG] BORF failing without numba (#2254)
Browse files (browse the repository at this point in the history)
* fix #2245

* Empty commit for CI

* try another solution to fix using njit and loop

* fixes

---------

Co-authored-by: MatthewMiddlehurst <[email protected]>
Co-authored-by: Tony Bagnall <[email protected]>
Co-authored-by: MatthewMiddlehurst <[email protected]>
  • Loading branch information
4 people authored Oct 30, 2024
1 parent 15d44e1 commit b76073e
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions aeon/transformations/collection/dictionary_based/_borf.py
Original file line number Diff line number Diff line change
@@ -548,12 +548,15 @@ def _ndindex_2d_array(idx, dim2_shape):

@nb.njit(cache=True)
def _get_norm_bins(alphabet_size: int, mu=0, std=1):
return _ppf(np.linspace(0, 1, alphabet_size + 1)[1:-1], mu, std)
bins = []
for i in np.linspace(0, 1, alphabet_size + 1)[1:-1]:
bins.append(_ppf(i, mu, std))
return np.array(bins)


@nb.njit(fastmath=True, cache=True)
def _erfinv(x: float) -> float:
w = -math.log((1 - x) * (1 + x))
w = -np.log((1 - x) * (1 + x))
if w < 5:
w = w - 2.5
p = 2.81022636e-08
@@ -566,7 +569,7 @@ def _erfinv(x: float) -> float:
p = 0.246640727 + p * w
p = 1.50140941 + p * w
else:
w = math.sqrt(w) - 3
w = np.sqrt(w) - 3
p = -0.000200214257
p = 0.000100950558 + p * w
p = 0.00134934322 + p * w
@@ -579,9 +582,9 @@ def _erfinv(x: float) -> float:
return p * x


@nb.vectorize(cache=True)
@nb.njit(cache=True)
def _ppf(x, mu=0, std=1):
return mu + math.sqrt(2) * _erfinv(2 * x - 1) * std
return mu + np.sqrt(2) * _erfinv(2 * x - 1) * std


@nb.njit(fastmath=True, cache=True)
@@ -763,17 +766,16 @@ def _length(a):

@nb.njit(cache=True)
def _hash_function(v):

byte_mask = np.uint64(255)
bs = np.uint64(v)
x1 = (bs) & byte_mask
x2 = (bs >> 8) & byte_mask
x3 = (bs >> 16) & byte_mask
x4 = (bs >> 24) & byte_mask
x5 = (bs >> 32) & byte_mask
x6 = (bs >> 40) & byte_mask
x7 = (bs >> 48) & byte_mask
x8 = (bs >> 56) & byte_mask
x2 = (bs >> np.uint64(8)) & byte_mask
x3 = (bs >> np.uint64(16)) & byte_mask
x4 = (bs >> np.uint64(24)) & byte_mask
x5 = (bs >> np.uint64(32)) & byte_mask
x6 = (bs >> np.uint64(40)) & byte_mask
x7 = (bs >> np.uint64(48)) & byte_mask
x8 = (bs >> np.uint64(56)) & byte_mask

FNV_primer = np.uint64(1099511628211)
FNV_bias = np.uint64(14695981039346656037)
@@ -800,7 +802,7 @@ def _hash_function(v):
@nb.njit(cache=True)
def _make_hash_table(ar):
a = _length(len(ar))
mask = a - 1
mask = np.uint64(a - 1)

uniques = np.empty(a, dtype=ar.dtype)
uniques_cnt = np.zeros(a, dtype=np.int_)
@@ -822,7 +824,7 @@ def _set_item(uniques, uniques_cnt, mask, h, v, total, miss_hits, weight):
break
else:
miss_hits += 1
index += 1
index += np.uint64(1)
index = index & mask
return total, miss_hits

Expand Down

0 comments on commit b76073e

Please sign in to comment.