diff --git a/emm/indexing/pandas_normalized_tfidf.py b/emm/indexing/pandas_normalized_tfidf.py index 4458615..3b3162b 100644 --- a/emm/indexing/pandas_normalized_tfidf.py +++ b/emm/indexing/pandas_normalized_tfidf.py @@ -37,8 +37,6 @@ class PandasNormalizedTfidfVectorizer(TfidfVectorizer): """Implementation of customized TFIDF vectorizer""" - dtype = np.float32 - def __init__(self, **kwargs: Any) -> None: """Implementation of customized TFIDF vectorizer @@ -53,6 +51,7 @@ def __init__(self, **kwargs: Any) -> None: Args: kwargs: kew-word arguments are same as TfidfVectorizer. """ + kwargs.setdefault("dtype", np.float32) kwargs.update({"norm": None, "smooth_idf": True, "lowercase": True}) if kwargs.get("analyzer") in {"word", None}: kwargs["token_pattern"] = r"\w+" @@ -74,6 +73,8 @@ def fit(self, X: pd.Series | pd.DataFrame) -> TfidfVectorizer: with Timer("CustomizedTfidfVectorizer.fit") as timer: timer.label("super fit") super().fit(X) + # scikit-learn's TfidfVectorizer does not preserve dtype for large X, so we force it here + self.idf_ = self.idf_.astype(self.dtype) timer.label("normalize") n_features = self.idf_.shape[0] diff --git a/emm/version.py b/emm/version.py index 006f8eb..acb041b 100644 --- a/emm/version.py +++ b/emm/version.py @@ -17,6 +17,6 @@ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -VERSION = "2.1.5" +VERSION = "2.1.6" __version__ = VERSION diff --git a/tests/integration/test_pandas_em.py b/tests/integration/test_pandas_em.py index b60e132..f14e30d 100644 --- a/tests/integration/test_pandas_em.py +++ b/tests/integration/test_pandas_em.py @@ -21,6 +21,7 @@ import logging import os +import uuid import numpy as np import pandas as pd @@ -135,6 +136,25 @@ def test_pandas_tfidf(dtype): np.testing.assert_allclose(actual_value, exp_value, rtol=0, atol=0.001) +def test_pandas_tfidf_default_dtype(): + pandas_t = PandasNormalizedTfidfVectorizer() + unique_names = [str(uuid.uuid4()) for i in range(100)] + gt_names = pd.Series(unique_names) + pandas_t.fit(gt_names) + assert pandas_t.idf_.dtype == np.float32 + + +@pytest.mark.parametrize( + ("dtype", "data_size"), [(np.float32, 100), (np.float64, 100), (np.float32, 1000000), (np.float64, 1000000)] +) +def test_pandas_tfidf_dtype_for_different_input_sizes(dtype, data_size): + pandas_t = PandasNormalizedTfidfVectorizer(dtype=dtype) + unique_names = [str(uuid.uuid4()) for i in range(data_size)] + gt_names = pd.Series(unique_names) + pandas_t.fit(gt_names) + assert pandas_t.idf_.dtype == dtype + + def test_pandas_tfidf_ngram(): pandas_t = PandasNormalizedTfidfVectorizer(binary=True, analyzer="char", ngram_range=(3, 3)) gt_names = pd.Series(["aaab", "bbbc"])