Recursive chunking (#8182)
Co-authored-by: Szymon Dudycz <[email protected]>
GitOrigin-RevId: 5c5610943bfe8cac35d6e799cb8b21db09fc0db2
2 people authored and Manul from Pathway committed Feb 25, 2025
1 parent e4d6d91 commit 74259ca
Showing 5 changed files with 160 additions and 1 deletion.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]

### Added
- Added `RecursiveSplitter`



## [0.20.0] - 2025-02-25

### Added
19 changes: 19 additions & 0 deletions docs/2.developers/4.user-guide/50.llm-xpack/50.splitters.md
@@ -31,3 +31,22 @@ text_splitter = TokenCountSplitter(
This configuration creates chunks of 100–500 tokens using the `cl100k_base` tokenizer, compatible with OpenAI's embedding models.

For more on token encodings, refer to [OpenAI's tiktoken guide](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#encodings).


## RecursiveSplitter

Another kind of splitter that you can use to chunk your documents is [`RecursiveSplitter`](/developers/api-docs/pathway-xpacks-llm/splitters#pathway.xpacks.llm.splitters.RecursiveSplitter).
It functions similarly to `TokenCountSplitter` in that it measures chunk length based on the number of tokens required to encode the text.
However, the way it determines split points differs.
`RecursiveSplitter` processes a document by iterating through an ordered list of `separators` (configurable in the constructor), starting with the coarsest and falling back to progressively finer ones. For example, it may first attempt to split on paragraph breaks (`\n\n`) and, if the resulting chunks are still too large, fall back to splitting at periods (`.`).
The splitter continues this process until all chunks are smaller than `chunk_size`.
Additionally, you can introduce overlapping chunks by setting the `chunk_overlap` parameter. This is particularly useful if you want to capture different contexts in your chunks. However, keep in mind that enabling overlap increases the total number of chunks retrieved, which could impact performance.
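
To make the recursion concrete, here is a minimal character-based sketch of the idea in plain Python. It is only an illustration: it counts characters rather than tokens, drops the separators, and omits the chunk-merging and overlap logic that the underlying LangChain implementation performs.

```python
def recursive_split(text: str, separators: list[str], chunk_size: int) -> list[str]:
    # Short enough already: keep the text as a single chunk.
    if len(text) <= chunk_size:
        return [text]
    # Use the first separator that actually occurs in the text.
    for i, sep in enumerate(separators):
        if sep and sep in text:
            pieces = text.split(sep)
            finer = separators[i + 1 :]
            break
    else:
        # No separator applies: fall back to hard cuts of `chunk_size` characters.
        return [text[j : j + chunk_size] for j in range(0, len(text), chunk_size)]
    chunks: list[str] = []
    for piece in pieces:
        if not piece:
            continue
        if len(piece) <= chunk_size:
            chunks.append(piece)
        else:
            # Piece is still too large: recurse with the remaining, finer separators.
            chunks.extend(recursive_split(piece, finer, chunk_size))
    return chunks
```

The constructor call below shows a typical configuration for Markdown documents, measuring chunk length with the `gpt-4o-mini` tokenizer: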

```python
splitter = RecursiveSplitter(
    chunk_size=400,
    chunk_overlap=200,
    separators=["\n#", "\n##", "\n\n", "\n"],  # separators for markdown documents
    model_name="gpt-4o-mini",
)
```
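
Once configured, the splitter can be applied directly to a column of a Pathway table, in the same way the tests in this commit use it. A minimal sketch, assuming a table `docs` with a `text` column holding the raw documents:

```python
import pathway as pw

# `docs` is assumed to be a Pathway table with a `text` column of raw documents.
chunked = docs.select(chunks=splitter(pw.this.text))

# Each row now holds a list of (chunk_text, metadata) pairs;
# flatten it so that every chunk gets its own row.
chunked = chunked.flatten(pw.this.chunks)
```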
23 changes: 23 additions & 0 deletions integration_tests/xpack/test_splitters.py
@@ -0,0 +1,23 @@
import pandas as pd

import pathway as pw
from pathway.xpacks.llm.splitters import RecursiveSplitter


def test_recursive_from_hf_tokenizer():

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    splitter = RecursiveSplitter(hf_tokenizer=tokenizer, chunk_size=25, chunk_overlap=0)
    txt = "Pójdź, kińże tę chmurność w głąb flaszy 🍾."  # 23 tokens in bert tokenizer
    big_txt = "\n\n".join([txt] * 5)
    input_table = pw.debug.table_from_pandas(pd.DataFrame([dict(ret=big_txt)]))

    result = input_table.select(ret=splitter(pw.this.ret)).flatten(pw.this.ret)
    result = pw.debug.table_to_pandas(result)

    assert len(result) == 5
    assert result.iloc[0].ret[0] == txt
    assert result.iloc[0].ret[1] == pw.Json({})
78 changes: 78 additions & 0 deletions python/pathway/xpacks/llm/splitters.py
@@ -5,6 +5,7 @@
"""
import abc
import unicodedata
from typing import TYPE_CHECKING

import pathway as pw
from pathway.optional_import import optional_imports
@@ -80,6 +81,83 @@ def chunk(self, text: str, metadata: dict = {}, **kwargs) -> list[tuple[str, dict]]:
        pass


SEPARATORS = ["\n\n", "\n", " ", ""]


# wrapper around Langchain splitter
class RecursiveSplitter(BaseSplitter):
"""
Splitter that splits a long text into smaller chunks based on a set of separators.
Chunking is performed recursively using first separator in the list and then second
separator in the list and so on, until the text is split into chunks of length smaller than ``chunk_size``.
Length of the chunks is measured by the number of characters in the text if none of
``encoding_name``, ``model_name`` or ``hf_tokenizer`` is provided. Otherwise, the length of the
chunks is measured by the number of tokens that particular tokenizer would output.
Under the hood it is a wrapper around ``langchain_text_splitters.RecursiveTextSplitter`` (MIT license).
Args:
chunk_size: maximum size of a chunk in characters/tokens.
chunk_overlap: number of characters/tokens to overlap between chunks.
separators: list of strings to split the text on.
is_separator_regex: whether the separators are regular expressions.
encoding_name: name of the encoding from ``tiktoken``.
For the list of available encodings please refer to tiktoken documentation:
https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken
model_name: name of the model from ``tiktoken``. See the link above for more details.
hf_tokenizer: Huggingface tokenizer to use for tokenization.
"""

    if TYPE_CHECKING:
        from transformers import PreTrainedTokenizerBase

    def __init__(
        self,
        chunk_size: int = 500,
        chunk_overlap: int = 0,
        separators: list[str] = SEPARATORS,
        is_separator_regex: bool = False,
        encoding_name: str | None = None,
        model_name: str | None = None,
        hf_tokenizer: "PreTrainedTokenizerBase | None" = None,
    ):
        super().__init__()

        with optional_imports("xpack-llm"):
            from langchain_text_splitters import (
                RecursiveCharacterTextSplitter,
                TextSplitter,
            )

        self.kwargs = dict(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=separators,
            is_separator_regex=is_separator_regex,
        )

        self._splitter: TextSplitter

        if encoding_name is not None:
            self._splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                encoding_name=encoding_name, **self.kwargs
            )
        elif model_name is not None:
            self._splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                model_name=model_name, **self.kwargs
            )
        elif hf_tokenizer is not None:
            self._splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
                tokenizer=hf_tokenizer, **self.kwargs
            )
        else:
            self._splitter = RecursiveCharacterTextSplitter(**self.kwargs)

    def chunk(self, text: str, metadata: dict = {}, **kwargs) -> list[tuple[str, dict]]:
        chunked = self._splitter.split_text(text)
        return [(chunk, metadata) for chunk in chunked]


class NullSplitter(BaseSplitter):
"""A splitter which returns its argument as one long text ith null metadata.
Expand Down
36 changes: 35 additions & 1 deletion python/pathway/xpacks/llm/tests/test_splitters.py
@@ -6,7 +6,11 @@

import pathway as pw
from pathway.tests.utils import assert_table_equality
from pathway.xpacks.llm.splitters import NullSplitter, TokenCountSplitter
from pathway.xpacks.llm.splitters import (
    NullSplitter,
    RecursiveSplitter,
    TokenCountSplitter,
)


def test_null():
@@ -25,3 +29,33 @@ def test_tokencount():
    result = input_table.select(ret=splitter(pw.this.ret)[0][0])

    assert_table_equality(result, input_table)


def test_recursive_from_encoding():
    splitter = RecursiveSplitter(
        encoding_name="cl100k_base", chunk_size=30, chunk_overlap=0
    )
    txt = "Pójdź, kińże tę chmurność w głąb flaszy 🍾."  # 26 tokens in cl100k_base
    big_txt = "\n\n".join([txt] * 5)
    input_table = pw.debug.table_from_pandas(pd.DataFrame([dict(ret=big_txt)]))

    result = input_table.select(ret=splitter(pw.this.ret)).flatten(pw.this.ret)
    result = pw.debug.table_to_pandas(result)

    assert len(result) == 5
    assert result.iloc[0].ret[0] == txt
    assert result.iloc[0].ret[1] == pw.Json({})


def test_recursive_from_model_name():
    splitter = RecursiveSplitter(model_name="gpt-4", chunk_size=30, chunk_overlap=0)
    txt = "Pójdź, kińże tę chmurność w głąb flaszy 🍾."  # 26 tokens in cl100k_base
    big_txt = "\n\n".join([txt] * 5)
    input_table = pw.debug.table_from_pandas(pd.DataFrame([dict(ret=big_txt)]))

    result = input_table.select(ret=splitter(pw.this.ret)).flatten(pw.this.ret)
    result = pw.debug.table_to_pandas(result)

    assert len(result) == 5
    assert result.iloc[0].ret[0] == txt
    assert result.iloc[0].ret[1] == pw.Json({})
