#include "dr_wav.h"

#include <cmath>
+#include <cstring>
#include <fstream>
#include <regex>
+#include <locale>
+#include <codecvt>
+#include <sstream>

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];
@@ -52,7 +60,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
            if (params.prompt.back() == '\n') {
                params.prompt.pop_back();
            }
-        } else {
+        } else if (arg == "-tt" || arg == "--token_test") {
+            params.token_test = argv[++i];
+        }
+        else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            gpt_print_usage(argc, argv, params);
            exit(0);
@@ -73,6 +84,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
    fprintf(stderr, "                        prompt to start generation with (default: random)\n");
    fprintf(stderr, "  -f FNAME, --file FNAME\n");
    fprintf(stderr, "                        load prompt from a file\n");
+    fprintf(stderr, "  -tt TOKEN_TEST, --token_test TOKEN_TEST\n");
+    fprintf(stderr, "                        test tokenization\n");
    fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict);
    fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
@@ -117,6 +130,10 @@ std::string replace(const std::string & s, const std::string & from, const std::
    return result;
}

+void gpt_vocab::add_special_token(const std::string & token) {
+    special_tokens.push_back(token);
+}
+
std::map<std::string, int32_t> json_parse(const std::string & fname) {
    std::map<std::string, int32_t> result;

@@ -208,8 +225,28 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
    return result;
}

-void gpt_vocab::add_special_token(const std::string & token) {
-    special_tokens.push_back(token);
+std::string convert_to_utf8(const std::wstring & input) {
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    return converter.to_bytes(input);
+}
+
+
+std::wstring convert_to_wstring(const std::string & input) {
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    return converter.from_bytes(input);
+}
+
+void gpt_split_words(std::string str, std::vector<std::string>& words) {
+    const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+    const std::regex re(pattern);
+    std::smatch m;
+
+    while (std::regex_search(str, m, re)) {
+        for (auto x : m) {
+            words.push_back(x);
+        }
+        str = m.suffix();
+    }
}

std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
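A quick illustration of what the new gpt_split_words() is expected to produce. The input string and the listed pieces are my own example, following the GPT-2-style pre-tokenization pattern above; they are not output taken from this diff:

std::vector<std::string> pieces;
gpt_split_words("Hello world, it's 2023!", pieces);
// expected pieces, one per regex match:
//   "Hello", " world", ",", " it", "'s", " 2023", "!"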
@@ -218,70 +255,123 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
    // first split the text into words
    {
        std::string str = text;
-        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";

        // Generate the subpattern from the special_tokens vector if it's not empty
        if (!vocab.special_tokens.empty()) {
+            const std::regex escape(R"([\[\\\^\$\.\|\?\*\+\(\)\{\}])");
            std::string special_tokens_subpattern;
            for (const auto & token : vocab.special_tokens) {
                if (!special_tokens_subpattern.empty()) {
                    special_tokens_subpattern += "|";
                }
-                special_tokens_subpattern += token;
+                special_tokens_subpattern += std::regex_replace(token, escape, R"(\$&)");
            }

-            // Modify the regex pattern with the generated special tokens subpattern
-            pat = special_tokens_subpattern + "|" + pat;
-        }
-
-        std::regex re(pat);
-        std::smatch m;
-
-        while (std::regex_search(str, m, re)) {
-            for (auto x : m) {
-                words.push_back(x);
+            std::regex re(special_tokens_subpattern);
+            std::smatch m;
+            // Split the text by special tokens.
+            while (std::regex_search(str, m, re)) {
+                // Split the substrings in-between special tokens into words.
+                gpt_split_words(m.prefix(), words);
+                // Add matched special tokens as words.
+                for (auto x : m) {
+                    words.push_back(x);
+                }
+                str = m.suffix();
            }
-            str = m.suffix();
+            // Remaining text without special tokens will be handled below.
        }
+
+        gpt_split_words(str, words);
    }

-    // find the longest tokens that form the words:
+    // find the longest token that forms each word in words:
    std::vector<gpt_vocab::id> tokens;
    for (const auto & word : words) {
-        if (word.size() == 0) continue;
-
-        int i = 0;
-        int n = word.size();
-        while (i < n) {
-            int j = n;
-            while (j > i) {
-                auto it = vocab.token_to_id.find(word.substr(i, j-i));
-                if (it != vocab.token_to_id.end()) {
+        for (int i = 0; i < (int) word.size(); ){
+            for (int j = word.size() - 1; j >= i; j--){
+                auto cand = word.substr(i, j-i+1);
+                auto it = vocab.token_to_id.find(cand);
+                if (it != vocab.token_to_id.end()){ // word.substr(i, j-i+1) in vocab
                    tokens.push_back(it->second);
-                    i = j;
-                    j = n;
+                    i = j + 1;
                    break;
                }
-                --j;
-            }
-            if (i == n) {
-                break;
-            }
-            if (j == i) {
-                auto sub = word.substr(i, 1);
-                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
-                    tokens.push_back(vocab.token_to_id.at(sub));
-                } else {
-                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
+                else if (j == i){ // word.substr(i, 1) has no matching
+                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).data());
+                    i++;
                }
-                ++i;
            }
        }
    }

    return tokens;
}
310
311
+ std::vector<gpt_vocab::id> parse_tokens_from_string (const std::string& input, char delimiter) {
312
+ std::vector<gpt_vocab::id> output;
313
+ std::stringstream ss (input);
314
+ std::string token;
315
+
316
+ while (std::getline (ss, token, delimiter)) {
317
+ output.push_back (std::stoi (token));
318
+ }
319
+
320
+ return output;
321
+ }
322
+
323
+ std::map<std::string, std::vector<gpt_vocab::id>> extract_tests_from_file (const std::string & fpath_test){
324
+ if (fpath_test.empty ()){
325
+ fprintf (stderr, " %s : No test file found.\n " , __func__);
326
+ return std::map<std::string, std::vector<gpt_vocab::id>>();
327
+ }
328
+
329
+ std::map<std::string, std::vector<gpt_vocab::id>> tests;
330
+
331
+ auto fin = std::ifstream (fpath_test, std::ios_base::in);
332
+ const char * delimeter = " => " ;
333
+ const char del_tok = ' ,' ;
334
+ std::string line;
335
+ while (std::getline (fin, line)) {
336
+ size_t delimiterPos = line.find (delimeter);
337
+ if (delimiterPos != std::string::npos) {
338
+ std::string text = line.substr (0 , delimiterPos);
339
+ std::string s_tokens = line.substr (delimiterPos + std::strlen (delimeter));
340
+ tests[text] = parse_tokens_from_string (s_tokens, del_tok);
341
+ }
342
+ }
343
+ return tests;
344
+ }
345
+
346
+ void test_gpt_tokenizer (gpt_vocab & vocab, const std::string & fpath_test){
347
+ std::map<std::string, std::vector<gpt_vocab::id>> tests = extract_tests_from_file (fpath_test);
348
+
349
+ size_t n_fails = 0 ;
350
+
351
+ for (const auto & test : tests) {
352
+ std::vector<gpt_vocab::id> tokens = gpt_tokenize (vocab, test.first );
353
+
354
+ if (tokens != test.second ){
355
+ n_fails++;
356
+
357
+ // print out failure cases
358
+ fprintf (stderr, " %s : failed test: '%s'\n " , __func__, test.first .c_str ());
359
+ fprintf (stderr, " %s : tokens in hf: " , __func__);
360
+ for (const auto & t : test.second ) {
361
+ fprintf (stderr, " %s(%d), " , vocab.id_to_token [t].c_str (), t);
362
+ }
363
+ fprintf (stderr, " \n " );
364
+ fprintf (stderr, " %s : tokens in ggml: " , __func__);
365
+ for (const auto & t : tokens) {
366
+ fprintf (stderr, " %s(%d), " , vocab.id_to_token [t].c_str (), t);
367
+ }
368
+ fprintf (stderr, " \n " );
369
+ }
370
+ }
371
+
372
+ fprintf (stderr, " %s : %zu tests failed out of %zu tests.\n " , __func__, n_fails, tests.size ());
373
+ }
374
+
285
375
bool gpt_vocab_init (const std::string & fname, gpt_vocab & vocab) {
286
376
printf (" %s: loading vocab from '%s'\n " , __func__, fname.c_str ());
287
377
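Tying the pieces together: from extract_tests_from_file() above, each line of the file passed via -tt / --token_test is expected to be raw text, the " => " delimiter, then comma-separated token ids. A minimal sketch of how a caller might drive it; the surrounding params/vocab setup is assumed from the examples, and the ids in the sample line are placeholders:

// token_test file, one case per line, e.g.:
//   Hello world => 1000,2000
//
// after gpt_params_parse() and gpt_vocab_init():
if (!params.token_test.empty()) {
    test_gpt_tokenizer(vocab, params.token_test);
}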
@@ -381,6 +471,122 @@ gpt_vocab::id gpt_sample_top_k_top_p(
    return logits_id[idx].second;
}

+gpt_vocab::id gpt_sample_top_k_top_p_repeat(
+        const gpt_vocab & vocab,
+        const float * logits,
+        const int32_t * last_n_tokens_data,
+        size_t last_n_tokens_data_size,
+        int    top_k,
+        double top_p,
+        double temp,
+        int repeat_last_n,
+        float repeat_penalty,
+        std::mt19937 & rng) {
+
+    int n_logits = vocab.id_to_token.size();
+
+    const auto * plogits = logits;
+
+    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_data_size);
+
+    if (temp <= 0) {
+        // select the token with the highest logit directly
+        float max_logit = plogits[0];
+        gpt_vocab::id max_id = 0;
+
+        for (int i = 1; i < n_logits; ++i) {
+            if (plogits[i] > max_logit) {
+                max_logit = plogits[i];
+                max_id = i;
+            }
+        }
+        return max_id;
+    }
+
+
+    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
+    logits_id.reserve(n_logits);
+
+    {
+        const float scale = 1.0f/temp;
+        for (int i = 0; i < n_logits; ++i) {
+            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
+            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
+            if (repeat_last_n > 0 && std::find(last_n_tokens.end()-repeat_last_n, last_n_tokens.end(), i) != last_n_tokens.end()) {
+                // if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+                if (plogits[i] < 0.0f) {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
+                } else {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
+                }
+            } else {
+                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
+            }
+        }
+    }
+
+    // find the top K tokens
+    std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + top_k, logits_id.end(),
+            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
+        return a.first > b.first;
+    });
+
+    logits_id.resize(top_k);
+
+    double maxl = -INFINITY;
+    for (const auto & kv : logits_id) {
+        maxl = std::max(maxl, kv.first);
+    }
+
+    // compute probs for the top K tokens
+    std::vector<double> probs;
+    probs.reserve(logits_id.size());
+
+    double sum = 0.0;
+    for (const auto & kv : logits_id) {
+        double p = exp(kv.first - maxl);
+        probs.push_back(p);
+        sum += p;
+    }
+
+    // normalize the probs
+    for (auto & p : probs) {
+        p /= sum;
+    }
+
+    if (top_p < 1.0f) {
+        double cumsum = 0.0f;
+        for (int i = 0; i < top_k; i++) {
+            cumsum += probs[i];
+            if (cumsum >= top_p) {
+                top_k = i + 1;
+                probs.resize(top_k);
+                logits_id.resize(top_k);
+                break;
+            }
+        }
+
+        cumsum = 1.0/cumsum;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            probs[i] *= cumsum;
+        }
+    }
+
+//    printf("\n");
+//    for (int i = 0; i < (int) probs.size(); i++) {
+//    for (int i = 0; i < 10; i++) {
+//        printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
+//    }
+
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int idx = dist(rng);
+
+    return logits_id[idx].second;
+
+}
+
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
    drwav wav;
    std::vector<uint8_t> wav_data; // used for pipe input from stdin
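A rough calling sketch for the new gpt_sample_top_k_top_p_repeat(). The logits / last_n_tokens buffers and the parameter values are assumptions about a typical generation loop, not part of this diff:

std::mt19937 rng(1234);
std::vector<float>   logits;        // n_vocab logits for the last position, filled by the model's eval step
std::vector<int32_t> last_n_tokens; // recently generated ids, maintained by the caller

gpt_vocab::id id = gpt_sample_top_k_top_p_repeat(
        vocab,
        logits.data(),
        last_n_tokens.data(), last_n_tokens.size(),
        40,    // top_k
        0.9,   // top_p
        0.8,   // temp
        64,    // repeat_last_n
        1.1f,  // repeat_penalty
        rng);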