diff --git a/src/datafilters/dataset_builder.py b/src/datafilters/dataset_builder.py
index d221764..c60bf7f 100644
--- a/src/datafilters/dataset_builder.py
+++ b/src/datafilters/dataset_builder.py
@@ -4,14 +4,13 @@
 from collections.abc import Collection, Iterable
 from datetime import datetime
 from pathlib import Path
-
+from typing import TypeVar
+import warnings
 import polars as pl
-from cfg_param_wrapper import CfgDict
 from polars import DataFrame, Expr
 from tqdm import tqdm
 
 from .base_filters import Column, Comparable, DataFilter, FastComparable
-from .custom_toml import TomlCustomCommentDecoder, TomlCustomCommentEncoder
 
 
 def current_time() -> datetime:
@@ -22,6 +21,9 @@ def _time(_=None) -> datetime:
     return current_time()
 
 
+T = TypeVar("T")
+
+
 class DatasetBuilder:
     def __init__(self, origin: str, db_path: Path) -> None:
         super().__init__()
@@ -31,8 +33,8 @@ def __init__(self, origin: str, db_path: Path) -> None:
 
         self.filepath: Path = db_path
 
-        self.basic_schema = {"path": str, "checkedtime": pl.Datetime}
-        self.filter_type_schema = self.basic_schema.copy()
+        self.basic_schema: dict[str, pl.DataType | type] = {"path": str, "checkedtime": pl.Datetime}
+        self.filter_type_schema: dict[str, pl.DataType | type] = self.basic_schema.copy()
         self.build_schema: dict[str, Expr] = {"checkedtime": pl.col("path").apply(_time)}
 
         if os.path.exists(self.filepath):
@@ -97,8 +99,8 @@ def populate_df(
         cond: Expr = now - pl.col("checkedtime") < datetime.fromtimestamp(trim_age_limit)
         if trim_check_exists:
             cond &= pl.col("path").apply(os.path.exists)
-
         self.df = self.df.filter(cond)
+
         from_full_to_relative: dict[str, Path] = self.absolute_dict(lst)
         # add new paths to the dataframe with missing data
         existing_paths = set(self.df.select(pl.col("path")).to_series())
@@ -119,32 +121,37 @@ def populate_df(
             pl.any(pl.col(col).is_null() for col in updated_df.columns if col in search_cols)
         )
         if len(unfinished):
-            with tqdm(desc="Gathering file info...", total=len(unfinished)) as t:
+            with (
+                tqdm(desc="Gathering file info...", unit="file", total=len(unfinished)) as total_t,
+                tqdm(desc="Processing chunks...", unit="chunk", total=len(unfinished) // chunksize) as sub_t,
+            ):
                 save_timer = datetime.now()
                 collected_data: DataFrame = DataFrame(schema=self.filter_type_schema)
-                for nulls, group in (
-                    unfinished.with_columns(self.filter_type_schema)
-                    .with_row_count("idx")
-                    .with_columns(pl.col("idx") // chunksize)
-                    .groupby("idx", *(pl.col(col).is_not_null() for col in self.build_schema))
+                for nulls, group in unfinished.with_columns(self.filter_type_schema).groupby(
+                    *(pl.col(col).is_not_null() for col in self.build_schema)
                 ):
-                    t.set_postfix_str(str(nulls))
-                    new_data = group.drop("idx").with_columns(
-                        **{
-                            col: expr
-                            for truth, (col, expr) in zip(nulls[1:], self.build_schema.items())  # type: ignore
-                            if not truth
-                        }
-                    )
-                    collected_data = pl.concat([collected_data, new_data], how="diagonal")
-                    t.update(len(group))
-                    if ((new_time := datetime.now()) - save_timer).total_seconds() > save_interval:
-                        self.df = self.df.update(collected_data, on="path")
-                        self.save_df()
-                        t.set_postfix_str(f"Autosaved at {current_time()}")
-                        collected_data = collected_data.clear()
-                        save_timer = new_time
-                self.df = self.df.update(collected_data, on="path").rechunk()
+                    group_expr = {
+                        col: expr
+                        for truth, (col, expr) in zip(nulls, self.build_schema.items())  # type: ignore
+                        if not truth
+                    }
+                    subgroups = list(self._split_into_chunks(group, chunksize=chunksize))
+                    total_t.set_postfix_str(str(tuple(group_expr.keys())))
+                    sub_t.total = len(subgroups)
+                    sub_t.n = 0
+
+                    for idx, subgroup in subgroups:
+                        new_data: DataFrame = subgroup.with_columns(**group_expr)
+                        collected_data = pl.concat((collected_data, new_data), how="diagonal")
+                        if ((new_time := datetime.now()) - save_timer).total_seconds() > save_interval:
+                            self.df = self.df.update(collected_data, on="path")
+                            self.save_df()
+                            collected_data = collected_data.clear()
+                            save_timer = new_time
+                        sub_t.set_postfix_str(str(idx))
+                        sub_t.update(1)
+                        total_t.update(len(subgroup))
+                self.df = self.df.update(collected_data, on="path").rechunk()
                 self.save_df()
         return
 
@@ -152,6 +159,8 @@ def filter(self, lst, sort_col="path") -> list[Path]:
         assert (
             sort_col in self.df.columns
         ), f"the column '{sort_col}' is not in the database. Available columns: {self.df.columns}"
+        if len(self.unready_filters):
+            warnings.warn(f"{len(self.unready_filters)} filters are not initialized and will not be populated")
         from_full_to_relative: dict[str, Path] = self.absolute_dict(lst)
         paths: set[str] = set(from_full_to_relative.keys())
 
@@ -192,6 +201,15 @@ def _make_schema_compliant(data_frame: DataFrame, schema) -> DataFrame:
         """adds columns from the schema to the dataframe. (not in-place)"""
         return pl.concat([data_frame, DataFrame(schema=schema)], how="diagonal")
 
+    @staticmethod
+    def _split_into_chunks(df: DataFrame, chunksize: int, column="_idx"):
+        return (
+            ((idx, len(part)), part.drop(column))
+            for idx, part in df.with_row_count(column)
+            .with_columns(pl.col(column) // chunksize)
+            .groupby(column, maintain_order=True)
+        )
+
     def __enter__(self, *args, **kwargs):
         self.__init__(*args, **kwargs)
         return self
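
Note: a minimal sketch of what the new _split_into_chunks helper yields, run on a
hypothetical toy frame. It assumes the same older polars API this module already
relies on (with_row_count, and groupby iteration producing (key, frame) pairs);
nothing below is part of the patch itself.

import polars as pl
from polars import DataFrame

def split_into_chunks(df: DataFrame, chunksize: int, column: str = "_idx"):
    # Same shape as DatasetBuilder._split_into_chunks: tag each row with its
    # chunk number (row index // chunksize), group on that tag in original
    # order, and yield ((chunk_index, chunk_length), chunk) with the helper
    # column dropped again.
    return (
        ((idx, len(part)), part.drop(column))
        for idx, part in df.with_row_count(column)
        .with_columns(pl.col(column) // chunksize)
        .groupby(column, maintain_order=True)
    )

df = DataFrame({"path": [f"{i}.png" for i in range(10)]})  # hypothetical data
for (idx, length), chunk in split_into_chunks(df, chunksize=4):
    print(idx, length, chunk.get_column("path").to_list())  # 4, 4, then 2 rows

Materializing the generator with list(...) up front, as populate_df does, is what
lets the patch set sub_t.total to an exact chunk count before iterating.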
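A toy illustration of the null-mask grouping the reworked populate_df loop is
built on: rows are grouped by which tracked columns are already non-null, so each
group computes only the expressions it is missing. The columns and expressions
here are hypothetical stand-ins for the real build_schema, under the same
polars-version assumption as above.

import polars as pl

build_schema: dict[str, pl.Expr] = {
    "size": pl.col("path").apply(len),                # stand-in file-info exprs
    "checked": pl.col("path").apply(lambda _: True),
}

df = pl.DataFrame(
    {
        "path": ["a.png", "b.png", "c.png"],
        "size": [5, None, None],
        "checked": [None, None, True],
    }
)

# One group per distinct null-mask; zip the mask with build_schema to pick out
# only the columns that still need computing for that group.
for nulls, group in df.groupby(*(pl.col(c).is_not_null() for c in build_schema)):
    missing = {c: e for present, (c, e) in zip(nulls, build_schema.items()) if not present}
    print(nulls, "->", list(missing), group.with_columns(**missing).shape)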