Skip to content

Commit

Permalink
add random rotate, random flip
Browse files Browse the repository at this point in the history
  • Loading branch information
zeptofine committed Oct 23, 2023
1 parent 959230b commit bf39a97
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 24 deletions.
73 changes: 68 additions & 5 deletions imdataset_creator/datarules/base_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path
from string import Formatter
from types import MappingProxyType
from typing import Any, ClassVar, Set
from typing import Any, ClassVar

import numpy as np
import wcmatch.glob as wglob
Expand Down Expand Up @@ -90,7 +90,7 @@ def __call__(self) -> ProducerSchema:
raise NotImplementedError


class ProducerSet(Set[Producer]):
class ProducerSet(set[Producer]):
@staticmethod
def _combine_schema(exprs: ProducerSchema) -> ExprDict:
return {col: expr for expression in exprs for col, expr in expression.items()}
Expand Down Expand Up @@ -139,7 +139,7 @@ def __init_subclass__(cls):
Filter.all_filters[cls.cfg_kwd()] = cls

@abstractmethod
def run(self, img: np.ndarray):
def run(self, img: np.ndarray) -> np.ndarray:
raise NotImplementedError


Expand All @@ -148,6 +148,17 @@ def run(self, img: np.ndarray):

@dataclass(repr=False)
class Input(Keyworded):
"""
A dataclass representing the input configuration.
Attributes
----------
folder : Path
The path to the input folder.
expressions : list[str]
A list of glob patterns to match files in the input folder.
"""

folder: Path
expressions: list[str]

Expand All @@ -156,6 +167,14 @@ def from_cfg(cls, cfg: InputData):
return cls(Path(cfg["folder"]), cfg["expressions"])

def run(self):
"""
Yield the paths of all files in the input folder that match the glob patterns.
Returns
-------
Iterator[Path]
An iterator over all paths of files in the input folder that match the glob patterns.
"""
for file in wglob.iglob(self.expressions, flags=flags, root_dir=self.folder): # type: ignore
yield self.folder / file

Expand Down Expand Up @@ -184,6 +203,26 @@ def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, A

@dataclass(repr=False)
class Output(Keyworded):
"""
A dataclass representing the output configuration.
Attributes
----------
folder : Path
The path to the output folder.
filters : list[Filter]
A list of `Filter` objects to be applied to the output.
output_format : str
The format of the output files.
overwrite : bool
Whether to overwrite existing files.
Raises
------
InvalidFormatException
If the `output_format` is invalid.
"""

