// TextChunker.cs (forked from microsoft/semantic-kernel)
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
namespace Microsoft.SemanticKernel.Text;
/// <summary>
/// Split text in chunks, attempting to leave meaning intact.
/// For plain text, split looking at new lines first, then periods, and so on.
/// For markdown, split looking at punctuation first, and so on.
/// </summary>
/// <remarks>
/// Token counts are not produced by a real tokenizer: they are estimated as one token
/// per four characters (integer division), so all limits are approximate.
/// </remarks>
public static class TextChunker
{
    /// <summary>Number of characters assumed to make up one token in the rough estimate.</summary>
    private const int CharsPerToken = 4;

    private static readonly char[] s_spaceChar = new[] { ' ' };

    // Each entry is a SET of separator characters (not a literal sequence). The sets are tried in
    // order; the trailing null means "no separator": cut at the midpoint of the text.
    private static readonly string?[] s_plaintextSplitOptions = new[] { "\n\r", ".", "?!", ";", ":", ",", ")]}", " ", "-", null };
    private static readonly string?[] s_markdownSplitOptions = new[] { ".", "?!", ";", ":", ",", ")]}", " ", "-", "\n\r", null };

    /// <summary>
    /// Split plain text into lines.
    /// </summary>
    /// <param name="text">Text to split</param>
    /// <param name="maxTokensPerLine">Maximum number of (estimated) tokens per line.</param>
    /// <returns>List of lines, each trimmed of leading and trailing white space.</returns>
    public static List<string> SplitPlainTextLines(string text, int maxTokensPerLine)
        => InternalSplitLines(text, maxTokensPerLine, trim: true, s_plaintextSplitOptions);

    /// <summary>
    /// Split markdown text into lines.
    /// </summary>
    /// <param name="text">Text to split</param>
    /// <param name="maxTokensPerLine">Maximum number of (estimated) tokens per line.</param>
    /// <returns>List of lines, each trimmed of leading and trailing white space.</returns>
    public static List<string> SplitMarkDownLines(string text, int maxTokensPerLine)
        => InternalSplitLines(text, maxTokensPerLine, trim: true, s_markdownSplitOptions);

    /// <summary>
    /// Split plain text into paragraphs.
    /// </summary>
    /// <param name="lines">Lines of text.</param>
    /// <param name="maxTokensPerParagraph">Maximum number of (estimated) tokens per paragraph.</param>
    /// <param name="overlapTokens">Number of tokens to overlap between paragraphs. Negative values are treated as 0.</param>
    /// <returns>List of paragraphs.</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="maxTokensPerParagraph"/> is not positive, or <paramref name="overlapTokens"/>
    /// is not smaller than <paramref name="maxTokensPerParagraph"/>.
    /// </exception>
    /// <exception cref="ArgumentNullException"><paramref name="lines"/> is null.</exception>
    public static List<string> SplitPlainTextParagraphs(List<string> lines, int maxTokensPerParagraph, int overlapTokens = 0)
        => InternalSplitTextParagraphs(
            lines,
            maxTokensPerParagraph,
            overlapTokens,
            (text, maxTokens) => InternalSplitLines(text, maxTokens, trim: false, s_plaintextSplitOptions));

    /// <summary>
    /// Split markdown text into paragraphs.
    /// </summary>
    /// <param name="lines">Lines of text.</param>
    /// <param name="maxTokensPerParagraph">Maximum number of (estimated) tokens per paragraph.</param>
    /// <param name="overlapTokens">Number of tokens to overlap between paragraphs. Negative values are treated as 0.</param>
    /// <returns>List of paragraphs.</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="maxTokensPerParagraph"/> is not positive, or <paramref name="overlapTokens"/>
    /// is not smaller than <paramref name="maxTokensPerParagraph"/>.
    /// </exception>
    /// <exception cref="ArgumentNullException"><paramref name="lines"/> is null.</exception>
    public static List<string> SplitMarkdownParagraphs(List<string> lines, int maxTokensPerParagraph, int overlapTokens = 0)
        => InternalSplitTextParagraphs(
            lines,
            maxTokensPerParagraph,
            overlapTokens,
            (text, maxTokens) => InternalSplitLines(text, maxTokens, trim: false, s_markdownSplitOptions));

