Skip to content

Commit

Permalink
add encode and decode
Browse files Browse the repository at this point in the history
  • Loading branch information
danny50610 committed Aug 19, 2023
1 parent ed12300 commit 888adc3
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 33 deletions.
16 changes: 16 additions & 0 deletions cspell.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"$schema": "https://raw.githubusercontent.com/streetsidesoftware/cspell/main/cspell.schema.json",
// Version of the setting file. Always 0.2
"version": "0.2",
// language - current active spelling language
"language": "en",
"useGitignore": true,
"ignorePaths": [
"vendor/**"
],
// words - list of words to be always considered correct
"words": [
"mergeable",
"tokeniser",
]
}
154 changes: 153 additions & 1 deletion src/Encoding.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,72 @@

namespace Danny50610\BpeTokeniser;

use Exception;
use ValueError;

class Encoding
{
protected $mergeableRanks;

protected $decodeMergeableRanks;

protected $pattenRegex;

public function __construct(&$mergeableRanks, $pattenRegex)
protected $specialRegex;

protected $specialTokens;

protected $decodeSpecialTokens;

/**
 * Builds an encoding from a BPE rank table, a split pattern and a set of special tokens.
 *
 * Besides storing its arguments, the constructor precomputes the regex used to
 * locate special tokens and the reverse (rank => token) tables used by decode().
 *
 * @param array  $mergeableRanks Map of token => rank. Taken by reference to keep the
 *                               existing signature; it is only read here.
 * @param string $pattenRegex    Delimited PCRE pattern; the `u` modifier is appended here,
 *                               so it is expected to end with its closing delimiter.
 * @param array  $specialTokens  Map of special token text => rank.
 *
 * @throws Exception When two entries of $mergeableRanks share the same rank.
 */
public function __construct(&$mergeableRanks, string $pattenRegex, array $specialTokens)
{
    $this->mergeableRanks = $mergeableRanks;
    $this->pattenRegex = $pattenRegex . 'u'; // u for unicode
    $this->specialTokens = $specialTokens;

    if (count($this->specialTokens) > 0) {
        $escapeToken = [];
        foreach ($this->specialTokens as $token => $rank) {
            // preg_quote escapes every regex metacharacter as well as the delimiter,
            // not just '|'. The cast is needed because PHP stores numeric-string keys as ints.
            $escapeToken[] = preg_quote((string) $token, '/');
        }
        $this->specialRegex = '/' . implode('|', $escapeToken) . '/u';
    } else {
        // An empty alternation would match the empty string at every offset;
        // an empty negative lookahead can never match.
        $this->specialRegex = '/(?!)/u';
    }

    // for decode
    $this->decodeMergeableRanks = [];
    foreach ($this->mergeableRanks as $token => $rank) {
        $this->decodeMergeableRanks[$rank] = $token;
    }

    // Duplicate ranks overwrite each other above, which shows up as a size mismatch.
    if (count($this->mergeableRanks) !== count($this->decodeMergeableRanks)) {
        throw new Exception('Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?');
    }

    $this->decodeSpecialTokens = [];
    foreach ($this->specialTokens as $token => $rank) {
        $this->decodeSpecialTokens[$rank] = $token;
    }

    /** TODO: check
    self.max_token_value = max(
    max(mergeable_ranks.values()), max(special_tokens.values(), default=0)
    )
    if explicit_n_vocab:
    assert len(mergeable_ranks) + len(special_tokens) == explicit_n_vocab
    assert self.max_token_value == explicit_n_vocab - 1
    */
}

/**
 * Lists the text of every special token known to this encoding.
 *
 * @return array Keys of the special-token map, in their original order.
 */
public function getSpecialTokensSet()
{
    $names = [];
    foreach ($this->specialTokens as $name => $unusedRank) {
        $names[] = $name;
    }

    return $names;
}

/**
* Encodes a string into tokens, ignoring special tokens.
* This is equivalent to `encode($text, disallowedSpecial: [])` (but slightly faster).
*
* @param string $text
* @return int[]
*/
public function encodeOrdinary(string $text): array
{
$result = [];
Expand All @@ -33,6 +87,89 @@ public function encodeOrdinary(string $text): array
return $result;
}

/**
 * Encodes a string into tokens, with control over how special tokens are handled.
 *
 * Text matching an allowed special token is emitted as that token's rank; the text
 * between special tokens is split with the pattern regex and byte-pair encoded.
 *
 * @param string       $text
 * @param array|string $allowedSpecial    Special tokens to encode as special tokens,
 *                                        or 'all' for every special token.
 * @param array|string $disallowedSpecial Special tokens whose presence in $text is an error,
 *                                        or 'all' for every special token not in $allowedSpecial.
 * @return int[]
 *
 * @throws ValueError When $text contains a disallowed special token.
 */
public function encode(string $text, $allowedSpecial = [], $disallowedSpecial = 'all'): array
{
    if ($allowedSpecial === 'all') {
        $allowedSpecial = $this->getSpecialTokensSet();
    }
    if ($disallowedSpecial === 'all') {
        $disallowedSpecial = array_diff($this->getSpecialTokensSet(), $allowedSpecial);
    }
    if (count($disallowedSpecial) > 0) {
        $escapeToken = [];
        foreach ($disallowedSpecial as $token) {
            // Escape every regex metacharacter and the delimiter, not only '|'.
            $escapeToken[] = preg_quote((string) $token, '/');
        }
        $disallowedSpecialRegex = '/' . implode('|', $escapeToken) . '/u';

        // Only the first occurrence is reported, so a single match is enough.
        if (preg_match($disallowedSpecialRegex, $text, $matches) === 1) {
            $token = $matches[0];
            throw new ValueError(
                "Encountered text corresponding to disallowed special token '{$token}'.\n" .
                "If you want this text to be encoded as a special token, " .
                "pass it to `allowedSpecial`, e.g. `allowedSpecial: ['{$token}', ...]`.\n" .
                "If you want this text to be encoded as normal text, disable the check for this token " .
                "by passing `disallowedSpecial: array_diff(\$enc->getSpecialTokensSet(), ['{$token}'])`.\n" .
                "To disable this check for all special tokens, pass `disallowedSpecial: []`.\n"
            );
        }
    }

    $textLength = strlen($text);
    // With nothing allowed, no match can ever be accepted, so skip the search entirely.
    $searchSpecial = $allowedSpecial !== [];

    $result = [];
    $start = 0;
    while (true) {
        $nextSpecial = null;
        $end = $textLength;

        // Find the next allowed special token, if any
        $startFind = $start;
        while ($searchSpecial && $startFind <= $textLength) {
            if (preg_match($this->specialRegex, $text, $matches, PREG_OFFSET_CAPTURE, $startFind) !== 1) {
                break;
            }

            [$matched, $offset] = $matches[0];
            if (in_array($matched, $allowedSpecial, true)) {
                $nextSpecial = $matched;
                $end = $offset;
                break;
            }

            // Resume one character after the rejected match. With the `u` modifier the
            // offset must sit on a character boundary, so skip UTF-8 continuation bytes.
            $startFind = $offset + 1;
            while ($startFind < $textLength && (ord($text[$startFind]) & 0xC0) === 0x80) {
                $startFind++;
            }
        }

        // Encode the ordinary text that precedes the special token (or the rest of the text).
        preg_match_all($this->pattenRegex, substr($text, $start, $end - $start), $matches);
        foreach ($matches[0] as $match) {
            $token = $this->mergeableRanks[$match] ?? null;
            if ($token !== null) {
                // The whole piece is a known token.
                $result[] = $token;
            } else {
                foreach ($this->bytePairEncode($match, $this->mergeableRanks) as $item) {
                    $result[] = $item;
                }
            }
        }

        if ($nextSpecial === null) {
            break;
        }

        $result[] = $this->specialTokens[$nextSpecial];
        $start = $end + strlen($nextSpecial);
    }

    return $result;
}

protected function bytePairEncode(string $piece, $ranks): array
{
// This is a vector of (start, rank).
Expand Down Expand Up @@ -119,4 +256,19 @@ protected function bytePairEncode(string $piece, $ranks): array

return $out;
}

/**
 * Decodes a list of token ids back into the text they represent.
 *
 * Each id is looked up among the mergeable ranks first, then among the special tokens.
 *
 * @param int[] $tokens
 * @return string
 *
 * @throws ValueError When an id is neither a mergeable rank nor a special token.
 */
public function decode(array $tokens): string
{
    $result = '';
    foreach ($tokens as $token) {
        // `??` avoids the "Undefined array key" warning that a plain index access
        // raises for every special token (and for unknown ids).
        $result .= $this->decodeMergeableRanks[$token]
            ?? $this->decodeSpecialTokens[$token]
            ?? throw new ValueError("Unknown token id '{$token}'.");
    }

    return $result;
}
}
17 changes: 15 additions & 2 deletions src/EncodingFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

