
Setting NUM_SAMPLES when using sampler with Queue #1108

@nicoloesch

Description

🚀 Feature
If a torchio.Sampler is used in combination with a torchio.Queue, the Queue requests the NUM_SAMPLES attribute of each torchio.Subject in _fill.

def _fill(self) -> None:
    assert self.sampler is not None

    if self._incomplete_subject is not None:
        subject = self._incomplete_subject
        iterable = self.sampler(subject)
        patches = list(islice(iterable, self._num_patches_incomplete))
        self.patches_list.extend(patches)
        self._incomplete_subject = None

    while True:
        subject = self._get_next_subject()
        iterable = self.sampler(subject)
        num_samples = self._get_subject_num_samples(subject)  # <-- HERE IS THE ATTRIBUTE CALL
        num_free_slots = self.max_length - len(self.patches_list)
        if num_free_slots < num_samples:
            self._incomplete_subject = subject
            self._num_patches_incomplete = num_samples - num_free_slots
        num_samples = min(num_samples, num_free_slots)
        patches = list(islice(iterable, num_samples))
        self.patches_list.extend(patches)
        self._num_sampled_subjects += 1
        list_full = len(self.patches_list) >= self.max_length
        all_sampled = self._num_sampled_subjects >= self.num_subjects
        if list_full or all_sampled:
            break

However, the maximum number of samples/patches per subject usually depends on the augmentations performed and is therefore reflected by the number of non-zero entries in the probability map, which is turned into a CDF to yield patches in _generate_patches of the respective sampler. As a result, the number of samples is only known AFTER the probability_map has been calculated, yet the current implementation of the Queue requests the attribute BEFORE the number of samples is known.
If one rewrote __call__ (sampler), _generate_patches (sampler), and _get_subject_num_samples (Queue) as shown below, the probability_map could be obtained before the generator is created in _generate_patches, and num_samples could therefore be set before the Queue requests the attribute.
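To illustrate the point, here is a minimal sketch (image shapes and key names are made up) showing that the patch budget only becomes known once the probability map has been computed, namely as its number of non-zero entries:

import torch
import torchio as tio

# Probability map with only 5 candidate patch centres (illustrative shapes)
prob = torch.zeros(1, 16, 16, 16)
prob[0, 8, 8, 8:13] = 1
subject = tio.Subject(
    img=tio.ScalarImage(tensor=torch.rand(1, 16, 16, 16)),
    sampling_map=tio.Image(tensor=prob, type=tio.SAMPLING_MAP),
)
sampler = tio.WeightedSampler(patch_size=4, probability_map='sampling_map')

probability_map = sampler.get_probability_map(subject)
num_max_patches = int(torch.count_nonzero(probability_map))  # 5, only known at this point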

Motivation

Allowing the user to set, or automatically setting, the number of samples retrieved from each subject makes the Queue more robust and functional, and alleviates the sampling of duplicates (e.g. if the probability_map only allows 5 patches but the user requests 10, each one is sampled approximately twice).

Pitch

Rewrite __call__ of torchio.Sampler to the following:

def __call__(
    self,
    subject: Subject,
    num_patches: Optional[int] = None,
) -> Generator[Subject, None, None]:
    subject.check_consistent_space()
    if np.any(self.patch_size > subject.spatial_shape):
        message = (
            f'Patch size {tuple(self.patch_size)} cannot be'
            f' larger than image size {tuple(subject.spatial_shape)}'
        )
        raise RuntimeError(message)

    probability_map = self.get_probability_map(subject)
    num_max_patches = int(torch.count_nonzero(probability_map))
    setattr(subject, NUM_SAMPLES, num_max_patches)

    # This is optional (see Alternatives below)
    if num_patches is None:
        num_patches = getattr(subject, NUM_SAMPLES)
    return self._generate_patches(subject, probability_map, num_patches)
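Continuing the sketch from the feature description above, a hypothetical usage of the rewritten __call__ would then be (NUM_SAMPLES being the same attribute-name constant used in the snippets here):

generator = sampler(subject)            # computes the probability map and sets NUM_SAMPLES
print(getattr(subject, NUM_SAMPLES))    # 5 for the subject sketched above
patches = list(generator)               # yields exactly that many patches when num_patches is None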

Rewrite _generate_patches of the samplers to the following (my example is the weighted.py sampler, but it needs to be done accordingly wherever the method is overridden in other samplers):

def _generate_patches(
    self,
    subject: Subject,
    probability_map: torch.Tensor,
    num_patches: Optional[int] = None,
) -> Generator[Subject, None, None]:
    # The only change: the probability map is passed in instead of being computed here
    probability_map_array = self.process_probability_map(
        probability_map,
        subject,
    )
    cdf = self.get_cumulative_distribution_function(probability_map_array)

    patches_left = num_patches if num_patches is not None else True
    while patches_left:
        yield self.extract_patch(subject, probability_map_array, cdf)
        if num_patches is not None:
            patches_left -= 1
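As a toy illustration (plain Python, not the sampler itself) of the intended generator contract: with an explicit num_patches the generator is finite, whereas with num_patches=None it is infinite and the caller has to slice it, which is what Queue._fill does via islice:

from itertools import islice

def generate(num_patches=None):
    patches_left = num_patches if num_patches is not None else True
    while patches_left:
        yield 'patch'
        if num_patches is not None:
            patches_left -= 1

print(len(list(generate(3))))            # 3: finite when num_patches is given
print(len(list(islice(generate(), 7))))  # 7: infinite otherwise, caller must slice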

And finally adapt the method _get_subject_num_samples of torchio.Queue to:

def _get_subject_num_samples(self, subject):
    num_samples = getattr(
        subject,
        NUM_SAMPLES,
        self.samples_per_volume,
    )
    # Prevents sampling more patches than there are in a subject
    return min(num_samples, self.samples_per_volume)
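A toy illustration (plain Python, not the Queue itself) of the proposed behaviour: the value set by the sampler is capped by samples_per_volume, and subjects without the attribute fall back to samples_per_volume.

samples_per_volume = 10

def num_samples_for(subject_num_samples=None):
    num_samples = subject_num_samples if subject_num_samples is not None else samples_per_volume
    return min(num_samples, samples_per_volume)

print(num_samples_for(5))    # 5: fewer patches available than requested
print(num_samples_for(25))   # 10: capped at samples_per_volume
print(num_samples_for())     # 10: attribute not set, fall back to samples_per_volume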

Alternatives

The section marked # This is optional in __call__ should be kept to prevent an endless loop when num_patches=None is passed to _generate_patches. As an alternative, one could force num_patches to be an integer in all cases (I can't think of a scenario that requires endless sampling), i.e. remove the Optional and the is not None checks.
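A sketch of that alternative (the patch-size check is omitted for brevity; the rest of the proposed __call__ stays unchanged):

def __call__(
    self,
    subject: Subject,
    num_patches: int,  # always an int: no Optional, no fallback, no endless generator
) -> Generator[Subject, None, None]:
    subject.check_consistent_space()
    probability_map = self.get_probability_map(subject)
    setattr(subject, NUM_SAMPLES, int(torch.count_nonzero(probability_map)))
    return self._generate_patches(subject, probability_map, num_patches)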

Remarks
The layout of the Queue might change depending on the outcome of #1096. Furthermore, there needs to be a mechanism for handling a subject that has zero available patches (for whatever reason). Currently, I ensure that this does not happen in get_probability_map, but things might break if a subject yields zero patches (not tested by me as of now).
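One possible (hypothetical) guard would be to fail fast in the proposed __call__ before the attribute is set:

probability_map = self.get_probability_map(subject)
num_max_patches = int(torch.count_nonzero(probability_map))
if num_max_patches == 0:
    message = f'No valid patch locations found for subject {subject}'
    raise RuntimeError(message)
setattr(subject, NUM_SAMPLES, num_max_patches)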