    /// <summary>
    /// Shared implementation of the paragraph splitters: breaks over-long lines, packs the resulting
    /// lines into paragraphs, optionally merges a very short last paragraph into the previous one,
    /// and finally appends the overlap taken from the start of each following paragraph.
    /// </summary>
    /// <param name="lines">Lines of text.</param>
    /// <param name="maxTokensPerParagraph">Maximum number of (estimated) tokens per paragraph.</param>
    /// <param name="overlapTokens">Number of tokens to overlap between paragraphs. Negative values are treated as 0.</param>
    /// <param name="longLinesSplitter">Function that splits a text into pieces of at most the given number of tokens.</param>
    /// <returns>List of paragraphs.</returns>
    private static List<string> InternalSplitTextParagraphs(List<string> lines, int maxTokensPerParagraph, int overlapTokens, Func<string, int, List<string>> longLinesSplitter)
    {
        if (maxTokensPerParagraph <= 0)
        {
            throw new ArgumentException("maxTokensPerParagraph should be a positive number", nameof(maxTokensPerParagraph));
        }

        // Equal values are rejected as well: no room would be left for the paragraph's own content.
        if (maxTokensPerParagraph <= overlapTokens)
        {
            throw new ArgumentException("overlapTokens must be smaller than maxTokensPerParagraph", nameof(overlapTokens));
        }

        // Checked after the numeric guards so that their exceptions keep precedence.
        if (lines is null)
        {
            throw new ArgumentNullException(nameof(lines));
        }

        if (lines.Count == 0)
        {
            return new List<string>();
        }

        // A negative overlap would otherwise raise the budget below above the requested maximum.
        overlapTokens = Math.Max(overlapTokens, 0);

        // Reserve room in every paragraph for the overlap that is appended at the end.
        var adjustedMaxTokensPerParagraph = maxTokensPerParagraph - overlapTokens;

        // Split long lines first
        IEnumerable<string> truncatedLines = lines.SelectMany(line => longLinesSplitter(line, adjustedMaxTokensPerParagraph));

        var paragraphs = BuildParagraph(truncatedLines, adjustedMaxTokensPerParagraph);

        // Distribute text more evenly in the last paragraphs when the last paragraph is too short.
        if (paragraphs.Count > 1)
        {
            var lastParagraph = paragraphs[paragraphs.Count - 1];
            var secondLastParagraph = paragraphs[paragraphs.Count - 2];

            if (TokenCount(lastParagraph.Length) < adjustedMaxTokensPerParagraph / 4)
            {
                var lastParagraphWords = lastParagraph.Split(s_spaceChar, StringSplitOptions.RemoveEmptyEntries);
                var secondLastParagraphWords = secondLastParagraph.Split(s_spaceChar, StringSplitOptions.RemoveEmptyEntries);

                // NOTE: this check counts space-separated words, whereas the rest of the class
                // estimates tokens from character counts.
                if (lastParagraphWords.Length + secondLastParagraphWords.Length <= adjustedMaxTokensPerParagraph)
                {
                    var newSecondLastParagraph = string.Join(" ", secondLastParagraphWords);
                    var newLastParagraph = string.Join(" ", lastParagraphWords);

                    paragraphs[paragraphs.Count - 2] = $"{newSecondLastParagraph} {newLastParagraph}";
                    paragraphs.RemoveAt(paragraphs.Count - 1);
                }
            }
        }

        if (overlapTokens > 0 && paragraphs.Count > 1)
        {
            // Every paragraph except the last gets the first overlap-sized piece of its successor appended.
            var overlapped = new List<string>(paragraphs.Count);

            for (int i = 0; i < paragraphs.Count - 1; i++)
            {
                var nextParagraphPieces = longLinesSplitter(paragraphs[i + 1], overlapTokens);
                overlapped.Add($"{paragraphs[i]} {nextParagraphPieces.FirstOrDefault()}");
            }

            overlapped.Add(paragraphs[paragraphs.Count - 1]);
            paragraphs = overlapped;
        }

        return paragraphs;
    }

    /// <summary>
    /// Packs lines into paragraphs: lines are accumulated until adding the next one would reach
    /// the token limit, at which point the current paragraph is closed and a new one is started.
    /// </summary>
    /// <param name="truncatedLines">Lines to pack, enumerated once.</param>
    /// <param name="maxTokensPerParagraph">Token budget of a single paragraph.</param>
    /// <returns>List of paragraphs, each trimmed of leading and trailing white space.</returns>
    private static List<string> BuildParagraph(IEnumerable<string> truncatedLines, int maxTokensPerParagraph)
    {
        StringBuilder paragraphBuilder = new();
        List<string> paragraphs = new();

        foreach (string line in truncatedLines)
        {
            // A line is always accepted into an empty paragraph, even if it alone reaches the limit.
            if (paragraphBuilder.Length > 0 && TokenCount(paragraphBuilder.Length) + TokenCount(line.Length) + 1 >= maxTokensPerParagraph)
            {
                // Complete the paragraph and prepare for the next
                paragraphs.Add(paragraphBuilder.ToString().Trim());
                paragraphBuilder.Clear();
            }

            paragraphBuilder.AppendLine(line);
        }

        if (paragraphBuilder.Length > 0)
        {
            // Add the final paragraph if there's anything remaining
            paragraphs.Add(paragraphBuilder.ToString().Trim());
        }

        return paragraphs;
    }