class EncodingFactory
{
protected const ENDOFTEXT = "<|endoftext|>";
protected const FIM_PREFIX = "<|fim_prefix|>";
protected const FIM_MIDDLE = "<|fim_middle|>";
protected const FIM_SUFFIX = "<|fim_suffix|>";
protected const ENDOFPROMPT = "<|endofprompt|>";

protected static $modelToEncoding = [
# chat
"gpt-4" => "cl100k_base",
Expand Down Expand Up @@ -118,8 +124,15 @@ protected static function initConstructor()
'cl100k_base' => function () {
$mergeableRanks = static::loadTiktokenBpe(__DIR__ . '/../assets/cl100k_base.tiktoken');
$pattenRegex = "/(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/";

return new Encoding($mergeableRanks, $pattenRegex);
$specialTokens = [
self::ENDOFTEXT => 100257,
self::FIM_PREFIX => 100258,
self::FIM_MIDDLE => 100259,
self::FIM_SUFFIX => 100260,
self::ENDOFPROMPT => 100276,
];

return new Encoding($mergeableRanks, $pattenRegex, $specialTokens);
},
];
}
Expand Down
30 changes: 0 additions & 30 deletions tests/EncodingCl100kBaseTest.php

This file was deleted.

72 changes: 72 additions & 0 deletions tests/EncodingTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
<?php

