Skip to content

Commit

Permalink
update scannet instance dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
QianyiWu committed Oct 19, 2022
1 parent b09eca4 commit d8fc381
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions code/datasets/scannet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def __init__(self,
all_sems = []
all_poses = []

self.instance = instance

img_dir = os.path.join(data_dir, 'color')
if instance:
sem_dir = os.path.join(data_dir, 'instance-filt')
Expand All @@ -40,10 +42,22 @@ def __init__(self,
pose_dir = os.path.join(data_dir, 'pose')

self.label_mapping = None
self.instance_mapping_dict= {}
if instance:
with open(os.path.join(data_dir, 'label_mapping_instance.txt'), 'r') as f:
content = f.readlines()
self.label_mapping = [int(a) for a in content[0].split(',')]
# with open(os.path.join(data_dir, 'label_mapping_instance.txt'), 'r') as f:
# content = f.readlines()
# self.label_mapping = [int(a) for a in content[0].split(',')]

# using the remapped instance label for training
with open(os.path.join(data_dir, 'instance_mapping.txt'), 'r') as f:
for l in f:
(k, v_sem, v_ins) = l.split(',')
self.instance_mapping_dict[int(k)] = int(v_ins)
self.label_mapping = [] # get the sorted label mapping list
for k, v in self.instance_mapping_dict.items():
if v not in self.label_mapping: # not a duplicate instance
self.label_mapping.append(v)
print('Instance Label Mapping: ', self.label_mapping)
else:
with open(os.path.join(data_dir, 'label_mapping.txt'), 'r') as f:
content = f.readlines()
Expand Down Expand Up @@ -176,18 +190,15 @@ def load_meta_data(self, split_path, img_dir, seg_dir, pose_dir):
for cur_id in img_ids:
split_imgs.append((cv2.resize(imageio.imread(os.path.join(img_dir, "%s.jpg" % str(cur_id))), (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA))\
.transpose(2, 0, 1).reshape(3, -1).transpose(1, 0))
# hacky, update some labels brute forcely....
cur_sems = cv2.resize(imageio.imread(os.path.join(seg_dir, "%s.png" % str(cur_id))), (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_NEAREST).\

ori_sems = cv2.resize(imageio.imread(os.path.join(seg_dir, "%s.png" % str(cur_id))), (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_NEAREST).\
reshape(1, -1).transpose(1, 0)

# remap the semantic label for continuous index
# cur_sems[cur_sems == 131] = 4
# cur_sems[cur_sems == 24] = 5
# cur_sems[cur_sems == 1163] = 6
# cur_sems[cur_sems == 56] = 7
cur_sems = np.copy(ori_sems)
if self.label_mapping is not None:
for i in self.label_mapping:
cur_sems[cur_sems == i] = self.label_mapping.index(i)
# cur_sems[cur_sems == i] = self.label_mapping.index(i)
cur_sems[ori_sems == i] = self.label_mapping.index(self.instance_mapping_dict[i]) if self.instance else self.label_mapping(i)
split_sems.append(cur_sems)

pose_path = os.path.join(pose_dir, "%s.txt" % str(cur_id))
Expand Down

0 comments on commit d8fc381

Please sign in to comment.