Skip to content

Commit 5b82f28

Browse files
committed
correct splits for attacked datasets
1 parent c5d352e commit 5b82f28

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

datasets/bone_attacked.py

+8
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,11 @@ def __getitem__(self, idx):
7878
image = self.augmentation(image)
7979

8080
return image.float(), bone_age
81+
82+
def get_subset_by_idxs(self, idxs):
83+
subset = super().get_subset_by_idxs(idxs)
84+
subset.artifact_labels = self.artifact_labels[np.array(idxs)]
85+
subset.artifact_ids = np.where(subset.artifact_labels)[0]
86+
subset.sample_ids_by_artifact = {"artificial": subset.artifact_ids}
87+
subset.clean_sample_ids = [i for i in range(len(subset)) if i not in subset.artifact_ids]
88+
return subset

datasets/isic_attacked.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,13 @@ def __getitem__(self, i):
7878
img = self.augmentation(img)
7979
columns = self.metadata.columns.to_list()
8080
target = torch.Tensor([columns.index(row[row == 1.0].index[0]) - 1 if self.train else 0]).long()[0]
81-
return img, target
81+
return img, target
82+
83+
def get_subset_by_idxs(self, idxs):
84+
subset = super().get_subset_by_idxs(idxs)
85+
subset.artifact_labels = self.artifact_labels[np.array(idxs)]
86+
87+
subset.artifact_ids = np.where(subset.artifact_labels)[0]
88+
subset.sample_ids_by_artifact = {"artificial": subset.artifact_ids}
89+
subset.clean_sample_ids = [i for i in range(len(subset)) if i not in subset.artifact_ids]
90+
return subset

0 commit comments

Comments
 (0)