Skip to content

Commit 2c301e9

Browse files
authored
common : handle unicode during partial json parsing (ggml-org#16526)
* common : handle unicode during partial json parsing * common : set missing `ensure_ascii = true` during json dump
1 parent 4b2dae3 commit 2c301e9

File tree

4 files changed

+162
-3
lines changed

4 files changed

+162
-3
lines changed

common/chat-parser.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
432432
if (is_arguments_path({})) {
433433
// Entire JSON is the arguments and was parsed fully.
434434
return consume_json_result {
435-
partial->json.dump(),
435+
partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true),
436436
/* .is_partial = */ false,
437437
};
438438
}
@@ -444,7 +444,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
444444
std::vector<std::string> path;
445445
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
446446
if (is_arguments_path(path)) {
447-
auto arguments = j.dump();
447+
auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true);
448448
if (is_partial() && !partial->healing_marker.marker.empty()) {
449449
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
450450
if (idx != std::string::npos) {

common/json-partial.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <nlohmann/json.hpp>
66

77
#include <string>
8+
#include <regex>
89

910
using json = nlohmann::ordered_json;
1011

@@ -168,6 +169,47 @@ bool common_json_parse(
168169
}
169170
}
170171

172+
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
173+
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
174+
175+
auto is_high_surrogate = [&](const std::string & s) {
176+
// Check if a partial of a high surrogate (U+D800-U+DBFF)
177+
return s.length() >= 4 &&
178+
s[0] == '\\' && s[1] == 'u' &&
179+
std::tolower(s[2]) == 'd' &&
180+
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
181+
};
182+
183+
// Initialize the unicode marker to a low surrogate to handle the edge case
184+
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
185+
// backslash (\)
186+
std::string unicode_marker_padding = "udc00";
187+
std::smatch last_unicode_seq;
188+
189+
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
190+
std::smatch second_last_seq;
191+
std::string prelude = str.substr(0, last_unicode_seq.position());
192+
193+
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
194+
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
195+
196+
if (is_high_surrogate(last_unicode_seq.str())) {
197+
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
198+
unicode_marker_padding += "\\udc00";
199+
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
200+
if (is_high_surrogate(second_last_seq.str())) {
201+
// If this follows a high surrogate, pad it to be a low surrogate
202+
if (last_unicode_seq.length() == 2) {
203+
unicode_marker_padding = "dc00";
204+
} else if (last_unicode_seq.length() == 3) {
205+
unicode_marker_padding = "c00";
206+
} else {
207+
// The original unicode_marker_padding is already padded with 0s
208+
}
209+
}
210+
}
211+
}
212+
171213
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
172214

