Skip to content

Commit 09da7a4

Browse files
committed
Add new serde option "detect_mixed_tables"
This option allows detecting mixed tables (tables with both array-like and map-like entries, or with several borders) and encoding them using the best-suited method (as a map or as an array).
1 parent bad2037 commit 09da7a4

File tree

4 files changed

+192
-28
lines changed

4 files changed

+192
-28
lines changed

src/serde/de.rs

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@ use crate::userdata::AnyUserData;
1515
use crate::value::Value;
1616

1717
/// A struct for deserializing Lua values into Rust values.
18-
#[derive(Debug)]
18+
#[derive(Debug, Default)]
1919
pub struct Deserializer {
2020
value: Value,
2121
options: Options,
2222
visited: Rc<RefCell<FxHashSet<*const c_void>>>,
23+
len: Option<usize>, // A length hint for sequences
2324
}
2425

2526
/// A struct with options to change default deserializer behavior.
@@ -54,6 +55,19 @@ pub struct Options {
5455
///
5556
/// Default: **false**
5657
pub encode_empty_tables_as_array: bool,
58+
59+
/// If true, enable detection of mixed tables.
60+
///
61+
/// A mixed table is a table that has both array-like and map-like entries or several borders.
62+
/// See [`The Length Operator`] documentation for details about borders.
63+
///
64+
/// When this option is disabled, a table with a non-zero length (with one or more borders) will
65+
/// always be encoded as an array.
66+
///
67+
/// Default: **false**
68+
///
69+
/// [`The Length Operator`]: https://www.lua.org/manual/5.4/manual.html#3.4.7
70+
pub detect_mixed_tables: bool,
5771
}
5872

5973
impl Default for Options {
@@ -70,6 +84,7 @@ impl Options {
7084
deny_recursive_tables: true,
7185
sort_keys: false,
7286
encode_empty_tables_as_array: false,
87+
detect_mixed_tables: false,
7388
}
7489
}
7590

@@ -108,6 +123,15 @@ impl Options {
108123
self.encode_empty_tables_as_array = enabled;
109124
self
110125
}
126+
127+
/// Sets [`detect_mixed_tables`] option.
128+
///
129+
/// [`detect_mixed_tables`]: #structfield.detect_mixed_tables
130+
#[must_use]
131+
pub const fn detect_mixed_tables(mut self, enable: bool) -> Self {
132+
self.detect_mixed_tables = enable;
133+
self
134+
}
111135
}
112136

113137
impl Deserializer {
@@ -121,7 +145,7 @@ impl Deserializer {
121145
Deserializer {
122146
value,
123147
options,
124-
visited: Rc::new(RefCell::new(FxHashSet::default())),
148+
..Default::default()
125149
}
126150
}
127151

@@ -130,8 +154,14 @@ impl Deserializer {
130154
value,
131155
options,
132156
visited,
157+
..Default::default()
133158
}
134159
}
160+
161+
fn with_len(mut self, len: usize) -> Self {
162+
self.len = Some(len);
163+
self
164+
}
135165
}
136166