folder: Path
filters: list[Filter]
output_format: str
Expand All @@ -204,6 +243,19 @@ def __init__(
self.filters = filters

def format_file(self, file: File):
"""
Format a `File` object according to the `output_format`.
Parameters
----------
file : File
The `File` object to be formatted.
Returns
-------
str
The formatted string.
"""
return output_formatter.format(self.output_format, **file.to_dict())

@classmethod
Expand All @@ -217,9 +269,20 @@ def from_cfg(cls, cfg: OutputData):


def combine_expr_conds(exprs: list[Expr]) -> Expr:
assert exprs
"""
Combine a list of `Expr` objects using the `&` operator.
comp = exprs[0]
Parameters
----------
exprs : list[Expr]
A list of `Expr` objects to be combined.
Returns
-------
Expr
A single `Expr` object representing the combined expression.
"""
comp: Expr = exprs[0]
for e in exprs[1:]:
comp &= e
return comp
8 changes: 8 additions & 0 deletions imdataset_creator/enum_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from enum import Enum
from typing import TypeVar

T = TypeVar("T")


def listostr2listoenum(lst: list[str], enum: type[T]) -> list[T]:
return [enum._member_map_[k] for k in lst] # type: ignore
3 changes: 3 additions & 0 deletions imdataset_creator/gui/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,9 @@ def get_config(self) -> dict[str, bool]:
def set_config(self, i: str, val: bool):
self.items[i].setChecked(val)

def get_enabled(self):
return [i for i, item in self.items.items() if item.isChecked()]


UnderlineFont = QFont()
UnderlineFont.setUnderline(True)
Expand Down
111 changes: 101 additions & 10 deletions imdataset_creator/gui/output_filters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from PySide6.QtCore import QSize
from PySide6.QtCore import QSize, Qt
from PySide6.QtWidgets import (
QCheckBox,
QComboBox,
Expand All @@ -14,13 +14,16 @@
QWidget,
)

from imdataset_creator.configs.configtypes import ItemData

from ..datarules import Filter
from ..image_filters import destroyers, resizer
from .frames import FlowItem, FlowList, MiniCheckList, tooltip


class FilterView(FlowItem):
title = "Filter"
needs_settings = True

bound_item: type[Filter]

Expand All @@ -42,7 +45,6 @@ def get(self):

class ResizeFilterView(FilterView):
title = "Resize"
needs_settings = True

bound_item = resizer.Resize

Expand Down Expand Up @@ -74,7 +76,6 @@ def from_config(cls, cfg, parent=None):
class CropFilterView(FilterView):
title = "Crop"
desc = "Crop the image to the specified size. If the item is 0, it will not be considered"
needs_settings = True

bound_item = resizer.Crop

Expand Down Expand Up @@ -117,7 +118,6 @@ def from_config(cls, cfg: resizer.CropData, parent=None):

class BlurFilterView(FilterView):
title = "Blur"
needs_settings = True

bound_item = destroyers.Blur

Expand Down Expand Up @@ -152,7 +152,7 @@ def reset_settings_group(self):
self.blur_range_y.setValue(16)

def get_config(self) -> destroyers.BlurData:
algos = [algo for algo, enabled in self.algorithms.get_config().items() if enabled]
algos = self.algorithms.get_enabled()
if not algos:
raise EmptyAlgorithmsError(self)
return destroyers.BlurData(
Expand All @@ -178,7 +178,6 @@ def from_config(cls, cfg, parent=None):

class NoiseFilterView(FilterView):
title = "Noise"
needs_settings = True

bound_item = destroyers.Noise

Expand Down Expand Up @@ -209,7 +208,7 @@ def reset_settings_group(self):
self.intensity_range_y.setValue(16)

def get_config(self) -> destroyers.NoiseData:
algos = [algo for algo, enabled in self.algorithms.get_config().items() if enabled]
algos = self.algorithms.get_enabled()
if not algos:
raise EmptyAlgorithmsError(self)
return destroyers.NoiseData(
Expand All @@ -235,7 +234,6 @@ def from_config(cls, cfg, parent=None):

class CompressionFilterView(FilterView):
title = "Compression"
needs_settings = True

bound_item = destroyers.Compression

Expand Down Expand Up @@ -291,6 +289,7 @@ def configure_settings_group(self):
self.group_grid.addWidget(self.mpeg2_bitrate, 6, 1, 1, 2)

def reset_settings_group(self):
self.algorithms.disable_all()
self.j_range_min.setValue(0)
self.j_range_max.setValue(100)
self.w_range_min.setValue(1)
Expand All @@ -301,12 +300,12 @@ def reset_settings_group(self):
self.hevc_range_max.setValue(33)

def get_config(self) -> destroyers.CompressionData:
algos = [algo for algo, enabled in self.algorithms.get_config().items() if enabled]
algos = self.algorithms.get_enabled()
if not algos:
raise EmptyAlgorithmsError(self)
return destroyers.CompressionData(
{
"algorithms": [algo for algo, enabled in self.algorithms.get_config().items() if enabled],
"algorithms": algos,
"jpeg_quality_range": [self.j_range_min.value(), self.j_range_max.value()],
"webp_quality_range": [self.w_range_min.value(), self.w_range_max.value()],
"h264_crf_range": [self.h264_range_min.value(), self.h264_range_max.value()],
Expand Down Expand Up @@ -335,6 +334,96 @@ def from_config(cls, cfg, parent=None):
return self


class RandomFlipFilterView(FilterView):
title = "Random Flip"

bound_item = resizer.RandomFlip

def configure_settings_group(self):
self.flip_x_slider = QSlider(Qt.Orientation.Horizontal, self)
self.flip_x_slider.setMaximum(100)
self.flip_x_slider.setMinimum(0)
self.flip_y_slider = QSlider(Qt.Orientation.Horizontal, self)
self.flip_y_slider.setMaximum(100)
self.flip_y_slider.setMinimum(0)
self.flip_x_chance = QDoubleSpinBox(self)
self.flip_x_chance.setRange(0, 100)
self.flip_x_chance.setSuffix("%")
self.flip_x_chance.setSingleStep(0.5)
self.flip_y_chance = QDoubleSpinBox(self)
self.flip_y_chance.setRange(0, 100)
self.flip_y_chance.setSuffix("%")
self.flip_y_chance.setSingleStep(0.5)

self.flip_x_slider.valueChanged.connect(self.flip_x_chance.setValue)
self.flip_y_slider.valueChanged.connect(self.flip_y_chance.setValue)
self.flip_x_chance.valueChanged.connect(self.flip_x_slider.setValue)
self.flip_y_chance.valueChanged.connect(self.flip_y_slider.setValue)

self.group_grid.addWidget(QLabel("Flip X Chance", self), 0, 0, 1, 1)
self.group_grid.addWidget(self.flip_x_slider, 0, 1, 1, 1)
self.group_grid.addWidget(self.flip_x_chance, 0, 2, 1, 1)
self.group_grid.addWidget(QLabel("Flip Y Chance", self), 1, 0, 1, 1)
self.group_grid.addWidget(self.flip_y_slider, 1, 1, 1, 1)
self.group_grid.addWidget(self.flip_y_chance, 1, 2, 1, 1)

def reset_settings_group(self):
self.flip_x_chance.setValue(50)
self.flip_y_chance.setValue(50)

@classmethod
def from_config(cls, cfg: resizer.RandomFlipData, parent=None):
self = cls(parent)
self.flip_x_chance.setValue(cfg["flip_x_chance"] * 100)
self.flip_y_chance.setValue(cfg["flip_y_chance"] * 100)
return self

def get_config(self) -> resizer.RandomFlipData:
return {
"flip_x_chance": self.flip_x_chance.value() / 100,
"flip_y_chance": self.flip_y_chance.value() / 100,
}


class RandomRotateFilterView(FilterView):
title = "Random Rotate"

bound_item = resizer.RandomRotate

def configure_settings_group(self) -> None:
self.rotate_chance = QDoubleSpinBox(self)
self.rotate_slider = QSlider(Qt.Orientation.Horizontal, self)
self.rotate_chance.valueChanged.connect(self.rotate_slider.setValue)
self.rotate_slider.valueChanged.connect(self.rotate_chance.setValue)
self.rotate_slider.setRange(0, 100)

self.rotation_list = MiniCheckList(resizer.RandomRotateDirections._member_names_, self)
self.group_grid.addWidget(self.rotation_list, 0, 0, 1, 3)
self.group_grid.addWidget(QLabel("Rotation chance:", self), 1, 0, 1, 1)
self.group_grid.addWidget(self.rotate_slider, 1, 1, 1, 1)
self.group_grid.addWidget(self.rotate_chance, 1, 2, 1, 1)

def reset_settings_group(self):
self.rotation_list.disable_all()

@classmethod
def from_config(cls, cfg: resizer.RandomRotateData, parent=None):
self = cls(parent)
for item in cfg["rotate_directions"]:
self.rotation_list.set_config(item, True)
self.rotate_chance.setValue(cfg["rotate_chance"] * 100)
return self

def get_config(self) -> resizer.RandomRotateData:
rots = self.rotation_list.get_enabled()
if not rots:
raise EmptyAlgorithmsError(self)
return {
"rotate_chance": self.rotate_chance.value() / 100,
"rotate_directions": rots,
}


class EmptyAlgorithmsError(Exception):
"""Raised when no algorithms are enabled"""

Expand All @@ -354,4 +443,6 @@ def __init__(self, *args, **kwargs):
BlurFilterView,
NoiseFilterView,
CompressionFilterView,
RandomFlipFilterView,
RandomRotateFilterView,
)
13 changes: 11 additions & 2 deletions imdataset_creator/image_filters/destroyers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from ..configs.configtypes import FilterData
from ..datarules import Filter
from ..enum_helpers import listostr2listoenum

log = logging.getLogger()

Expand Down Expand Up @@ -73,7 +74,7 @@ def run(
@classmethod
def from_cfg(cls, cfg: BlurData) -> Self:
return cls(
algorithms=[BlurAlgorithm._member_map_[k] for k in cfg["algorithms"]], # type: ignore
algorithms=listostr2listoenum(cfg["algorithms"], BlurAlgorithm),
blur_range=cfg["blur_range"], # type: ignore
scale=cfg["scale"],
)
Expand Down Expand Up @@ -124,6 +125,14 @@ def run(self, img: ndarray) -> ndarray:
noise = noise[..., None]
return img + noise

@classmethod
def from_cfg(cls, cfg: NoiseData) -> Self:
return cls(
algorithms=listostr2listoenum(cfg["algorithms"], NoiseAlgorithm),
intensity_range=cfg["intensity_range"], # type: ignore
scale=cfg["scale"],
)


class CompressionAlgorithms(Enum):
JPEG = "jpeg"
Expand Down Expand Up @@ -230,7 +239,7 @@ def run(self, img: ndarray):
@classmethod
def from_cfg(cls, cfg: CompressionData) -> Self:
return cls(
algorithms=[CompressionAlgorithms._member_map_[k] for k in cfg["algorithms"]], # type: ignore
algorithms=listostr2listoenum(cfg["algorithms"], CompressionAlgorithms),
jpeg_quality_range=cfg["jpeg_quality_range"], # type: ignore
webp_quality_range=cfg["webp_quality_range"], # type: ignore
h264_crf_range=cfg["h264_crf_range"], # type: ignore
Expand Down
Loading

0 comments on commit bf39a97

Please sign in to comment.