Skip to content

Commit

Permalink
add more test
Browse files Browse the repository at this point in the history
  • Loading branch information
danny50610 committed Aug 20, 2023
1 parent ed6f622 commit 5d4045a
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 9 deletions.
10 changes: 9 additions & 1 deletion src/Encoding.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

class Encoding
{
protected $name;

protected $mergeableRanks;

protected $decodeMergeableRanks;
Expand All @@ -19,8 +21,9 @@ class Encoding

protected $decodeSpecialTokens;

public function __construct(&$mergeableRanks, string $pattenRegex, array $specialTokens = [], ?int $explicitNVocab = null)
public function __construct(string $name, &$mergeableRanks, string $pattenRegex, array $specialTokens = [], ?int $explicitNVocab = null)
{
$this->name = $name;
$this->mergeableRanks = $mergeableRanks;
$this->pattenRegex = $pattenRegex . 'u'; // u for unicode
$this->specialTokens = $specialTokens;
Expand Down Expand Up @@ -64,6 +67,11 @@ public function __construct(&$mergeableRanks, string $pattenRegex, array $specia
}
}

/**
 * Returns the encoding name that was passed to the constructor.
 *
 * @return string The value stored in $this->name.
 */
public function getName(): string
{
    return $this->name;
}

public function getSpecialTokensSet()
{
return array_keys($this->specialTokens);
Expand Down
17 changes: 9 additions & 8 deletions src/EncodingFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

use Closure;
use Exception;
use InvalidArgumentException;
use SplFileObject;

class EncodingFactory
Expand Down Expand Up @@ -115,8 +116,8 @@ public static function createByModelName(string $modelName): Encoding
}

if (is_null($encodingName)) {
throw new Exception(
"Could not automatically map {$modelName} to a tokeniser." +
throw new InvalidArgumentException(
"Could not automatically map \"{$modelName}\" to a tokeniser. " .
"Please use `createByEncodingName` to explicitly get the tokeniser you expect."
);
}
Expand All @@ -133,7 +134,7 @@ public static function createByEncodingName(string $encodingName): Encoding
static::initConstructor();

if (!array_key_exists($encodingName, static::$encodingConstructors)) {
throw new Exception("Unknown encoding {$encodingName}");
throw new InvalidArgumentException("Unknown encoding: \"{$encodingName}\"");
}

$constructor = static::$encodingConstructors[$encodingName];
Expand All @@ -158,7 +159,7 @@ protected static function initConstructor()
self::ENDOFTEXT => 50256,
];

return new Encoding($mergeableRanks, $pattenRegex, $specialTokens, explicitNVocab: 50257);
return new Encoding('gpt2', $mergeableRanks, $pattenRegex, $specialTokens, explicitNVocab: 50257);
},
'r50k_base' => function () {
$mergeableRanks = static::loadTiktokenBpe(__DIR__ . '/../assets/r50k_base.tiktoken');
Expand All @@ -167,7 +168,7 @@ protected static function initConstructor()
self::ENDOFTEXT => 50256,
];

return new Encoding($mergeableRanks, $pattenRegex, $specialTokens, explicitNVocab: 50257);
return new Encoding('r50k_base', $mergeableRanks, $pattenRegex, $specialTokens, explicitNVocab: 50257);
},
'p50k_base' => function () {
$mergeableRanks = static::loadTiktokenBpe(__DIR__ . '/../assets/p50k_base.tiktoken');
Expand All @@ -176,7 +177,7 @@ protected static function initConstructor()
self::ENDOFTEXT => 50256,
];

return new Encoding($mergeableRanks, $pattenRegex, $specialTokens, explicitNVocab: 50281);
return new Encoding('p50k_base', $mergeableRanks, $pattenRegex, $specialTokens, explicitNVocab: 50281);
},
'p50k_edit' => function () {
$mergeableRanks = static::loadTiktokenBpe(__DIR__ . '/../assets/p50k_base.tiktoken');
Expand All @@ -188,7 +189,7 @@ protected static function initConstructor()
self::FIM_SUFFIX => 50283,
];

return new Encoding($mergeableRanks, $pattenRegex, $specialTokens);
return new Encoding('p50k_edit', $mergeableRanks, $pattenRegex, $specialTokens);
},
'cl100k_base' => function () {
$mergeableRanks = static::loadTiktokenBpe(__DIR__ . '/../assets/cl100k_base.tiktoken');
Expand All @@ -201,7 +202,7 @@ protected static function initConstructor()
self::ENDOFPROMPT => 100276,
];

return new Encoding($mergeableRanks, $pattenRegex, $specialTokens);
return new Encoding('cl100k_base', $mergeableRanks, $pattenRegex, $specialTokens);
},
];
}
Expand Down
47 changes: 47 additions & 0 deletions tests/EncodingFactoryTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
<?php

namespace Danny50610\BpeTokeniser\Tests;

use Danny50610\BpeTokeniser\EncodingFactory;
use InvalidArgumentException;
use PHPUnit\Framework\TestCase;

/**
 * Exercises the two public lookup entry points of EncodingFactory:
 * createByEncodingName() and createByModelName().
 *
 * Each positive test checks the name reported by the returned encoding;
 * each negative test pins the exception class and message.
 */
class EncodingFactoryTest extends TestCase
{
    /**
     * A registered encoding name yields an encoding that reports that same name.
     */
    public function testCreateByEncodingName(): void
    {
        $encoding = EncodingFactory::createByEncodingName('cl100k_base');

        self::assertSame('cl100k_base', $encoding->getName());
    }

    /**
     * An unregistered encoding name is rejected with InvalidArgumentException.
     */
    public function testCreateByEncodingNameWithNonExist(): void
    {
        // Expectations must be registered before the call that throws.
        $this->expectException(InvalidArgumentException::class);
        $this->expectExceptionMessage('Unknown encoding: "danny"');

        EncodingFactory::createByEncodingName('danny');
    }

    /**
     * The model name "gpt-4" resolves to the "cl100k_base" encoding.
     */
    public function testCreateByModelName(): void
    {
        $encoding = EncodingFactory::createByModelName('gpt-4');

        self::assertSame('cl100k_base', $encoding->getName());
    }

    /**
     * A versioned model name ("gpt-3.5-turbo-0301") also resolves to "cl100k_base".
     *
     * NOTE(review): presumably resolved through a model-name prefix rule in the
     * factory, as the test name suggests; the mapping itself is not shown here.
     */
    public function testCreateByModelNameUsePrefix(): void
    {
        $encoding = EncodingFactory::createByModelName('gpt-3.5-turbo-0301');

        self::assertSame('cl100k_base', $encoding->getName());
    }

    /**
     * A model name that cannot be mapped is rejected with InvalidArgumentException,
     * and the message points the caller at createByEncodingName().
     */
    public function testCreateByModelNameWithNonExist(): void
    {
        // Expectations must be registered before the call that throws.
        $this->expectException(InvalidArgumentException::class);
        $this->expectExceptionMessage('Could not automatically map "danny" to a tokeniser. Please use `createByEncodingName` to explicitly get the tokeniser you expect.');

        EncodingFactory::createByModelName('danny');
    }
}

0 comments on commit 5d4045a

Please sign in to comment.