Skip to content

Commit

Permalink
added test for default case of dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
chrispyl committed Oct 6, 2024
1 parent 31642ae commit 55b433c
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tests/integration/test_pandas_em.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,18 @@ def test_pandas_tfidf(dtype):
np.testing.assert_allclose(actual_value, exp_value, rtol=0, atol=0.001)


def test_pandas_tfidf_default_dtype():
    """Check that the fitted ``idf_`` is float32 when no dtype is passed."""
    # No dtype argument: this exercises the constructor's default.
    pandas_t = PandasNormalizedTfidfVectorizer()
    # 100 random UUID strings serve as the names to fit on.
    gt_names = pd.Series([str(uuid.uuid4()) for _ in range(100)])
    pandas_t.fit(gt_names)
    assert pandas_t.idf_.dtype == np.float32


@pytest.mark.parametrize(
("dtype", "data_size"), [(np.float32, 100), (np.float64, 100), (np.float32, 1000000), (np.float64, 1000000)]
)
def test_pandas_tfidf_dtype(dtype, data_size):
def test_pandas_tfidf_dtype_for_different_input_sizes(dtype, data_size):
pandas_t = PandasNormalizedTfidfVectorizer(dtype=dtype)
unique_names = [str(uuid.uuid4()) for i in range(data_size)]
gt_names = pd.Series(unique_names)
Expand Down

0 comments on commit 55b433c

Please sign in to comment.