Skip to content

Commit

Permalink
split config creation to dataset_builder
Browse files Browse the repository at this point in the history
  • Loading branch information
zeptofine committed Oct 2, 2023
1 parent b044506 commit 3938854
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 38 deletions.
40 changes: 9 additions & 31 deletions create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def main(
from rich.console import Console
from rich.progress import Progress

from src.datarules.dataset_builder import DatasetBuilder, chunk_split
from src.datarules.dataset_builder import ConfigHandler, DatasetBuilder, chunk_split
from src.file_list import get_file_list

c = Console(record=True)
Expand All @@ -105,33 +105,11 @@ def check_for_images(lst: list) -> bool:
return False
return True

# generate `Input`s
inputs: list[Input] = [
Input(
Path(folder["data"]["folder"]),
folder["data"]["expressions"],
)
for folder in cfg["inputs"]
]

# generate `Output`s
outputs: list[Output] = [
Output(
Path(folder["data"]["path"]),
{Filter.all_filters[filter_["name"]]: filter_["data"] for filter_ in folder["data"]["lst"]},
folder["data"]["overwrite"],
folder["data"]["output_format"],
)
for folder in cfg["output"]
]
for output in outputs:
output.path.mkdir(parents=True, exist_ok=True)

# generate `Producer`s
producers: list[Producer] = [Producer.all_producers[p["name"]].from_cfg(p["data"]) for p in cfg["producers"]]

# generate `Rule`s
rules: list[Rule] = [Rule.all_rules[r["name"]].from_cfg(r["data"]) for r in cfg["rules"]]
db_cfg = ConfigHandler(cfg)
inputs: list[Input] = db_cfg.inputs
outputs = db_cfg.outputs
producers = db_cfg.producers
rules = db_cfg.rules

db.add_rules(*rules)
db.add_producers(*producers)
Expand All @@ -155,10 +133,10 @@ def check_for_images(lst: list) -> bool:
folder_t = p.add_task("from folder", total=len(inputs))
for folder in inputs:
lst: list[Path] = []
for file in get_file_list(folder.path, *folder.expressions):
for file in get_file_list(folder.folder, *folder.expressions):
lst.append(file)
p.advance(count_t)
images[folder.path] = lst
images[folder.folder] = lst
p.advance(folder_t)
p.remove_task(folder_t)
resolved: dict[str, File] = {
Expand Down Expand Up @@ -254,7 +232,7 @@ def trigger_save(save_timer: datetime, collected: list[DataFrame]) -> tuple[date
output.filters,
)
for output in outputs
if not (pth := output.path / Path(output.format_file(file))).exists() or output.overwrite
if not (pth := output.folder / Path(output.format_file(file))).exists() or output.overwrite
]
)
]
Expand Down
2 changes: 1 addition & 1 deletion src/configs/configtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class FilterData(SpecialItemData):


class OutputData(SpecialItemData):
path: str
folder: str
lst: list[ItemConfig[FilterData]]
output_format: str
overwrite: bool
Expand Down
21 changes: 18 additions & 3 deletions src/datarules/base_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from polars import DataFrame, DataType, Expr, PolarsDataType

from src.configs.configtypes import InputData, OutputData

from ..configs import FilterData, Keyworded
from ..file import File

Expand Down Expand Up @@ -130,9 +132,13 @@ def run(self, *args, **kwargs):

@dataclass
class Input(Keyworded):
path: Path
folder: Path
expressions: list[str]

@classmethod
def from_cfg(cls, cfg: InputData):
    """Build an instance from an input config mapping.

    Reads ``cfg["folder"]`` (wrapped in a ``Path``) and
    ``cfg["expressions"]`` (passed through unchanged) and hands both to the
    class constructor positionally, in that order.
    """
    folder = Path(cfg["folder"])
    expressions = cfg["expressions"]
    return cls(folder, expressions)


class InvalidFormatException(Exception):
def __init__(self, disallowed: str):
Expand All @@ -158,13 +164,13 @@ def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, A

@dataclass
class Output(Keyworded):
path: Path
folder: Path
filters: dict[Filter, FilterData]
output_format: str
overwrite: bool