    /// <summary>
    /// Splits a text into pieces of at most <paramref name="maxTokensPerLine"/> estimated tokens by
    /// applying the separator sets in order, stopping as soon as no piece exceeds the limit.
    /// </summary>
    /// <param name="text">Text to split; its line endings are first passed through NormalizeLineEndings.</param>
    /// <param name="maxTokensPerLine">Maximum number of (estimated) tokens per piece.</param>
    /// <param name="trim">Whether each piece is trimmed of leading and trailing white space.</param>
    /// <param name="splitOptions">Separator sets to try in order; a null entry cuts at the midpoint.</param>
    /// <returns>List of pieces; always contains at least one element.</returns>
    private static List<string> InternalSplitLines(string text, int maxTokensPerLine, bool trim, string?[] splitOptions)
    {
        var result = new List<string> { text.NormalizeLineEndings() };

        foreach (string? separators in splitOptions)
        {
            var (pieces, oversizedPieceRemains) = Split(result, maxTokensPerLine, separators.AsSpan(), trim);

            // The pieces fully replace the previous generation of the input.
            result = pieces;

            if (!oversizedPieceRemains)
            {
                break;
            }
        }

        return result;
    }

    /// <summary>
    /// Applies the single-text <c>Split</c> to every element of <paramref name="input"/>.
    /// </summary>
    /// <returns>
    /// The concatenated pieces, and a flag that is true when at least one piece still exceeds
    /// <paramref name="maxTokens"/> and so needs a further separator set.
    /// </returns>
    private static (List<string>, bool) Split(List<string> input, int maxTokens, ReadOnlySpan<char> separators, bool trim)
    {
        bool oversizedPieceRemains = false;
        List<string> result = new();

        foreach (string item in input)
        {
            var (pieces, stillOversized) = Split(item.AsSpan(), item, maxTokens, separators, trim);
            result.AddRange(pieces);
            oversizedPieceRemains |= stillOversized;
        }

        return (result, oversizedPieceRemains);
    }

    /// <summary>
    /// Recursively halves <paramref name="input"/> at the separator closest to its midpoint until
    /// every piece fits within <paramref name="maxTokens"/> or no separator is left to cut at.
    /// </summary>
    /// <param name="input">Text to split.</param>
    /// <param name="inputString">The string backing <paramref name="input"/>, if one exists; lets an unsplit input be returned without a new allocation.</param>
    /// <param name="maxTokens">Maximum number of (estimated) tokens per piece.</param>
    /// <param name="separators">Set of separator characters; empty means cut exactly at the midpoint.</param>
    /// <param name="trim">Whether each piece is trimmed of leading and trailing white space.</param>
    /// <returns>
    /// The pieces, and a flag that is true when at least one returned piece still exceeds
    /// <paramref name="maxTokens"/> because it could not be cut with these separators.
    /// </returns>
    private static (List<string>, bool) Split(ReadOnlySpan<char> input, string? inputString, int maxTokens, ReadOnlySpan<char> separators, bool trim)
    {
        Debug.Assert(inputString is null || input.SequenceEqual(inputString.AsSpan()));

        List<string> result = new();
        bool oversized = TokenCount(input.Length) > maxTokens;

        if (oversized)
        {
            int half = input.Length / 2;
            int cutPoint = -1;

            if (separators.IsEmpty)
            {
                cutPoint = half;
            }
            else if (input.Length > 2)
            {
                int pos = 0;
                while (true)
                {
                    // The last character is excluded so that a cut never leaves an empty second half.
                    int index = input.Slice(pos, input.Length - 1 - pos).IndexOfAny(separators);
                    if (index < 0)
                    {
                        break;
                    }

                    index += pos;

                    // Keep the candidate nearest to the midpoint; the cut goes right after the
                    // separator so that the separator stays with the first half.
                    if (Math.Abs(half - index) < Math.Abs(half - cutPoint))
                    {
                        cutPoint = index + 1;
                    }

                    pos = index + 1;
                }
            }

            if (cutPoint > 0)
            {
                var firstHalf = input.Slice(0, cutPoint);
                var secondHalf = input.Slice(cutPoint);

                if (trim)
                {
                    firstHalf = firstHalf.Trim();
                    secondHalf = secondHalf.Trim();
                }

                // Recursion
                var (firstPieces, firstOversized) = Split(firstHalf, null, maxTokens, separators, trim);
                result.AddRange(firstPieces);
                var (secondPieces, secondOversized) = Split(secondHalf, null, maxTokens, separators, trim);
                result.AddRange(secondPieces);

                // After a successful cut only the halves decide whether more splitting is needed.
                return (result, firstOversized || secondOversized);
            }
        }

        // Either the input fits, or it is too long but cannot be cut with these separators.
        result.Add((inputString is not null, trim) switch
        {
            (true, true) => inputString!.Trim(),
            (true, false) => inputString!,
            (false, true) => input.Trim().ToString(),
            (false, false) => input.ToString(),
        });

        return (result, oversized);
    }

    /// <summary>
    /// Estimates the number of tokens in a text of the given length.
    /// </summary>
    /// <param name="inputLength">Length of the text in characters.</param>
    /// <returns>The length divided by four, rounded towards zero.</returns>
    private static int TokenCount(int inputLength)
    {
        // TODO: partitioning methods should be configurable to allow for different tokenization strategies
        //       depending on the model to be called. For now, we use an extremely rough estimate.
        return inputLength / CharsPerToken;
    }
}