namespace Danny50610\BpeTokeniser\Tests;

use Danny50610\BpeTokeniser\EncodingFactory;
use PHPUnit\Framework\TestCase;

class EncodingTest extends TestCase
{
    /**
     * Plain text must encode to the expected ids, and those ids must decode back to the text.
     *
     * @dataProvider textDataProvider
     */
    public function testEncodeAndDecode($encodingName, $testCaseList)
    {
        $encoding = EncodingFactory::createByEncodingName($encodingName);

        foreach ($testCaseList as [$text, $tokens]) {
            $this->assertSame($tokens, $encoding->encode($text));
            $this->assertSame($text, $encoding->decode($tokens));
        }
    }

    /**
     * Datasets of [encoding name, list of [text, expected token ids]].
     */
    public static function textDataProvider()
    {
        $cl100kBaseCases = [
            ['tiktoken is great!', [83, 1609, 5963, 374, 2294, 0]],
            ['台北 101 高度 508 公尺', [55038, 49409, 220, 4645, 18630, 41519, 27479, 220, 19869, 35469, 105, 16175, 118]],
            ['🫡🍣顏文字', [9468, 104, 94, 9468, 235, 96, 14167, 237, 88435]],
        ];

        return [
            ['cl100k_base', $cl100kBaseCases],
        ];
    }

    // TODO: test: encodeOrdinary === encode($text, disallowedSpecial: [])

    /**
     * With every special token allowed, special-token text must encode to its reserved id.
     *
     * @dataProvider specialDataProvider
     */
    public function testEncodeWithSpecial($encodingName, $testCaseList)
    {
        $encoding = EncodingFactory::createByEncodingName($encodingName);

        foreach ($testCaseList as [$text, $tokens]) {
            $this->assertSame($tokens, $encoding->encode($text, allowedSpecial: 'all'));
        }
    }

    /**
     * Datasets of [encoding name, list of [text containing special tokens, expected token ids]].
     */
    public static function specialDataProvider()
    {
        $cl100kBaseCases = [
            ['<|endoftext|>', [100257]],
            ['Hello World<|endoftext|>Hello danny.', [9906, 4435, 100257, 9906, 294, 13184, 13]],
            ['中文 <|endoftext|> 博大精深 aaa <|endofprompt|> bbbb', [16325, 17161, 220, 100257, 67621, 248, 27384, 90397, 122, 85315, 109, 84565, 220, 100276, 293, 54251]],
        ];

        return [
            ['cl100k_base', $cl100kBaseCases],
        ];
    }
}

0 comments on commit 888adc3

Please sign in to comment.