Skip to content

Commit

Permalink
rename classes
Browse files Browse the repository at this point in the history
Comparable to DataFrameMatcher and FastComparable to ExprMatcher
  • Loading branch information
zeptofine committed Oct 19, 2023
1 parent 198c68e commit 959230b
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 26 deletions.
13 changes: 12 additions & 1 deletion imdataset_creator/datarules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
from . import base_rules, data_rules, dataset_builder, image_rules
from .base_rules import Comparable, DataColumn, ExprDict, FastComparable, File, Filter, Input, Output, Producer, Rule
from .base_rules import (
DataColumn,
DataFrameMatcher,
ExprDict,
ExprMatcher,
File,
Filter,
Input,
Output,
Producer,
Rule,
)
from .dataset_builder import DatasetBuilder, chunk_split
21 changes: 16 additions & 5 deletions imdataset_creator/datarules/base_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,31 @@ def indent(t):
return textwrap.indent(t, " ")


class Comparable:
class DataFrameMatcher:
"""A class that filters sections of a dataframe based on a function"""

func: Callable[[PartialDataFrame, FullDataFrame], DataFrame]

def __init__(self, func: Callable[[PartialDataFrame, FullDataFrame], DataFrame]):
"""
Parameters
----------
func : Callable[[PartialDataFrame, FullDataFrame], DataFrame]
A function that takes a partial DataFrame and the full DataFrame (which holds more information) as input, and
returns a filtered DataFrame as output.
"""
self.func = func

def __call__(self, *args, **kwargs) -> DataFrame:
return self.func(*args, **kwargs)

def __repr__(self):
return f"Comparable({self.func.__name__})"
return f"DataFrameMatcher({self.func.__name__})"


class ExprMatcher:
"""A class that filters files based on an expression"""

class FastComparable:
expr: Expr

def __init__(self, expr: Expr):
Expand All @@ -52,7 +63,7 @@ def __call__(self) -> Expr:
return self.expr

def __repr__(self) -> str:
return f"FastComparable({self.expr})"
return f"ExprMatcher({self.expr})"


@dataclass(frozen=True)
Expand Down Expand Up @@ -102,7 +113,7 @@ class Rule(Keyworded):
"""An abstract DataFilter format, for use in DatasetBuilder."""

requires: DataColumn | tuple[DataColumn, ...]
comparer: Comparable | FastComparable
matcher: DataFrameMatcher | ExprMatcher

all_rules: ClassVar[dict[str, type[Rule]]] = {}

Expand Down
16 changes: 12 additions & 4 deletions imdataset_creator/datarules/data_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
from polars import DataFrame, Datetime, Expr, col

from ..configs.configtypes import SpecialItemData
from .base_rules import Comparable, DataColumn, FastComparable, Producer, ProducerSchema, Rule, combine_expr_conds
from .base_rules import (
DataColumn,
DataFrameMatcher,
ExprMatcher,
Producer,
ProducerSchema,
Rule,
combine_expr_conds,
)

STAT_TRACKED = ("st_size", "st_atime", "st_mtime", "st_ctime")

