From 8816b4682348b99ca9859487b6668623ca042e9c Mon Sep 17 00:00:00 2001 From: Luke Deen Taylor Date: Tue, 17 Dec 2024 13:28:12 -0500 Subject: [PATCH] Add additional generic types to DataFrame methods (#302) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adding generic types to a few more methods beyond what was added in #293 by @scarf005 Focusing mostly on adding identity types to methods which I believe don’t change the original type of the dataframe. I added “identity” type signatures to the following methods: > extend, fillNull, filter, interpolate, limit, max, mean, median, min, quantile, rechunk, shiftAndFill, shrinkToFit, slice, sort, std, sum, tail, unique, var, vstack, where, upsample These previously returned `DataFrame`, even when called on a well-typed DataFrame, but now return `DataFrame` (the original type) --- I also added better types for a few slightly more complex ones: - map - improved return type based on the function passed, but unimproved parameter type - nullCount - toRecords - toSeries - for now, returning a broad union type, rather than identifying the specific column by index - withColumn --- Along the way, I added minor fixes for the types of: 1. `pl.intRange` [[1]](https://github.com/pola-rs/nodejs-polars/pull/302/commits/890bf21bd2bf5ee222b666c21941f78c94af19a1) which had overloads in the wrong order leading to incorrect return types, and 2. the `pl.Series(name, values, dtype)` constructor [[2]](https://github.com/pola-rs/nodejs-polars/pull/302/commits/a2635bd971c4dfd5a83a57766ad1e993c2247309), whose strongly-typed overload was failing to apply in simple cases like `pl.Series("index", [0, 1, 2, 3, 4], pl.Int64)` when the input array used `number`s instead of `BigInt`s --- __tests__/expr.test.ts | 22 +++++---- polars/dataframe.ts | 92 ++++++++++++++++++++---------------- polars/datatypes/datatype.ts | 14 ++++++ polars/lazy/functions.ts | 42 ++++++++++------ polars/series/index.ts | 3 +- 5 files changed, 108 insertions(+), 65 deletions(-) diff --git a/__tests__/expr.test.ts b/__tests__/expr.test.ts index 1e5fad28..4d38192a 100644 --- a/__tests__/expr.test.ts +++ b/__tests__/expr.test.ts @@ -525,20 +525,22 @@ describe("expr", () => { a: [1, 2, 3, 3, 3], b: ["a", "a", "b", "a", "a"], }); - let actual = df.select(pl.len()); - let expected = pl.DataFrame({ len: [5] }); + const actual = df.select(pl.len()); + const expected = pl.DataFrame({ len: [5] }); expect(actual).toFrameEqual(expected); - actual = df.withColumn(pl.len()); - expected = df.withColumn(pl.lit(5).alias("len")); - expect(actual).toFrameEqual(expected); + const actual2 = df.withColumn(pl.len()); + const expected2 = df.withColumn(pl.lit(5).alias("len")); + expect(actual2).toFrameEqual(expected2); - actual = df.withColumn(pl.intRange(pl.len()).alias("index")); - expected = df.withColumn(pl.Series("index", [0, 1, 2, 3, 4], pl.Int64)); - expect(actual).toFrameEqual(expected); + const actual3 = df.withColumn(pl.intRange(pl.len()).alias("index")); + const expected3 = df.withColumn( + pl.Series("index", [0, 1, 2, 3, 4], pl.Int64), + ); + expect(actual3).toFrameEqual(expected3); - actual = df.groupBy("b").agg(pl.len()); - expect(actual.shape).toEqual({ height: 2, width: 2 }); + const actual4 = df.groupBy("b").agg(pl.len()); + expect(actual4.shape).toEqual({ height: 2, width: 2 }); }); test("list", () => { const df = pl.DataFrame({ diff --git a/polars/dataframe.ts b/polars/dataframe.ts index a798854c..f4de27a0 100644 --- a/polars/dataframe.ts +++ b/polars/dataframe.ts @@ -466,7 +466,7 @@ export interface DataFrame = any> * @param other DataFrame to vertically add. */ - extend(other: DataFrame): DataFrame; + extend(other: DataFrame): DataFrame; /** * Fill null/missing values by a filling strategy * @@ -480,7 +480,7 @@ export interface DataFrame = any> * - "one" * @returns DataFrame with None replaced with the filling strategy. */ - fillNull(strategy: FillNullStrategy): DataFrame; + fillNull(strategy: FillNullStrategy): DataFrame; /** * Filter the rows in the DataFrame based on a predicate expression. * ___ @@ -519,7 +519,7 @@ export interface DataFrame = any> * └─────┴─────┴─────┘ * ``` */ - filter(predicate: any): DataFrame; + filter(predicate: any): DataFrame; /** * Find the index of a column by name. * ___ @@ -764,7 +764,7 @@ export interface DataFrame = any> /** * Interpolate intermediate values. The interpolation method is linear. */ - interpolate(): DataFrame; + interpolate(): DataFrame; /** * Get a mask of all duplicated rows in this DataFrame. */ @@ -937,8 +937,11 @@ export interface DataFrame = any> * Get first N rows as DataFrame. * @see {@link head} */ - limit(length?: number): DataFrame; - map(func: (...args: any[]) => any): any[]; + limit(length?: number): DataFrame; + map( + // TODO: strong types for the mapping function + func: (row: any[], i: number, arr: any[][]) => ReturnT, + ): ReturnT[]; /** * Aggregate the columns of this DataFrame to their maximum value. @@ -962,8 +965,8 @@ export interface DataFrame = any> * ╰─────┴─────┴──────╯ * ``` */ - max(): DataFrame; - max(axis: 0): DataFrame; + max(): DataFrame; + max(axis: 0): DataFrame; max(axis: 1): Series; /** * Aggregate the columns of this DataFrame to their mean value. @@ -972,8 +975,8 @@ export interface DataFrame = any> * @param axis - either 0 or 1 * @param nullStrategy - this argument is only used if axis == 1 */ - mean(): DataFrame; - mean(axis: 0): DataFrame; + mean(): DataFrame; + mean(axis: 0): DataFrame; mean(axis: 1): Series; mean(axis: 1, nullStrategy?: "ignore" | "propagate"): Series; /** @@ -997,7 +1000,7 @@ export interface DataFrame = any> * ╰─────┴─────┴──────╯ * ``` */ - median(): DataFrame; + median(): DataFrame; /** * Unpivot a DataFrame from wide to long format. * @deprecated *since 0.13.0* use {@link unpivot} @@ -1059,8 +1062,8 @@ export interface DataFrame = any> * ╰─────┴─────┴──────╯ * ``` */ - min(): DataFrame; - min(axis: 0): DataFrame; + min(): DataFrame; + min(axis: 0): DataFrame; min(axis: 1): Series; /** * Get number of chunks used by the ChunkedArrays of this DataFrame. @@ -1087,12 +1090,14 @@ export interface DataFrame = any> * └─────┴─────┴─────┘ * ``` */ - nullCount(): DataFrame; + nullCount(): DataFrame<{ + [K in keyof T]: Series, K & string>; + }>; partitionBy( cols: string | string[], stable?: boolean, includeKey?: boolean, - ): DataFrame[]; + ): DataFrame[]; partitionBy( cols: string | string[], stable: boolean, @@ -1210,13 +1215,13 @@ export interface DataFrame = any> * ╰─────┴─────┴──────╯ * ``` */ - quantile(quantile: number): DataFrame; + quantile(quantile: number): DataFrame; /** * __Rechunk the data in this DataFrame to a contiguous allocation.__ * * This will make sure all subsequent operations have optimal and predictable performance. */ - rechunk(): DataFrame; + rechunk(): DataFrame; /** * __Rename column names.__ * ___ @@ -1443,12 +1448,15 @@ export interface DataFrame = any> * └─────┴─────┴─────┘ * ``` */ - shiftAndFill(n: number, fillValue: number): DataFrame; - shiftAndFill({ n, fillValue }: { n: number; fillValue: number }): DataFrame; + shiftAndFill(n: number, fillValue: number): DataFrame; + shiftAndFill({ + n, + fillValue, + }: { n: number; fillValue: number }): DataFrame; /** * Shrink memory usage of this DataFrame to fit the exact capacity needed to hold the data. */ - shrinkToFit(): DataFrame; + shrinkToFit(): DataFrame; shrinkToFit(inPlace: true): void; shrinkToFit({ inPlace }: { inPlace: true }): void; /** @@ -1477,8 +1485,8 @@ export interface DataFrame = any> * └─────┴─────┴─────┘ * ``` */ - slice({ offset, length }: { offset: number; length: number }): DataFrame; - slice(offset: number, length: number): DataFrame; + slice({ offset, length }: { offset: number; length: number }): DataFrame; + slice(offset: number, length: number): DataFrame; /** * Sort the DataFrame by column. * ___ @@ -1493,7 +1501,7 @@ export interface DataFrame = any> descending?: boolean, nullsLast?: boolean, maintainOrder?: boolean, - ): DataFrame; + ): DataFrame; sort({ by, reverse, // deprecated @@ -1504,7 +1512,7 @@ export interface DataFrame = any> reverse?: boolean; // deprecated nullsLast?: boolean; maintainOrder?: boolean; - }): DataFrame; + }): DataFrame; sort({ by, descending, @@ -1514,7 +1522,7 @@ export interface DataFrame = any> descending?: boolean; nullsLast?: boolean; maintainOrder?: boolean; - }): DataFrame; + }): DataFrame; /** * Aggregate the columns of this DataFrame to their standard deviation value. * ___ @@ -1536,7 +1544,7 @@ export interface DataFrame = any> * ╰─────┴─────┴──────╯ * ``` */ - std(): DataFrame; + std(): DataFrame; /** * Aggregate the columns of this DataFrame to their mean value. * ___ @@ -1544,8 +1552,8 @@ export interface DataFrame = any> * @param axis - either 0 or 1 * @param nullStrategy - this argument is only used if axis == 1 */ - sum(): DataFrame; - sum(axis: 0): DataFrame; + sum(): DataFrame; + sum(axis: 0): DataFrame; sum(axis: 1): Series; sum(axis: 1, nullStrategy?: "ignore" | "propagate"): Series; /** @@ -1595,7 +1603,7 @@ export interface DataFrame = any> * ╰─────────┴─────╯ * ``` */ - tail(length?: number): DataFrame; + tail(length?: number): DataFrame; /** * @deprecated *since 0.4.0* use {@link writeCSV} * @category Deprecated @@ -1614,7 +1622,7 @@ export interface DataFrame = any> * ``` * @category IO */ - toRecords(): Record[]; + toRecords(): { [K in keyof T]: DTypeToJs | null }[]; /** * compat with `JSON.stringify` @@ -1644,7 +1652,7 @@ export interface DataFrame = any> * ``` * @category IO */ - toObject(): { [K in keyof T]: DTypeToJs[] }; + toObject(): { [K in keyof T]: DTypeToJs[] }; /** * @deprecated *since 0.4.0* use {@link writeIPC} @@ -1656,7 +1664,7 @@ export interface DataFrame = any> * @category IO Deprecated */ toParquet(destination?, options?); - toSeries(index?: number): Series; + toSeries(index?: number): T[keyof T]; toString(): string; /** * Convert a ``DataFrame`` to a ``Series`` of type ``Struct`` @@ -1768,12 +1776,12 @@ export interface DataFrame = any> maintainOrder?: boolean, subset?: ColumnSelection, keep?: "first" | "last", - ): DataFrame; + ): DataFrame; unique(opts: { maintainOrder?: boolean; subset?: ColumnSelection; keep?: "first" | "last"; - }): DataFrame; + }): DataFrame; /** Decompose a struct into its fields. The fields will be inserted in to the `DataFrame` on the location of the `struct` type. @@ -1833,7 +1841,7 @@ export interface DataFrame = any> * ╰─────┴─────┴──────╯ * ``` */ - var(): DataFrame; + var(): DataFrame; /** * Grow this DataFrame vertically by stacking a DataFrame to it. * @param df - DataFrame to stack. @@ -1866,12 +1874,16 @@ export interface DataFrame = any> * ╰─────┴─────┴─────╯ * ``` */ - vstack(df: DataFrame): DataFrame; + vstack(df: DataFrame): DataFrame; /** * Return a new DataFrame with the column added or replaced. * @param column - Series, where the name of the Series refers to the column in the DataFrame. */ - withColumn(column: Series | Expr): DataFrame; + withColumn( + column: Series, + ): DataFrame< + Simplify }> + >; withColumn(column: Series | Expr): DataFrame; withColumns(...columns: (Expr | Series)[]): DataFrame; /** @@ -1896,7 +1908,7 @@ export interface DataFrame = any> */ withRowCount(name?: string): DataFrame; /** @see {@link filter} */ - where(predicate: any): DataFrame; + where(predicate: any): DataFrame; /** Upsample a DataFrame at a regular frequency. @@ -1972,13 +1984,13 @@ shape: (7, 3) every: string, by?: string | string[], maintainOrder?: boolean, - ): DataFrame; + ): DataFrame; upsample(opts: { timeColumn: string; every: string; by?: string | string[]; maintainOrder?: boolean; - }): DataFrame; + }): DataFrame; } function prepareOtherArg(anyValue: any): Series { diff --git a/polars/datatypes/datatype.ts b/polars/datatypes/datatype.ts index d3c38cb0..0a9400c2 100644 --- a/polars/datatypes/datatype.ts +++ b/polars/datatypes/datatype.ts @@ -491,6 +491,20 @@ export type DTypeToJs = T extends DataType.Decimal : T extends DataType.Utf8 ? string : never; +// some objects can be constructed with a looser JS type than they’d return when converted back to JS +export type DTypeToJsLoose = T extends DataType.Decimal + ? number | bigint + : T extends DataType.Float64 + ? number | bigint + : T extends DataType.Int64 + ? number | bigint + : T extends DataType.Int32 + ? number | bigint + : T extends DataType.Bool + ? boolean + : T extends DataType.Utf8 + ? string + : never; export type DtypeToJsName = T extends DataType.Decimal ? "Decimal" : T extends DataType.Float64 diff --git a/polars/lazy/functions.ts b/polars/lazy/functions.ts index 35988b8d..3b8385a3 100644 --- a/polars/lazy/functions.ts +++ b/polars/lazy/functions.ts @@ -204,8 +204,22 @@ export function intRange(opts: { end: number | Expr; step?: number | Expr; dtype?: DataType; + eager?: false; +}): Expr; +export function intRange
(opts: { + start: number | Expr; + end: number | Expr; + step?: number | Expr; + dtype?: DT; + eager: true; +}): Series
; +export function intRange
(opts: { + start: number | Expr; + end: number | Expr; + step?: number | Expr; + dtype?: DT; eager?: boolean; -}); +}): Expr | Series
; /** @deprecated *since 0.15.0* use `start` and `end` instead */ export function intRange(opts: { low: number | Expr; @@ -213,27 +227,27 @@ export function intRange(opts: { step?: number | Expr; dtype?: DataType; eager?: boolean; -}); +}): Expr | Series; export function intRange( start: number | Expr, end?: number | Expr, step?: number | Expr, dtype?: DataType, - eager?: true, -): Series; -export function intRange( + eager?: false, +): Expr; +export function intRange
( start: number | Expr, end?: number | Expr, step?: number | Expr, - dtype?: DataType, - eager?: false, -): Expr; -export function intRange( + dtype?: DT, + eager?: true, +): Series
; +export function intRange
( opts: any, - end?, + end?: number | Expr, step = 1 as number | Expr, - dtype: DataType = DataType.Int64, - eager?, + dtype?: DT, + eager?: boolean, ): Series | Expr { // @deprecated since 0.15.0 if (typeof opts?.low === "number") { @@ -256,7 +270,7 @@ export function intRange( .select(intRange(start, end, step).alias("intRange") as any) .getColumn("intRange") as any; } - return _Expr(pli.intRange(start, end, step, dtype)); + return _Expr(pli.intRange(start, end, step, dtype || DataType.Int64)); } /*** * Generate a range of integers for each row of the input columns. @@ -554,7 +568,7 @@ export function head(column: Series | ExprOrString, n?): Series | Expr { └───────┴──────┴──────┴─────┘ ``` */ -export function len(): any { +export function len(): Expr { return _Expr(pli.len()); } /** Get the last value. */ diff --git a/polars/series/index.ts b/polars/series/index.ts index 375616b9..d4cfadfe 100644 --- a/polars/series/index.ts +++ b/polars/series/index.ts @@ -2,6 +2,7 @@ import { DataFrame, _DataFrame } from "../dataframe"; import { DTYPE_TO_FFINAME, DataType, type Optional } from "../datatypes"; import type { DTypeToJs, + DTypeToJsLoose, DtypeToJsName, JsToDtype, JsType, @@ -1933,7 +1934,7 @@ export interface SeriesConstructor extends Deserialize { ): Series, Name>; ( name: Name, - values: ArrayLike>, + values: ArrayLike>, dtype?: T2, ): Series; (name: string, values: any[], dtype?): Series;