
Commit

add compression filter, ...
abstracted c_d more
fixed config emptying after an error
add an error to filters containing algos
zeptofine committed Oct 7, 2023
1 parent 26dffbd commit 0eed2d7
Showing 11 changed files with 346 additions and 121 deletions.
155 changes: 57 additions & 98 deletions create_dataset.py
@@ -1,7 +1,8 @@
-from __future__ import annotations
+# from __future__ import annotations

+import logging
 import os
-from collections.abc import Callable
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from datetime import datetime
 from multiprocessing import Pool, cpu_count, freeze_support
@@ -15,6 +16,7 @@
 import ujson
 from polars import DataFrame, concat
 from rich.console import Console
+from rich.logging import RichHandler
 from rich.progress import Progress
 from typer import Option
 from typing_extensions import Annotated
@@ -25,56 +27,35 @@
 from src.configs import FilterData, MainConfig
 from src.datarules.base_rules import File, Filter, Input, Output, Producer, Rule
 from src.datarules.dataset_builder import ConfigHandler, DatasetBuilder, chunk_split
-from src.file_list import get_file_list
+from src.scenarios import FileScenario, OutputScenario

 CPU_COUNT = int(cpu_count())
-app = typer.Typer()
+logging.basicConfig(level=logging.CRITICAL, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()])
+app = typer.Typer(pretty_exceptions_show_locals=True, pretty_exceptions_short=True)
+log = logging.getLogger()


-@dataclass
-class OutputScenario:
-    path: str
-    filters: dict[Filter, FilterData]
+def read_image(path: str) -> np.ndarray:
+    return cv2.imread(path, cv2.IMREAD_UNCHANGED)


-@dataclass
-class FileScenario:
-    file: File
-    outputs: list[OutputScenario]
+def get_outputs(file, outputs: Iterable[Output]):
+    return [
+        OutputScenario(str(pth), output.filters)
+        for output in outputs
+        if not (pth := output.folder / Path(output.format_file(file))).exists() or output.overwrite
+    ]


-def read_image(path: str) -> np.ndarray:
-    return cv2.imread(path, cv2.IMREAD_UNCHANGED)
+def parse_files(files: Iterable[File], outputs: list[Output]) -> Generator[FileScenario, None, None]:
+    for file in files:
+        if out_s := get_outputs(file, outputs):
+            yield FileScenario(file, out_s)


-def generate_scenarios(p: Progress, files: list[File], outputs: list[Output]) -> Generator[FileScenario, None, None]:
-    return (
-        FileScenario(file, outs)
-        for file in p.track(files, description="generating scenarios")
-        if (  # remove finished files
-            outs := [
-                OutputScenario(str(pth), output.filters)
-                for output in outputs
-                if not (pth := output.folder / Path(output.format_file(file))).exists() or output.overwrite
-            ]
-        )
-    )


-def parse_scenario(sc: FileScenario):
-    img: np.ndarray
-    original: np.ndarray
-
-    original = read_image(str(sc.file.absolute_pth))
-    mtime: os.stat_result = os.stat(str(sc.file.absolute_pth))
-    for output in sc.outputs:
-        img = original
-        for filter_, kwargs in output.filters.items():
-            img = filter_.run(img=img, **kwargs)
-        Path(output.path).parent.mkdir(parents=True, exist_ok=True)
-        cv2.imwrite(output.path, img)
-        os.utime(output.path, (mtime.st_atime, mtime.st_mtime))
-    return sc
+def gather_images(inputs: Iterable[Input]) -> Generator[tuple[Path, list[Path]], None, None]:
+    for input_ in inputs:
+        yield input_.folder, list(input_.run())


 @app.command()
@@ -83,13 +64,29 @@ def main(
     database_path: Annotated[Path, Option(help="Where the database is placed")] = Path("filedb.arrow"),
     threads: Annotated[int, Option(help="multiprocessing threads")] = CPU_COUNT * 3 // 4,
     chunksize: Annotated[int, Option(help="imap chunksize")] = 5,
-    pchunksize: Annotated[int, Option("-p", help="chunksize when populating the df")] = 100,
-    pinterval: Annotated[int, Option("-s", help="save interval in secs when populating the df")] = 60,
+    population_chunksize: Annotated[int, Option("-p", help="chunksize when populating the df")] = 100,
+    population_interval: Annotated[int, Option("-s", help="save interval in secs when populating the df")] = 60,
     simulate: Annotated[bool, Option(help="stops before conversion")] = False,
     verbose: Annotated[bool, Option(help="prints converted files")] = False,
     sort_by: Annotated[str, Option(help="Which database column to sort by")] = "path",
 ) -> int:
     """Takes a crap ton of images and creates dataset pairs"""
+    if not config_path.exists():
+        log.error(f"{config_path} does not exist. create it in the gui and restart this program.")
+        return 0
+
+    with config_path.open("r") as f:
+        cfg: MainConfig = ujson.load(f)
+
+    db = DatasetBuilder(db_path=Path(database_path))
+
+    db_cfg = ConfigHandler(cfg)
+    inputs: list[Input] = db_cfg.inputs
+    outputs: list[Output] = db_cfg.outputs
+    producers: list[Producer] = db_cfg.producers
+    rules: list[Rule] = db_cfg.rules
+    db.add_rules(*rules)
+    db.add_producers(*producers)

     c = Console(record=True)
     with Progress(
@@ -101,29 +98,6 @@ def main(
         progress.SpinnerColumn(),
         console=c,
     ) as p:
-        if not config_path.exists():
-            p.log(f"{config_path} does not exist. create it in the gui and restart this program.")
-            return 0
-
-        with config_path.open("r") as f:
-            cfg: MainConfig = ujson.load(f)
-
-        db = DatasetBuilder(db_path=Path(database_path))
-
-        def check_for_images(lst: list) -> bool:
-            if not lst:
-                return False
-            return True
-
-        db_cfg = ConfigHandler(cfg)
-        inputs: list[Input] = db_cfg.inputs
-        outputs = db_cfg.outputs
-        producers = db_cfg.producers
-        rules = db_cfg.rules
-
-        db.add_rules(*rules)
-        db.add_producers(*producers)
-
         if verbose:
             p.log(
                 "inputs:",
@@ -140,31 +114,17 @@ def check_for_images(lst: list) -> bool:
         # Gather images
         images: dict[Path, list[Path]] = {}
         count_t = p.add_task("Gathering", total=None)
-        folder_t = p.add_task("from folder", total=len(inputs))
-        for folder in inputs:
-            lst: list[Path] = []
-            for file in get_file_list(folder.folder, *folder.expressions):
-                lst.append(file)
-                p.advance(count_t)
-            images[folder.folder] = lst
-            p.advance(folder_t)
-        p.remove_task(folder_t)
+        for folder, lst in gather_images(inputs):
+            images[folder] = lst
+            p.update(count_t, advance=len(lst))

         resolved: dict[str, File] = {
-            str((src / pth).resolve()): File(
-                absolute_pth=str(pth),
-                src=str(src),
-                relative_path=str(pth.relative_to(src).parent),
-                file=pth.stem,
-                ext=pth.suffix[pth.suffix[0] == "." :],
-            )
-            for src, lst in images.items()
-            for pth in lst
+            str((src / pth).resolve()): File.from_src(src, pth) for src, lst in images.items() for pth in lst
         }
-        diff: int = sum(map(len, images.values())) - len(resolved)
-        if diff:
+        if diff := sum(map(len, images.values())) - len(resolved):
             p.log(f"removed an estimated {diff} conflicting symlinks")
-        total_images = len(resolved)
-        p.update(count_t, total=total_images, completed=total_images)
+        p.update(count_t, total=len(resolved), completed=len(resolved))

         total_t = p.add_task("populating df", total=None)
         db.add_new_paths(set(resolved))
@@ -176,7 +136,7 @@ def check_for_images(lst: list) -> bool:
             p.log(f"Skipping finished producers: {finished}")

         def trigger_save(save_timer: datetime, collected: list[DataFrame]) -> tuple[datetime, list[DataFrame]]:
-            if ((new_time := datetime.now()) - save_timer).total_seconds() > pinterval:
+            if ((new_time := datetime.now()) - save_timer).total_seconds() > population_interval:
                 data: DataFrame = concat(collected, how="diagonal")
                 db.update(data)
                 db.save_df()
@@ -196,7 +156,7 @@ def trigger_save(save_timer: datetime, collected: list[DataFrame]) -> tuple[date
         if verbose:
             p.log(df)
             p.log(schemas)
-        chunks = list(chunk_split(df, chunksize=pchunksize))
+        chunks = list(chunk_split(df, chunksize=population_chunksize))
         p.update(chunk_t, total=len(chunks), completed=0)

         for (_, size), chunk in chunks:
@@ -231,26 +191,25 @@ def trigger_save(save_timer: datetime, collected: list[DataFrame]) -> tuple[date
         files: list[File] = [resolved[file] for file in db.filter(set(resolved))]
         p.update(filter_t, total=len(files), completed=len(files))

-        scenarios = list(generate_scenarios(p, files, outputs))
+        scenarios = list(parse_files(p.track(files, description="parsing scenarios"), outputs))

-        if not check_for_images(scenarios):
+        if not scenarios:
             p.log("Finished. No images remain.")
             return 0
         if simulate:
             p.log(f"Simulated. {len(scenarios)} images remain.")
             return 0
-        # # * convert files. Finally!

         try:
             with Pool(threads) as pool:
-                pool_t = p.add_task("parsing scenarios", total=len(scenarios))
-                for file in pool.imap(parse_scenario, scenarios, chunksize=chunksize):
+                execute_t = p.add_task("executing scenarios", total=len(scenarios))
+                for file in pool.imap(FileScenario.run, scenarios, chunksize=chunksize):
                     if verbose:
                         p.log(f"finished: {file}")
-                    p.advance(pool_t)
+                    p.advance(execute_t)
         except KeyboardInterrupt:
             print(-1, "KeyboardInterrupt")
             return 1

         return 0


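The net effect of the create_dataset.py changes: OutputScenario, FileScenario, and the per-file conversion loop moved out of the CLI script into a new src.scenarios module, and the pool now maps FileScenario.run directly (passing the unbound method to imap feeds each scenario in as self). src/scenarios.py is among the diffs not rendered on this page; the sketch below is a hypothetical reconstruction inferred from the deleted parse_scenario code above, not the file's actual contents.

```python
# Hypothetical reconstruction of src/scenarios.py, inferred from the
# parse_scenario/OutputScenario/FileScenario code deleted in this commit.
import os
from dataclasses import dataclass
from pathlib import Path

import cv2
import numpy as np

from src.configs import FilterData
from src.datarules.base_rules import Filter
from src.file import File


@dataclass
class OutputScenario:
    path: str
    filters: dict[Filter, FilterData]


@dataclass
class FileScenario:
    file: File
    outputs: list[OutputScenario]

    def run(self) -> "FileScenario":
        # decode once; each output re-filters a fresh reference to the original
        original: np.ndarray = cv2.imread(str(self.file.absolute_pth), cv2.IMREAD_UNCHANGED)
        mtime: os.stat_result = os.stat(str(self.file.absolute_pth))
        for output in self.outputs:
            img = original
            for filter_, kwargs in output.filters.items():
                img = filter_.run(img=img, **kwargs)
            Path(output.path).parent.mkdir(parents=True, exist_ok=True)
            cv2.imwrite(output.path, img)
            # carry the source file's timestamps over to the output
            os.utime(output.path, (mtime.st_atime, mtime.st_mtime))
        return self
```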
24 changes: 10 additions & 14 deletions pyproject.toml
@@ -1,31 +1,27 @@
 [tool.poetry]
-name = "dataset-creator"
-version = "0.1.0"
-description = ""
 authors = ["zeptofine <[email protected]>"]
+description = ""
+name = "dataset-creator"
 readme = "README.md"
+version = "0.1.0"

 [tool.poetry.dependencies]
-python = "^3.10, <3.12"
 ImageHash = "^4.3.1"
+Pillow = "^10.0.1"
+PySide6-Essentials = "^6.5.2"
+ffmpeg-python = "^0.2.0"
 imagesize = "^1.4.1"
 numpy = "^1.26.0"
 opencv-python = "^4.8.0.76"
-Pillow = "^10.0.1"
 polars = "^0.19.3"
 pyarrow = "^13.0.0"
+python = "^3.10, <3.12"
 python-dateutil = "^2.8.2"
 rich = "^13.5.3"
-wcmatch = "^8.5"
-typer = { extras = ["all"], version = "^0.9.0" }
-PySide6-Essentials = "^6.5.2"
+typer = {extras = ["all"], version = "^0.9.0"}
 ujson = "^5.8.0"
-ffmpeg-python = "^0.2.0"


 [tool.poetry.group.dev.dependencies]
 ipykernel = "^6.25.2"
+wcmatch = "^8.5"

 [build-system]
-requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
+requires = ["poetry-core"]
1 change: 0 additions & 1 deletion requirements.txt
@@ -1,4 +1,3 @@
-cfg_param_wrapper>=1.0.1
 imagehash
 imagesize
 numpy
4 changes: 4 additions & 0 deletions src/datarules/base_rules.py
@@ -12,6 +12,7 @@
 from polars import DataFrame, DataType, Expr, PolarsDataType

 from src.configs.configtypes import InputData, OutputData
+from src.file_list import get_file_list

 from ..configs import FilterData, Keyworded
 from ..file import File
@@ -140,6 +141,9 @@ class Input(Keyworded):
     def from_cfg(cls, cfg: InputData):
         return cls(Path(cfg["folder"]), cfg["expressions"])

+    def run(self):
+        return get_file_list(self.folder, *self.expressions)
+

 class InvalidFormatException(Exception):
     def __init__(self, disallowed: str):
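Input.run wraps the get_file_list call that create_dataset.py previously spelled out in its gathering loop, which is what lets gather_images reduce to `yield input_.folder, list(input_.run())`. A minimal usage sketch — the folder and glob expressions are made-up examples, and Input is constructed positionally, mirroring Input.from_cfg:

```python
from pathlib import Path

from src.datarules.base_rules import Input

# hypothetical input; the expressions are illustrative globs, not from this commit
inp = Input(Path("datasets/raw"), ["**/*.png", "**/*.jpg"])
for file in inp.run():  # equivalent to get_file_list(inp.folder, *inp.expressions)
    print(file)
```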
11 changes: 11 additions & 0 deletions src/file.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from pathlib import Path


 @dataclass
@@ -17,3 +18,13 @@ def to_dict(self):
             "file": self.file,
             "ext": self.ext,
         }
+
+    @classmethod
+    def from_src(cls, src: Path, pth: Path):
+        return cls(
+            absolute_pth=str(pth),
+            src=str(src),
+            relative_path=str(pth.relative_to(src).parent),
+            file=pth.stem,
+            ext=pth.suffix[pth.suffix[0] == "." :],
+        )
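File.from_src absorbs the field-splitting that create_dataset.py used to inline in its resolved dict comprehension, which is why that comprehension collapses to a one-liner above. A quick sketch of what it produces, with made-up paths:

```python
from pathlib import Path

from src.file import File

src = Path("/data/src")                      # hypothetical source root
pth = Path("/data/src/landscapes/alps.png")  # hypothetical image path
f = File.from_src(src, pth)
# f.relative_path == "landscapes", f.file == "alps", f.ext == "png"
# The slice pth.suffix[pth.suffix[0] == "." :] drops the leading dot:
# the boolean indexes as 1, so ".png" becomes "png".
```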
3 changes: 3 additions & 0 deletions src/gui/frames.py
@@ -166,6 +166,9 @@ def get_config(self) -> ItemData:
     def from_config(cls, cfg: ItemData, parent=None):
         return cls(parent=parent)

+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.cfg_name()})"
+

 class FlowList(QGroupBox):  # TODO: Better name lmao
     n = Signal(int)
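The added __repr__ makes list items identify themselves by class name and configured keyword when a FlowList is printed or logged. A self-contained illustration with an invented stand-in class:

```python
# Hypothetical stand-in for an item class in frames.py; the names are made up.
class CompressionFilterView:
    def cfg_name(self) -> str:
        return "compression"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.cfg_name()})"


print(repr(CompressionFilterView()))  # -> CompressionFilterView(compression)
```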
3 changes: 1 addition & 2 deletions src/gui/main.py
@@ -176,8 +176,8 @@ def get_config(self) -> MainConfig:
     @catch_errors("Error saving")
     @Slot()
     def save_config(self):
-        cfg = self.get_config()
         with self.cfg_path.open("w") as f:
+            cfg = self.get_config()
             json.dump(cfg, f, indent=4)
         print("saved", cfg)

@@ -192,7 +192,6 @@ def load_cfg(self, cfg: MainConfig):
         self.producerlist.add_from_cfg(cfg["producers"])
         self.rulelist.add_from_cfg(cfg["rules"])
         self.outputlist.add_from_cfg(cfg["output"])
-        pass

     @catch_gathering
     @Slot()
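The save_config change is presumably the "fixed config emptying after an error" item from the commit message: self.cfg_path.open("w") truncates the file the moment it opens, so any exception raised before json.dump finishes leaves an empty config on disk. A general hardening pattern for this class of bug — a sketch of the usual approach, not code from this commit — is to serialize fully and then swap the file in atomically:

```python
import json
import os
from pathlib import Path


def save_json_atomic(path: Path, data: dict) -> None:
    # Write to a sibling temp file first; a crash mid-dump leaves the
    # original untouched. os.replace is atomic on POSIX and Windows.
    tmp = path.with_suffix(path.suffix + ".tmp")
    with tmp.open("w") as f:
        json.dump(data, f, indent=4)
    os.replace(tmp, path)
```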