Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Leftover Handling to Ensure Consistent Batch Sizes and Resolutions #5

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 43 additions & 32 deletions bucketmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, bucket_file, valid_ids=None, max_size=(768,512), divisible=64
epoch_seed = self.prng.tomaxint() % (2**32-1)
self.epoch_prng = get_prng(epoch_seed) # separate prng for sharding use for increased thread resilience
self.epoch = None
self.left_over = None
self.left_over = {}
self.batch_total = None
self.batch_delivered = None

Expand Down Expand Up @@ -138,7 +138,7 @@ def start_epoch(self, world_size=None, global_rank=None):
index = set(index)

self.epoch = {}
self.left_over = []
self.left_over = {}
self.batch_delivered = 0
for bucket_id in sorted(self.buckets.keys()):
if len(self.buckets[bucket_id]) > 0:
Expand All @@ -147,10 +147,10 @@ def start_epoch(self, world_size=None, global_rank=None):
self.epoch[bucket_id] = list(self.epoch[bucket_id])
overhang = len(self.epoch[bucket_id]) % self.bsz
if overhang != 0:
self.left_over.extend(self.epoch[bucket_id][:overhang])
if bucket_id not in self.left_over:
self.left_over[bucket_id] = []
self.left_over[bucket_id].extend(self.epoch[bucket_id][:overhang])
self.epoch[bucket_id] = self.epoch[bucket_id][overhang:]
if len(self.epoch[bucket_id]) == 0:
del self.epoch[bucket_id]

if self.debug:
timer = time.perf_counter() - timer
Expand All @@ -163,50 +163,61 @@ def start_epoch(self, world_size=None, global_rank=None):
def get_batch(self):
if self.debug:
timer = time.perf_counter()
# check if no data left or no epoch initialized
if self.epoch is None or self.left_over is None or (len(self.left_over) == 0 and not bool(self.epoch)) or self.batch_total == self.batch_delivered:
if self.epoch is None or self.batch_total == self.batch_delivered:
self.start_epoch()

found_batch = False
batch_data = None
resolution = self.base_res
while not found_batch:
bucket_ids = list(self.epoch.keys())
if len(self.left_over) >= self.bsz:
bucket_probs = [len(self.left_over)] + [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]
bucket_ids = [-1] + bucket_ids
else:
bucket_probs = [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]
bucket_probs = np.array(bucket_probs, dtype=np.float32)
bucket_lens = bucket_probs
bucket_probs = bucket_probs / bucket_probs.sum()
bucket_ids = np.array(bucket_ids, dtype=np.int64)
if bool(self.epoch):
chosen_id = int(self.prng.choice(bucket_ids, 1, p=bucket_probs)[0])
else:
chosen_id = -1
bucket_probs = [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]

if chosen_id == -1:
# using leftover images that couldn't make it into a bucketed batch and returning them for use with basic square image
self.prng.shuffle(self.left_over)
batch_data = self.left_over[:self.bsz]
self.left_over = self.left_over[self.bsz:]
found_batch = True
else:
left_over_bucket_ids = list(self.left_over.keys())
left_over_bucket_probs = [len(self.left_over[bucket_id]) for bucket_id in left_over_bucket_ids]

all_bucket_ids = bucket_ids + left_over_bucket_ids
all_bucket_probs = bucket_probs + left_over_bucket_probs

if len(all_bucket_probs) == 0:
# No buckets left, start new epoch
self.start_epoch()
continue

all_bucket_probs = np.array(all_bucket_probs, dtype=np.float32)
all_bucket_probs = all_bucket_probs / all_bucket_probs.sum()
all_bucket_ids = np.array(all_bucket_ids, dtype=np.int64)

chosen_id = int(self.prng.choice(all_bucket_ids, 1, p=all_bucket_probs)[0])

if chosen_id in self.epoch:
if len(self.epoch[chosen_id]) >= self.bsz:
# return bucket batch and resolution
batch_data = self.epoch[chosen_id][:self.bsz]
self.epoch[chosen_id] = self.epoch[chosen_id][self.bsz:]
resolution = tuple(self.resolutions[chosen_id])
found_batch = True
if len(self.epoch[chosen_id]) == 0:
del self.epoch[chosen_id]
else:
# can't make a batch from this, not enough images. move them to leftovers and try again
self.left_over.extend(self.epoch[chosen_id])
# Move leftovers to left_over dict
if chosen_id not in self.left_over:
self.left_over[chosen_id] = []
self.left_over[chosen_id].extend(self.epoch[chosen_id])
del self.epoch[chosen_id]

assert(found_batch or len(self.left_over) >= self.bsz or bool(self.epoch))
elif chosen_id in self.left_over:
if len(self.left_over[chosen_id]) >= self.bsz:
batch_data = self.left_over[chosen_id][:self.bsz]
self.left_over[chosen_id] = self.left_over[chosen_id][self.bsz:]
resolution = tuple(self.resolutions[chosen_id])
found_batch = True
if len(self.left_over[chosen_id]) == 0:
del self.left_over[chosen_id]
else:
# Not enough images to form a batch, keep them for the next epoch
del self.left_over[chosen_id]
else:
# Should not happen
assert False, "Chosen bucket ID not found in epoch or leftovers"

if self.debug:
timer = time.perf_counter() - timer
Expand Down