From 4148204bf45924515141a1e0e312ab5b333a6a22 Mon Sep 17 00:00:00 2001
From: zeptofine
Date: Mon, 23 Oct 2023 13:52:23 -0400
Subject: [PATCH] more types, docstrings, simplify, optimizations

---
 imdataset_creator/__main__.py                 | 21 ++---
 imdataset_creator/config_handler.py           | 21 ++++-
 imdataset_creator/datarules/base_rules.py     | 29 +++----
 imdataset_creator/datarules/data_rules.py     |  4 +-
 .../datarules/dataset_builder.py              | 77 +++++++++++--------
 imdataset_creator/datarules/image_rules.py    |  2 +-
 imdataset_creator/file_list.py                | 23 ------
 imdataset_creator/gui/output_view.py          | 17 ----
 imdataset_creator/gui/producer_views.py       |  2 +-
 imdataset_creator/gui/rule_views.py           |  8 +-
 imdataset_creator/image_filters/destroyers.py |  3 +-
 11 files changed, 96 insertions(+), 111 deletions(-)
 delete mode 100644 imdataset_creator/file_list.py

diff --git a/imdataset_creator/__main__.py b/imdataset_creator/__main__.py
index b2b4dc6..d290609 100644
--- a/imdataset_creator/__main__.py
+++ b/imdataset_creator/__main__.py
@@ -1,6 +1,5 @@
 import json
 import logging
-import time
 from datetime import datetime
 from multiprocessing import Pool, cpu_count, freeze_support
 from pathlib import Path
@@ -22,7 +21,6 @@
     File,
     FileScenario,
     MainConfig,
-    alphanumeric_sort,
     chunk_split,
 )

@@ -78,27 +76,24 @@ def main(
     if verbose:
         p.log(pformat(db_cfg))
     # Gather images
-    images: dict[Path, list[Path]] = {}
     resolved: dict[str, File] = {}
     count_t = p.add_task("Gathering", total=None)
-    for folder, gen in db_cfg.gather_images():
-        images[folder] = sorted(gen, key=lambda x: alphanumeric_sort(str(x)), reverse=True)
-        for pth in images[folder]:
+
+    for folder, lst in db_cfg.gather_images(sort=True, reverse=True):
+        for pth in lst:
             resolved[str((folder / pth).resolve())] = File.from_src(folder, pth)
-        p.update(count_t, advance=len(images[folder]))
+        p.update(count_t, advance=len(lst))

-    if diff := sum(map(len, images.values())) - len(resolved):
+    if diff := p.tasks[count_t].completed - len(resolved):
         p.log(f"removed an estimated {diff} conflicting symlinks")
     p.update(count_t, total=len(resolved), completed=len(resolved))

     total_t = p.add_task("populating df", total=None)
     db.add_new_paths(set(resolved))
-    db_schema = db.type_schema
-
-    db.comply_to_schema(db_schema)
-    unfinished: DataFrame = db.get_unfinished_existing()
+    db.comply_to_schema(db.type_schema, in_place=True)
+    unfinished: DataFrame = db.get_unfinished_existing().collect()
     if not unfinished.is_empty():
         p.update(total_t, total=len(unfinished))
         null_t = p.add_task("nulls set", total=None)
@@ -152,7 +147,7 @@ def main(
     files: list[File]
     if db_cfg.rules:
         filter_t = p.add_task("filtering", total=0)
-        files = [resolved[file] for file in db.filter(set(resolved))]
+        files = [resolved[file] for file in db.filter(set(resolved)).get_column("path")]
         p.update(filter_t, total=len(files), completed=len(files))
     else:
         files = list(resolved.values())
diff --git a/imdataset_creator/config_handler.py b/imdataset_creator/config_handler.py
index cc6cced..074fffa 100644
--- a/imdataset_creator/config_handler.py
+++ b/imdataset_creator/config_handler.py
@@ -1,8 +1,11 @@
 from collections.abc import Generator, Iterable
 from pathlib import Path
+from typing import Literal, overload

+from .alphanumeric_sort import alphanumeric_sort
 from .configs import MainConfig, _repr_indent
 from .datarules import Input, Output, Producer, Rule
+from .datarules.base_rules import PathGenerator
 from .file import File
 from .scenarios import FileScenario, OutputScenario

@@ -21,9 +24,25 @@ def __init__(self, cfg: MainConfig):
         # generate `Rule`s
         self.rules: list[Rule] = [Rule.all_rules[r["name"]].from_cfg(r["data"]) for r in cfg["rules"]]

-    def gather_images(self) -> Generator[tuple[Path, Generator[Path, None, None]], None, None]:
+    @overload
+    def gather_images(self, sort: Literal[True], reverse: bool = False) -> Generator[tuple[Path, list[Path]], None, None]:
+        ...
+
+    @overload
+    def gather_images(
+        self, sort: Literal[False] = False, reverse: bool = False
+    ) -> Generator[tuple[Path, PathGenerator], None, None]:
+        ...
+
+    def gather_images(
+        self, sort: bool = False, reverse: bool = False
+    ) -> Generator[tuple[Path, PathGenerator | list[Path]], None, None]:
         for input_ in self.inputs:
-            yield input_.folder, input_.run()
+            gen = input_.run()
+            if sort:
+                yield input_.folder, list(map(Path, sorted(map(str, gen), key=alphanumeric_sort, reverse=reverse)))
+            else:
+                yield input_.folder, gen

     def get_outputs(self, file: File) -> list[OutputScenario]:
         return [
diff --git a/imdataset_creator/datarules/base_rules.py b/imdataset_creator/datarules/base_rules.py
index c588882..a42fd0d 100644
--- a/imdataset_creator/datarules/base_rules.py
+++ b/imdataset_creator/datarules/base_rules.py
@@ -3,12 +3,12 @@
 import textwrap
 from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import Callable, Mapping, Sequence
+from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
 from dataclasses import dataclass
 from pathlib import Path
 from string import Formatter
 from types import MappingProxyType
 from typing import Any, ClassVar

 import numpy as np
 import wcmatch.glob as wglob
@@ -36,10 +36,8 @@ class DataFrameMatcher:

     def __init__(self, func: Callable[[PartialDataFrame, FullDataFrame], DataFrame]):
         """
-        Parameters
-        ----------
-        func : Callable[[PartialDataFrame, FullDataFrame], DataFrame]
-            A function that takes a DataFrame and a Dataframe with more information as input, and returns a filtered
+        Args:
+            func (Callable[[PartialDataFrame, FullDataFrame], DataFrame]): A function that takes a DataFrame and a DataFrame with more information as input, and returns a filtered
                 DataFrame as an output.
         """
         self.func = func
@@ -56,8 +54,8 @@ class ExprMatcher:

     expr: Expr

-    def __init__(self, expr: Expr):
-        self.expr = expr
+    def __init__(self, *exprs: Expr):
+        self.expr = combine_expr_conds(exprs)

     def __call__(self) -> Expr:
         return self.expr
@@ -146,6 +144,9 @@ def run(self, img: np.ndarray) -> np.ndarray:
 flags: int = wglob.BRACE | wglob.SPLIT | wglob.EXTMATCH | wglob.IGNORECASE | wglob.GLOBSTAR


+PathGenerator = Generator[Path, None, None]
+
+
 @dataclass(repr=False)
 class Input(Keyworded):
     """
@@ -166,7 +167,7 @@ class Input(Keyworded):
     def from_cfg(cls, cfg: InputData):
         return cls(Path(cfg["folder"]), cfg["expressions"])

-    def run(self):
+    def run(self) -> PathGenerator:
         """
         Yield the paths of all files in the input folder that match the glob patterns.

@@ -268,7 +269,7 @@ def from_cfg(cls, cfg: OutputData):
     )


-def combine_expr_conds(exprs: list[Expr]) -> Expr:
+def combine_expr_conds(exprs: Iterable[Expr]) -> Expr:
     """
     Combine a list of `Expr` objects using the `&` operator.

@@ -282,7 +283,9 @@ def combine_expr_conds(exprs: list[Expr]) -> Expr:
     Expr
         A single `Expr` object representing the combined expression.
""" - comp: Expr = exprs[0] - for e in exprs[1:]: - comp &= e + comp: Expr | None = None + for e in exprs: + comp = comp & e if comp is None else e + assert comp is not None + return comp diff --git a/imdataset_creator/datarules/data_rules.py b/imdataset_creator/datarules/data_rules.py index 36857f2..76cc1ed 100644 --- a/imdataset_creator/datarules/data_rules.py +++ b/imdataset_creator/datarules/data_rules.py @@ -87,7 +87,7 @@ def __init__( if self.before is not None and self.after is not None and self.after > self.before: raise self.AgeError(self.after, self.before) - self.matcher = ExprMatcher(combine_expr_conds(exprs)) + self.matcher = ExprMatcher(*exprs) @classmethod def from_cfg(cls, cfg) -> Self: @@ -131,7 +131,7 @@ def __init__( if self.blacklist: exprs.extend(col("path").str.contains(item).is_not() for item in self.blacklist) - self.matcher = ExprMatcher(combine_expr_conds(exprs)) + self.matcher = ExprMatcher(*exprs) @classmethod def get_cfg(cls) -> BlackWhitelistData: diff --git a/imdataset_creator/datarules/dataset_builder.py b/imdataset_creator/datarules/dataset_builder.py index a16ea05..de9bfd4 100644 --- a/imdataset_creator/datarules/dataset_builder.py +++ b/imdataset_creator/datarules/dataset_builder.py @@ -8,7 +8,7 @@ from typing import Generator, Literal, TypeVar, overload import polars as pl -from polars import DataFrame, Expr +from polars import DataFrame, Expr, LazyFrame from polars.type_aliases import SchemaDefinition from .base_rules import ( @@ -39,6 +39,17 @@ def chunk_split( chunksize: int, col_name: str = "_idx", ) -> Generator[DataFrame, None, None]: + """ + Splits a dataframe into chunks based on index. + + Args: + df (DataFrame): the dataframe to split + chunksize (int): the size of each resulting chunk + col_name (str, optional): the name of the temporary chunk. Defaults to "_idx". + + Yields: + Generator[DataFrame, None, None]: a chunk + """ return ( part.drop(col_name) for _, part in df.with_row_count(col_name) @@ -47,25 +58,23 @@ def chunk_split( ) -def combine_exprs(rules: Iterable[Rule]) -> list[Expr | bool | DataFrameMatcher]: - """this combines expressions from different objects to a list of compressed expressions. - DataRules that are `ExprComparer`s can be combined, but `DataFrameComparer`s cannot. They will be copied to the list. +def combine_matchers( + matchers: Iterable[DataFrameMatcher | ExprMatcher], +) -> Generator[Expr | DataFrameMatcher, None, None]: + """this combines expressions from different matchers to compressed expressions. + DataRules that are `ExprMatcher`s can be combined, but `DataFrameMatcher`s cannot. They will be copied to the list. 
""" - combinations: list[Expr | bool | DataFrameMatcher] = [] combination: Expr | bool | None = None - for rule in rules: - comparer: DataFrameMatcher | ExprMatcher = rule.matcher - if isinstance(comparer, ExprMatcher): - combination = combination & comparer() if combination is not None else comparer() - elif isinstance(comparer, DataFrameMatcher): + for matcher in matchers: + if isinstance(matcher, DataFrameMatcher): if combination is not None: - combinations.append(combination) + yield combination combination = None - combinations.append(comparer) + yield matcher + elif isinstance(matcher, ExprMatcher): + combination = combination & matcher() if combination is not None else matcher() if combination is not None: - combinations.append(combination) - - return combinations + yield combination def blacklist_schema(schema: ProducerSchema, blacklist: Collection) -> ProducerSchema: @@ -185,7 +194,15 @@ def remove_finished_producers(self) -> ProducerSet: self.producers = new return ProducerSet(old - new) + @overload def unfinished_by_col(self, df: DataFrame, cols: Iterable[str] | None = None) -> DataFrame: + ... + + @overload + def unfinished_by_col(self, df: LazyFrame, cols: Iterable[str] | None = None) -> LazyFrame: + ... + + def unfinished_by_col(self, df: DataFrame | LazyFrame, cols: Iterable[str] | None = None) -> DataFrame | LazyFrame: if cols is None: cols = set(self.type_schema) & set(df.columns) return df.filter(pl.any_horizontal(pl.col(col).is_null() for col in cols)) @@ -232,7 +249,7 @@ def populate_chunks( chunks: Iterable[DataFrame], schemas: ProducerSchema, db_schema: DataTypeSchema | None = None, - ): + ) -> Generator[DataFrame, None, None]: if db_schema is None: db_schema = self.type_schema chunk: DataFrame @@ -248,37 +265,31 @@ def df_with_types(self, types: DataTypeSchema | None = None): types = self.type_schema return self.comply_to_schema(types).with_columns(types) - def get_unfinished(self) -> DataFrame: + def get_unfinished(self) -> LazyFrame: # check if producers are completely finished type_schema: DataTypeSchema = self.type_schema self.comply_to_schema(type_schema, in_place=True) - return self.unfinished_by_col(self.__df.with_columns(type_schema)) + return self.unfinished_by_col(self.__df.lazy().with_columns(type_schema)) - def get_unfinished_existing(self) -> DataFrame: - unfinished = self.get_unfinished() - if not len(unfinished): - return unfinished - return unfinished.filter(pl.col("path").apply(os.path.exists)) + def get_unfinished_existing(self) -> LazyFrame: + return self.get_unfinished().filter(pl.col("path").apply(os.path.exists)) - def filter(self, lst, sort_col="path") -> Iterable[str]: # noqa: A003 - assert sort_col in self.__df.columns, f"'{sort_col}' is not in {self.__df.columns}" + def filter(self, lst) -> DataFrame: # noqa: A003 if len(self.unready_rules): warnings.warn( f"{len(self.unready_rules)} filters are not initialized and will not be populated", stacklevel=2 ) vdf: DataFrame = self.__df.filter(pl.col("path").is_in(lst)) - combined = combine_exprs(self.rules) - for f in combined: - vdf = f(vdf, self.__df) if isinstance(f, DataFrameMatcher) else vdf.filter(f) - - return vdf.sort(sort_col).get_column("path") + for matcher in combine_matchers([rule.matcher for rule in self.rules]): + vdf = matcher(vdf, self.__df) if isinstance(matcher, DataFrameMatcher) else vdf.filter(matcher) + return vdf def save_df(self, pth: str | Path | None = None) -> None: """saves the dataframe to self.filepath""" self.__df.write_ipc(pth or self.filepath) - def update(self, 
+    def update(self, df: DataFrame, on="path", how: Literal["left", "inner", "outer"] = "left"):
         self.__df = self.__df.update(df, on=on, how=how)

     def trigger_save_via_time(
@@ -316,11 +327,11 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}({', '.join(attrlist)})"

     @overload
-    def comply_to_schema(self, schema: SchemaDefinition, in_place: Literal[False] = False) -> DataFrame:
+    def comply_to_schema(self, schema: SchemaDefinition) -> DataFrame:
         ...

     @overload
-    def comply_to_schema(self, schema: SchemaDefinition, in_place: Literal[True] = True) -> None:
+    def comply_to_schema(self, schema: SchemaDefinition, in_place: Literal[True]) -> None:
         ...

     def comply_to_schema(self, schema: SchemaDefinition, in_place: bool = False) -> DataFrame | None:
diff --git a/imdataset_creator/datarules/image_rules.py b/imdataset_creator/datarules/image_rules.py
index 571c275..7b154b2 100644
--- a/imdataset_creator/datarules/image_rules.py
+++ b/imdataset_creator/datarules/image_rules.py
@@ -82,7 +82,7 @@ def __init__(
         if max_res:
             exprs.append((largest // scale * scale if crop else largest) <= max_res)

-        self.matcher = ExprMatcher(combine_expr_conds(exprs))
+        self.matcher = ExprMatcher(*exprs)

     @classmethod
     def get_cfg(cls) -> ResData:
diff --git a/imdataset_creator/file_list.py b/imdataset_creator/file_list.py
deleted file mode 100644
index f25d08e..0000000
--- a/imdataset_creator/file_list.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from collections.abc import Generator
-from os import sep
-from pathlib import Path
-
-
-def get_file_list(folder, *patterns: str) -> Generator[Path, None, None]:
-    """
-    Args folders: One or more folder paths.
-    Returns list[Path]: paths in the specified folders."""
-
-    return (y for pattern in patterns for y in folder.rglob(pattern))
-
-
-def to_recursive(path: Path | str, recursive: bool = False, replace_spaces: bool = False) -> Path:
-    """Convert the file path to a recursive path if recursive is False
-    (Also replaces spaces with underscores)
-    Ex: i/path/to/image.png => i/path_to_image.png"""
-    new_pth: str = str(path)
-    if replace_spaces and " " in new_pth:
-        new_pth = new_pth.replace(" ", "_")
-    if not recursive and sep in new_pth:
-        new_pth = new_pth.replace(sep, "_")
-    return Path(new_pth)
diff --git a/imdataset_creator/gui/output_view.py b/imdataset_creator/gui/output_view.py
index 0e5df08..1911c27 100644
--- a/imdataset_creator/gui/output_view.py
+++ b/imdataset_creator/gui/output_view.py
@@ -13,23 +13,6 @@
 from .output_filters import FilterList


-class InvalidFormatException(Exception):
-    def __init__(self, disallowed: str):
-        super().__init__(f"invalid format string. '{disallowed}' is not allowed.")
-
-
-class SafeFormatter(Formatter):
-    def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
-        # the goal is to make sure `property`s and indexing is still available, while dunders and things are not
-        if "__" in field_name:
-            raise InvalidFormatException("__")
-
-        return super().get_field(field_name, args, kwargs)
-
-
-output_formatter = SafeFormatter()
-
-
 class OutputView(InputView):
     bound_item = Output
diff --git a/imdataset_creator/gui/producer_views.py b/imdataset_creator/gui/producer_views.py
index d61493a..504832b 100644
--- a/imdataset_creator/gui/producer_views.py
+++ b/imdataset_creator/gui/producer_views.py
@@ -74,7 +74,7 @@ class ProducerList(BuilderDependencyList):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.__registered_by: dict[str, type[ProducerView]] = {}
-        self.set_text("Rules")
+        self.set_text("Producers")
         self.register_item(
             FileInfoProducerView,
             ImShapeProducerView,
diff --git a/imdataset_creator/gui/rule_views.py b/imdataset_creator/gui/rule_views.py
index d7f9773..5e06048 100644
--- a/imdataset_creator/gui/rule_views.py
+++ b/imdataset_creator/gui/rule_views.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 from abc import abstractmethod
+from typing import Self

 from PySide6.QtCore import QDate, QDateTime, QTime, Slot
 from PySide6.QtWidgets import QCheckBox, QDateTimeEdit, QLabel, QLineEdit, QSpinBox, QTextEdit
@@ -31,11 +32,11 @@ def get(self) -> base_rules.Rule:
         return base_rules.Rule()

     @classmethod
-    def __wrap_get(cls: type[RuleView]):
+    def __wrap_get(cls: type[Self]):
         original_get = cls.get
         original_get_config = cls.get_config

-        def get_wrapper(self: RuleView):
+        def get_wrapper(self: Self):
             rule = original_get(self)
             if rule.requires:
                 if isinstance(rule.requires, base_rules.DataColumn):
@@ -44,7 +45,7 @@ def get_wrapper(self: RuleView):
                     self.set_requires(str(set({r.name for r in rule.requires})))
             return rule

-        def get_config_wrapper(self: RuleView):
+        def get_config_wrapper(self: Self):
             self.get()
             return original_get_config(self)

@@ -318,7 +319,6 @@ def configure_settings_group(self):
         self.group_grid.addWidget(self.resolver, 1, 0, 1, 2)

     def get(self):
-        super().get()
         return image_rules.HashRule(resolver=self.resolver.text())

     def get_config(self):
diff --git a/imdataset_creator/image_filters/destroyers.py b/imdataset_creator/image_filters/destroyers.py
index 624b874..a364ca5 100644
--- a/imdataset_creator/image_filters/destroyers.py
+++ b/imdataset_creator/image_filters/destroyers.py
@@ -1,11 +1,10 @@
 import logging
 import subprocess
-import typing
 from dataclasses import dataclass
 from enum import Enum
 from math import sqrt
 from random import choice, randint
-from typing import Literal, Self
+from typing import Self

 import cv2
 import ffmpeg
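
How the reworked `gather_images` is meant to be consumed, as a minimal sketch rather than part of the patch: the class name `ConfigHandler` is assumed (the diff only shows its `__init__`), and `cfg` is a stand-in for a loaded `MainConfig` dict.

    from imdataset_creator.config_handler import ConfigHandler  # assumed class name

    handler = ConfigHandler(cfg)  # cfg: a MainConfig-shaped dict (hypothetical)

    # sort=True materializes each folder's generator into an alphanumerically
    # sorted list[Path], which is what __main__.py now relies on:
    for folder, paths in handler.gather_images(sort=True, reverse=True):
        print(folder, len(paths))

    # the default sort=False keeps the lazy PathGenerator:
    for folder, gen in handler.gather_images():
        first = next(gen, None)  # pull paths on demand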
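
The `ExprMatcher(*exprs)` change lets rule constructors pass raw polars expressions and have the matcher AND them together itself. A minimal, self-contained sketch (the sample frame and column names are illustrative only):

    import polars as pl

    from imdataset_creator.datarules.base_rules import ExprMatcher

    # equivalent to the pre-patch ExprMatcher(combine_expr_conds([...]))
    matcher = ExprMatcher(pl.col("size") > 0, pl.col("path").str.contains("png"))

    df = pl.DataFrame({"size": [1, 2], "path": ["a.png", "b.jpg"]})
    print(df.filter(matcher()))  # keeps only the "a.png" row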
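
`chunk_split` as documented above, end to end; only polars is needed, and the sizes follow directly from the docstring:

    import polars as pl

    from imdataset_creator.datarules.dataset_builder import chunk_split

    df = pl.DataFrame({"x": list(range(10))})
    sizes = [len(chunk) for chunk in chunk_split(df, chunksize=4)]
    # sizes == [4, 4, 2]; the temporary "_idx" column is dropped from every chunk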
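
Since `get_unfinished_existing` now returns a LazyFrame and `filter` returns the filtered DataFrame instead of a pre-sorted path column, call sites adapt as `__main__.py` does above (`db` and `resolved` here are stand-ins for a DatasetBuilder and the gathered path map):

    unfinished = db.get_unfinished_existing().collect()  # LazyFrame -> DataFrame
    paths = db.filter(set(resolved)).get_column("path")  # column selection is now explicit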