diff --git a/bucketmanager.py b/bucketmanager.py index bb6adbb..ecb6e82 100644 --- a/bucketmanager.py +++ b/bucketmanager.py @@ -34,7 +34,7 @@ def __init__(self, bucket_file, valid_ids=None, max_size=(768,512), divisible=64 epoch_seed = self.prng.tomaxint() % (2**32-1) self.epoch_prng = get_prng(epoch_seed) # separate prng for sharding use for increased thread resilience self.epoch = None - self.left_over = None + self.left_over = {} self.batch_total = None self.batch_delivered = None @@ -138,7 +138,7 @@ def start_epoch(self, world_size=None, global_rank=None): index = set(index) self.epoch = {} - self.left_over = [] + self.left_over = {} self.batch_delivered = 0 for bucket_id in sorted(self.buckets.keys()): if len(self.buckets[bucket_id]) > 0: @@ -147,10 +147,10 @@ def start_epoch(self, world_size=None, global_rank=None): self.epoch[bucket_id] = list(self.epoch[bucket_id]) overhang = len(self.epoch[bucket_id]) % self.bsz if overhang != 0: - self.left_over.extend(self.epoch[bucket_id][:overhang]) + if bucket_id not in self.left_over: + self.left_over[bucket_id] = [] + self.left_over[bucket_id].extend(self.epoch[bucket_id][:overhang]) self.epoch[bucket_id] = self.epoch[bucket_id][overhang:] - if len(self.epoch[bucket_id]) == 0: - del self.epoch[bucket_id] if self.debug: timer = time.perf_counter() - timer @@ -163,8 +163,7 @@ def start_epoch(self, world_size=None, global_rank=None): def get_batch(self): if self.debug: timer = time.perf_counter() - # check if no data left or no epoch initialized - if self.epoch is None or self.left_over is None or (len(self.left_over) == 0 and not bool(self.epoch)) or self.batch_total == self.batch_delivered: + if self.epoch is None or self.batch_total == self.batch_delivered: self.start_epoch() found_batch = False @@ -172,29 +171,27 @@ def get_batch(self): resolution = self.base_res while not found_batch: bucket_ids = list(self.epoch.keys()) - if len(self.left_over) >= self.bsz: - bucket_probs = [len(self.left_over)] + [len(self.epoch[bucket_id]) for bucket_id in bucket_ids] - bucket_ids = [-1] + bucket_ids - else: - bucket_probs = [len(self.epoch[bucket_id]) for bucket_id in bucket_ids] - bucket_probs = np.array(bucket_probs, dtype=np.float32) - bucket_lens = bucket_probs - bucket_probs = bucket_probs / bucket_probs.sum() - bucket_ids = np.array(bucket_ids, dtype=np.int64) - if bool(self.epoch): - chosen_id = int(self.prng.choice(bucket_ids, 1, p=bucket_probs)[0]) - else: - chosen_id = -1 + bucket_probs = [len(self.epoch[bucket_id]) for bucket_id in bucket_ids] - if chosen_id == -1: - # using leftover images that couldn't make it into a bucketed batch and returning them for use with basic square image - self.prng.shuffle(self.left_over) - batch_data = self.left_over[:self.bsz] - self.left_over = self.left_over[self.bsz:] - found_batch = True - else: + left_over_bucket_ids = list(self.left_over.keys()) + left_over_bucket_probs = [len(self.left_over[bucket_id]) for bucket_id in left_over_bucket_ids] + + all_bucket_ids = bucket_ids + left_over_bucket_ids + all_bucket_probs = bucket_probs + left_over_bucket_probs + + if len(all_bucket_probs) == 0: + # No buckets left, start new epoch + self.start_epoch() + continue + + all_bucket_probs = np.array(all_bucket_probs, dtype=np.float32) + all_bucket_probs = all_bucket_probs / all_bucket_probs.sum() + all_bucket_ids = np.array(all_bucket_ids, dtype=np.int64) + + chosen_id = int(self.prng.choice(all_bucket_ids, 1, p=all_bucket_probs)[0]) + + if chosen_id in self.epoch: if len(self.epoch[chosen_id]) >= self.bsz: - # return bucket batch and resolution batch_data = self.epoch[chosen_id][:self.bsz] self.epoch[chosen_id] = self.epoch[chosen_id][self.bsz:] resolution = tuple(self.resolutions[chosen_id]) @@ -202,11 +199,25 @@ def get_batch(self): if len(self.epoch[chosen_id]) == 0: del self.epoch[chosen_id] else: - # can't make a batch from this, not enough images. move them to leftovers and try again - self.left_over.extend(self.epoch[chosen_id]) + # Move leftovers to left_over dict + if chosen_id not in self.left_over: + self.left_over[chosen_id] = [] + self.left_over[chosen_id].extend(self.epoch[chosen_id]) del self.epoch[chosen_id] - - assert(found_batch or len(self.left_over) >= self.bsz or bool(self.epoch)) + elif chosen_id in self.left_over: + if len(self.left_over[chosen_id]) >= self.bsz: + batch_data = self.left_over[chosen_id][:self.bsz] + self.left_over[chosen_id] = self.left_over[chosen_id][self.bsz:] + resolution = tuple(self.resolutions[chosen_id]) + found_batch = True + if len(self.left_over[chosen_id]) == 0: + del self.left_over[chosen_id] + else: + # Not enough images to form a batch, keep them for the next epoch + del self.left_over[chosen_id] + else: + # Should not happen + assert False, "Chosen bucket ID not found in epoch or leftovers" if self.debug: timer = time.perf_counter() - timer