Skip to content

Commit

Permalink
Adding struct tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Bidek56 committed Dec 28, 2024
1 parent 9aa5c6f commit f92688b
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 49 deletions.
79 changes: 63 additions & 16 deletions __tests__/dataframe.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,72 @@ describe("dataframe", () => {

test("df from JSON with struct", () => {
const rows = [
{id: 1, name: 'one', attributes: {x: 700, colour: 'black'}},
{id: 2, name: 'two', attributes: {x: 800, colour: 'blue'}},
{id: 3, name: 'three', attributes: {x: 100, colour: 'red'}}
{
id: 1,
name: "one",
attributes: { b: false, bb: true, s: "one", x: 1 },
},
{
id: 2,
name: "two",
attributes: { b: false, bb: true, s: "two", x: 2 },
},
{
id: 3,
name: "three",
attributes: { b: false, bb: true, s: "three", x: 3 },
},
];

const df = pl.DataFrame(rows);
expect(df.schema).toStrictEqual(
{
id: {DataType:"Float64"},
name:{DataType:"String"},
attributes:{
DataType:{Struct:[
{name:"x",dtype:{DataType:"Float64"}},
{name:"colour",dtype:{DataType:"String"}}
]}}
}
);
});
let actual = pl.DataFrame(rows);
expect(actual.schema).toStrictEqual({
id: pl.Float64,
name: pl.String,
attributes: pl.Struct([
new pl.Field("b", pl.Bool),
new pl.Field("bb", pl.Bool),
new pl.Field("s", pl.String),
new pl.Field("x", pl.Float64),
]),
});

let expected = `shape: (3, 3)
┌─────┬───────┬──────────────────────────┐
│ id ┆ name ┆ attributes │
│ --- ┆ --- ┆ --- │
│ f64 ┆ str ┆ struct[4] │
╞═════╪═══════╪══════════════════════════╡
│ 1.0 ┆ one ┆ {false,true,"one",1.0} │
│ 2.0 ┆ two ┆ {false,true,"two",2.0} │
│ 3.0 ┆ three ┆ {false,true,"three",3.0} │
└─────┴───────┴──────────────────────────┘`;
expect(actual.toString()).toStrictEqual(expected);

const schema = {
id: pl.Int32,
name: pl.String,
attributes: pl.Struct([
new pl.Field("b", pl.Bool),
new pl.Field("bb", pl.Bool),
new pl.Field("s", pl.String),
new pl.Field("x", pl.Int16),
]),
};
actual = pl.DataFrame(rows, { schema: schema });
expected = `shape: (3, 3)
┌─────┬───────┬────────────────────────┐
│ id ┆ name ┆ attributes │
│ --- ┆ --- ┆ --- │
│ i32 ┆ str ┆ struct[4] │
╞═════╪═══════╪════════════════════════╡
│ 1 ┆ one ┆ {false,true,"one",1} │
│ 2 ┆ two ┆ {false,true,"two",2} │
│ 3 ┆ three ┆ {false,true,"three",3} │
└─────┴───────┴────────────────────────┘`;
expect(actual.toString()).toStrictEqual(expected);
expect(actual.getColumn('name').toArray()).toEqual(rows.map(e=>e['name']));
expect(actual.getColumn('attributes').toArray()).toMatchObject(rows.map(e=>e['attributes']));
});
test("dtypes", () => {
const expected = [pl.Float64, pl.String];
const actual = pl.DataFrame({ a: [1, 2, 3], b: ["a", "b", "c"] }).dtypes;
Expand Down
4 changes: 2 additions & 2 deletions src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -830,8 +830,8 @@ impl FromNapiValue for Wrap<CsvWriterOptions> {
let obj = Object::from_napi_value(env, napi_val)?;
let include_bom = obj.get::<_, bool>("includeBom")?.unwrap_or(false);
let include_header = obj.get::<_, bool>("includeHeader")?.unwrap_or(true);
let batch_size =
NonZero::new(obj.get::<_, i64>("batchSize")?.unwrap_or(1024) as usize).ok_or_else(|| napi::Error::from_reason("Invalid batch size"))?;
let batch_size = NonZero::new(obj.get::<_, i64>("batchSize")?.unwrap_or(1024) as usize)
.ok_or_else(|| napi::Error::from_reason("Invalid batch size"))?;
let maintain_order = obj.get::<_, bool>("maintainOrder")?.unwrap_or(true);
let date_format = obj.get::<_, String>("dateFormat")?;
let time_format = obj.get::<_, String>("timeFormat")?;
Expand Down
87 changes: 56 additions & 31 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,20 +451,22 @@ pub fn from_rows(
.unwrap_or_else(|| env.create_object().unwrap());

Row(schema
.iter_fields().map(|fld| {
.iter_fields()
.map(|fld| {
let dtype: &DataType = fld.dtype();
let key: &PlSmallStr = fld.name();
if let Ok(unknown) = obj.get::<&polars::prelude::PlSmallStr, JsUnknown>(key) {
match unknown {
Some(unknown) => unsafe {
Some(unknown) => {
coerce_js_anyvalue(unknown, dtype.clone()).unwrap_or(AnyValue::Null)
},
}
_ => AnyValue::Null,
}
} else {
AnyValue::Null
}
}).collect())
})
.collect())
})
.collect();
let df = DataFrame::from_rows_and_schema(&it, &schema).map_err(JsPolarsErr::from)?;
Expand Down Expand Up @@ -1657,19 +1659,22 @@ fn obj_to_pairs(rows: &Array, len: usize) -> impl '_ + Iterator<Item = Vec<(Stri
} else {
let inner_val: napi::JsObject = unsafe { val.cast() };
let inner_keys = Object::keys(&inner_val).unwrap();
let mut fldvec: Vec<Field> = Vec::with_capacity(inner_keys.len() as usize);
let mut fldvec: Vec<Field> =
Vec::with_capacity(inner_keys.len() as usize);

inner_keys.iter().for_each(|key| {
let inner_val = &inner_val.get::<_, napi::JsUnknown>(&key).unwrap();
let dtype = match inner_val.as_ref().unwrap().get_type().unwrap() {
ValueType::Boolean => DataType::Boolean,
ValueType::Number => DataType::Float64,
ValueType::BigInt => DataType::UInt64,
ValueType::String => DataType::String,
ValueType::Object => DataType::Struct(vec![]),
_ => DataType::Null
};

let inner_val =
&inner_val.get::<_, napi::JsUnknown>(&key).unwrap();
let dtype =
match inner_val.as_ref().unwrap().get_type().unwrap() {
ValueType::Boolean => DataType::Boolean,
ValueType::Number => DataType::Float64,
ValueType::BigInt => DataType::UInt64,
ValueType::String => DataType::String,
ValueType::Object => DataType::Struct(vec![]),
_ => DataType::Null,
};

let fld = Field::new(key.into(), dtype);
fldvec.push(fld);
});
Expand All @@ -1679,15 +1684,15 @@ fn obj_to_pairs(rows: &Array, len: usize) -> impl '_ + Iterator<Item = Vec<(Stri
_ => DataType::Null,
}
}
None => DataType::Null
None => DataType::Null,
};
(key.to_owned(), dtype)
})
.collect()
})
}

unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<AnyValue<'a>> {
fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<AnyValue<'a>> {
use DataType::*;
let vtype = val.get_type().unwrap();
match (vtype, dtype) {
Expand Down Expand Up @@ -1762,33 +1767,53 @@ unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<An
}
(ValueType::Object, DataType::Datetime(_, _)) => {
if val.is_date()? {
let d: napi::JsDate = val.cast();
let d: napi::JsDate = unsafe { val.cast() };
let d = d.value_of()?;
Ok(AnyValue::Datetime(d as i64, TimeUnit::Milliseconds, None))
} else {
Ok(AnyValue::Null)
}
}
(ValueType::Object, DataType::List(_)) => {
let s = val.to_series();
let s = unsafe { val.to_series() };
Ok(AnyValue::List(s))
}
(ValueType::Object, DataType::Struct(fields)) => {
let number_of_fields: i8 = fields.len().try_into().map_err(
|e| napi::Error::from_reason(format!("the number of `fields` cannot be larger than i8::MAX {e:?}"))
)?;
let number_of_fields: i8 = fields.len().try_into().map_err(|e| {
napi::Error::from_reason(format!(
"the number of `fields` cannot be larger than i8::MAX {e:?}"
))
})?;

let inner_val: napi::JsObject = val.cast();
let mut val_vec: Vec<polars::prelude::AnyValue<'_>> = Vec::with_capacity(number_of_fields as usize);
let inner_val: napi::JsObject = unsafe { val.cast() };
let mut val_vec: Vec<polars::prelude::AnyValue<'_>> =
Vec::with_capacity(number_of_fields as usize);
fields.iter().for_each(|fld| {
let single_val = inner_val.get::<_, napi::JsUnknown>(&fld.name).unwrap().unwrap();
let vv = match fld.dtype {
DataType::Boolean => AnyValue::Boolean(single_val.coerce_to_bool().unwrap().get_value().unwrap()),
let single_val = inner_val
.get::<_, napi::JsUnknown>(&fld.name)
.unwrap()
.unwrap();
let vv = match &fld.dtype {
DataType::Boolean => {
AnyValue::Boolean(single_val.coerce_to_bool().unwrap().get_value().unwrap())
}
DataType::String => AnyValue::from_js(single_val).expect("Expecting string"),
DataType::Int32 => AnyValue::Int32(single_val.coerce_to_number().unwrap().get_int32().unwrap()),
DataType::Int64 => AnyValue::Int64(single_val.coerce_to_number().unwrap().get_int64().unwrap()),
DataType::Float64 => AnyValue::Float64(single_val.coerce_to_number().unwrap().get_double().unwrap()),
_ => AnyValue::Null
DataType::Int16 => AnyValue::Int16(
single_val.coerce_to_number().unwrap().get_int32().unwrap() as i16,
),
DataType::Int32 => {
AnyValue::Int32(single_val.coerce_to_number().unwrap().get_int32().unwrap())
}
DataType::Int64 => {
AnyValue::Int64(single_val.coerce_to_number().unwrap().get_int64().unwrap())
}
DataType::Float64 => AnyValue::Float64(
single_val.coerce_to_number().unwrap().get_double().unwrap(),
),
DataType::Struct(_) => {
coerce_js_anyvalue(single_val, fld.dtype.clone()).unwrap()
}
_ => AnyValue::Null,
};
val_vec.push(vv);
});
Expand Down

0 comments on commit f92688b

Please sign in to comment.