Expand Down Expand Up @@ -79,7 +87,7 @@ def __init__(
if self.before is not None and self.after is not None and self.after > self.before:
raise self.AgeError(self.after, self.before)

self.comparer = FastComparable(combine_expr_conds(exprs))
self.matcher = ExprMatcher(combine_expr_conds(exprs))

@classmethod
def from_cfg(cls, cfg) -> Self:
Expand Down Expand Up @@ -123,7 +131,7 @@ def __init__(
if self.blacklist:
exprs.extend(col("path").str.contains(item).is_not() for item in self.blacklist)

self.comparer = FastComparable(combine_expr_conds(exprs))
self.matcher = ExprMatcher(combine_expr_conds(exprs))

@classmethod
def get_cfg(cls) -> BlackWhitelistData:
Expand All @@ -134,7 +142,7 @@ class TotalLimitRule(Rule):
def __init__(self, limit=1000):
super().__init__()
self.total = limit
self.comparer = Comparable(self.compare)
self.matcher = DataFrameMatcher(self.compare)

def compare(self, selected: DataFrame, _) -> DataFrame:
return selected.head(self.total)
18 changes: 9 additions & 9 deletions imdataset_creator/datarules/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from polars.type_aliases import SchemaDefinition

from .base_rules import (
Comparable,
DataFrameMatcher,
DataTypeSchema,
ExprDict,
FastComparable,
ExprMatcher,
Producer,
ProducerSchema,
ProducerSet,
Expand Down Expand Up @@ -47,17 +47,17 @@ def chunk_split(
)


def combine_exprs(rules: Iterable[Rule]) -> list[Expr | bool | Comparable]:
def combine_exprs(rules: Iterable[Rule]) -> list[Expr | bool | DataFrameMatcher]:
"""this combines expressions from different objects to a list of compressed expressions.
DataRules that are FastComparable can be combined, but `Comparable`s cannot. They will be copied to the list.
DataRules that are `ExprMatcher`s can be combined, but `DataFrameMatcher`s cannot; those are copied to the list as-is.
"""
combinations: list[Expr | bool | Comparable] = []
combinations: list[Expr | bool | DataFrameMatcher] = []
combination: Expr | bool | None = None
for rule in rules:
comparer: Comparable | FastComparable = rule.comparer
if isinstance(comparer, FastComparable):
comparer: DataFrameMatcher | ExprMatcher = rule.matcher
if isinstance(comparer, ExprMatcher):
combination = combination & comparer() if combination is not None else comparer()
elif isinstance(comparer, Comparable):
elif isinstance(comparer, DataFrameMatcher):
if combination is not None:
combinations.append(combination)
combination = None
Expand Down Expand Up @@ -270,7 +270,7 @@ def filter(self, lst, sort_col="path") -> Iterable[str]: # noqa: A003
vdf: DataFrame = self.__df.filter(pl.col("path").is_in(lst))
combined = combine_exprs(self.rules)
for f in combined:
vdf = f(vdf, self.__df) if isinstance(f, Comparable) else vdf.filter(f)
vdf = f(vdf, self.__df) if isinstance(f, DataFrameMatcher) else vdf.filter(f)

return vdf.sort(sort_col).get_column("path")

Expand Down
12 changes: 6 additions & 6 deletions imdataset_creator/datarules/image_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

from ..configs.configtypes import SpecialItemData
from .base_rules import (
Comparable,
DataColumn,
FastComparable,
DataFrameMatcher,
ExprMatcher,
Producer,
ProducerSchema,
Rule,
Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(
if max_res:
exprs.append((largest // scale * scale if crop else largest) <= max_res)

self.comparer = FastComparable(combine_expr_conds(exprs))
self.matcher = ExprMatcher(combine_expr_conds(exprs))

@classmethod
def get_cfg(cls) -> ResData:
Expand All @@ -109,7 +109,7 @@ class ChannelRule(Rule):
def __init__(self, min_channels=1, max_channels=4) -> None:
super().__init__()
self.requires = DataColumn("channels", int)
self.comparer = FastComparable((min_channels <= col("channels")) & (col("channels") <= max_channels))
self.matcher = ExprMatcher((min_channels <= col("channels")) & (col("channels") <= max_channels))


def get_size(pth):
Expand Down Expand Up @@ -166,10 +166,10 @@ def __init__(self, resolver: str | Literal["ignore_all"] = "ignore_all") -> None
if resolver != "ignore_all":
self.requires = (self.requires, DataColumn(resolver))
self.resolver: Expr | bool = {"ignore_all": False}.get(resolver, col(resolver) == col(resolver).max())
self.comparer = Comparable(self.compare)
self.matcher = DataFrameMatcher(self.compare)

def compare(self, partial: DataFrame, full: DataFrame) -> DataFrame:
return partial.groupby("hash").apply(lambda df: df.filter(self.resolver) if len(df) > 1 else df)
return partial.groupby("hash").apply(lambda group: group.filter(self.resolver) if len(group) > 1 else group)

@classmethod
def get_cfg(cls) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion imdataset_creator/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class File:
file: MalleablePath
ext: str

def to_dict(self):
def to_dict(self) -> dict[str, str | MalleablePath]:
return {
"absolute_pth": self.absolute_pth,
"src": self.src,
Expand Down

0 comments on commit 959230b

Please sign in to comment.