Skip to content

Commit

Permalink
Merge pull request #50 from outbrain/more-speedups
Browse files (browse the repository at this point in the history)
Around 5x faster MI
  • Loading branch information
SkBlaz committed Oct 16, 2023
2 parents 2fc30f2 + dded349 commit 5f4bf26
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions outrank/algorithms/feature_ranking/ranking_mi_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def numba_unique(a):


@njit(
'float32(int32[:], int32[:], int32, float32, int32[:])',
'float32(uint32[:], int32[:], int32, float32, uint32[:])',
cache=True,
fastmath=True,
error_model='numpy',
Expand Down Expand Up @@ -82,11 +82,17 @@ def compute_entropies(
initial_prob = _f_value_counts / all_events
x_value_subspace = np.where(X == f_values[f_index])

Y_classes = Y[x_value_subspace]
Y_classes_spoofed = np.roll(Y, _f_value_counts)[x_value_subspace]
Y_classes = Y[x_value_subspace].astype(np.uint32)
subspace_size = x_value_subspace[0].size

nonzero_class_counts = np.zeros(len(class_values), dtype=np.int32)
nonzero_class_counts_spoofed = np.zeros(len(class_values), dtype=np.int32)
# Simulate noise: read Y at the subspace indices offset forward by _f_value_counts (wrapping around)
Y_classes_spoofed = np.zeros(subspace_size, dtype=np.uint32)
for enx, el in enumerate(x_value_subspace[0]):
index = (el + _f_value_counts) % len(Y)
Y_classes_spoofed[enx] = Y[index]

nonzero_class_counts = np.zeros(len(class_values), dtype=np.uint32)
nonzero_class_counts_spoofed = np.zeros(len(class_values), dtype=np.uint32)

# Cache nonzero counts
for index, c in enumerate(class_values):
Expand Down

0 comments on commit 5f4bf26

Please sign in to comment.