From 5d4045a961f761ed6468232c7d60cf402a1438ab Mon Sep 17 00:00:00 2001 From: "Chan, Danny" Date: Sun, 20 Aug 2023 07:43:27 +0000 Subject: [PATCH] add more test --- src/Encoding.php | 10 +++++++- src/EncodingFactory.php | 17 +++++++------ tests/EncodingFactoryTest.php | 47 +++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 9 deletions(-) create mode 100644 tests/EncodingFactoryTest.php diff --git a/src/Encoding.php b/src/Encoding.php index a5b0737..e084160 100644 --- a/src/Encoding.php +++ b/src/Encoding.php @@ -7,6 +7,8 @@ class Encoding { + protected $name; + protected $mergeableRanks; protected $decodeMergeableRanks; @@ -19,8 +21,9 @@ class Encoding protected $decodeSpecialTokens; - public function __construct(&$mergeableRanks, string $pattenRegex, array $specialTokens = [], ?int $explicitNVocab = null) + public function __construct(string $name, &$mergeableRanks, string $pattenRegex, array $specialTokens = [], ?int $explicitNVocab = null) { + $this->name = $name; $this->mergeableRanks = $mergeableRanks; $this->pattenRegex = $pattenRegex . 'u'; // u for unicode $this->specialTokens = $specialTokens; @@ -64,6 +67,11 @@ public function __construct(&$mergeableRanks, string $pattenRegex, array $specia } } + public function getName() + { + return $this->name; + } + public function getSpecialTokensSet() { return array_keys($this->specialTokens); diff --git a/src/EncodingFactory.php b/src/EncodingFactory.php index af267c6..56453a1 100644 --- a/src/EncodingFactory.php +++ b/src/EncodingFactory.php @@ -4,6 +4,7 @@ use Closure; use Exception; +use InvalidArgumentException; use SplFileObject; class EncodingFactory @@ -115,8 +116,8 @@ public static function createByModelName(string $modelName): Encoding } if (is_null($encodingName)) { - throw new Exception( - "Could not automatically map {$modelName} to a tokeniser." + + throw new InvalidArgumentException( + "Could not automatically map \"{$modelName}\" to a tokeniser. " . "Please use `createByEncodingName` to explicitly get the tokeniser you expect." ); } @@ -133,7 +134,7 @@ public static function createByEncodingName(string $encodingName): Encoding static::initConstructor(); if (!array_key_exists($encodingName, static::$encodingConstructors)) { - throw new Exception("Unknown encoding {$encodingName}"); + throw new InvalidArgumentException("Unknown encoding: \"{$encodingName}\""); } $constructor = static::$encodingConstructors[$encodingName]; @@ -158,7 +159,7 @@ protected static function initConstructor() self::ENDOFTEXT => 50256, ]; - return new Encoding($mergeableRanks, $pattenRegex, $specialTokens, explicitNVocab: 50257); + return new Encoding('gpt2', $mergeableRanks, $pattenRegex, $specialTokens, explicitNVocab: 50257); }, 'r50k_base' => function () { $mergeableRanks = static::loadTiktokenBpe(__DIR__ . '/../assets/r50k_base.tiktoken'); @@ -167,7 +168,7 @@ protected static function initConstructor() self::ENDOFTEXT => 50256, ]; - return new Encoding($mergeableRanks, $pattenRegex, $specialTokens, explicitNVocab: 50257); + return new Encoding('r50k_base', $mergeableRanks, $pattenRegex, $specialTokens, explicitNVocab: 50257); }, 'p50k_base' => function () { $mergeableRanks = static::loadTiktokenBpe(__DIR__ . '/../assets/p50k_base.tiktoken'); @@ -176,7 +177,7 @@ protected static function initConstructor() self::ENDOFTEXT => 50256, ]; - return new Encoding($mergeableRanks, $pattenRegex, $specialTokens, explicitNVocab: 50281); + return new Encoding('p50k_base', $mergeableRanks, $pattenRegex, $specialTokens, explicitNVocab: 50281); }, 'p50k_edit' => function () { $mergeableRanks = static::loadTiktokenBpe(__DIR__ . '/../assets/p50k_base.tiktoken'); @@ -188,7 +189,7 @@ protected static function initConstructor() self::FIM_SUFFIX => 50283, ]; - return new Encoding($mergeableRanks, $pattenRegex, $specialTokens); + return new Encoding('p50k_edit', $mergeableRanks, $pattenRegex, $specialTokens); }, 'cl100k_base' => function () { $mergeableRanks = static::loadTiktokenBpe(__DIR__ . '/../assets/cl100k_base.tiktoken'); @@ -201,7 +202,7 @@ protected static function initConstructor() self::ENDOFPROMPT => 100276, ]; - return new Encoding($mergeableRanks, $pattenRegex, $specialTokens); + return new Encoding('cl100k_base', $mergeableRanks, $pattenRegex, $specialTokens); }, ]; } diff --git a/tests/EncodingFactoryTest.php b/tests/EncodingFactoryTest.php new file mode 100644 index 0000000..edddff6 --- /dev/null +++ b/tests/EncodingFactoryTest.php @@ -0,0 +1,47 @@ +assertSame('cl100k_base', $enc->getName()); + } + + public function testCreateByEncodingNameWithNonExist() + { + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Unknown encoding: "danny"'); + + EncodingFactory::createByEncodingName('danny'); + } + + public function testCreateByModelName() + { + $enc = EncodingFactory::createByModelName('gpt-4'); + + $this->assertSame('cl100k_base', $enc->getName()); + } + + public function testCreateByModelNameUsePrefix() + { + $enc = EncodingFactory::createByModelName('gpt-3.5-turbo-0301'); + + $this->assertSame('cl100k_base', $enc->getName()); + } + + public function testCreateByModelNameWithNonExist() + { + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Could not automatically map "danny" to a tokeniser. Please use `createByEncodingName` to explicitly get the tokeniser you expect.'); + + EncodingFactory::createByModelName('danny'); + } +}