Skip to content

Commit

Permalink
Adding support for recursive struct
Browse files Browse the repository at this point in the history
  • Loading branch information
Bidek56 committed Dec 28, 2024
1 parent 0038898 commit 66d77e9
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 137 deletions.
168 changes: 95 additions & 73 deletions __tests__/dataframe.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,79 +7,6 @@ describe("dataframe", () => {
pl.Series("foo", [1, 2, 9], pl.Int16),
pl.Series("bar", [6, 2, 8], pl.Int16),
]);

// Builds a DataFrame from an array of row objects whose `attributes` value is
// a flat object, and checks the inferred schema, the printed table, and the
// round-tripped column values, first without and then with an explicit schema.
test("df from JSON with struct", () => {
const rows = [
{
id: 1,
name: "one",
attributes: { b: false, bb: true, s: "one", x: 1 },
},
{
id: 2,
name: "two",
attributes: { b: false, bb: true, s: "two", x: 2 },
},
{
id: 3,
name: "three",
attributes: { b: false, bb: true, s: "three", x: 3 },
},
];

// No schema supplied: numbers are expected to come back as Float64 and the
// nested object as a Struct with one field per key.
let actual = pl.DataFrame(rows);
expect(actual.schema).toStrictEqual({
id: pl.Float64,
name: pl.String,
attributes: pl.Struct([
new pl.Field("b", pl.Bool),
new pl.Field("bb", pl.Bool),
new pl.Field("s", pl.String),
new pl.Field("x", pl.Float64),
]),
});

let expected = `shape: (3, 3)
┌─────┬───────┬──────────────────────────┐
│ id ┆ name ┆ attributes │
│ --- ┆ --- ┆ --- │
│ f64 ┆ str ┆ struct[4] │
╞═════╪═══════╪══════════════════════════╡
│ 1.0 ┆ one ┆ {false,true,"one",1.0} │
│ 2.0 ┆ two ┆ {false,true,"two",2.0} │
│ 3.0 ┆ three ┆ {false,true,"three",3.0} │
└─────┴───────┴──────────────────────────┘`;
expect(actual.toString()).toStrictEqual(expected);

// Explicit schema: `id` is requested as Int32 and the struct field `x` as
// Int16, so the printed values below carry no decimal point.
const schema = {
id: pl.Int32,
name: pl.String,
attributes: pl.Struct([
new pl.Field("b", pl.Bool),
new pl.Field("bb", pl.Bool),
new pl.Field("s", pl.String),
new pl.Field("x", pl.Int16),
]),
};
actual = pl.DataFrame(rows, { schema: schema });
expected = `shape: (3, 3)
┌─────┬───────┬────────────────────────┐
│ id ┆ name ┆ attributes │
│ --- ┆ --- ┆ --- │
│ i32 ┆ str ┆ struct[4] │
╞═════╪═══════╪════════════════════════╡
│ 1 ┆ one ┆ {false,true,"one",1} │
│ 2 ┆ two ┆ {false,true,"two",2} │
│ 3 ┆ three ┆ {false,true,"three",3} │
└─────┴───────┴────────────────────────┘`;
expect(actual.toString()).toStrictEqual(expected);
// Column values must round-trip back to the original row data.
expect(actual.getColumn("name").toArray()).toEqual(
rows.map((e) => e["name"]),
);
expect(actual.getColumn("attributes").toArray()).toMatchObject(
rows.map((e) => e["attributes"]),
);
});
test("dtypes", () => {
const expected = [pl.Float64, pl.String];
const actual = pl.DataFrame({ a: [1, 2, 3], b: ["a", "b", "c"] }).dtypes;
Expand Down Expand Up @@ -1386,6 +1313,101 @@ describe("dataframe", () => {
]);
expect(actual).toFrameEqual(expected);
});
// Builds a DataFrame from a row whose `attributes` object contains a nested
// object (`att2`) that itself contains another nested object (`att3`), and
// checks that the printed `attributes` series shows every nesting level.
test("df from JSON with multiple struct", () => {
const rows = [
{
id: 1,
name: "one",
attributes: {
b: false,
bb: true,
s: "one",
x: 1,
att2: { s: "two", y: 2, att3: { s: "three", y: 3 } },
},
},
];

const actual = pl.DataFrame(rows);
// Expected rendering: a 5-field struct whose last field is the nested struct;
// numbers print with a decimal point (e.g. 1.0) since no schema is given.
const expected = `shape: (1,)
Series: 'attributes' [struct[5]]
[
{false,true,"one",1.0,{"two",2.0,{"three",3.0}}}
]`;
expect(actual.select("attributes").toSeries().toString()).toEqual(expected);
});
// Builds a DataFrame from an array of row objects whose `attributes` value is
// a flat object, and checks the inferred schema, the printed table, and the
// round-tripped column values, first without and then with an explicit schema.
test("df from JSON with struct", () => {
const rows = [
{
id: 1,
name: "one",
attributes: { b: false, bb: true, s: "one", x: 1 },
},
{
id: 2,
name: "two",
attributes: { b: false, bb: true, s: "two", x: 2 },
},
{
id: 3,
name: "three",
attributes: { b: false, bb: true, s: "three", x: 3 },
},
];

// No schema supplied: numbers are expected to come back as Float64 and the
// nested object as a Struct with one field per key.
let actual = pl.DataFrame(rows);
expect(actual.schema).toStrictEqual({
id: pl.Float64,
name: pl.String,
attributes: pl.Struct([
new pl.Field("b", pl.Bool),
new pl.Field("bb", pl.Bool),
new pl.Field("s", pl.String),
new pl.Field("x", pl.Float64),
]),
});

let expected = `shape: (3, 3)
┌─────┬───────┬──────────────────────────┐
│ id ┆ name ┆ attributes │
│ --- ┆ --- ┆ --- │
│ f64 ┆ str ┆ struct[4] │
╞═════╪═══════╪══════════════════════════╡
│ 1.0 ┆ one ┆ {false,true,"one",1.0} │
│ 2.0 ┆ two ┆ {false,true,"two",2.0} │
│ 3.0 ┆ three ┆ {false,true,"three",3.0} │
└─────┴───────┴──────────────────────────┘`;
expect(actual.toString()).toStrictEqual(expected);

// Explicit schema: `id` is requested as Int32 and the struct field `x` as
// Int16, so the printed values below carry no decimal point.
const schema = {
id: pl.Int32,
name: pl.String,
attributes: pl.Struct([
new pl.Field("b", pl.Bool),
new pl.Field("bb", pl.Bool),
new pl.Field("s", pl.String),
new pl.Field("x", pl.Int16),
]),
};
actual = pl.DataFrame(rows, { schema: schema });
expected = `shape: (3, 3)
┌─────┬───────┬────────────────────────┐
│ id ┆ name ┆ attributes │
│ --- ┆ --- ┆ --- │
│ i32 ┆ str ┆ struct[4] │
╞═════╪═══════╪════════════════════════╡
│ 1 ┆ one ┆ {false,true,"one",1} │
│ 2 ┆ two ┆ {false,true,"two",2} │
│ 3 ┆ three ┆ {false,true,"three",3} │
└─────┴───────┴────────────────────────┘`;
expect(actual.toString()).toStrictEqual(expected);
// Column values must round-trip back to the original row data.
expect(actual.getColumn("name").toArray()).toEqual(
rows.map((e) => e["name"]),
);
expect(actual.getColumn("attributes").toArray()).toMatchObject(
rows.map((e) => e["attributes"]),
);
});
test("pivot", () => {
{
const df = pl.DataFrame({
Expand Down
127 changes: 63 additions & 64 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1623,75 +1623,74 @@ fn obj_to_pairs(rows: &Array, len: usize) -> impl '_ + Iterator<Item = Vec<(Stri
keys.iter()
.map(|key| {
let value = obj.get::<_, napi::JsUnknown>(&key).unwrap_or(None);
let dtype = match value {
Some(val) => {
let ty = val.get_type().unwrap();
match ty {
ValueType::Boolean => DataType::Boolean,
ValueType::Number => DataType::Float64,
ValueType::BigInt => DataType::UInt64,
ValueType::String => DataType::String,
ValueType::Object => {
if val.is_array().unwrap() {
let arr: napi::JsObject = unsafe { val.cast() };
let len = arr.get_array_length().unwrap();
if len == 0 {
DataType::List(DataType::Null.into())
} else {
// dont compare too many items, as it could be expensive
let max_take = std::cmp::min(len as usize, 10);
let mut dtypes: Vec<DataType> =
Vec::with_capacity(len as usize);

for idx in 0..max_take {
let item: napi::JsUnknown =
arr.get_element(idx as u32).unwrap();
let ty = item.get_type().unwrap();
let dt: Wrap<DataType> = ty.into();
dtypes.push(dt.0)
}
let dtype = coerce_data_type(&dtypes);

DataType::List(dtype.into())
}
} else if val.is_date().unwrap() {
DataType::Datetime(TimeUnit::Milliseconds, None)
} else {
let inner_val: napi::JsObject = unsafe { val.cast() };
let inner_keys = Object::keys(&inner_val).unwrap();
let mut fldvec: Vec<Field> =
Vec::with_capacity(inner_keys.len() as usize);

inner_keys.iter().for_each(|key| {
let inner_val =
&inner_val.get::<_, napi::JsUnknown>(&key).unwrap();
let dtype =
match inner_val.as_ref().unwrap().get_type().unwrap() {
ValueType::Boolean => DataType::Boolean,
ValueType::Number => DataType::Float64,
ValueType::BigInt => DataType::UInt64,
ValueType::String => DataType::String,
ValueType::Object => DataType::Struct(vec![]),
_ => DataType::Null,
};

let fld = Field::new(key.into(), dtype);
fldvec.push(fld);
});
DataType::Struct(fldvec)
}
}
_ => DataType::Null,
}
}
None => DataType::Null,
};
(key.to_owned(), dtype)
(key.to_owned(), obj_to_type(value))
})
.collect()
})
}

fn obj_to_type(value: Option<JsUnknown>) -> DataType {
match value {
Some(val) => {
let ty = val.get_type().unwrap();
match ty {
ValueType::Boolean => DataType::Boolean,
ValueType::Number => DataType::Float64,
ValueType::BigInt => DataType::UInt64,
ValueType::String => DataType::String,
ValueType::Object => {
if val.is_array().unwrap() {
let arr: napi::JsObject = unsafe { val.cast() };
let len = arr.get_array_length().unwrap();
if len == 0 {
DataType::List(DataType::Null.into())
} else {
// dont compare too many items, as it could be expensive
let max_take = std::cmp::min(len as usize, 10);
let mut dtypes: Vec<DataType> = Vec::with_capacity(len as usize);

for idx in 0..max_take {
let item: napi::JsUnknown = arr.get_element(idx as u32).unwrap();
let ty = item.get_type().unwrap();
let dt: Wrap<DataType> = ty.into();
dtypes.push(dt.0)
}
let dtype = coerce_data_type(&dtypes);

DataType::List(dtype.into())
}
} else if val.is_date().unwrap() {
DataType::Datetime(TimeUnit::Milliseconds, None)
} else {
let inner_val: napi::JsObject = unsafe { val.cast() };
let inner_keys = Object::keys(&inner_val).unwrap();
let mut fldvec: Vec<Field> = Vec::with_capacity(inner_keys.len() as usize);

inner_keys.iter().for_each(|key| {
let inner_val = inner_val.get::<_, napi::JsUnknown>(&key).unwrap();
let dtype = match inner_val.as_ref().unwrap().get_type().unwrap() {
ValueType::Boolean => DataType::Boolean,
ValueType::Number => DataType::Float64,
ValueType::BigInt => DataType::UInt64,
ValueType::String => DataType::String,
// determine struct type using a recursive func
ValueType::Object => obj_to_type(inner_val),
_ => DataType::Null,
};

let fld = Field::new(key.into(), dtype);
fldvec.push(fld);
});
DataType::Struct(fldvec)
}
}
_ => DataType::Null,
}
}
None => DataType::Null,
}
}

fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<AnyValue<'a>> {
use DataType::*;
let vtype = val.get_type().unwrap();
Expand Down

0 comments on commit 66d77e9

Please sign in to comment.