137167
impl<'de> serde::Deserializer<'de> for Deserializer {
@@ -155,11 +185,13 @@ impl<'de> serde::Deserializer<'de> for Deserializer {
155185
Ok(s) => visitor.visit_str(&s),
156186
Err(_) => visitor.visit_bytes(&s.as_bytes()),
157187
},
158-
Value::Table(ref t) if t.raw_len() > 0 || t.is_array() => self.deserialize_seq(visitor),
159-
Value::Table(ref t) if self.options.encode_empty_tables_as_array && t.is_empty() => {
160-
self.deserialize_seq(visitor)
188+
Value::Table(ref t) => {
189+
if let Some(len) = t.encode_as_array(self.options) {
190+
self.with_len(len).deserialize_seq(visitor)
191+
} else {
192+
self.deserialize_map(visitor)
193+
}
161194
}
162-
Value::Table(_) => self.deserialize_map(visitor),
163195
Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
164196
Value::UserData(ud) if ud.is_serializable() => {
165197
serde_userdata(ud, |value| value.deserialize_any(visitor))
@@ -270,14 +302,14 @@ impl<'de> serde::Deserializer<'de> for Deserializer {
270302
Value::Table(t) => {
271303
let _guard = RecursionGuard::new(&t, &self.visited);
272304

273-
let len = t.raw_len();
305+
let len = self.len.unwrap_or_else(|| t.raw_len());
274306
let mut deserializer = SeqDeserializer {
275-
seq: t.sequence_values(),
307+
seq: t.sequence_values().with_len(len),
276308
options: self.options,
277309
visited: self.visited,
278310
};
279311
let seq = visitor.visit_seq(&mut deserializer)?;
280-
if deserializer.seq.count() == 0 {
312+
if deserializer.seq.next().is_none() {
281313
Ok(seq)
282314
} else {
283315
Err(de::Error::invalid_length(len, &"fewer elements in the table"))

src/table.rs

Lines changed: 101 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -668,26 +668,40 @@ impl Table {
668668
guard: self.0.lua.lock(),
669669
table: self,
670670
index: 1,
671+
len: None,
671672
_phantom: PhantomData,
672673
}
673674
}
674675

675676
/// Iterates over the sequence part of the table, invoking the given closure on each value.
677+
///
678+
/// This method is similar to [`Table::sequence_values`], but optimized for performance.
676679
#[doc(hidden)]
677-
pub fn for_each_value<V>(&self, mut f: impl FnMut(V) -> Result<()>) -> Result<()>
678-
where
679-
V: FromLua,
680-
{
680+
pub fn for_each_value<V: FromLua>(&self, f: impl FnMut(V) -> Result<()>) -> Result<()> {
681+
self.for_each_value_by_len(None, f)
682+
}
683+
684+
fn for_each_value_by_len<V: FromLua>(
685+
&self,
686+
len: impl Into<Option<usize>>,
687+
mut f: impl FnMut(V) -> Result<()>,
688+
) -> Result<()> {
689+
let len = len.into();
681690
let lua = self.0.lua.lock();
682691
let state = lua.state();
683692
unsafe {
684693
let _sg = StackGuard::new(state);
685694
check_stack(state, 4)?;
686695

687696
lua.push_ref(&self.0);
688-
let len = ffi::lua_rawlen(state, -1);
689-
for i in 1..=len {
690-
ffi::lua_rawgeti(state, -1, i as _);
697+
for i in 1.. {
698+
if len.map(|len| i > len).unwrap_or(false) {
699+
break;
700+
}
701+
let t = ffi::lua_rawgeti(state, -1, i as _);
702+
if len.is_none() && t == ffi::LUA_TNIL {
703+
break;
704+
}
691705
f(V::from_stack(-1, &lua)?)?;
692706
ffi::lua_pop(state, 1);
693707
}
@@ -720,8 +734,9 @@ impl Table {
720734
Ok(())
721735
}
722736

737+
/// Checks if the table has the array metatable attached.
723738
#[cfg(feature = "serde")]
724-
pub(crate) fn is_array(&self) -> bool {
739+
fn has_array_metatable(&self) -> bool {
725740
let lua = self.0.lua.lock();
726741
let state = lua.state();
727742
unsafe {
@@ -737,6 +752,70 @@ impl Table {
737752
}
738753
}
739754

755+
/// If the table is an array, returns the number of non-nil elements and max index.
756+
///
757+
/// Returns `None` if the table is not an array.
758+
///
759+
/// This operation has O(n) complexity.
760+
#[cfg(feature = "serde")]
761+
fn find_array_len(&self) -> Option<(usize, usize)> {
762+
let lua = self.0.lua.lock();
763+
let ref_thread = lua.ref_thread();
764+
unsafe {
765+
let _sg = StackGuard::new(ref_thread);
766+
767+
let (mut count, mut max_index) = (0, 0);
768+
ffi::lua_pushnil(ref_thread);
769+
while ffi::lua_next(ref_thread, self.0.index) != 0 {
770+
if ffi::lua_type(ref_thread, -2) != ffi::LUA_TNUMBER {
771+
return None;
772+
}
773+
774+
let k = ffi::lua_tonumber(ref_thread, -2);
775+
if k.trunc() != k || k < 1.0 {
776+
return None;
777+
}
778+
max_index = std::cmp::max(max_index, k as usize);
779+
count += 1;
780+
ffi::lua_pop(ref_thread, 1);
781+
}
782+
Some((count, max_index))
783+
}
784+
}
785+
786+
/// Determines if the table should be encoded as an array or a map.
787+
///
788+
/// The algorithm is the following:
789+
/// 1. If `detect_mixed_tables` is enabled, iterate over all keys in the table checking if they
790+
/// all are positive integers. If non-array key is found, return `None` (encode as map).
791+
/// Otherwise check the sparsity of the array. Too sparse arrays are encoded as maps.
792+
///
793+
/// 2. If `detect_mixed_tables` is disabled, check if the table has a positive length or has the
794+
/// array metatable. If so, encode as array. If the table is empty and
795+
/// `encode_empty_tables_as_array` is enabled, encode as array.
796+
///
797+
/// Returns the length of the array if it should be encoded as an array.
798+
#[cfg(feature = "serde")]
799+
pub(crate) fn encode_as_array(&self, options: crate::serde::de::Options) -> Option<usize> {
800+
if options.detect_mixed_tables {
801+
if let Some((len, max_idx)) = self.find_array_len() {
802+
// If the array is too sparse, serialize it as a map instead
803+
if len < 10 || len * 2 >= max_idx {
804+
return Some(max_idx);
805+
}
806+
}
807+
} else {
808+
let len = self.raw_len();
809+
if len > 0 || self.has_array_metatable() {
810+
return Some(len);
811+
}
812+
if options.encode_empty_tables_as_array && self.is_empty() {
813+
return Some(0);
814+
}
815+
}
816+
None
817+
}
818+
740819
#[cfg(feature = "luau")]
741820
#[inline(always)]
742821
fn check_readonly_write(&self, lua: &RawLua) -> Result<()> {
@@ -980,6 +1059,15 @@ impl<'a> SerializableTable<'a> {
9801059
}
9811060
}
9821061

1062+
impl<V> TableSequence<'_, V> {
1063+
/// Sets the length (hint) of the sequence.
1064+
#[cfg(feature = "serde")]
1065+
pub(crate) fn with_len(mut self, len: usize) -> Self {
1066+
self.len = Some(len);
1067+
self
1068+
}
1069+
}
1070+
9831071
#[cfg(feature = "serde")]
9841072
impl Serialize for SerializableTable<'_> {
9851073
fn serialize<S>(&self, serializer: S) -> StdResult<S::Ok, S::Error>
@@ -1001,14 +1089,10 @@ impl Serialize for SerializableTable<'_> {
10011089
let _guard = RecursionGuard::new(self.table, visited);
10021090

10031091
// Array
1004-
let len = self.table.raw_len();
1005-
if len > 0
1006-
|| self.table.is_array()
1007-
|| (self.options.encode_empty_tables_as_array && self.table.is_empty())
1008-
{
1092+
if let Some(len) = self.table.encode_as_array(self.options) {
10091093
let mut seq = serializer.serialize_seq(Some(len))?;
10101094
let mut serialize_err = None;
1011-
let res = self.table.for_each_value::<Value>(|value| {
1095+
let res = self.table.for_each_value_by_len::<Value>(len, |value| {
10121096
let skip = check_value_for_skip(&value, self.options, visited)
10131097
.map_err(|err| Error::SerializeError(err.to_string()))?;
10141098
if skip {
@@ -1132,13 +1216,11 @@ pub struct TableSequence<'a, V> {
11321216
guard: LuaGuard,
11331217
table: &'a Table,
11341218
index: Integer,
1219+
len: Option<usize>,
11351220
_phantom: PhantomData<V>,
11361221
}
11371222

1138-
impl<V> Iterator for TableSequence<'_, V>
1139-
where
1140-
V: FromLua,
1141-
{
1223+
impl<V: FromLua> Iterator for TableSequence<'_, V> {
11421224
type Item = Result<V>;
11431225

11441226
fn next(&mut self) -> Option<Self::Item> {
@@ -1152,7 +1234,7 @@ where
11521234

11531235
lua.push_ref(&self.table.0);
11541236
match ffi::lua_rawgeti(state, -1, self.index) {
1155-
ffi::LUA_TNIL => None,
1237+
ffi::LUA_TNIL if self.index as usize > self.len.unwrap_or(0) => None,
11561238
_ => {
11571239
self.index += 1;
11581240
Some(V::from_stack(-1, lua))

src/value.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,17 @@ impl<'a> SerializableValue<'a> {
717717
self.options.encode_empty_tables_as_array = enabled;
718718
self
719719
}
720+
721+
/// If true, enable detection of mixed tables.
722+
///
723+
/// A mixed table is a table that has both array-like and map-like entries or several borders.
724+
///
725+
/// Default: **false**
726+
#[must_use]
727+
pub const fn detect_mixed_tables(mut self, enabled: bool) -> Self {
728+
self.options.detect_mixed_tables = enabled;
729+
self
730+
}
720731
}
721732

722733
#[cfg(feature = "serde")]

tests/serde.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,45 @@ fn test_serialize_empty_table() -> LuaResult<()> {
269269
Ok(())
270270
}
271271

272+
#[test]
273+
fn test_serialize_mixed_table() -> LuaResult<()> {
274+
let lua = Lua::new();
275+
276+
// Check that sparse array is serialized similarly when using direct serialization
277+
// and via `Lua::from_value`
278+
let table = lua.load("{1,2,3,nil,5}").eval::<Value>()?;
279+
let json1 = serde_json::to_string(&table).unwrap();
280+
let json2 = lua.from_value::<serde_json::Value>(table)?;
281+
assert_eq!(json1, json2.to_string());
282+
283+
// A table with several borders should be correctly encoded when `detect_mixed_tables` is enabled
284+
let table = lua
285+
.load(
286+
r#"
287+
local t = {1,2,3,nil,5,6}
288+
t[10] = 10
289+
return t
290+
"#,
291+
)
292+
.eval::<Value>()?;
293+
let json = serde_json::to_string(&table.to_serializable().detect_mixed_tables(true)).unwrap();
294+
assert_eq!(json, r#"[1,2,3,null,5,6,null,null,null,10]"#);
295+
296+
// A mixed table with both array-like and map-like entries
297+
let table = lua.load(r#"{1,2,3, key="value"}"#).eval::<Value>()?;
298+
let json = serde_json::to_string(&table).unwrap();
299+
assert_eq!(json, r#"[1,2,3]"#);
300+
let json = serde_json::to_string(&table.to_serializable().detect_mixed_tables(true)).unwrap();
301+
assert_eq!(json, r#"{"1":1,"2":2,"3":3,"key":"value"}"#);
302+
303+
// A mixed table with duplicate keys of different types
304+
let table = lua.load(r#"{1,2,3, ["1"]="value"}"#).eval::<Value>()?;
305+
let json = serde_json::to_string(&table.to_serializable().detect_mixed_tables(true)).unwrap();
306+
assert_eq!(json, r#"{"1":1,"2":2,"3":3,"1":"value"}"#);
307+
308+
Ok(())
309+
}
310+
272311
#[test]
273312
fn test_to_value_struct() -> LuaResult<()> {
274313
let lua = Lua::new();

0 commit comments

Comments
 (0)