From 4bcc68cbf5fce50d889e1872347f2cff4c22beb8 Mon Sep 17 00:00:00 2001 From: Darek Date: Mon, 30 Sep 2024 11:49:22 -0400 Subject: [PATCH] Fixing series fold (#271) Fixing series fold to close #79 --- __tests__/dataframe.test.ts | 42 ++++++++++++++++++-------------- polars/series/index.ts | 48 ++++++++++++++++++++++++++++++------- 2 files changed, 64 insertions(+), 26 deletions(-) diff --git a/__tests__/dataframe.test.ts b/__tests__/dataframe.test.ts index 39e0c191..f87dd0c2 100644 --- a/__tests__/dataframe.test.ts +++ b/__tests__/dataframe.test.ts @@ -338,24 +338,30 @@ describe("dataframe", () => { const actual = df.fold((a, b) => a.concat(b)); expect(actual).toSeriesEqual(expected); }); - // test("fold", () => { - // const s1 = pl.Series([1, 2, 3]); - // const s2 = pl.Series([4, 5, 6]); - // const s3 = pl.Series([7, 8, 1]); - // const expected = pl.Series("foo", [true, true, false]); - // const df = pl.DataFrame([s1, s2, s3]); - // const actual = df.fold((a, b) => a.lessThan(b)).alias("foo"); - // expect(actual).toSeriesEqual(expected); - // }); - // test("fold-again", () => { - // const s1 = pl.Series([1, 2, 3]); - // const s2 = pl.Series([4, 5, 6]); - // const s3 = pl.Series([7, 8, 1]); - // const expected = pl.Series("foo", [12, 15, 10]); - // const df = pl.DataFrame([s1, s2, s3]); - // const actual = df.fold((a, b) => a.plus(b)).alias("foo"); - // expect(actual).toSeriesEqual(expected); - // }); + it.each` + name | actual | expected + ${"fold:lessThan"} | ${df.fold((a, b) => a.lessThan(b)).alias("foo")} | ${pl.Series("foo", [true, false, false])} + ${"fold:lt"} | ${df.fold((a, b) => a.lt(b)).alias("foo")} | ${pl.Series("foo", [true, false, false])} + ${"fold:lessThanEquals"} | ${df.fold((a, b) => a.lessThanEquals(b)).alias("foo")} | ${pl.Series("foo", [true, true, false])} + ${"fold:ltEq"} | ${df.fold((a, b) => a.ltEq(b)).alias("foo")} | ${pl.Series("foo", [true, true, false])} + ${"fold:neq"} | ${df.fold((a, b) => a.neq(b)).alias("foo")} | ${pl.Series("foo", [true, false, true])} + ${"fold:plus"} | ${df.fold((a, b) => a.plus(b)).alias("foo")} | ${pl.Series("foo", [7, 4, 17])} + ${"fold:minus"} | ${df.fold((a, b) => a.minus(b)).alias("foo")} | ${pl.Series("foo", [-5, 0, 1])} + ${"fold:mul"} | ${df.fold((a, b) => a.mul(b)).alias("foo")} | ${pl.Series("foo", [6, 4, 72])} + `("$# $name expected matches actual", ({ expected, actual }) => { + expect(expected).toSeriesEqual(actual); + }); + test("fold:lt", () => { + const s1 = pl.Series([1, 2, 3]); + const s2 = pl.Series([4, 5, 6]); + const s3 = pl.Series([7, 8, 1]); + const df = pl.DataFrame([s1, s2, s3]); + const expected = pl.Series("foo", [true, true, false]); + let actual = df.fold((a, b) => a.lessThan(b)).alias("foo"); + expect(actual).toSeriesEqual(expected); + actual = df.fold((a, b) => a.lt(b)).alias("foo"); + expect(actual).toSeriesEqual(expected); + }); test("frameEqual:true", () => { const df = pl.DataFrame({ foo: [1, 2, 3], diff --git a/polars/series/index.ts b/polars/series/index.ts index fcd36599..861626da 100644 --- a/polars/series/index.ts +++ b/polars/series/index.ts @@ -1528,16 +1528,32 @@ export function _Series(_s: any): Series { return this.length; }, lt(field) { - return dtypeWrap("Lt", field); + if (typeof field === "number") return dtypeWrap("Lt", field); + if (Series.isSeries(field)) { + return wrap("lt", (field as any)._s); + } + throw new Error("Not a number nor a series"); }, lessThan(field) { - return dtypeWrap("Lt", field); + if (typeof field === "number") return dtypeWrap("Lt", field); + if (Series.isSeries(field)) { + return wrap("lt", (field as any)._s); + } + throw new Error("Not a number nor a series"); }, ltEq(field) { - return dtypeWrap("LtEq", field); + if (typeof field === "number") return dtypeWrap("LtEq", field); + if (Series.isSeries(field)) { + return wrap("ltEq", (field as any)._s); + } + throw new Error("Not a number nor a series"); }, lessThanEquals(field) { - return dtypeWrap("LtEq", field); + if (typeof field === "number") return dtypeWrap("LtEq", field); + if (Series.isSeries(field)) { + return wrap("ltEq", (field as any)._s); + } + throw new Error("Not a number nor a series"); }, limit(n = 10) { return wrap("limit", n); @@ -1558,16 +1574,28 @@ export function _Series(_s: any): Series { return wrap("mode"); }, minus(other) { - return dtypeWrap("Sub", other); + if (typeof other === "number") return dtypeWrap("Sub", other); + if (Series.isSeries(other)) { + return wrap("sub", (other as any)._s); + } + throw new Error("Not a number nor a series"); }, mul(other) { - return dtypeWrap("Mul", other); + if (typeof other === "number") return dtypeWrap("Mul", other); + if (Series.isSeries(other)) { + return wrap("mul", (other as any)._s); + } + throw new Error("Not a number nor a series"); }, nChunks() { return _s.nChunks(); }, neq(other) { - return dtypeWrap("Neq", other); + if (typeof other === "number") return dtypeWrap("Neq", other); + if (Series.isSeries(other)) { + return wrap("neq", (other as any)._s); + } + throw new Error("Not a number nor a series"); }, notEquals(other) { return this.neq(other); @@ -1585,7 +1613,11 @@ export function _Series(_s: any): Series { return expr_op("peakMin"); }, plus(other) { - return dtypeWrap("Add", other); + if (typeof other === "number") return dtypeWrap("Add", other); + if (Series.isSeries(other)) { + return wrap("add", (other as any)._s); + } + throw new Error("Not a number nor a series"); }, quantile(quantile, interpolation = "nearest") { return _s.quantile(quantile, interpolation);