From fca1214c2b668a91a7ac8060df9920150e783d48 Mon Sep 17 00:00:00 2001
From: Kalyan Dutia
Date: Mon, 16 Sep 2024 13:53:26 +0100
Subject: [PATCH] add tests

---
 src/test/test_ml.py | 39 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/src/test/test_ml.py b/src/test/test_ml.py
index 0a5c14d..f9aad39 100644
--- a/src/test/test_ml.py
+++ b/src/test/test_ml.py
@@ -1,7 +1,7 @@
 import numpy as np
 
 from src import config
-from src.ml import SBERTEncoder
+from src.ml import SBERTEncoder, sliding_window
 
 
 def test_encoder():
@@ -16,3 +16,40 @@ def test_encoder():
     assert isinstance(encoder.encode_batch(["Hello world!"] * 100), np.ndarray)
 
     assert encoder.dimension == 768
+
+
+def test_encoder_sliding_window():
+    """Assert that we can encode long texts using a sliding window."""
+
+    encoder = SBERTEncoder(config.SBERT_MODEL)
+
+    long_text = "Hello world! " * 50
+    short_text = "Hello world!"
+
+    batch_to_encode = [short_text, long_text, short_text, short_text]
+    embeddings = encoder._encode_batch_using_sliding_window(
+        batch_to_encode, batch_size=32
+    )
+
+    assert isinstance(embeddings, np.ndarray)
+    assert embeddings.shape[0] == len(batch_to_encode)
+    assert embeddings.shape[1] == encoder.dimension
+
+    # embeddings of all short texts should be the same
+    assert np.array_equal(embeddings[0, :], embeddings[2, :])
+    assert np.array_equal(embeddings[0, :], embeddings[3, :])
+
+    # embedding of long text should not be the same as short text
+    assert not np.array_equal(embeddings[0, :], embeddings[1, :])
+
+
+def test_sliding_window():
+    """Assert that sliding_window splits text into the expected character windows."""
+    text = "Hello world! " * 50
+    window_size = 10
+    stride = 5
+
+    windows = sliding_window(text=text, window_size=window_size, stride=stride)
+
+    assert windows[0] == "Hello worl"
+    assert windows[1] == " world! He"