From f92688b6c8a6d835dc3b0977f3844cbd606c3244 Mon Sep 17 00:00:00 2001 From: Bidek56 Date: Fri, 27 Dec 2024 19:13:01 -0500 Subject: [PATCH] Adding struct tests --- __tests__/dataframe.test.ts | 79 ++++++++++++++++++++++++++------- src/conversion.rs | 4 +- src/dataframe.rs | 87 ++++++++++++++++++++++++------------- 3 files changed, 121 insertions(+), 49 deletions(-) diff --git a/__tests__/dataframe.test.ts b/__tests__/dataframe.test.ts index 5e11b898..a8fdf2ff 100644 --- a/__tests__/dataframe.test.ts +++ b/__tests__/dataframe.test.ts @@ -10,25 +10,72 @@ describe("dataframe", () => { test("df from JSON with struct", () => { const rows = [ - {id: 1, name: 'one', attributes: {x: 700, colour: 'black'}}, - {id: 2, name: 'two', attributes: {x: 800, colour: 'blue'}}, - {id: 3, name: 'three', attributes: {x: 100, colour: 'red'}} + { + id: 1, + name: "one", + attributes: { b: false, bb: true, s: "one", x: 1 }, + }, + { + id: 2, + name: "two", + attributes: { b: false, bb: true, s: "two", x: 2 }, + }, + { + id: 3, + name: "three", + attributes: { b: false, bb: true, s: "three", x: 3 }, + }, ]; - const df = pl.DataFrame(rows); - expect(df.schema).toStrictEqual( - { - id: {DataType:"Float64"}, - name:{DataType:"String"}, - attributes:{ - DataType:{Struct:[ - {name:"x",dtype:{DataType:"Float64"}}, - {name:"colour",dtype:{DataType:"String"}} - ]}} - } - ); - }); + let actual = pl.DataFrame(rows); + expect(actual.schema).toStrictEqual({ + id: pl.Float64, + name: pl.String, + attributes: pl.Struct([ + new pl.Field("b", pl.Bool), + new pl.Field("bb", pl.Bool), + new pl.Field("s", pl.String), + new pl.Field("x", pl.Float64), + ]), + }); + + let expected = `shape: (3, 3) +┌─────┬───────┬──────────────────────────┐ +│ id ┆ name ┆ attributes │ +│ --- ┆ --- ┆ --- │ +│ f64 ┆ str ┆ struct[4] │ +╞═════╪═══════╪══════════════════════════╡ +│ 1.0 ┆ one ┆ {false,true,"one",1.0} │ +│ 2.0 ┆ two ┆ {false,true,"two",2.0} │ +│ 3.0 ┆ three ┆ {false,true,"three",3.0} │ +└─────┴───────┴──────────────────────────┘`; + expect(actual.toString()).toStrictEqual(expected); + const schema = { + id: pl.Int32, + name: pl.String, + attributes: pl.Struct([ + new pl.Field("b", pl.Bool), + new pl.Field("bb", pl.Bool), + new pl.Field("s", pl.String), + new pl.Field("x", pl.Int16), + ]), + }; + actual = pl.DataFrame(rows, { schema: schema }); + expected = `shape: (3, 3) +┌─────┬───────┬────────────────────────┐ +│ id ┆ name ┆ attributes │ +│ --- ┆ --- ┆ --- │ +│ i32 ┆ str ┆ struct[4] │ +╞═════╪═══════╪════════════════════════╡ +│ 1 ┆ one ┆ {false,true,"one",1} │ +│ 2 ┆ two ┆ {false,true,"two",2} │ +│ 3 ┆ three ┆ {false,true,"three",3} │ +└─────┴───────┴────────────────────────┘`; + expect(actual.toString()).toStrictEqual(expected); + expect(actual.getColumn('name').toArray()).toEqual(rows.map(e=>e['name'])); + expect(actual.getColumn('attributes').toArray()).toMatchObject(rows.map(e=>e['attributes'])); + }); test("dtypes", () => { const expected = [pl.Float64, pl.String]; const actual = pl.DataFrame({ a: [1, 2, 3], b: ["a", "b", "c"] }).dtypes; diff --git a/src/conversion.rs b/src/conversion.rs index 5a5786e5..bb7d70db 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -830,8 +830,8 @@ impl FromNapiValue for Wrap { let obj = Object::from_napi_value(env, napi_val)?; let include_bom = obj.get::<_, bool>("includeBom")?.unwrap_or(false); let include_header = obj.get::<_, bool>("includeHeader")?.unwrap_or(true); - let batch_size = - NonZero::new(obj.get::<_, i64>("batchSize")?.unwrap_or(1024) as usize).ok_or_else(|| napi::Error::from_reason("Invalid batch size"))?; + let batch_size = NonZero::new(obj.get::<_, i64>("batchSize")?.unwrap_or(1024) as usize) + .ok_or_else(|| napi::Error::from_reason("Invalid batch size"))?; let maintain_order = obj.get::<_, bool>("maintainOrder")?.unwrap_or(true); let date_format = obj.get::<_, String>("dateFormat")?; let time_format = obj.get::<_, String>("timeFormat")?; diff --git a/src/dataframe.rs b/src/dataframe.rs index 162545f9..a9fb7552 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -451,20 +451,22 @@ pub fn from_rows( .unwrap_or_else(|| env.create_object().unwrap()); Row(schema - .iter_fields().map(|fld| { + .iter_fields() + .map(|fld| { let dtype: &DataType = fld.dtype(); let key: &PlSmallStr = fld.name(); if let Ok(unknown) = obj.get::<&polars::prelude::PlSmallStr, JsUnknown>(key) { match unknown { - Some(unknown) => unsafe { + Some(unknown) => { coerce_js_anyvalue(unknown, dtype.clone()).unwrap_or(AnyValue::Null) - }, + } _ => AnyValue::Null, } } else { AnyValue::Null } - }).collect()) + }) + .collect()) }) .collect(); let df = DataFrame::from_rows_and_schema(&it, &schema).map_err(JsPolarsErr::from)?; @@ -1657,19 +1659,22 @@ fn obj_to_pairs(rows: &Array, len: usize) -> impl '_ + Iterator = Vec::with_capacity(inner_keys.len() as usize); + let mut fldvec: Vec = + Vec::with_capacity(inner_keys.len() as usize); inner_keys.iter().for_each(|key| { - let inner_val = &inner_val.get::<_, napi::JsUnknown>(&key).unwrap(); - let dtype = match inner_val.as_ref().unwrap().get_type().unwrap() { - ValueType::Boolean => DataType::Boolean, - ValueType::Number => DataType::Float64, - ValueType::BigInt => DataType::UInt64, - ValueType::String => DataType::String, - ValueType::Object => DataType::Struct(vec![]), - _ => DataType::Null - }; - + let inner_val = + &inner_val.get::<_, napi::JsUnknown>(&key).unwrap(); + let dtype = + match inner_val.as_ref().unwrap().get_type().unwrap() { + ValueType::Boolean => DataType::Boolean, + ValueType::Number => DataType::Float64, + ValueType::BigInt => DataType::UInt64, + ValueType::String => DataType::String, + ValueType::Object => DataType::Struct(vec![]), + _ => DataType::Null, + }; + let fld = Field::new(key.into(), dtype); fldvec.push(fld); }); @@ -1679,7 +1684,7 @@ fn obj_to_pairs(rows: &Array, len: usize) -> impl '_ + Iterator DataType::Null, } } - None => DataType::Null + None => DataType::Null, }; (key.to_owned(), dtype) }) @@ -1687,7 +1692,7 @@ fn obj_to_pairs(rows: &Array, len: usize) -> impl '_ + Iterator(val: JsUnknown, dtype: DataType) -> JsResult> { +fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult> { use DataType::*; let vtype = val.get_type().unwrap(); match (vtype, dtype) { @@ -1762,7 +1767,7 @@ unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult { if val.is_date()? { - let d: napi::JsDate = val.cast(); + let d: napi::JsDate = unsafe { val.cast() }; let d = d.value_of()?; Ok(AnyValue::Datetime(d as i64, TimeUnit::Milliseconds, None)) } else { @@ -1770,25 +1775,45 @@ unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult { - let s = val.to_series(); + let s = unsafe { val.to_series() }; Ok(AnyValue::List(s)) } (ValueType::Object, DataType::Struct(fields)) => { - let number_of_fields: i8 = fields.len().try_into().map_err( - |e| napi::Error::from_reason(format!("the number of `fields` cannot be larger than i8::MAX {e:?}")) - )?; + let number_of_fields: i8 = fields.len().try_into().map_err(|e| { + napi::Error::from_reason(format!( + "the number of `fields` cannot be larger than i8::MAX {e:?}" + )) + })?; - let inner_val: napi::JsObject = val.cast(); - let mut val_vec: Vec> = Vec::with_capacity(number_of_fields as usize); + let inner_val: napi::JsObject = unsafe { val.cast() }; + let mut val_vec: Vec> = + Vec::with_capacity(number_of_fields as usize); fields.iter().for_each(|fld| { - let single_val = inner_val.get::<_, napi::JsUnknown>(&fld.name).unwrap().unwrap(); - let vv = match fld.dtype { - DataType::Boolean => AnyValue::Boolean(single_val.coerce_to_bool().unwrap().get_value().unwrap()), + let single_val = inner_val + .get::<_, napi::JsUnknown>(&fld.name) + .unwrap() + .unwrap(); + let vv = match &fld.dtype { + DataType::Boolean => { + AnyValue::Boolean(single_val.coerce_to_bool().unwrap().get_value().unwrap()) + } DataType::String => AnyValue::from_js(single_val).expect("Expecting string"), - DataType::Int32 => AnyValue::Int32(single_val.coerce_to_number().unwrap().get_int32().unwrap()), - DataType::Int64 => AnyValue::Int64(single_val.coerce_to_number().unwrap().get_int64().unwrap()), - DataType::Float64 => AnyValue::Float64(single_val.coerce_to_number().unwrap().get_double().unwrap()), - _ => AnyValue::Null + DataType::Int16 => AnyValue::Int16( + single_val.coerce_to_number().unwrap().get_int32().unwrap() as i16, + ), + DataType::Int32 => { + AnyValue::Int32(single_val.coerce_to_number().unwrap().get_int32().unwrap()) + } + DataType::Int64 => { + AnyValue::Int64(single_val.coerce_to_number().unwrap().get_int64().unwrap()) + } + DataType::Float64 => AnyValue::Float64( + single_val.coerce_to_number().unwrap().get_double().unwrap(), + ), + DataType::Struct(_) => { + coerce_js_anyvalue(single_val, fld.dtype.clone()).unwrap() + } + _ => AnyValue::Null, }; val_vec.push(vv); });