🚀 Feature
If a `torchio.Sampler` is used in combination with a `torchio.Queue`, the `Queue` requests the `NUM_SAMPLES` attribute of each `torchio.Subject` in `_fill`:
```python
def _fill(self) -> None:
    assert self.sampler is not None
    if self._incomplete_subject is not None:
        subject = self._incomplete_subject
        iterable = self.sampler(subject)
        patches = list(islice(iterable, self._num_patches_incomplete))
        self.patches_list.extend(patches)
        self._incomplete_subject = None

    while True:
        subject = self._get_next_subject()
        iterable = self.sampler(subject)
        num_samples = self._get_subject_num_samples(subject)  # <- HERE IS THE ATTRIBUTE CALL
        num_free_slots = self.max_length - len(self.patches_list)
        if num_free_slots < num_samples:
            self._incomplete_subject = subject
            self._num_patches_incomplete = num_samples - num_free_slots
        num_samples = min(num_samples, num_free_slots)
        patches = list(islice(iterable, num_samples))
        self.patches_list.extend(patches)
        self._num_sampled_subjects += 1
        list_full = len(self.patches_list) >= self.max_length
        all_sampled = self._num_sampled_subjects >= self.num_subjects
        if list_full or all_sampled:
            break
```
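As a side note on the mechanics above: `_fill` never exhausts the sampler's generator directly; it uses `itertools.islice` to take only as many patches as fit into the free slots. A minimal sketch of that pattern, with a hypothetical `patch_generator` standing in for `self.sampler(subject)`:

```python
from itertools import islice

def patch_generator():
    # Hypothetical stand-in for self.sampler(subject):
    # a potentially endless stream of patches.
    i = 0
    while True:
        yield f'patch-{i}'
        i += 1

max_length = 3
patches_list = []
iterable = patch_generator()
# Take only as many patches as there are free slots, as _fill does.
num_free_slots = max_length - len(patches_list)
patches = list(islice(iterable, num_free_slots))
patches_list.extend(patches)
assert patches_list == ['patch-0', 'patch-1', 'patch-2']
```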
However, the maximum number of samples/patches per subject usually depends on the augmentations performed, and is therefore reflected by the number of non-zero entries in the probability map, which is processed into a CDF to yield patches in `_generate_patches` of the respective sampler. As a result, the number of samples is only known AFTER calculating the `probability_map`; the current implementation of the `Queue`, however, requests the attribute BEFORE the number of samples is known.
If one rewrote `__call__` (sampler), `_generate_patches` (sampler), and `_get_subject_num_samples` (`Queue`) as shown in the following, one could obtain the `probability_map` before creating the generator in `_generate_patches` and therefore set `num_samples` before the `Queue` requests the attribute.
Motivation
Allowing the user to set, or automatically setting, the number of samples retrieved from each subject makes the `Queue` more robust and alleviates the sampling of duplicates (e.g. if the `probability_map` only allows 5 patches but the user requested 10, each one is sampled approximately twice).
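The duplicate problem can be illustrated with a toy example (plain NumPy instead of TorchIO, and a hypothetical 1-D probability map): drawing 10 patch centers from a map with only 5 non-zero entries necessarily repeats positions.

```python
import numpy as np

rng = np.random.default_rng(0)
probability_map = np.zeros(100)
probability_map[:5] = 1.0  # only 5 voxels are valid patch centers
p = probability_map / probability_map.sum()
# Request 10 patches: with only 5 allowed positions, duplicates are guaranteed.
indices = rng.choice(len(probability_map), size=10, p=p)
num_unique = len(set(indices.tolist()))
assert num_unique <= 5
```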
Pitch
Rewrite `__call__` of `torchio.Sampler` to the following:
```python
def __call__(
    self,
    subject: Subject,
    num_patches: Optional[int] = None,
) -> Generator[Subject, None, None]:
    subject.check_consistent_space()
    if np.any(self.patch_size > subject.spatial_shape):
        message = (
            f'Patch size {tuple(self.patch_size)} cannot be'
            f' larger than image size {tuple(subject.spatial_shape)}'
        )
        raise RuntimeError(message)
    probability_map = self.get_probability_map(subject)
    num_max_patches = int(torch.count_nonzero(probability_map))
    setattr(subject, NUM_SAMPLES, num_max_patches)
    # This is optional
    if num_patches is None:
        num_patches = getattr(subject, NUM_SAMPLES)
    return self._generate_patches(subject, probability_map, num_patches)
```
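The key step above is deriving the per-subject maximum from the probability map and attaching it as an attribute. A self-contained sketch of that idea, using NumPy instead of torch and a plain object with an assumed attribute name instead of `torchio.Subject`:

```python
import numpy as np

NUM_SAMPLES = 'num_samples'  # assumed value of the TorchIO constant

class FakeSubject:
    """Hypothetical stand-in for torchio.Subject."""

probability_map = np.array([0.0, 0.3, 0.0, 0.7, 0.0])
subject = FakeSubject()
# The number of non-zero entries bounds how many distinct patches exist.
num_max_patches = int(np.count_nonzero(probability_map))
setattr(subject, NUM_SAMPLES, num_max_patches)
assert getattr(subject, NUM_SAMPLES) == 2
```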
Rewrite `_generate_patches` of the samplers to the following (my example is the `weighted.py` sampler, but this needs to be done accordingly wherever the method is overridden in other samplers):
```python
def _generate_patches(
    self,
    subject: Subject,
    probability_map: torch.Tensor,
    num_patches: Optional[int] = None,
) -> Generator[Subject, None, None]:
    # Only removes the call to calculating the probability map here
    probability_map_array = self.process_probability_map(
        probability_map,
        subject,
    )
    cdf = self.get_cumulative_distribution_function(probability_map_array)
    patches_left = num_patches if num_patches is not None else True
    while patches_left:
        yield self.extract_patch(subject, probability_map_array, cdf)
        if num_patches is not None:
            patches_left -= 1
```
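The `patches_left` sentinel lets the same loop serve both a finite count and endless sampling: an `int` counts down to 0 (falsy), while `True` keeps the loop alive forever. A minimal sketch with a hypothetical item generator:

```python
from itertools import islice
from typing import Generator, Optional

def generate_items(num_items: Optional[int] = None) -> Generator[int, None, None]:
    # Same sentinel pattern as the proposed _generate_patches:
    # an int counts down; True loops forever.
    items_left = num_items if num_items is not None else True
    value = 0
    while items_left:
        yield value
        value += 1
        if num_items is not None:
            items_left -= 1

finite = list(generate_items(3))
endless_head = list(islice(generate_items(None), 5))
assert finite == [0, 1, 2]
assert endless_head == [0, 1, 2, 3, 4]
```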
And finally, adapt the method `_get_subject_num_samples` of `torchio.Queue` to:
```python
def _get_subject_num_samples(self, subject):
    num_samples = getattr(
        subject,
        NUM_SAMPLES,
        self.samples_per_volume,
    )
    # <- Prevents sampling of more patches than there are in a subject
    return min(num_samples, self.samples_per_volume)
```
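The fallback-and-clamp behaviour can be checked in isolation. A sketch with a hypothetical subject class and a module-level `samples_per_volume` (in the real `Queue` these would be `torchio.Subject` and an instance attribute):

```python
NUM_SAMPLES = 'num_samples'  # assumed value of the TorchIO constant
samples_per_volume = 10

class FakeSubject:
    """Hypothetical stand-in for torchio.Subject."""

def get_subject_num_samples(subject):
    # Fall back to samples_per_volume if the sampler never set the attribute,
    # and never return more than the user-requested samples_per_volume.
    num_samples = getattr(subject, NUM_SAMPLES, samples_per_volume)
    return min(num_samples, samples_per_volume)

subject = FakeSubject()
assert get_subject_num_samples(subject) == 10  # attribute missing -> default
setattr(subject, NUM_SAMPLES, 5)
assert get_subject_num_samples(subject) == 5   # fewer patches available -> clamped down
setattr(subject, NUM_SAMPLES, 50)
assert get_subject_num_samples(subject) == 10  # never exceeds samples_per_volume
```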
Alternatives
The highlighted section in __call__
should be kept in to prevent an endless loop in the case of num_patches=None
in _generate_patches
. As an alternative, one could force to have num_patches set to an integer in any case (I can't think of a scenario of endless sampling), i.e. remove the Optional
and test for is not None
.
Remarks
The layout of the `Queue` might change depending on the outcome of #1096. Furthermore, there needs to be a mechanism for the case where a subject has zero available patches (for whatever reason). Currently, I ensure this does not happen in `get_probability_map`, but the whole approach might break down if a subject has zero patches (not yet tested by me).