Allow embedding model to capture full content of text blocks #28

Merged
3 changes: 3 additions & 0 deletions .gitignore
@@ -2,6 +2,9 @@
data/**
!data/**/.gitkeep

# Cached models folder
models/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
70 changes: 68 additions & 2 deletions src/ml.py
@@ -9,6 +9,16 @@
from src import config


def sliding_window(text: str, window_size: int, stride: int) -> list[str]:
"""Split a text string into overlapping windows."""
windows = [
text[i : i + window_size]
for i in range(0, len(text), stride)
if i + window_size <= len(text)
]
return windows
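# Illustrative sketch (not part of the diff): with window_size=5 and stride=2,
# sliding_window("abcdefgh", window_size=5, stride=2) returns
# ['abcde', 'cdefg']; starts at i = 4 and i = 6 are dropped because
# i + window_size would run past the end of the text, so a short tail of the
# string can be left uncovered.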


class SentenceEncoder(ABC):
"""Base class for a sentence encoder"""

@@ -45,6 +55,15 @@ def __init__(self, model_name: str):
model_name, cache_folder=config.INDEX_ENCODER_CACHE_FOLDER
)

def get_n_tokens(self, text: str) -> int:
"""Return the number of tokens in the text."""

tokenized = self.encoder[0].tokenizer(
text, return_attention_mask=False, return_token_type_ids=False
)

return len(tokenized["input_ids"])
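# Note (illustrative, not part of the diff): self.encoder[0] is the
# SentenceTransformer's first module (the transformer), whose Hugging Face
# tokenizer is reused here, so the count includes any special tokens the
# model adds; the exact number for a given string depends on the configured
# SBERT_MODEL.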

def encode(self, text: str, device: Optional[str] = None) -> np.ndarray:
"""Encode a string, return a numpy array.

@@ -70,10 +89,57 @@ def encode_batch(
Returns:
np.ndarray
"""
return self.encoder.encode(
text_batch, batch_size=batch_size, show_progress_bar=False, device=device
return self._encode_batch_using_sliding_window(
text_batch, batch_size=batch_size, device=device
)

def _encode_batch_using_sliding_window(
self, text_batch: list[str], batch_size: int = 32, device: Optional[str] = None
):
"""
Encode a batch of strings accommodating long texts using a sliding window.

The window length equals the underlying encoder's context window
(max_seq_length), and the stride is half of that length.

For args, see encode_batch.
"""

max_seq_length = self.encoder.max_seq_length
assert isinstance(max_seq_length, int)

# Split the texts based on length and apply sliding window only to longer texts
processed_texts = []
window_lengths = []

for text in text_batch:
if self.get_n_tokens(text) > max_seq_length:
windows = sliding_window(
text, window_size=max_seq_length, stride=max_seq_length // 2
) # Window size and stride follow from the encoder's context window (max_seq_length)
processed_texts.extend(windows)
window_lengths.append(len(windows))
else:
processed_texts.append(text)
window_lengths.append(1)

embeddings = self.encoder.encode(
processed_texts, batch_size=batch_size, device=device # type: ignore
)

# Reduce the embeddings to the original number of texts
reduced_embeddings = []

for length in window_lengths:
if length > 1:
reduced_embeddings.append(np.mean(embeddings[:length], axis=0))
embeddings = embeddings[length:]
else:
reduced_embeddings.append(embeddings[0])
embeddings = embeddings[1:]

return np.vstack(reduced_embeddings)
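# Worked example (sketch, not part of the diff): for a batch of three texts
# where only the second exceeds max_seq_length and is split into 3 windows,
# window_lengths == [1, 3, 1] and `embeddings` has 5 rows. The loop consumes
# rows front to back: row 0 is kept for the first text, rows 1-3 are averaged
# into a single vector for the second text, and the last row is kept for the
# third text, so the stacked result has 3 rows again, one per input text.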

@property
def dimension(self) -> int:
"""Return the dimension of the embedding."""
39 changes: 38 additions & 1 deletion src/test/test_ml.py
@@ -1,7 +1,7 @@
import numpy as np

from src import config
from src.ml import SBERTEncoder
from src.ml import SBERTEncoder, sliding_window


def test_encoder():
@@ -16,3 +16,40 @@ def test_encoder():
assert isinstance(encoder.encode_batch(["Hello world!"] * 100), np.ndarray)

assert encoder.dimension == 768


def test_encoder_sliding_window():
"""Assert that we can encode long texts using a sliding window."""

encoder = SBERTEncoder(config.SBERT_MODEL)

long_text = "Hello world! " * 50
short_text = "Hello world!"

batch_to_encode = [short_text, long_text, short_text, short_text]
embeddings = encoder._encode_batch_using_sliding_window(
batch_to_encode, batch_size=32
)

assert isinstance(embeddings, np.ndarray)
assert embeddings.shape[0] == len(batch_to_encode)
assert embeddings.shape[1] == encoder.dimension

# embeddings of all short texts should be the same
assert np.array_equal(embeddings[0, :], embeddings[2, :])
assert np.array_equal(embeddings[0, :], embeddings[3, :])

# embedding of long text should not be the same as short text
assert not np.array_equal(embeddings[0, :], embeddings[1, :])


def test_sliding_window():
"""Tests that the sliding_window function returns the correct embeddings."""
text = "Hello world! " * 50
window_size = 10
stride = 5

windows = sliding_window(text=text, window_size=window_size, stride=stride)

assert windows[0] == "Hello worl"
assert windows[1] == " world! He"
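
One more check that could be layered on (a sketch, not part of the merged test file; the test name test_sliding_window_count is hypothetical): the number of windows follows directly from the text length, window size, and stride.

def test_sliding_window_count():
    """Sketch: the window count is determined by len(text), window_size and stride."""
    text = "Hello world! " * 50  # 650 characters
    window_size = 10
    stride = 5

    windows = sliding_window(text=text, window_size=window_size, stride=stride)

    # Window starts are i = 0, 5, ..., 640, so (650 - 10) // 5 + 1 == 129 windows.
    assert len(windows) == (len(text) - window_size) // stride + 1
    assert all(len(w) == window_size for w in windows)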