add more tqdm stuff and reorder population chunks
zeptofine committed Aug 19, 2023
1 parent c91a106 commit a8ca167
Showing 1 changed file with 46 additions and 29 deletions.
75 changes: 46 additions & 29 deletions src/datafilters/dataset_builder.py
@@ -4,14 +4,12 @@
 from collections.abc import Collection, Iterable
 from datetime import datetime
 from pathlib import Path
-
+from typing import TypeVar
 import polars as pl
 from cfg_param_wrapper import CfgDict
 from polars import DataFrame, Expr
 from tqdm import tqdm
 
 from .base_filters import Column, Comparable, DataFilter, FastComparable
 from .custom_toml import TomlCustomCommentDecoder, TomlCustomCommentEncoder
 
 
 def current_time() -> datetime:
@@ -22,6 +20,9 @@ def _time(_=None) -> datetime:
     return current_time()
 
 
+T = TypeVar("T")
+
+
 class DatasetBuilder:
     def __init__(self, origin: str, db_path: Path) -> None:
         super().__init__()
Expand All @@ -31,8 +32,8 @@ def __init__(self, origin: str, db_path: Path) -> None:

self.filepath: Path = db_path

self.basic_schema = {"path": str, "checkedtime": pl.Datetime}
self.filter_type_schema = self.basic_schema.copy()
self.basic_schema: dict[str, pl.DataType | type] = {"path": str, "checkedtime": pl.Datetime}
self.filter_type_schema: dict[str, pl.DataType | type] = self.basic_schema.copy()

self.build_schema: dict[str, Expr] = {"checkedtime": pl.col("path").apply(_time)}
if os.path.exists(self.filepath):
@@ -97,8 +98,8 @@ def populate_df(
         cond: Expr = now - pl.col("checkedtime") < datetime.fromtimestamp(trim_age_limit)
         if trim_check_exists:
             cond &= pl.col("path").apply(os.path.exists)
-
         self.df = self.df.filter(cond)
+
         from_full_to_relative: dict[str, Path] = self.absolute_dict(lst)
         # add new paths to the dataframe with missing data
         existing_paths = set(self.df.select(pl.col("path")).to_series())
@@ -119,39 +120,46 @@ def populate_df(
             pl.any(pl.col(col).is_null() for col in updated_df.columns if col in search_cols)
         )
         if len(unfinished):
-            with tqdm(desc="Gathering file info...", total=len(unfinished)) as t:
+            with (
+                tqdm(desc="Gathering file info...", unit="file", total=len(unfinished)) as total_t,
+                tqdm(desc="Processing chunks...", unit="chunk", total=len(unfinished) // chunksize) as sub_t,
+            ):
                 save_timer = datetime.now()
                 collected_data: DataFrame = DataFrame(schema=self.filter_type_schema)
-                for nulls, group in (
-                    unfinished.with_columns(self.filter_type_schema)
-                    .with_row_count("idx")
-                    .with_columns(pl.col("idx") // chunksize)
-                    .groupby("idx", *(pl.col(col).is_not_null() for col in self.build_schema))
+                for nulls, group in unfinished.with_columns(self.filter_type_schema).groupby(
+                    *(pl.col(col).is_not_null() for col in self.build_schema)
                 ):
-                    t.set_postfix_str(str(nulls))
-                    new_data = group.drop("idx").with_columns(
-                        **{
-                            col: expr
-                            for truth, (col, expr) in zip(nulls[1:], self.build_schema.items())  # type: ignore
-                            if not truth
-                        }
-                    )
-                    collected_data = pl.concat([collected_data, new_data], how="diagonal")
-                    t.update(len(group))
-                    if ((new_time := datetime.now()) - save_timer).total_seconds() > save_interval:
-                        self.df = self.df.update(collected_data, on="path")
-                        self.save_df()
-                        t.set_postfix_str(f"Autosaved at {current_time()}")
-                        collected_data = collected_data.clear()
-                        save_timer = new_time
-            self.df = self.df.update(collected_data, on="path").rechunk()
+                    group_expr = {
+                        col: expr
+                        for truth, (col, expr) in zip(nulls, self.build_schema.items())  # type: ignore
+                        if not truth
+                    }
+                    subgroups = list(self._split_into_chunks(group, chunksize=chunksize))
+                    total_t.set_postfix_str(str(tuple(group_expr.keys())))
+                    sub_t.total = len(subgroups)
+                    sub_t.n = 0
+
+                    for idx, subgroup in subgroups:
+                        new_data: DataFrame = subgroup.with_columns(**group_expr)
+                        collected_data = pl.concat((collected_data, new_data), how="diagonal")
+                        if ((new_time := datetime.now()) - save_timer).total_seconds() > save_interval:
+                            self.df = self.df.update(collected_data, on="path")
+                            self.save_df()
+                            collected_data = collected_data.clear()
+                            save_timer = new_time
+                        sub_t.set_postfix_str(str(idx))
+                        sub_t.update(1)
+                        total_t.update(len(subgroup))
+            self.df = self.df.update(collected_data, on="path").rechunk()
+            self.save_df()
         return
 
     def filter(self, lst, sort_col="path") -> list[Path]:
         assert (
             sort_col in self.df.columns
         ), f"the column '{sort_col}' is not in the database. Available columns: {self.df.columns}"
         if len(self.unready_filters):
             warnings.warn(f"{len(self.unready_filters)} filters are not initialized and will not be populated")
 
         from_full_to_relative: dict[str, Path] = self.absolute_dict(lst)
         paths: set[str] = set(from_full_to_relative.keys())
@@ -192,6 +200,15 @@ def _make_schema_compliant(data_frame: DataFrame, schema) -> DataFrame:
         """adds columns from the schema to the dataframe. (not in-place)"""
         return pl.concat([data_frame, DataFrame(schema=schema)], how="diagonal")
 
+    @staticmethod
+    def _split_into_chunks(df: DataFrame, chunksize: int, column="_idx"):
+        return (
+            ((idx, len(part)), part.drop(column))
+            for idx, part in df.with_row_count(column)
+            .with_columns(pl.col(column) // chunksize)
+            .groupby(column, maintain_order=True)
+        )
+
     def __enter__(self, *args, **kwargs):
         self.__init__(*args, **kwargs)
         return self
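The new _split_into_chunks helper is what feeds the inner progress bar: it numbers the rows, floor-divides the row index by chunksize to assign a chunk id, and yields each chunk together with its id and row count. Below is a minimal sketch of what it produces, written against the same polars API the commit uses (with_row_count and groupby were renamed with_row_index and group_by in later polars releases); the frame contents are made up for illustration.

import polars as pl

df = pl.DataFrame({"path": [f"img_{i:02}.png" for i in range(10)]})

# Mirrors _split_into_chunks(df, chunksize=4): number the rows, derive a
# chunk id via floor division, then group by that id in original order.
chunks = (
    ((idx, len(part)), part.drop("_idx"))
    for idx, part in df.with_row_count("_idx")
    .with_columns(pl.col("_idx") // 4)
    .groupby("_idx", maintain_order=True)
)

for (chunk_id, size), chunk in chunks:
    print(chunk_id, size, chunk["path"].to_list())
# 0 4 ['img_00.png', 'img_01.png', 'img_02.png', 'img_03.png']
# 1 4 ['img_04.png', 'img_05.png', 'img_06.png', 'img_07.png']
# 2 2 ['img_08.png', 'img_09.png']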

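The two tqdm bars follow a reset pattern: the outer bar counts files across every null-group, while the inner bar's total and n are rewound for each group's chunk list. A standalone sketch of just that pattern, with made-up group/chunk data rather than the repository's dataframes (a refresh() call is added here so the reset redraws immediately):

from tqdm import tqdm

# Hypothetical input: two groups of work, each already split into chunks.
groups = [
    [["a.png", "b.png"], ["c.png"]],
    [["d.png", "e.png"], ["f.png", "g.png"]],
]
total_files = sum(len(chunk) for group in groups for chunk in group)

with (
    tqdm(desc="Gathering file info...", unit="file", total=total_files) as total_t,
    tqdm(desc="Processing chunks...", unit="chunk") as sub_t,
):
    for chunks in groups:
        sub_t.total = len(chunks)  # retarget the inner bar for this group
        sub_t.n = 0                # rewind it to zero
        sub_t.refresh()            # redraw so the reset is visible
        for chunk in chunks:
            ...                         # per-chunk processing would go here
            sub_t.update(1)             # one tick per chunk
            total_t.update(len(chunk))  # len(chunk) files of overall progress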