another rewrite of config handling
zeptofine committed Aug 12, 2023
1 parent 3db5310 commit 6f2a384
Showing 7 changed files with 301 additions and 273 deletions.
create_dataset.py (35 changes: 25 additions & 10 deletions)
@@ -134,12 +134,25 @@ def main(
help="Which column in the database to sort by. It must be in the database.", rich_help_panel="modifiers"
),
] = "path",
stat: Annotated[bool, Option("--stat", "-s", help="use statfilter", rich_help_panel="filters")] = False,
res: Annotated[bool, Option("--res", "-r", help="use resfilter", rich_help_panel="filters")] = False,
hsh: Annotated[bool, Option("--hash", "-h", help="use hashfilter", rich_help_panel="filters")] = False,
chn: Annotated[bool, Option("--channel", "-c", help="use channelfilter", rich_help_panel="filters")] = False,
blw: Annotated[
bool, Option("--blackwhitelist", "-b", help="use blacknwhitelistfilter", rich_help_panel="filters")
] = False,
) -> int:
"""Does all the heavy lifting"""
s: RichStepper = RichStepper(loglevel=1, step=-1)
s.next("Settings: ")

db = DatasetBuilder(origin=str(input_folder), config_path=config_path)
if not config_path.exists():
db.add_filters([StatFilter, ResFilter, HashFilter, ChannelFilter, BlacknWhitelistFilter])
db.generate_config().save()
print(f"{config_path} created. edit it and restart this program.")
return 0
db.config.load()

def check_for_images(image_list: list[Path]) -> bool:
if not image_list:
@@ -197,17 +210,19 @@ def hrlr_pair(path: Path) -> tuple[Path, Path | None]:

return hr_path, lr_path

stat = StatFilter()
res = ResFilter()
hsh = HashFilter()
chn = ChannelFilter()
blw = BlacknWhitelistFilter()
db.add_filters(stat, res, hsh, chn, blw)

filters = []
if stat:
filters.append(StatFilter)
if res:
s.print(f"Filtering by size ({res.min} <= x <= {res.max})")
filters.append(ResFilter)
if hsh:
filters.append(HashFilter)
if chn:
filters.append(ChannelFilter)
if blw:
s.print(f"Whitelist: {blw.whitelist}", f"Blacklist: {blw.blacklist}")
filters.append(BlacknWhitelistFilter)
db.add_filters(filters)
db.fill_from_config(db.config)

# * Gather images
s.next("Gathering images...")
@@ -246,7 +261,7 @@ def hrlr_pair(path: Path) -> tuple[Path, Path | None]:
folders: list[Path] = [hr_folder]
if make_lr:
folders.append(lr_folder)
db.add_filters(ExistingFilter(*folders, recurse_func=recurse))
db.add_filter(ExistingFilter(folders, recurse_func=recurse))

# * Run filters
s.next("Using: ")
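The create_dataset.py change adds a config file step: on the first run the script registers every known filter, writes a default config, and exits so the user can edit it; on later runs it loads that file and fills the selected filters from it. Below is a standalone sketch of that first-run pattern, for illustration only; the DEFAULTS string and section names are placeholders, not the output of this project's generate_config().

import sys
import tomllib  # stdlib TOML reader, Python 3.11+
from pathlib import Path

# Placeholder defaults; the real file is produced by DatasetBuilder.generate_config().
DEFAULTS = """\
save_interval_secs = 60
chunksize = 600
filepath = "filedb.feather"

[resolution]
enabled = true
min = 0
max = 2048
"""

config_path = Path("database_config.toml")
if not config_path.exists():
    # First run: write defaults and stop, mirroring the "created, edit it and restart" flow above.
    config_path.write_text(DEFAULTS)
    print(f"{config_path} created. Edit it and restart this program.")
    sys.exit(0)

# Later runs: read the config and see which filter sections are switched on.
config = tomllib.loads(config_path.read_text())
enabled = [name for name, section in config.items()
           if isinstance(section, dict) and section.get("enabled")]
print("enabled filter sections:", enabled)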
database_config.toml (34 changes: 13 additions & 21 deletions)
@@ -1,34 +1,26 @@
trim = true
trim_age_limit_secs = 604800
trim_check_exists = true
save_interval_secs = 20
chunksize = 400
save_interval_secs = 60
chunksize = 600
filepath = "filedb.feather"

[stats]
enabled = false
before = "2040" # only get items before this threshold
after = "2010" # only get items after this threshold

before = "2100" #Only get items before this threshold
after = "1980" #Only get items after this threshold
[resolution]
enabled = true
min = 128
min = 0
max = 2048
crop = false #if true, then it will check if it is valid after cropping a little to be divisible by scale
crop = false #if true, then it will check if it is valid after cropping a little to be divisible by scale
scale = 4

[hashing]
enabled = true
hasher = "average" # average | crop_resistant | color | dhash | dhash_vertical | phash | phash_simple | whash | whash-db4
resolver = "ignore_all" # ignore_all | newest | oldest | size

hasher = "average" #average | color | crop_resistant | dhash | dhash_vertical | phash | phash_simple | whash | whash_db4
resolver = "ignore_all" #ignore_all | newest | oldest | size
[channels]
enabled = false
channel_num = 3
strict = false # if true, only images with {channel_num} channels are available
min_channels = 1
max_channels = 4

[blackwhitelists]
enabled = false
whitelist = ["safe"] # files with these strings are filtered in
all_whitelists_are_true = true # allow files that are valid to __every__ whitelist string
blacklist = ["explicit"] # items with these strings are filtered out
whitelist = []
blacklist = []
exclusive = false #Only allow files that are valid by every whitelist string
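Each filter reads its own table from this file ([stats], [resolution], [hashing], [channels], [blackwhitelists]), with an enabled switch alongside the filter's own parameters. A hedged sketch of how such a section could be forwarded to a filter constructor follows; the HashFilter class below is invented for illustration and is not this repository's implementation.

import tomllib

class HashFilter:
    # Hypothetical stand-in; the real filter lives in src/datafilters and may differ.
    def __init__(self, hasher: str = "average", resolver: str = "ignore_all") -> None:
        self.hasher = hasher
        self.resolver = resolver

with open("database_config.toml", "rb") as f:
    config = tomllib.load(f)

section = dict(config.get("hashing", {}))
if section.pop("enabled", False):
    # Forward the remaining keys of the [hashing] table as keyword arguments.
    hash_filter = HashFilter(**section)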
manage_dataframe.ipynb (41 changes: 30 additions & 11 deletions)
@@ -22,15 +22,19 @@
"import polars as pl\n",
"from tqdm import tqdm\n",
"\n",
"from dataset_filters.dataset_builder import DatasetBuilder\n",
"from dataset_filters import DataFilter"
"from src.datafilters.dataset_builder import DatasetBuilder\n",
"from src.datafilters.base_filters import DataFilter"
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Read dataframe\n"
"filename = \"filedb.feather\"\n",
"df = pl.read_ipc(filename)\n",
"df"
]
},
{
@@ -39,7 +43,10 @@
"metadata": {},
"outputs": [],
"source": [
"filename = \"filedb.feather\""
"# shuffle dataframe\n",
"def shuffle(df: pl.DataFrame, rand_col=\"rnd\") -> pl.DataFrame:\n",
" return df.with_columns(pl.Series(rand_col, [random.random() for _ in range(len(df))])).sort(rand_col).drop(rand_col)\n",
"df = shuffle(df)"
]
},
{
@@ -48,7 +55,21 @@
"metadata": {},
"outputs": [],
"source": [
"df = pl.read_ipc(filename)"
"# delete random items in dataframe\n",
"thresh = 0.9\n",
"import random\n",
"\n",
"\n",
"def rnd(_):\n",
" return random.random()\n",
"\n",
"\n",
"def drop_rand(df: pl.DataFrame, exclude: list[str], thresh: float = 0.9) -> pl.DataFrame:\n",
" new = df.select(pl.when(pl.all().exclude(*exclude).apply(rnd, skip_nulls=False) < thresh).then(pl.all()))\n",
" return new.with_columns(df.select(*exclude)).select(df.columns)\n",
"\n",
"\n",
"df = drop_rand(df, [\"path\", \"checkedtime\"], 0.5)"
]
},
{
@@ -57,9 +78,8 @@
"metadata": {},
"outputs": [],
"source": [
"# shuffle dataframe\n",
"def shuffle(df: pl.DataFrame, rand_col=\"rnd\") -> pl.DataFrame:\n",
" return df.with_columns(pl.Series(rand_col, [random.random() for _ in range(len(df))])).sort(rand_col).drop(rand_col)"
"# save\n",
"df.write_ipc(filename)"
]
},
{
@@ -68,8 +88,7 @@
"metadata": {},
"outputs": [],
"source": [
"# save shuffled\n",
"shuffle(df).write_ipc(filename)"
"df"
]
},
{
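The notebook's shuffle helper adds a random column, sorts by it, and drops it again. polars also ships an equivalent built-in; the snippet below is a small aside, not part of the commit.

import polars as pl

df = pl.DataFrame({"path": ["a.png", "b.png", "c.png"], "res": [128, 256, 512]})

# Equivalent one-liner: sample every row in random order.
shuffled = df.sample(fraction=1.0, shuffle=True, seed=42)
print(shuffled)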
src/datafilters/base_filters.py (75 changes: 39 additions & 36 deletions)
@@ -1,27 +1,30 @@
import inspect
import sys
from abc import abstractmethod
from collections.abc import Collection
from dataclasses import dataclass
from enum import EnumType, Enum
from pathlib import Path
from typing import Any
from typing import Any, Self

from polars import DataFrame, Expr, PolarsDataType


class Comparable:
@abstractmethod
def compare(self, lst: Collection[Path], cols: DataFrame) -> list:
"""Uses collected data to return a new list of only valid images, depending on what the filter does."""
"""Uses all collected data to return a new list of only valid images, depending on what the filter does."""
raise NotImplementedError


class FastComparable:
@abstractmethod
def fast_comp(self) -> Expr | bool:
"""Returns an Expr that can be used to filter more efficiently, in Rust"""
"""Returns an Expr that can be used to filter more efficiently"""
raise NotImplementedError


@dataclass
@dataclass(frozen=True)
class Column:
"""A class defining what is in a column which a filter may use to apply a"""

@@ -33,47 +36,47 @@ class Column:
class DataFilter:
"""An abstract DataFilter format, for use in DatasetBuilder."""

config_keyword: str

def __init__(self) -> None:
"""
filedict: dict[str, Path]
This is filled from the dataset builder, and contains a dictionary going from the resolved versions of
the files to the ones given from the user.
column_schema: dict[str, PolarsDataType | type]
This is used to add a column using names and types to the file database.
These *must* be filled by the build_schema.
build_schema: dict[str, Expr]
This is used to build the data given in the column_schema.
config: tuple[str | None, dict[str, Any]]
This is used to populate self attributes from the database's config file.
The string represents the section of the config that belongs to this filter.
If none, it's disabled.
the dictionary __must__ include an "enabled" flag. It represents whether the filter is enabled by default.
"""
self.schema: tuple[Column] = tuple()
self.filedict: dict[str, Path] = {} # used for certain filters, like Existing
self.schema: list[Column] = []
self.config: tuple[str | None, dict[str, Any]] = (None, {"enabled": False})
self.__enabled = False

def enable(self):
self.__enabled = True

def get_config(self) -> tuple[str | None, dict[str, Any]]:
return self.config

def populate_from_cfg(self, dct: dict[str, Any]):
for key, val in dct.items():
if key not in ("filedict", "column_schema", "build_schema", "config", "config_triggers"):
setattr(self, key, val)

def __bool__(self):
return self.__enabled
@classmethod
def from_cfg(cls, *args, **kwargs) -> Self:
return cls(*args, **kwargs) # type: ignore

@classmethod
def get_cfg(cls) -> dict:
cfg: dict[str, Any] = {}
module = sys.modules[cls.__module__]
for key, val in list(inspect.signature(cls.__init__).parameters.items())[1:]:
if issubclass(type(val.default), Enum):
cfg[key] = val.default.value
else:
cfg[key] = val.default
if val.annotation is not inspect._empty:
annotation = eval(val.annotation, module.__dict__)
comment = DataFilter._obj_to_comment(annotation)
if comment:
cfg[f"!#{key}"] = comment

return cfg

@staticmethod
def _obj_to_comment(obj) -> str:
if type(obj) is EnumType:
return " | ".join(obj._member_map_.values()) # type: ignore
elif hasattr(obj, "__metadata__"):
return str(obj.__metadata__[0])

return ""

def __repr__(self) -> str:
attrlist: list[str] = [
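The diff introduces get_cfg and from_cfg classmethods: get_cfg derives a default config section from the signature of a filter's __init__ (unwrapping Enum defaults and emitting annotation comments), and from_cfg rebuilds a filter from those values. A self-contained sketch of the same defaults-harvesting idea; the ExampleFilter and Resolver classes are invented for illustration.

import inspect
from enum import Enum

class Resolver(Enum):
    IGNORE_ALL = "ignore_all"
    NEWEST = "newest"

class ExampleFilter:
    def __init__(self, hasher: str = "average", resolver: Resolver = Resolver.IGNORE_ALL) -> None:
        self.hasher = hasher
        self.resolver = resolver

def defaults_from_init(cls) -> dict:
    """Collect keyword defaults from __init__, unwrapping Enum members to their values."""
    cfg = {}
    for name, param in list(inspect.signature(cls.__init__).parameters.items())[1:]:  # skip self
        default = param.default
        cfg[name] = default.value if isinstance(default, Enum) else default
    return cfg

print(defaults_from_init(ExampleFilter))  # {'hasher': 'average', 'resolver': 'ignore_all'}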