173215
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
@@ -186,6 +228,9 @@ bool common_json_parse(
186228
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
187229
// Was inside an object value string after an escape
188230
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
231+
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
232+
// Was inside an object value string after a partial unicode escape
233+
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
189234
} else {
190235
// find last :
191236
auto last_pos = str.find_last_of(':');
@@ -205,6 +250,9 @@ bool common_json_parse(
205250
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
206251
// Was inside an array value string after an escape
207252
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
253+
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
254+
// Was inside an array value string after a partial unicode escape
255+
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
208256
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
209257
// Had just finished a value
210258
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
@@ -230,6 +278,9 @@ bool common_json_parse(
230278
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
231279
// Was inside an object key string after an escape
232280
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
281+
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
282+
// Was inside an object key string after a partial unicode escape
283+
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
233284
} else {
234285
auto last_pos = str.find_last_of(':');
235286
if (last_pos == std::string::npos) {

tests/test-chat-parser.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,64 @@ static void test_json_with_dumped_args() {
524524
R"({"foo": "bar", "args": {"arg1": [)",
525525
R"({"foo":"bar","args":"{\"arg1\":["})"
526526
);
527+
528+
// Unicode tests
529+
test_with_args(
530+
R"({"foo": "bar", "args": {"arg1": "\u)",
531+
R"({"foo":"bar","args":"{\"arg1\":\"\\u"})"
532+
);
533+
test_with_args(
534+
R"({"foo": "bar", "args": {"arg1": "\u0)",
535+
R"({"foo":"bar","args":"{\"arg1\":\"\\u0"})"
536+
);
537+
test_with_args(
538+
R"({"foo": "bar", "args": {"arg1": "\u00)",
539+
R"({"foo":"bar","args":"{\"arg1\":\"\\u00"})"
540+
);
541+
test_with_args(
542+
R"({"foo": "bar", "args": {"arg1": "\u000)",
543+
R"({"foo":"bar","args":"{\"arg1\":\"\\u000"})"
544+
);
545+
test_with_args(
546+
R"({"foo": "bar", "args": {"arg1": "\u0000)",
547+
R"({"foo":"bar","args":"{\"arg1\":\"\\u0000"})"
548+
);
549+
test_with_args(
550+
R"({"foo": "bar", "args": {"arg1": "\ud8)",
551+
R"({"foo":"bar","args":"{\"arg1\":\"\\ud8"})"
552+
);
553+
test_with_args(
554+
R"({"foo": "bar", "args": {"arg1": "\ud80)",
555+
R"({"foo":"bar","args":"{\"arg1\":\"\\ud80"})"
556+
);
557+
test_with_args(
558+
R"({"foo": "bar", "args": {"arg1": "\ud800)",
559+
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800"})"
560+
);
561+
test_with_args(
562+
R"({"foo": "bar", "args": {"arg1": "\ud800\)",
563+
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\"})"
564+
);
565+
test_with_args(
566+
R"({"foo": "bar", "args": {"arg1": "\ud800\u)",
567+
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\u"})"
568+
);
569+
test_with_args(
570+
R"({"foo": "bar", "args": {"arg1": "\ud800\ud)",
571+
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\ud"})"
572+
);
573+
test_with_args(
574+
R"({"foo": "bar", "args": {"arg1": "\ud800\udc)",
575+
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc"})"
576+
);
577+
test_with_args(
578+
R"({"foo": "bar", "args": {"arg1": "\ud800\udc0)",
579+
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc0"})"
580+
);
581+
test_with_args(
582+
R"({"foo": "bar", "args": {"arg1": "\ud800\udc00)",
583+
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc00"})"
584+
);
527585
}
528586

529587
static void test_positions() {

tests/test-json-partial.cpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ static void test_json_healing() {
5858
for (const auto & input : inputs) {
5959
common_json out;
6060
assert_equals(true, common_json_parse(input, "$foo", out));
61-
assert_equals<std::string>(expected, out.json.dump());
61+
assert_equals<std::string>(expected, out.json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true));
6262
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
6363
}
6464
};
@@ -228,6 +228,56 @@ static void test_json_healing() {
228228
R"({"key":"$foo"})",
229229
R"(:"$foo)"
230230
);
231+
// Test unicode escape sequences
232+
test(
233+
{
234+
R"({"a":"\u)",
235+
},
236+
R"({"a":"\u0000$foo"})",
237+
R"(0000$foo)"
238+
);
239+
test(
240+
{
241+
R"({"a":"\u00)",
242+
},
243+
R"({"a":"\u0000$foo"})",
244+
R"(00$foo)"
245+
);
246+
test(
247+
{
248+
R"({"a":"\ud300)",
249+
},
250+
R"({"a":"\ud300$foo"})",
251+
R"($foo)"
252+
);
253+
test(
254+
{
255+
R"({"a":"\ud800)",
256+
},
257+
R"({"a":"\ud800\udc00$foo"})",
258+
R"(\udc00$foo)"
259+
);
260+
test(
261+
{
262+
R"({"a":"\ud800\)",
263+
},
264+
R"({"a":"\ud800\udc00$foo"})",
265+
R"(udc00$foo)"
266+
);
267+
test(
268+
{
269+
R"({"a":"\ud800\u)",
270+
},
271+
R"({"a":"\ud800\udc00$foo"})",
272+
R"(dc00$foo)"
273+
);
274+
test(
275+
{
276+
R"({"a":"\ud800\udc00)",
277+
},
278+
R"({"a":"\ud800\udc00$foo"})",
279+
R"($foo)"
280+
);
231281
}
232282

233283
int main() {

0 commit comments

Comments
 (0)