ENH: added parse_blocking_func function that accepts strings as input
If an input string is provided, check whether it is included in the blocking map.
mbaak committed Apr 23, 2024
1 parent 5c67105 commit 693f83b
Showing 3 changed files with 34 additions and 4 deletions.
28 changes: 28 additions & 0 deletions emm/helper/blocking_functions.py
@@ -23,6 +23,10 @@
Please don't modify the function names.
"""

+from __future__ import annotations
+
+from typing import Callable
+

def first(x: str) -> str:
    """First character blocking function."""
@@ -37,3 +41,27 @@ def first2(x: str) -> str:
def first3(x: str) -> str:
    """First three characters blocking function."""
    return x.strip().lower()[:3]
+
+
+BLOCKING_MAP = {"first": first, "first2": first2, "first3": first3}
+
+
+def _parse_blocking_func(input: Callable[[str], str] | str | None = None) -> Callable[[str], str] | None:
+    """Helper function to get a blocking function.
+
+    Args:
+        input: blocking function or name of an existing blocking function
+
+    Returns:
+        blocking function or None
+    """
+    if input is None or callable(input):
+        return input
+    if isinstance(input, str):
+        if input not in BLOCKING_MAP:
+            msg = f"Input {input} is not a recognized blocking function."
+            raise ValueError(msg)
+        return BLOCKING_MAP[input]
+
+    msg = "Input is not None, not a string, and not callable."
+    raise TypeError(msg)
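
For illustration only (not part of the commit), the helper resolves a known name, passes a callable through, and leaves None untouched; anything else raises:

    from emm.helper.blocking_functions import _parse_blocking_func, first3

    assert _parse_blocking_func("first3") is first3   # known name, looked up in BLOCKING_MAP
    assert _parse_blocking_func(first3) is first3     # callables are returned unchanged
    assert _parse_blocking_func(None) is None         # None means no blocking
    # Unknown names raise ValueError; non-string, non-callable input raises TypeError:
    #   _parse_blocking_func("unknown")  -> ValueError
    #   _parse_blocking_func(42)         -> TypeError
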
5 changes: 3 additions & 2 deletions emm/indexing/pandas_cos_sim_matcher.py
@@ -30,6 +30,7 @@
from sklearn.base import TransformerMixin
from sparse_dot_topn import sp_matmul_topn

+from emm.helper.blocking_functions import _parse_blocking_func
from emm.helper.util import groupby
from emm.indexing.base_indexer import CosSimBaseIndexer
from emm.indexing.pandas_normalized_tfidf import PandasNormalizedTfidfVectorizer
@@ -52,7 +53,7 @@ def __init__(
        max_features: int | None = None,
        n_jobs: int = 1,
        spark_session: Any | None = None,
-        blocking_func: Callable[[str], str] | None = None,
+        blocking_func: Callable[[str], str] | str | None = None,
        dtype: type[float] = np.float32,
        indexer_id: int | None = None,
    ) -> None:
@@ -99,7 +100,7 @@ def __init__(
        self.dtype = dtype
        self.cos_sim_lower_bound = cos_sim_lower_bound
        self.partition_size = partition_size
-        self.blocking_func = blocking_func
+        self.blocking_func = _parse_blocking_func(blocking_func)
        self.n_jobs = n_jobs if n_jobs != -1 else multiprocessing.cpu_count()
        self.spark_session = spark_session
        # attributes below are set during fit
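
As a usage note (not part of the diff; the class name PandasCosSimIndexer and defaults for the other constructor arguments are assumed here, not shown above), the pandas indexer can now be configured with the string name instead of the callable:

    from emm.indexing.pandas_cos_sim_matcher import PandasCosSimIndexer

    # Passing the name "first" is now equivalent to passing the `first` callable.
    indexer = PandasCosSimIndexer(blocking_func="first")
    assert indexer.blocking_func(" Apple Inc ") == "a"
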
5 changes: 3 additions & 2 deletions emm/indexing/spark_cos_sim_matcher.py
@@ -37,6 +37,7 @@
from pyspark.sql.types import ArrayType, FloatType, LongType, StringType, StructField, StructType
from sparse_dot_topn import awesome_cossim_topn

+from emm.helper.blocking_functions import _parse_blocking_func
from emm.helper.spark_custom_reader_writer import SparkReadable, SparkWriteable
from emm.helper.spark_utils import set_spark_job_group
from emm.indexing.base_indexer import BaseIndexer, CosSimBaseIndexer
@@ -74,7 +75,7 @@ def __init__(
        num_candidates: int = 2,
        cos_sim_lower_bound: float = 0.5,
        max_features: int = 2**25,
-        blocking_func: Callable[[str], str] | None = None,
+        blocking_func: Callable[[str], str] | str | None = None,
        streaming: bool = False,
        indexer_id: int | None = None,
        keep_all_cols: bool = False,
@@ -146,7 +147,7 @@ def __init__(
            num_candidates=parameters.get("num_candidates", 10),
            cos_sim_lower_bound=parameters["cos_sim_lower_bound"],
            streaming=parameters["streaming"],
-            blocking_func=parameters["blocking_func"],
+            blocking_func=_parse_blocking_func(parameters["blocking_func"]),
            indexer_id=parameters["indexer_id"],
            n_threads=parameters.get("n_threads", 1),
        )
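Similarly (a sketch, not part of the commit; the parameters dict below is an assumed example), a string stored in a saved parameters dict now resolves back to a callable when the Spark indexer is rebuilt:

    from emm.helper.blocking_functions import _parse_blocking_func

    parameters = {"blocking_func": "first2"}                  # assumed persisted parameters
    func = _parse_blocking_func(parameters["blocking_func"])
    assert func(" ING Bank ") == "in"                         # first2: strip, lowercase, first two characters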
