change encoder to use sliding window
kdutia committed Sep 16, 2024
1 parent 4e7dc40 commit ab850b3
Showing 1 changed file with 68 additions and 2 deletions.
src/ml.py
@@ -9,6 +9,16 @@
from src import config


def sliding_window(text: str, window_size: int, stride: int) -> list[str]:
"""Split a text string into overlapping windows."""
windows = [
text[i : i + window_size]
for i in range(0, len(text), stride)
if i + window_size <= len(text)
]
return windows
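
# Editor's illustration (not part of the commit): with window_size=4 and
# stride=2, adjacent windows overlap by half, e.g.
#
#     sliding_window("abcdefgh", window_size=4, stride=2)
#     # -> ["abcd", "cdef", "efgh"]
#
# Any tail shorter than window_size is dropped by the
# `i + window_size <= len(text)` guard.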


class SentenceEncoder(ABC):
"""Base class for a sentence encoder"""

@@ -45,6 +55,15 @@ def __init__(self, model_name: str):
            model_name, cache_folder=config.INDEX_ENCODER_CACHE_FOLDER
        )

    def get_n_tokens(self, text: str) -> int:
        """Return the number of tokens in the text."""

        tokenized = self.encoder[0].tokenizer(
            text, return_attention_mask=False, return_token_type_ids=False
        )

        return len(tokenized["input_ids"])
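
    # Editor's illustration (not part of the commit): the count comes from the
    # underlying Hugging Face tokenizer, whose output includes special tokens.
    # For a BERT-style model, get_n_tokens("hello world") would return 4
    # ([CLS], hello, world, [SEP]); exact values depend on the model's
    # tokenizer, so treat this example as hypothetical.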

    def encode(self, text: str, device: Optional[str] = None) -> np.ndarray:
        """Encode a string, return a numpy array.
@@ -70,10 +89,57 @@ def encode_batch(
        Returns:
            np.ndarray
        """
-        return self.encoder.encode(
-            text_batch, batch_size=batch_size, show_progress_bar=False, device=device
+        return self._encode_batch_using_sliding_window(
+            text_batch, batch_size=batch_size, device=device
        )
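
    # Editor's illustration (not part of the commit): the public interface is
    # unchanged; a batch still yields one embedding row per input, regardless
    # of input length. Hypothetical usage, with `encoder` an instance of the
    # concrete subclass defined in this file:
    #
    #     embeddings = encoder.encode_batch(["short text", very_long_text])
    #     assert embeddings.shape == (2, encoder.dimension)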

    def _encode_batch_using_sliding_window(
        self, text_batch: list[str], batch_size: int = 32, device: Optional[str] = None
    ):
        """
        Encode a batch of strings, accommodating long texts with a sliding window.
        The window length equals the underlying encoder's context window, and the
        stride is half of it.

        For args, see encode_batch.
        """

        max_seq_length = self.encoder.max_seq_length
        assert isinstance(max_seq_length, int)

        # Split the texts based on length and apply the sliding window only to
        # texts longer than the context window
        processed_texts = []
        window_lengths = []

        for text in text_batch:
            if self.get_n_tokens(text) > max_seq_length:
                windows = sliding_window(
                    text, window_size=max_seq_length, stride=max_seq_length // 2
                )  # Use max_seq_length as window size and half of it as stride
                processed_texts.extend(windows)
                window_lengths.append(len(windows))
            else:
                processed_texts.append(text)
                window_lengths.append(1)

        embeddings = self.encoder.encode(
            processed_texts, batch_size=batch_size, device=device  # type: ignore
        )

        # Reduce the embeddings to the original number of texts by averaging
        # each text's window embeddings
        reduced_embeddings = []

        for length in window_lengths:
            if length > 1:
                reduced_embeddings.append(np.mean(embeddings[:length], axis=0))
                embeddings = embeddings[length:]
            else:
                reduced_embeddings.append(embeddings[0])
                embeddings = embeddings[1:]

        return np.vstack(reduced_embeddings)
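
    # Editor's illustration (not part of the commit): if text_batch holds
    # three texts and only the second is long enough to be split, say into 3
    # windows, then window_lengths == [1, 3, 1] and embeddings has 5 rows.
    # The loop above returns row 0 for the first text, the mean of rows 1-3
    # for the second, and row 4 for the third, i.e. a (3, dim) array. Note
    # that the length check counts tokens while sliding_window slices
    # characters, so each window usually holds far fewer tokens than
    # max_seq_length.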

    @property
    def dimension(self) -> int:
        """Return the dimension of the embedding."""
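Editor's note: below is a minimal, self-contained sketch of the pattern this commit implements: window long inputs, encode all pieces in one batch, then mean-pool each input's window embeddings back to one vector per input. Everything here (encode_with_windows, toy_encode, the character-based length check) is hypothetical and simplified, not the repository's API.

import numpy as np

def sliding_window(text: str, window_size: int, stride: int) -> list[str]:
    return [
        text[i : i + window_size]
        for i in range(0, len(text), stride)
        if i + window_size <= len(text)
    ]

def encode_with_windows(texts, encode_fn, window_size=16, stride=8):
    # Split long texts into windows, encode everything in one batch, then
    # mean-pool each text's window embeddings back to a single vector.
    # Length check simplified to characters; the commit counts tokens.
    processed, counts = [], []
    for t in texts:
        if len(t) > window_size:
            windows = sliding_window(t, window_size, stride)
            processed.extend(windows)
            counts.append(len(windows))
        else:
            processed.append(t)
            counts.append(1)
    emb = encode_fn(processed)
    out, start = [], 0
    for n in counts:
        out.append(emb[start : start + n].mean(axis=0))
        start += n
    return np.vstack(out)

# Toy stand-in for a sentence encoder: hash characters into a fixed-size
# count vector, one row per input string.
def toy_encode(batch):
    return np.stack([
        np.bincount([ord(c) % 8 for c in s], minlength=8).astype(float)
        for s in batch
    ])

vecs = encode_with_windows(["short", "a much longer text " * 4], toy_encode)
print(vecs.shape)  # (2, 8): one pooled embedding per input text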
