more types, docstrings, simplify, optimizations
zeptofine committed Oct 23, 2023
1 parent 846d6fa commit 4148204
Showing 11 changed files with 96 additions and 111 deletions.
21 changes: 8 additions & 13 deletions imdataset_creator/__main__.py
@@ -1,6 +1,5 @@
import json
import logging
import time
from datetime import datetime
from multiprocessing import Pool, cpu_count, freeze_support
from pathlib import Path
@@ -22,7 +21,6 @@
File,
FileScenario,
MainConfig,
alphanumeric_sort,
chunk_split,
)

@@ -78,27 +76,24 @@ def main(
if verbose:
p.log(pformat(db_cfg))
# Gather images
images: dict[Path, list[Path]] = {}
resolved: dict[str, File] = {}
count_t = p.add_task("Gathering", total=None)
for folder, gen in db_cfg.gather_images():
images[folder] = sorted(gen, key=lambda x: alphanumeric_sort(str(x)), reverse=True)
for pth in images[folder]:

for folder, lst in db_cfg.gather_images(sort=True, reverse=True):
for pth in lst:
resolved[str((folder / pth).resolve())] = File.from_src(folder, pth)
p.update(count_t, advance=len(images[folder]))
p.update(count_t, advance=len(lst))

if diff := sum(map(len, images.values())) - len(resolved):
if diff := p.tasks[count_t].completed - len(resolved):
p.log(f"removed an estimated {diff} conflicting symlinks")

p.update(count_t, total=len(resolved), completed=len(resolved))

total_t = p.add_task("populating df", total=None)
db.add_new_paths(set(resolved))
db_schema = db.type_schema

db.comply_to_schema(db_schema)
unfinished: DataFrame = db.get_unfinished_existing()
db.comply_to_schema(db.type_schema, in_place=True)

unfinished: DataFrame = db.get_unfinished_existing().collect()
if not unfinished.is_empty():
p.update(total_t, total=len(unfinished))
null_t = p.add_task("nulls set", total=None)
@@ -152,7 +147,7 @@ def main(
files: list[File]
if db_cfg.rules:
filter_t = p.add_task("filtering", total=0)
files = [resolved[file] for file in db.filter(set(resolved))]
files = [resolved[file] for file in db.filter(set(resolved)).get_column("path")]
p.update(filter_t, total=len(files), completed=len(files))
else:
files = list(resolved.values())
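Taken together, the `__main__.py` changes push sorting into `gather_images` and switch the database calls to their new lazy/columnar return types. A minimal sketch of the resulting call pattern, assuming a configured `ConfigHandler` as `db_cfg` and a `DatasetBuilder` as `db` (progress-bar plumbing omitted):

```python
resolved: dict[str, File] = {}

# sorting now happens inside gather_images, so no separate `images`
# dict is kept around just to sort each folder's paths
for folder, lst in db_cfg.gather_images(sort=True, reverse=True):
    for pth in lst:
        resolved[str((folder / pth).resolve())] = File.from_src(folder, pth)

db.add_new_paths(set(resolved))
db.comply_to_schema(db.type_schema, in_place=True)

# get_unfinished_existing() is now lazy; collect() materializes it once
unfinished = db.get_unfinished_existing().collect()

# filter() now returns a DataFrame, so paths are read off its "path" column
files = [resolved[f] for f in db.filter(set(resolved)).get_column("path")]
```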
21 changes: 19 additions & 2 deletions imdataset_creator/config_handler.py
@@ -1,8 +1,11 @@
from collections.abc import Generator, Iterable
from pathlib import Path
from typing import overload

from .alphanumeric_sort import alphanumeric_sort
from .configs import MainConfig, _repr_indent
from .datarules import Input, Output, Producer, Rule
from .datarules.base_rules import PathGenerator
from .file import File
from .scenarios import FileScenario, OutputScenario

@@ -21,9 +24,23 @@ def __init__(self, cfg: MainConfig):
# generate `Rule`s
self.rules: list[Rule] = [Rule.all_rules[r["name"]].from_cfg(r["data"]) for r in cfg["rules"]]

def gather_images(self) -> Generator[tuple[Path, Generator[Path, None, None]], None, None]:
@overload
def gather_images(self, sort=True, reverse=False) -> Generator[tuple[Path, list[Path]], None, None]:
...

@overload
def gather_images(self, sort=False, reverse=False) -> Generator[tuple[Path, PathGenerator], None, None]:
...

def gather_images(
self, sort=False, reverse=False
) -> Generator[tuple[Path, PathGenerator | list[Path]], None, None]:
for input_ in self.inputs:
yield input_.folder, input_.run()
gen = input_.run()
if sort:
yield input_.folder, list(map(Path, sorted(map(str, gen), key=alphanumeric_sort, reverse=reverse)))
else:
yield input_.folder, gen

def get_outputs(self, file: File) -> list[OutputScenario]:
return [
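The two `@overload` declarations document that `sort=True` produces concrete, alphanumerically sorted lists while the default keeps the lazy generator. Note that the overloads differ only in their default values rather than using `Literal[True]`/`Literal[False]` annotations, so static checkers may not actually discriminate between them. A sketch of both call styles, assuming a loaded `MainConfig` in `cfg`:

```python
handler = ConfigHandler(cfg)

# default: each folder pairs with a lazy PathGenerator
for folder, gen in handler.gather_images():
    first = next(gen, None)  # consume lazily

# sort=True: each folder pairs with a sorted list[Path]
for folder, paths in handler.gather_images(sort=True, reverse=True):
    print(folder, len(paths))  # len() works because it is a real list
```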
29 changes: 16 additions & 13 deletions imdataset_creator/datarules/base_rules.py
@@ -3,12 +3,12 @@
import textwrap
from abc import abstractmethod
from collections import defaultdict
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
from pathlib import Path
from string import Formatter
from types import MappingProxyType
from typing import Any, ClassVar
from typing import Any, ClassVar, Iterable

import numpy as np
import wcmatch.glob as wglob
@@ -36,10 +36,8 @@ class DataFrameMatcher:

def __init__(self, func: Callable[[PartialDataFrame, FullDataFrame], DataFrame]):
"""
Parameters
----------
func : Callable[[PartialDataFrame, FullDataFrame], DataFrame]
A function that takes a DataFrame and a Dataframe with more information as input, and returns a filtered
Args:
func (Callable[[PartialDataFrame, FullDataFrame], DataFrame]): A function that takes a DataFrame and a DataFrame with more information as input, and returns a filtered
DataFrame as an output.
"""
self.func = func
@@ -56,8 +54,8 @@ class ExprMatcher:

expr: Expr

def __init__(self, expr: Expr):
self.expr = expr
def __init__(self, *exprs: Expr):
self.expr = combine_expr_conds(exprs)

def __call__(self) -> Expr:
return self.expr
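`ExprMatcher` now accepts any number of expressions and ANDs them internally via `combine_expr_conds`, which is why the call sites elsewhere in this commit drop their explicit `combine_expr_conds(exprs)` wrapping. A sketch with hypothetical column conditions:

```python
import polars as pl

matcher = ExprMatcher(pl.col("width") >= 512, pl.col("height") >= 512)

df = pl.DataFrame({"width": [1024, 100], "height": [768, 100]})
df.filter(matcher())  # keeps only the 1024x768 row
```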
@@ -146,6 +144,9 @@ def run(self, img: np.ndarray) -> np.ndarray:
flags: int = wglob.BRACE | wglob.SPLIT | wglob.EXTMATCH | wglob.IGNORECASE | wglob.GLOBSTAR


PathGenerator = Generator[Path, None, None]


@dataclass(repr=False)
class Input(Keyworded):
"""
@@ -166,7 +167,7 @@ class Input(Keyworded):
def from_cfg(cls, cfg: InputData):
return cls(Path(cfg["folder"]), cfg["expressions"])

def run(self):
def run(self) -> PathGenerator:
"""
Yield the paths of all files in the input folder that match the glob patterns.
@@ -268,7 +269,7 @@ def from_cfg(cls, cfg: OutputData):
)


def combine_expr_conds(exprs: list[Expr]) -> Expr:
def combine_expr_conds(exprs: Iterable[Expr]) -> Expr:
"""
Combine an iterable of `Expr` objects using the `&` operator.
@@ -282,7 +283,9 @@
Expr
A single `Expr` object representing the combined expression.
"""
comp: Expr = exprs[0]
for e in exprs[1:]:
comp &= e
comp: Expr | None = None
for e in exprs:
comp = comp & e if comp is not None else e
assert comp is not None

return comp
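Widening `combine_expr_conds` to accept an `Iterable` means generators now work too; the old body relied on `exprs[0]`/`exprs[1:]` indexing, which only sequences support. The `None` sentinel plus the `assert` also makes an empty input fail loudly instead of raising an opaque `IndexError`. A sketch with a hypothetical substring blacklist:

```python
import polars as pl

blacklist = ("tmp", "cache")  # hypothetical substrings to exclude
expr = combine_expr_conds(
    pl.col("path").str.contains(s).is_not() for s in blacklist
)
```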
4 changes: 2 additions & 2 deletions imdataset_creator/datarules/data_rules.py
@@ -87,7 +87,7 @@ def __init__(
if self.before is not None and self.after is not None and self.after > self.before:
raise self.AgeError(self.after, self.before)

self.matcher = ExprMatcher(combine_expr_conds(exprs))
self.matcher = ExprMatcher(*exprs)

@classmethod
def from_cfg(cls, cfg) -> Self:
@@ -131,7 +131,7 @@ def __init__(
if self.blacklist:
exprs.extend(col("path").str.contains(item).is_not() for item in self.blacklist)

self.matcher = ExprMatcher(combine_expr_conds(exprs))
self.matcher = ExprMatcher(*exprs)

@classmethod
def get_cfg(cls) -> BlackWhitelistData:
77 changes: 44 additions & 33 deletions imdataset_creator/datarules/dataset_builder.py
@@ -8,7 +8,7 @@
from typing import Generator, Literal, TypeVar, overload

import polars as pl
from polars import DataFrame, Expr
from polars import DataFrame, Expr, LazyFrame
from polars.type_aliases import SchemaDefinition

from .base_rules import (
@@ -39,6 +39,17 @@ def chunk_split(
chunksize: int,
col_name: str = "_idx",
) -> Generator[DataFrame, None, None]:
"""
Splits a dataframe into chunks based on index.

Args:
    df (DataFrame): the dataframe to split
    chunksize (int): the size of each resulting chunk
    col_name (str, optional): the name of the temporary index column. Defaults to "_idx".

Yields:
    Generator[DataFrame, None, None]: a chunk
"""
return (
part.drop(col_name)
for _, part in df.with_row_count(col_name)
@@ -47,25 +58,23 @@
)
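A usage sketch for `chunk_split` (the grouping lines are collapsed in the hunk above; this assumes they bucket rows by `row_index // chunksize`, as the signature and docstring indicate):

```python
import polars as pl

df = pl.DataFrame({"path": [f"img_{i}.png" for i in range(10)]})
for chunk in chunk_split(df, chunksize=4):
    print(len(chunk))  # 4, 4, 2 -- the temporary "_idx" column is dropped
```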


def combine_exprs(rules: Iterable[Rule]) -> list[Expr | bool | DataFrameMatcher]:
"""this combines expressions from different objects to a list of compressed expressions.
DataRules that are `ExprComparer`s can be combined, but `DataFrameComparer`s cannot. They will be copied to the list.
def combine_matchers(
matchers: Iterable[DataFrameMatcher | ExprMatcher],
) -> Generator[Expr | DataFrameMatcher, None, None]:
"""this combines expressions from different matchers to compressed expressions.
DataRules that are `ExprMatcher`s can be combined, but `DataFrameMatcher`s cannot. They will be yielded unchanged.
"""
combinations: list[Expr | bool | DataFrameMatcher] = []
combination: Expr | bool | None = None
for rule in rules:
comparer: DataFrameMatcher | ExprMatcher = rule.matcher
if isinstance(comparer, ExprMatcher):
combination = combination & comparer() if combination is not None else comparer()
elif isinstance(comparer, DataFrameMatcher):
for matcher in matchers:
if isinstance(matcher, DataFrameMatcher):
if combination is not None:
combinations.append(combination)
yield combination
combination = None
combinations.append(comparer)
yield matcher
elif isinstance(matcher, ExprMatcher):
combination = combination & matcher() if combination is not None else matcher()
if combination is not None:
combinations.append(combination)

return combinations
yield combination
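`combine_matchers` replaces `combine_exprs`: it is a generator over matchers rather than rules, it coalesces adjacent `ExprMatcher`s into a single combined `Expr`, and it passes `DataFrameMatcher`s through unchanged, preserving order. A sketch with hypothetical matchers:

```python
import polars as pl

m1 = ExprMatcher(pl.col("width") >= 64)
m2 = DataFrameMatcher(lambda partial, full: partial.head(100))
m3 = ExprMatcher(pl.col("height") >= 64)

out = list(combine_matchers([m1, m2, m3]))
# -> [<Expr: width >= 64>, m2, <Expr: height >= 64>]
# order is preserved; only adjacent ExprMatchers collapse into one Expr
```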


def blacklist_schema(schema: ProducerSchema, blacklist: Collection) -> ProducerSchema:
@@ -185,7 +194,15 @@ def remove_finished_producers(self) -> ProducerSet:
self.producers = new
return ProducerSet(old - new)

@overload
def unfinished_by_col(self, df: DataFrame, cols: Iterable[str] | None = None) -> DataFrame:
...

@overload
def unfinished_by_col(self, df: LazyFrame, cols: Iterable[str] | None = None) -> LazyFrame:
...

def unfinished_by_col(self, df: DataFrame | LazyFrame, cols: Iterable[str] | None = None) -> DataFrame | LazyFrame:
if cols is None:
cols = set(self.type_schema) & set(df.columns)
return df.filter(pl.any_horizontal(pl.col(col).is_null() for col in cols))
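The `unfinished_by_col` overloads let either frame type pass through, and the body shows the core trick: a row counts as "unfinished" when any tracked column is still null. A small concrete sketch of that `any_horizontal` filter, assuming a hypothetical producer column named "hash":

```python
import polars as pl

df = pl.DataFrame({"path": ["a.png", "b.png"], "hash": ["x1", None]})
df.filter(pl.any_horizontal(pl.col(c).is_null() for c in ["hash"]))
# -> keeps only the "b.png" row, whose hash has not been produced yet
```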
@@ -232,7 +249,7 @@ def populate_chunks(
chunks: Iterable[DataFrame],
schemas: ProducerSchema,
db_schema: DataTypeSchema | None = None,
):
) -> Generator[DataFrame, None, None]:
if db_schema is None:
db_schema = self.type_schema
chunk: DataFrame
@@ -248,37 +265,31 @@ def df_with_types(self, types: DataTypeSchema | None = None):
types = self.type_schema
return self.comply_to_schema(types).with_columns(types)

def get_unfinished(self) -> DataFrame:
def get_unfinished(self) -> LazyFrame:
# check if producers are completely finished
type_schema: DataTypeSchema = self.type_schema
self.comply_to_schema(type_schema, in_place=True)
return self.unfinished_by_col(self.__df.with_columns(type_schema))
return self.unfinished_by_col(self.__df.lazy().with_columns(type_schema))

def get_unfinished_existing(self) -> DataFrame:
unfinished = self.get_unfinished()
if not len(unfinished):
return unfinished
return unfinished.filter(pl.col("path").apply(os.path.exists))
def get_unfinished_existing(self) -> LazyFrame:
return self.get_unfinished().filter(pl.col("path").apply(os.path.exists))

def filter(self, lst, sort_col="path") -> Iterable[str]: # noqa: A003
assert sort_col in self.__df.columns, f"'{sort_col}' is not in {self.__df.columns}"
def filter(self, lst) -> DataFrame: # noqa: A003
if len(self.unready_rules):
warnings.warn(
f"{len(self.unready_rules)} filters are not initialized and will not be populated", stacklevel=2
)

vdf: DataFrame = self.__df.filter(pl.col("path").is_in(lst))
combined = combine_exprs(self.rules)
for f in combined:
vdf = f(vdf, self.__df) if isinstance(f, DataFrameMatcher) else vdf.filter(f)

return vdf.sort(sort_col).get_column("path")
for matcher in combine_matchers([rule.matcher for rule in self.rules]):
vdf = matcher(vdf, self.__df) if isinstance(matcher, DataFrameMatcher) else vdf.filter(matcher)
return vdf

def save_df(self, pth: str | Path | None = None) -> None:
"""saves the dataframe to self.filepath"""
self.__df.write_ipc(pth or self.filepath)

def update(self, df: DataFrame, on="path", how: Literal["left", "inner"] = "left"):
def update(self, df: DataFrame, on="path", how: Literal["left", "inner", "outer"] = "left"):
self.__df = self.__df.update(df, on=on, how=how)

def trigger_save_via_time(
@@ -316,11 +327,11 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({', '.join(attrlist)})"

@overload
def comply_to_schema(self, schema: SchemaDefinition, in_place: Literal[False] = False) -> DataFrame:
def comply_to_schema(self, schema: SchemaDefinition) -> DataFrame:
...

@overload
def comply_to_schema(self, schema: SchemaDefinition, in_place: Literal[True] = True) -> None:
def comply_to_schema(self, schema: SchemaDefinition, in_place=True) -> None:
...

def comply_to_schema(self, schema: SchemaDefinition, in_place: bool = False) -> DataFrame | None:
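The through-line of the dataset_builder.py changes is laziness: `get_unfinished` and `get_unfinished_existing` now return `LazyFrame`s, so the schema cast, the null scan, and the `os.path.exists` filter chain into a single query that only runs when the caller collects. A sketch, assuming `db` is a `DatasetBuilder`:

```python
lazy = db.get_unfinished()               # LazyFrame: no work done yet
existing = db.get_unfinished_existing()  # still lazy; chains the exists() filter
df = existing.collect()                  # the combined query executes here, once
```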
2 changes: 1 addition & 1 deletion imdataset_creator/datarules/image_rules.py
@@ -82,7 +82,7 @@ def __init__(
if max_res:
exprs.append((largest // scale * scale if crop else largest) <= max_res)

self.matcher = ExprMatcher(combine_expr_conds(exprs))
self.matcher = ExprMatcher(*exprs)

@classmethod
def get_cfg(cls) -> ResData:
23 changes: 0 additions & 23 deletions imdataset_creator/file_list.py

This file was deleted.

(Diffs for the remaining 4 changed files did not load.)
