From 693f83b95025fd5e88402a0145d5618d2c73bbfa Mon Sep 17 00:00:00 2001 From: mbaak Date: Mon, 22 Apr 2024 16:31:42 +0200 Subject: [PATCH] ENH: added parse_blocking_func function that accepts strings as input If input string gets provided, check if included in blocking map. --- emm/helper/blocking_functions.py | 28 ++++++++++++++++++++++++++ emm/indexing/pandas_cos_sim_matcher.py | 5 +++-- emm/indexing/spark_cos_sim_matcher.py | 5 +++-- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/emm/helper/blocking_functions.py b/emm/helper/blocking_functions.py index 27366ab..6f0ccaf 100644 --- a/emm/helper/blocking_functions.py +++ b/emm/helper/blocking_functions.py @@ -23,6 +23,10 @@ Please don't modify the function names. """ +from __future__ import annotations + +from typing import Callable + def first(x: str) -> str: """First character blocking function.""" @@ -37,3 +41,27 @@ def first2(x: str) -> str: def first3(x: str) -> str: """First two characters blocking function.""" return x.strip().lower()[:3] + + +BLOCKING_MAP = {"first": first, "first2": first2, "first3": first3} + + +def _parse_blocking_func(input: Callable[[str], str] | str | None = None) -> Callable[[str], str] | None: + """Resolve a blocking function from a callable, a name in BLOCKING_MAP, or None + + Args: + input: blocking function, name of a blocking function in BLOCKING_MAP, or None; an unknown name raises ValueError and any other type raises TypeError + + Returns: + the input itself when it is callable or None, otherwise the BLOCKING_MAP entry for the given name + """ + if input is None or callable(input): + return input + if isinstance(input, str): + if input not in BLOCKING_MAP: + msg = f"Input {input} is not a recognized blocking function." + raise ValueError(msg) + return BLOCKING_MAP[input] + + msg = "Input is not None, no string and not callable." 
+ raise TypeError(msg) diff --git a/emm/indexing/pandas_cos_sim_matcher.py b/emm/indexing/pandas_cos_sim_matcher.py index ec349c0..21ecb85 100644 --- a/emm/indexing/pandas_cos_sim_matcher.py +++ b/emm/indexing/pandas_cos_sim_matcher.py @@ -30,6 +30,7 @@ from sklearn.base import TransformerMixin from sparse_dot_topn import sp_matmul_topn +from emm.helper.blocking_functions import _parse_blocking_func from emm.helper.util import groupby from emm.indexing.base_indexer import CosSimBaseIndexer from emm.indexing.pandas_normalized_tfidf import PandasNormalizedTfidfVectorizer @@ -52,7 +53,7 @@ def __init__( max_features: int | None = None, n_jobs: int = 1, spark_session: Any | None = None, - blocking_func: Callable[[str], str] | None = None, + blocking_func: Callable[[str], str] | str | None = None, dtype: type[float] = np.float32, indexer_id: int | None = None, ) -> None: @@ -99,7 +100,7 @@ def __init__( self.dtype = dtype self.cos_sim_lower_bound = cos_sim_lower_bound self.partition_size = partition_size - self.blocking_func = blocking_func + self.blocking_func = _parse_blocking_func(blocking_func) self.n_jobs = n_jobs if n_jobs != -1 else multiprocessing.cpu_count() self.spark_session = spark_session # attributes below are set during fit diff --git a/emm/indexing/spark_cos_sim_matcher.py b/emm/indexing/spark_cos_sim_matcher.py index 4da5e4c..7b84bc0 100644 --- a/emm/indexing/spark_cos_sim_matcher.py +++ b/emm/indexing/spark_cos_sim_matcher.py @@ -37,6 +37,7 @@ from pyspark.sql.types import ArrayType, FloatType, LongType, StringType, StructField, StructType from sparse_dot_topn import awesome_cossim_topn +from emm.helper.blocking_functions import _parse_blocking_func from emm.helper.spark_custom_reader_writer import SparkReadable, SparkWriteable from emm.helper.spark_utils import set_spark_job_group from emm.indexing.base_indexer import BaseIndexer, CosSimBaseIndexer @@ -74,7 +75,7 @@ def __init__( num_candidates: int = 2, cos_sim_lower_bound: float = 0.5, max_features: 
int = 2**25, - blocking_func: Callable[[str], str] | None = None, + blocking_func: Callable[[str], str] | str | None = None, streaming: bool = False, indexer_id: int | None = None, keep_all_cols: bool = False, @@ -146,7 +147,7 @@ def __init__( num_candidates=parameters.get("num_candidates", 10), cos_sim_lower_bound=parameters["cos_sim_lower_bound"], streaming=parameters["streaming"], - blocking_func=parameters["blocking_func"], + blocking_func=_parse_blocking_func(parameters["blocking_func"]), indexer_id=parameters["indexer_id"], n_threads=parameters.get("n_threads", 1), )