From 68365742ee77e19ef46d96a5cb823492e5d288d2 Mon Sep 17 00:00:00 2001 From: Bidek56 Date: Tue, 17 Sep 2024 21:14:03 -0400 Subject: [PATCH 1/2] Fixing series fold --- __tests__/dataframe.test.ts | 42 +++++++++++++++++++++---------------- polars/series/index.ts | 27 +++++++++++++++++------- 2 files changed, 43 insertions(+), 26 deletions(-) diff --git a/__tests__/dataframe.test.ts b/__tests__/dataframe.test.ts index 39e0c191..f87dd0c2 100644 --- a/__tests__/dataframe.test.ts +++ b/__tests__/dataframe.test.ts @@ -338,24 +338,30 @@ describe("dataframe", () => { const actual = df.fold((a, b) => a.concat(b)); expect(actual).toSeriesEqual(expected); }); - // test("fold", () => { - // const s1 = pl.Series([1, 2, 3]); - // const s2 = pl.Series([4, 5, 6]); - // const s3 = pl.Series([7, 8, 1]); - // const expected = pl.Series("foo", [true, true, false]); - // const df = pl.DataFrame([s1, s2, s3]); - // const actual = df.fold((a, b) => a.lessThan(b)).alias("foo"); - // expect(actual).toSeriesEqual(expected); - // }); - // test("fold-again", () => { - // const s1 = pl.Series([1, 2, 3]); - // const s2 = pl.Series([4, 5, 6]); - // const s3 = pl.Series([7, 8, 1]); - // const expected = pl.Series("foo", [12, 15, 10]); - // const df = pl.DataFrame([s1, s2, s3]); - // const actual = df.fold((a, b) => a.plus(b)).alias("foo"); - // expect(actual).toSeriesEqual(expected); - // }); + it.each` + name | actual | expected + ${"fold:lessThan"} | ${df.fold((a, b) => a.lessThan(b)).alias("foo")} | ${pl.Series("foo", [true, false, false])} + ${"fold:lt"} | ${df.fold((a, b) => a.lt(b)).alias("foo")} | ${pl.Series("foo", [true, false, false])} + ${"fold:lessThanEquals"} | ${df.fold((a, b) => a.lessThanEquals(b)).alias("foo")} | ${pl.Series("foo", [true, true, false])} + ${"fold:ltEq"} | ${df.fold((a, b) => a.ltEq(b)).alias("foo")} | ${pl.Series("foo", [true, true, false])} + ${"fold:neq"} | ${df.fold((a, b) => a.neq(b)).alias("foo")} | ${pl.Series("foo", [true, false, true])} + ${"fold:plus"} | ${df.fold((a, b) => a.plus(b)).alias("foo")} | ${pl.Series("foo", [7, 4, 17])} + ${"fold:minus"} | ${df.fold((a, b) => a.minus(b)).alias("foo")} | ${pl.Series("foo", [-5, 0, 1])} + ${"fold:mul"} | ${df.fold((a, b) => a.mul(b)).alias("foo")} | ${pl.Series("foo", [6, 4, 72])} + `("$# $name expected matches actual", ({ expected, actual }) => { + expect(expected).toSeriesEqual(actual); + }); + test("fold:lt", () => { + const s1 = pl.Series([1, 2, 3]); + const s2 = pl.Series([4, 5, 6]); + const s3 = pl.Series([7, 8, 1]); + const df = pl.DataFrame([s1, s2, s3]); + const expected = pl.Series("foo", [true, true, false]); + let actual = df.fold((a, b) => a.lessThan(b)).alias("foo"); + expect(actual).toSeriesEqual(expected); + actual = df.fold((a, b) => a.lt(b)).alias("foo"); + expect(actual).toSeriesEqual(expected); + }); test("frameEqual:true", () => { const df = pl.DataFrame({ foo: [1, 2, 3], diff --git a/polars/series/index.ts b/polars/series/index.ts index 4f49fb35..f6eb9e33 100644 --- a/polars/series/index.ts +++ b/polars/series/index.ts @@ -1125,6 +1125,9 @@ export function _Series(_s: any): Series { const wrap = (method, ...args): Series => { return _Series(unwrap(method, ...args)); }; + const wraps = (method, args: any): Series => { + return _Series(_s[method as any](args._s)); + }; const dtypeWrap = (method: string, ...args: any[]) => { const dtype = _s.dtype; @@ -1528,16 +1531,20 @@ export function _Series(_s: any): Series { return this.length; }, lt(field) { - return dtypeWrap("Lt", field); + if (typeof field === "number") return dtypeWrap("Lt", field); + return wraps("lt", field); }, lessThan(field) { - return dtypeWrap("Lt", field); + if (typeof field === "number") return dtypeWrap("Lt", field); + return wraps("lt", field); }, ltEq(field) { - return dtypeWrap("LtEq", field); + if (typeof field === "number") return dtypeWrap("LtEq", field); + return wraps("ltEq", field); }, lessThanEquals(field) { - return dtypeWrap("LtEq", field); + if (typeof field === "number") return dtypeWrap("LtEq", field); + return wraps("ltEq", field); }, limit(n = 10) { return wrap("limit", n); @@ -1558,16 +1565,19 @@ export function _Series(_s: any): Series { return wrap("mode"); }, minus(other) { - return dtypeWrap("Sub", other); + if (typeof other === "number") return dtypeWrap("Sub", other); + return wraps("sub", other); }, mul(other) { - return dtypeWrap("Mul", other); + if (typeof other === "number") return dtypeWrap("Mul", other); + return wraps("mul", other); }, nChunks() { return _s.nChunks(); }, neq(other) { - return dtypeWrap("Neq", other); + if (typeof other === "number") return dtypeWrap("Neq", other); + return wraps("neq", other); }, notEquals(other) { return this.neq(other); @@ -1585,7 +1595,8 @@ export function _Series(_s: any): Series { return expr_op("peakMin"); }, plus(other) { - return dtypeWrap("Add", other); + if (typeof other === "number") return dtypeWrap("Add", other); + return wraps("add", other); }, quantile(quantile, interpolation = "nearest") { return _s.quantile(quantile, interpolation); From 907c01e65e1eaee5123794f78897a020a546e167 Mon Sep 17 00:00:00 2001 From: Bidek56 Date: Tue, 24 Sep 2024 16:32:44 -0400 Subject: [PATCH 2/2] Adding isSeries check --- polars/series/index.ts | 43 +++++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/polars/series/index.ts b/polars/series/index.ts index f6eb9e33..d6c4511f 100644 --- a/polars/series/index.ts +++ b/polars/series/index.ts @@ -1125,9 +1125,6 @@ export function _Series(_s: any): Series { const wrap = (method, ...args): Series => { return _Series(unwrap(method, ...args)); }; - const wraps = (method, args: any): Series => { - return _Series(_s[method as any](args._s)); - }; const dtypeWrap = (method: string, ...args: any[]) => { const dtype = _s.dtype; @@ -1532,19 +1529,31 @@ export function _Series(_s: any): Series { }, lt(field) { if (typeof field === "number") return dtypeWrap("Lt", field); - return wraps("lt", field); + if (Series.isSeries(field)) { + return wrap("lt", (field as any)._s); + } + throw new Error("Not a number nor a series"); }, lessThan(field) { if (typeof field === "number") return dtypeWrap("Lt", field); - return wraps("lt", field); + if (Series.isSeries(field)) { + return wrap("lt", (field as any)._s); + } + throw new Error("Not a number nor a series"); }, ltEq(field) { if (typeof field === "number") return dtypeWrap("LtEq", field); - return wraps("ltEq", field); + if (Series.isSeries(field)) { + return wrap("ltEq", (field as any)._s); + } + throw new Error("Not a number nor a series"); }, lessThanEquals(field) { if (typeof field === "number") return dtypeWrap("LtEq", field); - return wraps("ltEq", field); + if (Series.isSeries(field)) { + return wrap("ltEq", (field as any)._s); + } + throw new Error("Not a number nor a series"); }, limit(n = 10) { return wrap("limit", n); @@ -1566,18 +1575,27 @@ export function _Series(_s: any): Series { }, minus(other) { if (typeof other === "number") return dtypeWrap("Sub", other); - return wraps("sub", other); + if (Series.isSeries(other)) { + return wrap("sub", (other as any)._s); + } + throw new Error("Not a number nor a series"); }, mul(other) { if (typeof other === "number") return dtypeWrap("Mul", other); - return wraps("mul", other); + if (Series.isSeries(other)) { + return wrap("mul", (other as any)._s); + } + throw new Error("Not a number nor a series"); }, nChunks() { return _s.nChunks(); }, neq(other) { if (typeof other === "number") return dtypeWrap("Neq", other); - return wraps("neq", other); + if (Series.isSeries(other)) { + return wrap("neq", (other as any)._s); + } + throw new Error("Not a number nor a series"); }, notEquals(other) { return this.neq(other); @@ -1596,7 +1614,10 @@ export function _Series(_s: any): Series { }, plus(other) { if (typeof other === "number") return dtypeWrap("Add", other); - return wraps("add", other); + if (Series.isSeries(other)) { + return wrap("add", (other as any)._s); + } + throw new Error("Not a number nor a series"); }, quantile(quantile, interpolation = "nearest") { return _s.quantile(quantile, interpolation);