Skip to content

Commit

Permalink
add DATALOADER.REPEAT_SQRT
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #5245

X-link: fairinternal/detectron2#602

For sampler **RepeatFactorTrainingSampler**, current per-category weight is computed as **1/sqrt(frequency)**.

This works fine on LVIS but is not sufficient in highly imbalanced data we have for person segmentation.

Thus we add an argument **DATALOADER.REPEAT_SQRT**. If false, we compute per-category weight as **1/frequency**.

This change is entirely back-compatible.

Reviewed By: wat3rBro

Differential Revision:
D55355021

Privacy Context Container: L1165023

fbshipit-source-id: 6bca2eecc3b9a7b4693a288c5779627254cd5ec5
  • Loading branch information
stephenyan1231 authored and facebook-github-bot committed Mar 28, 2024
1 parent eb96ee1 commit afe9eb9
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
2 changes: 2 additions & 0 deletions detectron2/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@
_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler"
# Repeat threshold for RepeatFactorTrainingSampler
_C.DATALOADER.REPEAT_THRESHOLD = 0.0
# if True, take square root when computing repeating factor
_C.DATALOADER.REPEAT_SQRT = True
# Tf True, when working on datasets that have instance annotations, the
# training dataloader will filter out images without associated annotations
_C.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True
Expand Down
4 changes: 2 additions & 2 deletions detectron2/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def _build_weighted_sampler(cfg, enable_category_balance=False):
"""
category_repeat_factors = [
RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD
dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD, sqrt=cfg.DATALOADER.REPEAT_SQRT
)
for dataset_dict in dataset_name_to_dicts.values()
]
Expand Down Expand Up @@ -482,7 +482,7 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
sampler = TrainingSampler(len(dataset))
elif sampler_name == "RepeatFactorTrainingSampler":
repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
dataset, cfg.DATALOADER.REPEAT_THRESHOLD
dataset, cfg.DATALOADER.REPEAT_THRESHOLD, sqrt=cfg.DATALOADER.REPEAT_SQRT
)
sampler = RepeatFactorTrainingSampler(repeat_factors)
elif sampler_name == "RandomSubsetTrainingSampler":
Expand Down
12 changes: 10 additions & 2 deletions detectron2/data/samplers/distributed_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(self, repeat_factors, *, shuffle=True, seed=None):
self._frac_part = repeat_factors - self._int_part

@staticmethod
def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh):
def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh, sqrt=True):
"""
Compute (fractional) per-image repeat factors based on category frequency.
The repeat factor for an image is a function of the frequency of the rarest
Expand All @@ -169,6 +169,7 @@ def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh):
repeat_thresh (float): frequency threshold below which data is repeated.
If the frequency is half of `repeat_thresh`, the image will be
repeated twice.
sqrt (bool): if True, apply :func:`math.sqrt` to the repeat factor.
Returns:
torch.Tensor:
Expand All @@ -187,7 +188,14 @@ def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh):
# 2. For each category c, compute the category-level repeat factor:
# r(c) = max(1, sqrt(t / f(c)))
category_rep = {
cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq))
cat_id: max(
1.0,
(
math.sqrt(repeat_thresh / cat_freq)
if sqrt
else (repeat_thresh / cat_freq)
),
)
for cat_id, cat_freq in category_freq.items()
}
for cat_id in sorted(category_rep.keys()):
Expand Down

0 comments on commit afe9eb9

Please sign in to comment.