File tree 2 files changed +18
-1
lines changed
2 files changed +18
-1
lines changed Original file line number Diff line number Diff line change @@ -78,3 +78,11 @@ def __getitem__(self, idx):
78
78
image = self .augmentation (image )
79
79
80
80
return image .float (), bone_age
81
+
82
+ def get_subset_by_idxs (self , idxs ):
83
+ subset = super ().get_subset_by_idxs (idxs )
84
+ subset .artifact_labels = self .artifact_labels [np .array (idxs )]
85
+ subset .artifact_ids = np .where (subset .artifact_labels )[0 ]
86
+ subset .sample_ids_by_artifact = {"artificial" : subset .artifact_ids }
87
+ subset .clean_sample_ids = [i for i in range (len (subset )) if i not in subset .artifact_ids ]
88
+ return subset
Original file line number Diff line number Diff line change @@ -78,4 +78,13 @@ def __getitem__(self, i):
78
78
img = self .augmentation (img )
79
79
columns = self .metadata .columns .to_list ()
80
80
target = torch .Tensor ([columns .index (row [row == 1.0 ].index [0 ]) - 1 if self .train else 0 ]).long ()[0 ]
81
- return img , target
81
+ return img , target
82
+
83
+ def get_subset_by_idxs (self , idxs ):
84
+ subset = super ().get_subset_by_idxs (idxs )
85
+ subset .artifact_labels = self .artifact_labels [np .array (idxs )]
86
+
87
+ subset .artifact_ids = np .where (subset .artifact_labels )[0 ]
88
+ subset .sample_ids_by_artifact = {"artificial" : subset .artifact_ids }
89
+ subset .clean_sample_ids = [i for i in range (len (subset )) if i not in subset .artifact_ids ]
90
+ return subset
You can’t perform that action at this time.
0 commit comments