Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More primitive, much faster counter #62

Merged
merged 6 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions outrank/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ def main():
help='Name of the target attribute for ranking. Note that this can be any other feature for most implemented heuristics.',
)

parser.add_argument(
'--max_unique_hist_constraint',
type=int,
default=30_000,
help='Max number of unique values for which counts are recalled.',
)

parser.add_argument(
'--transformers',
type=str,
Expand Down
8 changes: 0 additions & 8 deletions outrank/algorithms/sketches/counting_cms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(self, depth=6, width=2**15, M=None):
self.width = width
self.hash_seeds = np.array(np.random.randint(low=0, high=2**31 - 1, size=depth), dtype=np.uint32)
self.M = np.zeros((depth, width), dtype=np.int32) if M is None else M
self.tmp_vals = set()

@staticmethod
@njit
Expand All @@ -33,8 +32,6 @@ def _add(M, x, depth, width, hash_seeds, delta=1):
M[i, location] += delta

def add(self, x, delta=1):
if len(self.tmp_vals) < 10 ** 4 or sys.getsizeof(self.tmp_vals) / (10 ** 3) < 100.0:
self.tmp_vals.add(x)
CountMinSketch._add(self.M, x, self.depth, self.width, self.hash_seeds, delta)

def batch_add(self, lst, delta=1):
Expand All @@ -47,10 +44,6 @@ def query(self, x):
def get_matrix(self):
return self.M

def stream_hist_update(self):
""" A bit hacky way to aggregate cms results """
return Counter(self.query(x) for x in self.tmp_vals)


if __name__ == '__main__':
from collections import Counter
Expand All @@ -69,4 +62,3 @@ def stream_hist_update(self):
print(cms.query(5))

print(Counter(items)) # Print the exact counts for comparison
print(cms.stream_hist_update())
40 changes: 40 additions & 0 deletions outrank/algorithms/sketches/counting_counters_ordinary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from collections import Counter


class PrimitiveConstrainedCounter:
"""
A memory-efficient implementation of the count min sketch algorithm with optimized hashing using Numba JIT.
"""

def __init__(self, bound: int=(10**4) * 3):
self.max_bound_thr = bound
self.default_counter: Counter = Counter()

def batch_add(self, lst):
if len(self.default_counter) < self.max_bound_thr:
self.default_counter = self.default_counter + Counter(lst)

def add(self, val):
if len(self.default_counter) < self.max_bound_thr:
self.default_counter[val] += 1

def stream_hist_update(self):
miha-jenko marked this conversation as resolved.
Show resolved Hide resolved
return dict(self.default_counter)


if __name__ == '__main__':
from collections import Counter

depth = 8
width = 2**22
import numpy as np
cms = PrimitiveConstrainedCounter()

items = [1, 1, 2, 3, 3, 3, 4, 5, 2] * 10000
cms.batch_add(items) # Use the batch_add function

print(Counter(items)) # Print the exact counts for comparison
print(cms.stream_hist_update())
print(list(v for _, v in cms.stream_hist_update().items()))
miha-jenko marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 4 additions & 4 deletions outrank/core_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import tqdm

from outrank.algorithms.importance_estimator import get_importances_estimate_pairwise
from outrank.algorithms.sketches.counting_cms import CountMinSketch
from outrank.algorithms.sketches.counting_counters_ordinary import PrimitiveConstrainedCounter
from outrank.algorithms.sketches.counting_ultiloglog import (
HyperLogLogWCache as HyperLogLog,
)
Expand Down Expand Up @@ -421,7 +421,7 @@ def compute_value_counts(input_dataframe: pd.DataFrame, args: Any):
del GLOBAL_RARE_VALUE_STORAGE[to_remove_val]


def compute_cardinalities(input_dataframe: pd.DataFrame, pbar: Any) -> None:
def compute_cardinalities(input_dataframe: pd.DataFrame, pbar: Any, max_unique_hist_constraint: int) -> None:
"""Compute cardinalities of features, incrementally"""

global GLOBAL_CARDINALITY_STORAGE
Expand All @@ -434,7 +434,7 @@ def compute_cardinalities(input_dataframe: pd.DataFrame, pbar: Any) -> None:
)

