Skip to content

Commit

Permalink
rename Column class to DataColumn
Browse files Browse the repository at this point in the history
  • Loading branch information
zeptofine committed Oct 19, 2023
1 parent def9806 commit 198c68e
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion imdataset_creator/datarules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import base_rules, data_rules, dataset_builder, image_rules
from .base_rules import ExprDict, File, Filter, Input, Output, Producer, Rule
from .base_rules import Comparable, DataColumn, ExprDict, FastComparable, File, Filter, Input, Output, Producer, Rule
from .dataset_builder import DatasetBuilder, chunk_split
4 changes: 2 additions & 2 deletions imdataset_creator/datarules/base_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __repr__(self) -> str:


@dataclass(frozen=True)
class Column:
class DataColumn:
"""A class defining a column that a filter may need"""

name: str
Expand Down Expand Up @@ -101,7 +101,7 @@ def type_schema(self) -> DataTypeSchema:
class Rule(Keyworded):
"""An abstract DataFilter format, for use in DatasetBuilder."""

requires: Column | tuple[Column, ...]
requires: DataColumn | tuple[DataColumn, ...]
comparer: Comparable | FastComparable

all_rules: ClassVar[dict[str, type[Rule]]] = {}
Expand Down
4 changes: 2 additions & 2 deletions imdataset_creator/datarules/data_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from polars import DataFrame, Datetime, Expr, col

from ..configs.configtypes import SpecialItemData
from .base_rules import Column, Comparable, FastComparable, Producer, ProducerSchema, Rule, combine_expr_conds
from .base_rules import Comparable, DataColumn, FastComparable, Producer, ProducerSchema, Rule, combine_expr_conds

STAT_TRACKED = ("st_size", "st_atime", "st_mtime", "st_ctime")

Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(
after: str | datetime | None = None,
) -> None:
super().__init__()
self.requires = Column("mtime", Datetime("ms"))
self.requires = DataColumn("mtime", Datetime("ms"))

exprs: list[Expr] = []

Expand Down
20 changes: 14 additions & 6 deletions imdataset_creator/datarules/image_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@
from polars import DataFrame, Expr, col

from ..configs.configtypes import SpecialItemData
from .base_rules import Column, Comparable, ExprDict, FastComparable, Producer, ProducerSchema, Rule, combine_expr_conds
from .base_rules import (
Comparable,
DataColumn,
FastComparable,
Producer,
ProducerSchema,
Rule,
combine_expr_conds,
)


def whash_db4(img) -> imagehash.ImageHash:
Expand Down Expand Up @@ -60,8 +68,8 @@ def __init__(
) -> None:
super().__init__()
self.requires = (
Column("width", int),
Column("height", int),
DataColumn("width", int),
DataColumn("height", int),
)

smallest = pl.min_horizontal(col("width"), col("height"))
Expand Down Expand Up @@ -100,7 +108,7 @@ class ChannelRule(Rule):

def __init__(self, min_channels=1, max_channels=4) -> None:
super().__init__()
self.requires = Column("channels", int)
self.requires = DataColumn("channels", int)
self.comparer = FastComparable((min_channels <= col("channels")) & (col("channels") <= max_channels))


Expand Down Expand Up @@ -154,9 +162,9 @@ class HashRule(Rule):
def __init__(self, resolver: str | Literal["ignore_all"] = "ignore_all") -> None:
super().__init__()

self.requires = Column("hash", str)
self.requires = DataColumn("hash", str)
if resolver != "ignore_all":
self.requires = (self.requires, Column(resolver))
self.requires = (self.requires, DataColumn(resolver))
self.resolver: Expr | bool = {"ignore_all": False}.get(resolver, col(resolver) == col(resolver).max())
self.comparer = Comparable(self.compare)

Expand Down
2 changes: 1 addition & 1 deletion imdataset_creator/gui/rule_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __wrap_get(cls: type[RuleView]):
def get_wrapper(self: RuleView):
rule = original_get(self)
if rule.requires:
if isinstance(rule.requires, base_rules.Column):
if isinstance(rule.requires, base_rules.DataColumn):
self.set_requires(str({rule.requires.name}))
else:
self.set_requires(str(set({r.name for r in rule.requires})))
Expand Down

0 comments on commit 198c68e

Please sign in to comment.