diff --git a/cspell.json b/cspell.json new file mode 100644 index 0000000..2e5947c --- /dev/null +++ b/cspell.json @@ -0,0 +1,16 @@ +{ + "$schema": "https://raw.githubusercontent.com/streetsidesoftware/cspell/main/cspell.schema.json", + // Version of the setting file. Always 0.2 + "version": "0.2", + // language - current active spelling language + "language": "en", + "useGitignore": true, + "ignorePaths": [ + "vendor/**" + ], + // words - list of words to be always considered correct + "words": [ + "mergeable", + "tokeniser", + ] +} diff --git a/src/Encoding.php b/src/Encoding.php index 1efcb85..6720a24 100644 --- a/src/Encoding.php +++ b/src/Encoding.php @@ -2,18 +2,72 @@ namespace Danny50610\BpeTokeniser; +use Exception; +use ValueError; + class Encoding { protected $mergeableRanks; + protected $decodeMergeableRanks; + protected $pattenRegex; - public function __construct(&$mergeableRanks, $pattenRegex) + protected $specialRegex; + + protected $specialTokens; + + protected $decodeSpecialTokens; + + public function __construct(&$mergeableRanks, string $pattenRegex, array $specialTokens) { $this->mergeableRanks = $mergeableRanks; $this->pattenRegex = $pattenRegex . 'u'; // u for unicode + $this->specialTokens = $specialTokens; + + $escapeToken = []; + foreach ($this->specialTokens as $token => $rank) { + $escapeToken[] = str_replace('|', '\|', $token); + } + $this->specialRegex = '/' . implode('|', $escapeToken) . 
'/u'; + + // for decode + $this->decodeMergeableRanks = []; + foreach ($this->mergeableRanks as $token => $rank) { + $this->decodeMergeableRanks[$rank] = $token; + } + + if (count($this->mergeableRanks) !== count($this->decodeMergeableRanks)) { + throw new Exception('Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?'); + } + + $this->decodeSpecialTokens = []; + foreach ($this->specialTokens as $token => $rank) { + $this->decodeSpecialTokens[$rank] = $token; + } + + /** TODO: check + self.max_token_value = max( + max(mergeable_ranks.values()), max(special_tokens.values(), default=0) + ) + if explicit_n_vocab: + assert len(mergeable_ranks) + len(special_tokens) == explicit_n_vocab + assert self.max_token_value == explicit_n_vocab - 1 + */ } + public function getSpecialTokensSet() + { + return array_keys($this->specialTokens); + } + + /** + * Encodes a string into tokens, ignoring special tokens. + * This is equivalent to `encode($text, disallowedSpecial=[])` (but slightly faster). + * + * @param string $text + * @return int[] + */ public function encodeOrdinary(string $text): array { $result = []; @@ -33,6 +87,89 @@ public function encodeOrdinary(string $text): array return $result; } + public function encode(string $text, $allowedSpecial = [], $disallowedSpecial = 'all'): array + { + if ($allowedSpecial === 'all') { + $allowedSpecial = $this->getSpecialTokensSet(); + } + if ($disallowedSpecial === 'all') { + $disallowedSpecial = array_diff($this->getSpecialTokensSet(), $allowedSpecial); + } + if (count($disallowedSpecial) > 0) { + $escapeToken = []; + foreach ($disallowedSpecial as $token) { + $escapeToken[] = str_replace('|', '\|', $token); + } + $disallowedSpecialRegex = '/' . implode('|', $escapeToken) . 
'/u'; + + preg_match_all($disallowedSpecialRegex, $text, $matches); + if (count($matches[0]) > 0) { + $token = $matches[0][0]; + throw new ValueError( + "Encountered text corresponding to disallowed special token '{$token}'.\n" . + "If you want this text to be encoded as a special token, " . + "pass it to `allowedSpecial`, e.g. `allowedSpecial: ['{$token}', ...]`.\n" . + "If you want this text to be encoded as normal text, disable the check for this token " . + "by passing `disallowedSpecial: array_diff(\$enc->getSpecialTokensSet(), ['{$token}']))`.\n" . + "To disable this check for all special tokens, pass `disallowedSpecial: []`.\n" + ); + } + } + + $result = []; + $start = 0; + while (true) { + $hasNextSpecial = false; + $nextSpecial = null; + + $startFind = $start; + while (true) { + // Find the next allowed special token, if any + preg_match($this->specialRegex, $text, $matches, PREG_OFFSET_CAPTURE, $startFind); + if (count($matches) > 0) { + if (in_array($matches[0][0], $allowedSpecial, true)) { + $hasNextSpecial = true; + $nextSpecial = $matches[0][0]; + break; + } + + $startFind = $matches[0][1] + 1; + } else { + break; + } + } + if ($hasNextSpecial) { + $end = $matches[0][1]; + } else { + $end = strlen($text); + } + + // Okay, here we go, compare this logic to _encode_ordinary_native + preg_match_all($this->pattenRegex, substr($text, $start, $end - $start), $matches); + foreach ($matches[0] as $match) { + $token = $this->mergeableRanks[$match] ?? null; + if (!is_null($token)) { + $result[] = $token; + } else { + $resultList = $this->bytePairEncode($match, $this->mergeableRanks); + foreach ($resultList as $item) { + $result[] = $item; + } + } + } + + if ($hasNextSpecial) { + $token = $this->specialTokens[$nextSpecial]; + $result[] = $token; + $start = $end + strlen($nextSpecial); + } else { + break; + } + } + + return $result; + } + protected function bytePairEncode(string $piece, $ranks): array { // This is a vector of (start, rank). 
/**
 * Decodes a list of token ranks back into the text they represent.
 *
 * Each rank is looked up among the mergeable (ordinary) tokens first and
 * among the special tokens second; the resulting byte strings are
 * concatenated in order.
 *
 * @param int[] $tokens
 * @return string
 * @throws \ValueError when a rank belongs to neither vocabulary
 */
public function decode(array $tokens): string
{
    $result = '';
    foreach ($tokens as $token) {
        // `??` avoids the "Undefined array key" warning that a plain index
        // lookup raises for every special token.
        $piece = $this->decodeMergeableRanks[$token] ?? $this->decodeSpecialTokens[$token] ?? null;
        if ($piece === null) {
            // Dropping an unknown rank would silently corrupt the output.
            throw new \ValueError("Unknown token '{$token}': it is neither a mergeable nor a special token.");
        }

        $result .= $piece;
    }

    return $result;
}
/**
 * Verifies that encode() emits the expected ranks for each case when
 * every special token is allowed.
 *
 * @dataProvider specialDataProvider
 */
public function testEncodeWithSpecial($encodingName, $testCaseList)
{
    $encoding = EncodingFactory::createByEncodingName($encodingName);

    // Each case is a [text, expected token list] pair.
    foreach ($testCaseList as [$input, $expected]) {
        $actual = $encoding->encode($input, allowedSpecial: 'all');

        $this->assertSame($expected, $actual);
    }
}