if column not in GLOBAL_COUNTS_STORAGE:
GLOBAL_COUNTS_STORAGE[column] = CountMinSketch()
GLOBAL_COUNTS_STORAGE[column] = PrimitiveConstrainedCounter(max_unique_hist_constraint)

[GLOBAL_COUNTS_STORAGE[column].add(value) for value in input_dataframe[column].values]

Expand Down Expand Up @@ -553,7 +553,7 @@ def compute_batch_ranking(
feature_memory_consumption = compute_feature_memory_consumption(
input_dataframe, args,
)
compute_cardinalities(input_dataframe, pbar)
compute_cardinalities(input_dataframe, pbar, args.max_unique_hist_constraint)

if args.task == 'identify_rare_values':
compute_value_counts(input_dataframe, args)
Expand Down
2 changes: 1 addition & 1 deletion outrank/core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def parse_namespace(namespace_path: str) -> tuple[set[str], dict[str, str]]:
if type_name == 'f32':
float_set.add(feature)
except Exception as es:
logging.error(f'\U0001F631 {es} -- {namespace_parts}')
pass
miha-jenko marked this conversation as resolved.
Show resolved Hide resolved

return float_set, id_feature_map

Expand Down
2 changes: 1 addition & 1 deletion outrank/task_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def outrank_task_conduct_ranking(args: Any) -> None:
with open(f'{args.output_folder}/value_repetitions.json', 'w') as out_counts:
out_dict = {}
for k, v in GLOBAL_ITEM_COUNTS.items():
actual_hist = np.array([k + v for k, v in v.stream_hist_update().items()])
actual_hist = np.array([v for _, v in v.stream_hist_update().items()])
miha-jenko marked this conversation as resolved.
Show resolved Hide resolved
more_than = lambda n, ary: len(np.where(ary > n)[0])
out_dict[k] = {x: more_than(x, actual_hist) for x in [0] + [1 * 10 ** x for x in range(6)]}
out_counts.write(json.dumps(out_dict))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _read_description():
packages = [x for x in setuptools.find_packages() if x != 'test']
setuptools.setup(
name='outrank',
version='0.95.8',
version='0.95.9',
miha-jenko marked this conversation as resolved.
Show resolved Hide resolved
description='OutRank: Feature ranking for massive sparse data sets.',
long_description=_read_description(),
long_description_content_type='text/markdown',
Expand Down
21 changes: 0 additions & 21 deletions tests/cms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def test_init(self):
self.assertEqual(self.cms.width, self.width)
self.assertEqual(self.cms.M.shape, (self.depth, self.width))
self.assertEqual(len(self.cms.hash_seeds), self.depth)
self.assertIsInstance(self.cms.tmp_vals, set)

def test_add_and_query_single_element(self):
# Test adding a single element and querying it
Expand All @@ -46,26 +45,6 @@ def test_batch_add_and_query(self):
for elem in set(elements):
self.assertGreaterEqual(self.cms.query(elem), 10)

def test_stream_hist_update(self):
self.cms.add('foo')
self.cms.add('foo')
self.cms.add('bar')

hist = self.cms.stream_hist_update()

# Note: we cannot test for exact counts because the CountMinSketch is a probabilistic data structure
# and may overcount. However, we never expect it to undercount an element.
self.assertGreaterEqual(hist[self.cms.query('foo')], 1)
self.assertGreaterEqual(hist[self.cms.query('bar')], 1)

def test_overflow_protection(self):
# This test ensures that the set doesn't grow beyond its allowed size and memory usage
for i in range(100001):
self.cms.add(f'element{i}')

self.assertLessEqual(len(self.cms.tmp_vals), 100000)
self.assertLessEqual(sys.getsizeof(self.cms.tmp_vals) / (10 ** 3), 4200.0)

def test_hash_uniformity(self):
# Basic check for hash function's distribution
seeds = np.array(np.random.randint(low=0, high=2**31 - 1, size=self.depth), dtype=np.uint32)
Expand Down
Loading