def __init__(self, path, filters, overwrite=False, output_format=DEFAULT_OUTPUT_FORMAT):
self.path = path
self.folder = path
# try to format. If it fails, it will raise InvalidFormatException
outputformatter.format(output_format, **PLACEHOLDER_FORMAT_KWARGS)
self.output_format = output_format
Expand All @@ -173,3 +179,12 @@ def __init__(self, path, filters, overwrite=False, output_format=DEFAULT_OUTPUT_

def format_file(self, file: File):
return outputformatter.format(self.output_format, **file.to_dict())

@classmethod
def from_cfg(cls, cfg: OutputData):
    """Build an instance from an output config mapping.

    Reads ``cfg["folder"]``, ``cfg["lst"]``, ``cfg["overwrite"]`` and
    ``cfg["output_format"]``. Every entry of ``cfg["lst"]`` is resolved to
    a key by looking its ``"name"`` up in ``Filter.all_filters``; the
    entry's ``"data"`` becomes the associated value.
    """
    folder = Path(cfg["folder"])

    filters = {}
    for entry in cfg["lst"]:
        # Resolve the registry key before touching "data", mirroring the
        # key-then-value evaluation order of a dict comprehension.
        filter_key = Filter.all_filters[entry["name"]]
        filters[filter_key] = entry["data"]

    overwrite = cfg["overwrite"]
    output_format = cfg["output_format"]
    # Positional order follows __init__(path, filters, overwrite, output_format).
    return cls(folder, filters, overwrite, output_format)
22 changes: 21 additions & 1 deletion src/datarules/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@
from typing import Generator, Literal, TypeVar, overload

import polars as pl
from polars import DataFrame, Expr
import ujson
from polars import DataFrame, Expr, concat
from polars.type_aliases import SchemaDefinition

from ..configs import FilterData, MainConfig
from .base_rules import (
Comparable,
DataTypeSchema,
ExprDict,
FastComparable,
Filter,
Input,
Output,
Producer,
ProducerSet,
Rule,
Expand Down Expand Up @@ -294,3 +299,18 @@ def comply_to_schema(self, schema: SchemaDefinition, in_place=False) -> DataFram
if in_place:
self.__df = new_df
return new_df


class ConfigHandler:
    """Turn a ``MainConfig`` mapping into the objects it describes.

    After construction the results are exposed as the ``inputs``,
    ``outputs``, ``producers`` and ``rules`` list attributes.

    NOTE(review): this constructor does not itself create the output
    folders on disk -- confirm that whoever writes the files does so.
    """

    def __init__(self, cfg: MainConfig):
        # Inputs and outputs are built straight from each entry's "data".
        self.inputs: list[Input] = [Input.from_cfg(entry["data"]) for entry in cfg["inputs"]]
        self.outputs: list[Output] = [Output.from_cfg(entry["data"]) for entry in cfg["output"]]

        # Producers are looked up by "name" in the Producer registry, then
        # built from the entry's "data".
        producer_registry = Producer.all_producers
        self.producers: list[Producer] = [
            producer_registry[entry["name"]].from_cfg(entry["data"]) for entry in cfg["producers"]
        ]

        # Rules follow the same name -> class -> from_cfg pattern.
        rule_registry = Rule.all_rules
        self.rules: list[Rule] = [
            rule_registry[entry["name"]].from_cfg(entry["data"]) for entry in cfg["rules"]
        ]
4 changes: 2 additions & 2 deletions src/gui/output_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get(self) -> str:

def get_config(self) -> OutputData:
return {
"path": self.text.text(),
"folder": self.text.text(),
"output_format": self.format_str.text() or self.format_str.placeholderText(),
"lst": self.list.get_config(),
"overwrite": self.overwrite.isChecked(),
Expand All @@ -87,7 +87,7 @@ def get_config(self) -> OutputData:
@classmethod
def from_config(cls, cfg: OutputData, parent=None):
self = cls(parent)
self.text.setText(cfg["path"])
self.text.setText(cfg["folder"])
self.format_str.setText(cfg["output_format"])
self.list.add_from_cfg(cfg["lst"])
self.overwrite.setChecked(cfg["overwrite"])
Expand Down

0 comments on commit 3938854

Please sign in to comment.