From 959230b49f56f0ed7fa78e674e7a57766cd4d39b Mon Sep 17 00:00:00 2001 From: zeptofine Date: Thu, 19 Oct 2023 14:21:21 -0400 Subject: [PATCH] rename classes Comparable to DataFrameMatcher and FastComparable to ExprMatcher --- imdataset_creator/datarules/__init__.py | 13 +++++++++++- imdataset_creator/datarules/base_rules.py | 21 ++++++++++++++----- imdataset_creator/datarules/data_rules.py | 16 ++++++++++---- .../datarules/dataset_builder.py | 18 ++++++++-------- imdataset_creator/datarules/image_rules.py | 12 +++++------ imdataset_creator/file.py | 2 +- 6 files changed, 56 insertions(+), 26 deletions(-) diff --git a/imdataset_creator/datarules/__init__.py b/imdataset_creator/datarules/__init__.py index b17e8a2..c82680f 100644 --- a/imdataset_creator/datarules/__init__.py +++ b/imdataset_creator/datarules/__init__.py @@ -1,3 +1,14 @@ from . import base_rules, data_rules, dataset_builder, image_rules -from .base_rules import Comparable, DataColumn, ExprDict, FastComparable, File, Filter, Input, Output, Producer, Rule +from .base_rules import ( + DataColumn, + DataFrameMatcher, + ExprDict, + ExprMatcher, + File, + Filter, + Input, + Output, + Producer, + Rule, +) from .dataset_builder import DatasetBuilder, chunk_split diff --git a/imdataset_creator/datarules/base_rules.py b/imdataset_creator/datarules/base_rules.py index a0d9b35..0fba72b 100644 --- a/imdataset_creator/datarules/base_rules.py +++ b/imdataset_creator/datarules/base_rules.py @@ -29,20 +29,31 @@ def indent(t): return textwrap.indent(t, " ") -class Comparable: +class DataFrameMatcher: + """A class that filters sections of a dataframe based on a function""" + func: Callable[[PartialDataFrame, FullDataFrame], DataFrame] def __init__(self, func: Callable[[PartialDataFrame, FullDataFrame], DataFrame]): + """ + Parameters + ---------- + func : Callable[[PartialDataFrame, FullDataFrame], DataFrame] + A function that takes a DataFrame and a Dataframe with more information as input, and returns a filtered + DataFrame as an output. + """ self.func = func def __call__(self, *args, **kwargs) -> DataFrame: return self.func(*args, **kwargs) def __repr__(self): - return f"Comparable({self.func.__name__})" + return f"DataFrameMatcher({self.func.__name__})" + +class ExprMatcher: + """A class that filters files based on an expression""" -class FastComparable: expr: Expr def __init__(self, expr: Expr): @@ -52,7 +63,7 @@ def __call__(self) -> Expr: return self.expr def __repr__(self) -> str: - return f"FastComparable({self.expr})" + return f"ExprMatcher({self.expr})" @dataclass(frozen=True) @@ -102,7 +113,7 @@ class Rule(Keyworded): """An abstract DataFilter format, for use in DatasetBuilder.""" requires: DataColumn | tuple[DataColumn, ...] - comparer: Comparable | FastComparable + matcher: DataFrameMatcher | ExprMatcher all_rules: ClassVar[dict[str, type[Rule]]] = {} diff --git a/imdataset_creator/datarules/data_rules.py b/imdataset_creator/datarules/data_rules.py index 9560a69..36857f2 100644 --- a/imdataset_creator/datarules/data_rules.py +++ b/imdataset_creator/datarules/data_rules.py @@ -10,7 +10,15 @@ from polars import DataFrame, Datetime, Expr, col from ..configs.configtypes import SpecialItemData -from .base_rules import Comparable, DataColumn, FastComparable, Producer, ProducerSchema, Rule, combine_expr_conds +from .base_rules import ( + DataColumn, + DataFrameMatcher, + ExprMatcher, + Producer, + ProducerSchema, + Rule, + combine_expr_conds, +) STAT_TRACKED = ("st_size", "st_atime", "st_mtime", "st_ctime") @@ -79,7 +87,7 @@ def __init__( if self.before is not None and self.after is not None and self.after > self.before: raise self.AgeError(self.after, self.before) - self.comparer = FastComparable(combine_expr_conds(exprs)) + self.matcher = ExprMatcher(combine_expr_conds(exprs)) @classmethod def from_cfg(cls, cfg) -> Self: @@ -123,7 +131,7 @@ def __init__( if self.blacklist: exprs.extend(col("path").str.contains(item).is_not() for item in self.blacklist) - self.comparer = FastComparable(combine_expr_conds(exprs)) + self.matcher = ExprMatcher(combine_expr_conds(exprs)) @classmethod def get_cfg(cls) -> BlackWhitelistData: @@ -134,7 +142,7 @@ class TotalLimitRule(Rule): def __init__(self, limit=1000): super().__init__() self.total = limit - self.comparer = Comparable(self.compare) + self.matcher = DataFrameMatcher(self.compare) def compare(self, selected: DataFrame, _) -> DataFrame: return selected.head(self.total) diff --git a/imdataset_creator/datarules/dataset_builder.py b/imdataset_creator/datarules/dataset_builder.py index 6d0a235..a16ea05 100644 --- a/imdataset_creator/datarules/dataset_builder.py +++ b/imdataset_creator/datarules/dataset_builder.py @@ -12,10 +12,10 @@ from polars.type_aliases import SchemaDefinition from .base_rules import ( - Comparable, + DataFrameMatcher, DataTypeSchema, ExprDict, - FastComparable, + ExprMatcher, Producer, ProducerSchema, ProducerSet, @@ -47,17 +47,17 @@ def chunk_split( ) -def combine_exprs(rules: Iterable[Rule]) -> list[Expr | bool | Comparable]: +def combine_exprs(rules: Iterable[Rule]) -> list[Expr | bool | DataFrameMatcher]: """this combines expressions from different objects to a list of compressed expressions. - DataRules that are FastComparable can be combined, but `Comparable`s cannot. They will be copied to the list. + DataRules that are `ExprComparer`s can be combined, but `DataFrameComparer`s cannot. They will be copied to the list. """ - combinations: list[Expr | bool | Comparable] = [] + combinations: list[Expr | bool | DataFrameMatcher] = [] combination: Expr | bool | None = None for rule in rules: - comparer: Comparable | FastComparable = rule.comparer - if isinstance(comparer, FastComparable): + comparer: DataFrameMatcher | ExprMatcher = rule.matcher + if isinstance(comparer, ExprMatcher): combination = combination & comparer() if combination is not None else comparer() - elif isinstance(comparer, Comparable): + elif isinstance(comparer, DataFrameMatcher): if combination is not None: combinations.append(combination) combination = None @@ -270,7 +270,7 @@ def filter(self, lst, sort_col="path") -> Iterable[str]: # noqa: A003 vdf: DataFrame = self.__df.filter(pl.col("path").is_in(lst)) combined = combine_exprs(self.rules) for f in combined: - vdf = f(vdf, self.__df) if isinstance(f, Comparable) else vdf.filter(f) + vdf = f(vdf, self.__df) if isinstance(f, DataFrameMatcher) else vdf.filter(f) return vdf.sort(sort_col).get_column("path") diff --git a/imdataset_creator/datarules/image_rules.py b/imdataset_creator/datarules/image_rules.py index 302caca..571c275 100644 --- a/imdataset_creator/datarules/image_rules.py +++ b/imdataset_creator/datarules/image_rules.py @@ -14,9 +14,9 @@ from ..configs.configtypes import SpecialItemData from .base_rules import ( - Comparable, DataColumn, - FastComparable, + DataFrameMatcher, + ExprMatcher, Producer, ProducerSchema, Rule, @@ -82,7 +82,7 @@ def __init__( if max_res: exprs.append((largest // scale * scale if crop else largest) <= max_res) - self.comparer = FastComparable(combine_expr_conds(exprs)) + self.matcher = ExprMatcher(combine_expr_conds(exprs)) @classmethod def get_cfg(cls) -> ResData: @@ -109,7 +109,7 @@ class ChannelRule(Rule): def __init__(self, min_channels=1, max_channels=4) -> None: super().__init__() self.requires = DataColumn("channels", int) - self.comparer = FastComparable((min_channels <= col("channels")) & (col("channels") <= max_channels)) + self.matcher = ExprMatcher((min_channels <= col("channels")) & (col("channels") <= max_channels)) def get_size(pth): @@ -166,10 +166,10 @@ def __init__(self, resolver: str | Literal["ignore_all"] = "ignore_all") -> None if resolver != "ignore_all": self.requires = (self.requires, DataColumn(resolver)) self.resolver: Expr | bool = {"ignore_all": False}.get(resolver, col(resolver) == col(resolver).max()) - self.comparer = Comparable(self.compare) + self.matcher = DataFrameMatcher(self.compare) def compare(self, partial: DataFrame, full: DataFrame) -> DataFrame: - return partial.groupby("hash").apply(lambda df: df.filter(self.resolver) if len(df) > 1 else df) + return partial.groupby("hash").apply(lambda group: group.filter(self.resolver) if len(group) > 1 else group) @classmethod def get_cfg(cls) -> dict: diff --git a/imdataset_creator/file.py b/imdataset_creator/file.py index 33cae61..515e87d 100644 --- a/imdataset_creator/file.py +++ b/imdataset_creator/file.py @@ -33,7 +33,7 @@ class File: file: MalleablePath ext: str - def to_dict(self): + def to_dict(self) -> dict[str, str | MalleablePath]: return { "absolute_pth": self.absolute_pth, "src": self.src,