Commit
add use_tqdm flag in dataset_builder
zeptofine committed Aug 19, 2023
1 parent ae0701c commit 86b3724
Showing 3 changed files with 71 additions and 49 deletions.
7 changes: 6 additions & 1 deletion create_dataset.py
@@ -289,7 +289,12 @@ def hrlr_pair(path: Path) -> tuple[Path, Path | None]:
 
     s.print("Populating df...")
     db.populate_df(
-        image_list, cfg["trim"], cfg["trim_age_limit"], cfg["save_interval"], cfg["trim_check_exists"], cfg["chunksize"]
+        image_list,
+        trim=cfg["trim"],
+        trim_age_limit=cfg["trim_age_limit"],
+        save_interval=cfg["save_interval"],
+        trim_check_exists=cfg["trim_check_exists"],
+        chunksize=cfg["chunksize"],
     )
 
     s.print("Filtering...")
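Note: the call site above now uses keyword arguments, which also makes room for the new flag this commit adds. A caller that wants to suppress progress bars could presumably pass it explicitly — a hypothetical invocation reusing db, image_list, and cfg from create_dataset.py, not part of this diff:

db.populate_df(
    image_list,
    trim=cfg["trim"],
    trim_age_limit=cfg["trim_age_limit"],
    save_interval=cfg["save_interval"],
    trim_check_exists=cfg["trim_check_exists"],
    chunksize=cfg["chunksize"],
    use_tqdm=False,  # hypothetical: disable progress bars, e.g. in non-interactive runs
)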
104 changes: 58 additions & 46 deletions src/datafilters/dataset_builder.py
@@ -91,6 +91,7 @@ def populate_df(
         lst: Iterable[Path],
         trim: bool = True,
         trim_age_limit: int = 60 * 60 * 24 * 7,
+        use_tqdm=True,
         save_interval: int = 60,
         trim_check_exists: bool = True,
         chunksize: int = 100,
@@ -103,57 +104,68 @@
             cond &= pl.col("path").apply(os.path.exists)
         self.df = self.df.filter(cond)
 
-        from_full_to_relative: dict[str, Path] = self.absolute_dict(lst)
-        # add new paths to the dataframe with missing data
-        existing_paths = set(self.df.select(pl.col("path")).to_series())
-        new_paths: list[str] = [path for path in from_full_to_relative if path not in existing_paths]
-        if new_paths:
-            self.df = pl.concat(
-                [self.df, DataFrame({"path": new_paths})],
-                how="diagonal",
-            )
+        from_full_to_relative: dict[str, Path] = self.get_absolutes(lst)
+        if new_paths := set(from_full_to_relative) - set(self.df.get_column("path")):
+            self.df = pl.concat((self.df, DataFrame({"path": new_paths})), how="diagonal")
 
         for filter_ in self.filters:
             filter_.filedict = from_full_to_relative
-        self.df = self._make_schema_compliant(self.df, self.filter_type_schema)
-        updated_df: DataFrame = self.df.with_columns(self.filter_type_schema)
-        search_cols: set[str] = {*self.build_schema, *self.basic_schema}
+        self.df = self._comply_to_schema(self.df, self.filter_type_schema)
+
+        search_cols: set[str] = {*self.build_schema, *self.basic_schema}
+        updated_df: DataFrame = self.df.with_columns(self.filter_type_schema)
         unfinished: DataFrame = updated_df.filter(
             pl.any(pl.col(col).is_null() for col in updated_df.columns if col in search_cols)
         )
         if len(unfinished):
-            with (
-                tqdm(desc="Gathering file info...", unit="file", total=len(unfinished)) as total_t,
-                tqdm(desc="Processing chunks...", unit="chunk", total=len(unfinished) // chunksize) as sub_t,
-            ):
-                save_timer = datetime.now()
-                collected_data: DataFrame = DataFrame(schema=self.filter_type_schema)
+
+            def trigger_save(save_timer: datetime, collected: list[DataFrame]) -> tuple[datetime, list[DataFrame]]:
+                if ((new_time := datetime.now()) - save_timer).total_seconds() > save_interval:
+                    data = pl.concat(collected, how="diagonal")
+                    self.df = self.df.update(data, on="path")
+                    self.save_df()
+                    return new_time, []
+                return save_timer, collected
+
+            def get_nulls_chunks_expr():
                 for nulls, group in unfinished.with_columns(self.filter_type_schema).groupby(
                     *(pl.col(col).is_not_null() for col in self.build_schema)
                 ):
-                    group_expr = {
-                        col: expr
-                        for truth, (col, expr) in zip(nulls, self.build_schema.items())  # type: ignore
-                        if not truth
-                    }
-                    subgroups = list(self._split_into_chunks(group, chunksize=chunksize))
-                    total_t.set_postfix_str(str(tuple(group_expr.keys())))
-                    sub_t.total = len(subgroups)
-                    sub_t.n = 0
-
-                    for idx, subgroup in subgroups:
-                        new_data: DataFrame = subgroup.with_columns(**group_expr)
-                        collected_data = pl.concat((collected_data, new_data), how="diagonal")
-                        if ((new_time := datetime.now()) - save_timer).total_seconds() > save_interval:
-                            self.df = self.df.update(collected_data, on="path")
-                            self.save_df()
-                            collected_data = collected_data.clear()
-                            save_timer = new_time
-                        sub_t.set_postfix_str(str(idx))
-                        sub_t.update(1)
-                        total_t.update(len(subgroup))
-            self.df = self.df.update(collected_data, on="path").rechunk()
+                    yield (
+                        self._split_into_chunks(group, chunksize=chunksize),
+                        {
+                            col: expr
+                            for truth, (col, expr) in zip(nulls, self.build_schema.items())  # type: ignore
+                            if not truth
+                        },
+                    )
+
+            save_timer: datetime = datetime.now()
+            collected: list[DataFrame] = []
+            if not use_tqdm:
+                for subgroups, expr in get_nulls_chunks_expr():
+                    for _, subgroup in subgroups:
+                        collected.append(subgroup.with_columns(**expr))
+                        save_timer, collected = trigger_save(save_timer, collected)
+            else:
+                with (
+                    tqdm(desc="Gathering file info...", unit="file", total=len(unfinished)) as total_t,
+                    tqdm(desc="Processing chunks...", unit="chunk", total=len(unfinished) // chunksize) as sub_t,
+                ):
+                    for subgroups, expr in get_nulls_chunks_expr():
+                        subgroups = list(subgroups)
+                        sub_t.total = len(subgroups)
+                        sub_t.update(-sub_t.n)
+                        total_t.set_postfix_str(str(set(expr)))
+
+                        for idx, subgroup in subgroups:
+                            collected.append(subgroup.with_columns(**expr))
+                            save_timer, collected = trigger_save(save_timer, collected)
+
+                            sub_t.set_postfix_str(str(idx))
+                            sub_t.update()
+                            total_t.update(len(subgroup))
+            self.df = self.df.update(pl.concat(collected, how="diagonal"), on="path")
         self.save_df()
         return
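The net effect of the refactor above: chunk generation (get_nulls_chunks_expr) is now independent of progress display, so the tqdm bars become optional, and periodic saving lives in one place (trigger_save). A minimal standalone sketch of that time-based flush pattern — the flush_if_due helper is an illustrative stand-in, not code from this repository:

from datetime import datetime

import polars as pl


def flush_if_due(
    df: pl.DataFrame,
    collected: list[pl.DataFrame],
    last_save: datetime,
    save_interval: float = 60.0,
) -> tuple[pl.DataFrame, list[pl.DataFrame], datetime]:
    # Merge the collected chunks into df and reset the timer once
    # save_interval seconds have elapsed; otherwise keep accumulating.
    if collected and ((now := datetime.now()) - last_save).total_seconds() > save_interval:
        df = df.update(pl.concat(collected, how="diagonal"), on="path")
        return df, [], now
    return df, collected, last_save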

@@ -169,7 +181,7 @@ def filter(self, lst, sort_col="path", ignore_missing_columns=False) -> Iterable
                 f"the following columns are required but may not be in the database: {missing_requirements}"
             )
 
-        from_full_to_relative: dict[str, Path] = self.absolute_dict(lst)
+        from_full_to_relative: dict[str, Path] = self.get_absolutes(lst)
         paths: set[str] = set(from_full_to_relative.keys())
 
         vdf: LazyFrame = self.df.lazy().filter(pl.col("path").is_in(paths))
@@ -182,30 +194,30 @@
             .filter(
                 pl.col("path").is_in(
                     dfilter.compare(
-                        set(c.select(pl.col("path")).to_series()),
+                        set(c.get_column("path")),
                         self.df.select(pl.col("path"), *[pl.col(col.name) for col in dfilter.schema]),
                     )
                 )
             )
             .lazy()
         )
 
-        return (from_full_to_relative[p] for p in vdf.sort(sort_col).select(pl.col("path")).collect().to_series())
+        return (from_full_to_relative[p] for p in vdf.sort(sort_col).collect().get_column("path"))
 
     def save_df(self) -> None:
         """saves the dataframe to self.filepath"""
         self.df.write_ipc(self.filepath)
 
-    def absolute_dict(self, lst: Iterable[Path]) -> dict[str, Path]:
+    def get_absolutes(self, lst: Iterable[Path]) -> dict[str, Path]:
         return {(str((self.origin / pth).resolve())): pth for pth in lst}  # type: ignore
 
     def get_path_data(self, pths: Collection[str]):
         return self.df.filter(pl.col("path").is_in(pths))
 
     @staticmethod
-    def _make_schema_compliant(data_frame: DataFrame, schema) -> DataFrame:
+    def _comply_to_schema(data_frame: DataFrame, schema) -> DataFrame:
         """adds columns from the schema to the dataframe. (not in-place)"""
-        return pl.concat([data_frame, DataFrame(schema=schema)], how="diagonal")
+        return pl.concat((data_frame, DataFrame(schema=schema)), how="diagonal")
 
     @staticmethod
     def _split_into_chunks(df: DataFrame, chunksize: int, column="_idx"):
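Aside: the renamed _comply_to_schema helper relies on a handy polars behavior — a diagonal concat with an empty, fully-typed frame adds any missing columns as null columns of the right dtype. A quick illustration with made-up column names:

import polars as pl

df = pl.DataFrame({"path": ["a.png"]})
schema = {"path": pl.Utf8, "hash": pl.Utf8}  # hypothetical schema

# how="diagonal" unions the column sets, filling missing values with null.
complied = pl.concat((df, pl.DataFrame(schema=schema)), how="diagonal")
print(complied.columns)  # ['path', 'hash']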
9 changes: 7 additions & 2 deletions src/datafilters/external_filters.py
@@ -146,14 +146,19 @@ def __init__(
     def compare(self, lst: Collection, cols: DataFrame) -> set:
         assert self.resolver is not None
         applied: DataFrame = (
-            cols.filter(col("hash").is_in(cols.filter(col("path").is_in(lst)).select(col("hash")).unique().to_series()))
+            cols.filter(
+                col("hash").is_in(cols.filter(col("path").is_in(lst)).get_column("hash").unique()),
+            )
             .groupby("hash")
             .apply(lambda df: df.filter(self.resolver) if len(df) > 1 else df)  # type: ignore
         )
 
-        resolved_paths = set(applied.select(col("path")).to_series())
+        resolved_paths = set(applied.get_column("path"))
         return resolved_paths
 
+    def apply_resolver(self, df: DataFrame):
+        return df.filter(self.resolver)
+
     def _hash_img(self, pth) -> str:
         assert self.hasher is not None
         return str(self.hasher(Image.open(pth)))
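A recurring cleanup throughout this commit swaps select(col(...)).to_series() for get_column(...), which returns the named column as a Series directly instead of building an intermediate one-column frame. The two spellings are equivalent — a tiny check on toy data, illustrative only:

import polars as pl

df = pl.DataFrame({"path": ["a.png", "b.png"], "hash": ["x", "x"]})

# get_column returns the column as a Series without an intermediate DataFrame.
assert df.get_column("path").to_list() == df.select(pl.col("path")).to_series().to_list()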
