diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py
index 70a53a4220..d2f4d35e87 100644
--- a/torchgeo/datasets/biomassters.py
+++ b/torchgeo/datasets/biomassters.py
@@ -53,7 +53,7 @@ class BioMassters(NonGeoDataset):
     valid_splits = ('train', 'test')
     valid_sensors = ('S1', 'S2')
 
-    metadata_filename = 'The_BioMassters_-_features_metadata.csv.csv'
+    default_metadata_filename = 'The_BioMassters_-_features_metadata.csv.csv'
 
     def __init__(
         self,
@@ -61,6 +61,10 @@ def __init__(
         split: str = 'train',
         sensors: Sequence[str] = ['S1', 'S2'],
         as_time_series: bool = False,
+        metadata_filename: str = default_metadata_filename,
+        max_cloud_percentage: float | None = None,
+        max_red_mean: float | None = None,
+        include_corrupt: bool = True,
     ) -> None:
         """Initialize a new instance of BioMassters dataset.
 
@@ -74,6 +78,10 @@ def __init__(
                 Sentinel 2 ('S1', 'S2')
             as_time_series: whether or not to return all available
                 time-steps or just a single one for a given target location
+            metadata_filename: metadata file to be used
+            max_cloud_percentage: maximum allowed cloud percentage for images
+            max_red_mean: maximum allowed red_mean value for images
+            include_corrupt: whether to include images marked as corrupted
 
         Raises:
             AssertionError: if ``split`` or ``sensors`` is invalid
@@ -91,6 +99,10 @@ def __init__(
         ), f'Please choose a subset of valid sensors: {self.valid_sensors}.'
         self.sensors = sensors
         self.as_time_series = as_time_series
+        self.metadata_filename = metadata_filename
+        self.max_cloud_percentage = max_cloud_percentage
+        self.max_red_mean = max_red_mean
+        self.include_corrupt = include_corrupt
 
         self._verify()
 
@@ -103,6 +115,17 @@ def __init__(
         # filter split
         self.df = self.df[self.df['split'] == self.split]
 
+        # additional filtering based on metadata
+        if self.max_cloud_percentage is not None and 'cloud_percentage' in self.df.columns:
+            self.df = self.df[self.df['cloud_percentage'] <= self.max_cloud_percentage]
+
+        if self.max_red_mean is not None and 'red_mean' in self.df.columns:
+            self.df = self.df[self.df['red_mean'] <= self.max_red_mean]
+
+        if not self.include_corrupt and 'corrupt_values' in self.df.columns:
+            # elementwise test; `is False` would compare the Series object itself
+            self.df = self.df[self.df['corrupt_values'].eq(False)]
+
         # generate numerical month from filename since first month is September
         # and has numerical index of 0
         self.df['num_month'] = (