diff --git a/.dvc/.gitignore b/.dvc/.gitignore new file mode 100644 index 00000000..528f30c7 --- /dev/null +++ b/.dvc/.gitignore @@ -0,0 +1,3 @@ +/config.local +/tmp +/cache diff --git a/.dvc/config b/.dvc/config new file mode 100644 index 00000000..5b81ce5f --- /dev/null +++ b/.dvc/config @@ -0,0 +1,4 @@ +[core] + remote = storage +['remote "storage"'] + url = s3://machinelearning-assets/roofmeasurements/datasets/roof_segmentation diff --git a/.dvcignore b/.dvcignore new file mode 100644 index 00000000..51973055 --- /dev/null +++ b/.dvcignore @@ -0,0 +1,3 @@ +# Add patterns of files dvc should ignore, which could improve +# the performance. Learn more at +# https://dvc.org/doc/user-guide/dvcignore diff --git a/.gitignore b/.gitignore index df124a5c..dbe25587 100755 --- a/.gitignore +++ b/.gitignore @@ -5,5 +5,5 @@ # data files *.obj checkpoints -datasets +#datasets runs diff --git a/data/blender_scripts/extract_vertex_labels.py b/data/blender_scripts/extract_vertex_labels.py new file mode 100644 index 00000000..f9fd2ba6 --- /dev/null +++ b/data/blender_scripts/extract_vertex_labels.py @@ -0,0 +1,16 @@ + +import bpy + +ob = bpy.context.object +obdata = bpy.context.object.data + +label = [] +for v in obdata.vertices: + if bpy.context.object.vertex_groups['roof'].index in [i.group for i in v.groups]: + label.append(str(1)) + else: + label.append(str(0)) + +with open('/home/ihahanov/Projects/roof-measurements/dl_roof_extraction/meshcnn' + '/datasets/roof_seg/vseg/model12.eseg', 'w') as f: + f.write('\n'.join(label)) diff --git a/data/make_annotation_from_vertex_labels.py b/data/make_annotation_from_vertex_labels.py new file mode 100644 index 00000000..9322e154 --- /dev/null +++ b/data/make_annotation_from_vertex_labels.py @@ -0,0 +1,196 @@ +import numpy as np +import os +import glob +import filecmp +import sys + + +''' +Creates esseg files for accuracy with smooth transitions between classes +Requires Objects and corresponding labels per edge +Author: Rana Hanocka / Lisa Schneider + +@input: + path where seg, sseg, train, test folders are placed + +@output: + esseg files for all objects + to run it from cmd line: + python create_sseg.py /home/user/MedMeshCNN/datasets/human_seg/ +''' + +def compute_face_normals_and_areas(vs, faces): + face_normals = np.cross(vs[faces[:, 1]] - vs[faces[:, 0]], + vs[faces[:, 2]] - vs[faces[:, 1]]) + face_areas = np.sqrt((face_normals ** 2).sum(axis=1)) + face_normals /= face_areas[:, np.newaxis] + assert (not np.any(face_areas[:, np.newaxis] == 0)), 'has zero area face' + face_areas *= 0.5 + return face_normals, face_areas + + +def remove_non_manifolds(vs, faces): + edges_set = set() + mask = np.ones(len(faces), dtype=bool) + _, face_areas = compute_face_normals_and_areas(vs, faces) + for face_id, face in enumerate(faces): + if face_areas[face_id] == 0: + mask[face_id] = False + continue + faces_edges = [] + is_manifold = False + for i in range(3): + cur_edge = (face[i], face[(i + 1) % 3]) + if cur_edge in edges_set: + is_manifold = True + break + else: + faces_edges.append(cur_edge) + if is_manifold: + mask[face_id] = False + else: + for idx, edge in enumerate(faces_edges): + edges_set.add(edge) + return faces[mask], face_areas[mask] + +def get_gemm_edges(faces, export_name_edges): + """ + gemm_edges: array (#E x 4) of the 4 one-ring neighbors for each edge + sides: array (#E x 4) indices (values of: 0,1,2,3) indicating where an edge is in the gemm_edge entry of the 4 neighboring edges + for example edge i -> gemm_edges[gemm_edges[i], sides[i]] == [i, i, i, i] 
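+              boundary edges have fewer than four one-ring neighbors, so their rows keep the -1 placeholder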
+ """ + edge_nb = [] + sides = [] + edge2key = dict() + edges = [] + edges_count = 0 + nb_count = [] + for face_id, face in enumerate(faces): + faces_edges = [] + for i in range(3): + cur_edge = (face[i], face[(i + 1) % 3]) + faces_edges.append(cur_edge) + for idx, edge in enumerate(faces_edges): + edge = tuple(sorted(list(edge))) + faces_edges[idx] = edge + if edge not in edge2key: + edge2key[edge] = edges_count + edges.append(list(edge)) + edge_nb.append([-1, -1, -1, -1]) + sides.append([-1, -1, -1, -1]) + nb_count.append(0) + edges_count += 1 + for idx, edge in enumerate(faces_edges): + edge_key = edge2key[edge] + edge_nb[edge_key][nb_count[edge_key]] = edge2key[faces_edges[(idx + 1) % 3]] + edge_nb[edge_key][nb_count[edge_key] + 1] = edge2key[faces_edges[(idx + 2) % 3]] + nb_count[edge_key] += 2 + for idx, edge in enumerate(faces_edges): + edge_key = edge2key[edge] + sides[edge_key][nb_count[edge_key] - 2] = nb_count[edge2key[faces_edges[(idx + 1) % 3]]] - 1 + sides[edge_key][nb_count[edge_key] - 1] = nb_count[edge2key[faces_edges[(idx + 2) % 3]]] - 2 + edges = np.array(edges, dtype=np.int32) + np.savetxt(export_name_edges, edges, fmt='%i') + return edge_nb, edges + + +def load_labels(path): + with open(path, 'r') as f: + content = f.read().splitlines() + return content + +def create_sseg_file(gemms, labels, export_name_seseg): + gemmlabels = {} + classes = len(np.unique(labels)) + class_to_idx = {v: i for i, v in enumerate(np.unique(labels))} + totaledges = len(gemms) + sseg = np.zeros([ totaledges, classes]) + for i, edges in enumerate(gemms): + alllabels = [] + for edge in range(len(edges)): + lookupEdge = edges[edge] + label = labels[lookupEdge] + alllabels.append(label) + gemmlabels[i] = alllabels + + for i, edges in enumerate(gemms): + gemmlab = gemmlabels[i] + uniqueValues, counts = np.unique(gemmlab, return_counts=True) + for j, label in enumerate(uniqueValues): + weight = 0.125*counts[j] + sseg[i][class_to_idx[label]] = weight + np.savetxt(export_name_seseg, sseg, fmt='%1.6f') + +def get_obj(file): + vs, faces = [], [] + f = open(file) + for line in f: + line = line.strip() + splitted_line = line.split() + if not splitted_line: + continue + elif splitted_line[0] == 'v': + vs.append([float(v) for v in splitted_line[1:4]]) + elif splitted_line[0] == 'f': + face_vertex_ids = [int(c.split('/')[0]) for c in splitted_line[1:]] + assert len(face_vertex_ids) == 3 + face_vertex_ids = [(ind - 1) if (ind >= 0) else (len(vs) + ind) + for ind in face_vertex_ids] + faces.append(face_vertex_ids) + f.close() + vs = np.asarray(vs) + faces = np.asarray(faces, dtype=int) + assert np.logical_and(faces >= 0, faces < len(vs)).all() + return faces, vs + + +import trimesh as tm +def edges_to_path(edges, color=tm.visual.color.random_color()): + lines = np.asarray(edges) + args = tm.path.exchange.misc.lines_to_path(lines) + colors = [color for _ in range(len(args['entities']))] + path = tm.path.Path3D(**args, colors=colors) + return path + + +def show_mesh(edges, vs, label, colors=[[0,0,0,255], [120,120,120,255]]): + colors = np.array(colors) + edges = vs[edges] + tm.Scene([edges_to_path(e, colors[int(l)]) for e, l in zip(edges, label)]).show() + + + +def create_files(path): + for filename in glob.glob(os.path.join(path, 'obj/*.obj')): + print(filename) + basename = os.path.splitext(os.path.basename(filename))[0] + v_label_name = os.path.join(os.path.join(path, 'vseg'), basename + '.eseg') + label_name = os.path.join(os.path.join(path, 'seg'), basename + '.eseg') + export_name_seseg = 
os.path.join(os.path.join(path, 'sseg'), basename + '.seseg') + export_name_edges = os.path.join(os.path.join(path, 'edges'), basename + '.edges') + + faces, vs = get_obj(filename) + faces, face_areas = remove_non_manifolds(vs, faces) + gemms, edges = get_gemm_edges(faces, export_name_edges) + with open(v_label_name) as f: + v_label = np.array(f.readlines(), dtype=int) + + edge_label = [] + for e in edges: + if v_label[e[0]] == 1 and v_label[e[1]] == 1: + edge_label.append(str(2)) + else: + edge_label.append(str(1)) + + with open(label_name, 'w') as f: + f.write('\n'.join(edge_label)) + print(len(edge_label)) + if os.path.isfile(label_name): + + create_sseg_file(gemms, edge_label, export_name_seseg) + else: + print(label_name, "is no directory") + + +if __name__ == '__main__': + create_files(sys.argv[1]) \ No newline at end of file diff --git a/data/segmentation_data.py b/data/segmentation_data.py index 7d687ae0..ee5db680 100644 --- a/data/segmentation_data.py +++ b/data/segmentation_data.py @@ -5,6 +5,27 @@ import numpy as np from models.layers.mesh import Mesh + +import trimesh as tm +def edges_to_path(edges, color=tm.visual.color.random_color()): + lines = np.asarray(edges) + args = tm.path.exchange.misc.lines_to_path(lines) + colors = [color for _ in range(len(args['entities']))] + path = tm.path.Path3D(**args, colors=colors) + return path + + +def show_mesh(mesh, label, colors=[[0,0,0,255], [120,120,120,255]]): + colors = np.array(colors) + edges = mesh.vs[mesh.edges] + tm.Scene([edges_to_path(e, colors[int(l)]) for e, l in zip(edges, label)]).show() + + +def show_vertices(mesh, label, colors=[[0,0,0,255], [120,120,120,255]]): + colors = np.array(colors) + tm.PointCloud(mesh.vs, colors=np.array(colors)[label]).show() + + class SegmentationData(BaseDataset): def __init__(self, opt): @@ -29,6 +50,7 @@ def __getitem__(self, index): mesh = Mesh(file=path, opt=self.opt, hold_history=True, export_folder=self.opt.export_folder) meta = {} meta['mesh'] = mesh + meta['path'] = path label = read_seg(self.seg_paths[index]) - self.offset label = pad(label, self.opt.ninput_edges, val=-1, dim=0) meta['label'] = label diff --git a/datasets/.gitignore b/datasets/.gitignore new file mode 100644 index 00000000..dc61639c --- /dev/null +++ b/datasets/.gitignore @@ -0,0 +1 @@ +/roof_seg diff --git a/datasets/roof_seg.dvc b/datasets/roof_seg.dvc new file mode 100644 index 00000000..fe5014e9 --- /dev/null +++ b/datasets/roof_seg.dvc @@ -0,0 +1,5 @@ +outs: +- md5: 5ad27ad6392a64493c0976a5ad1a8296.dir + size: 176735174 + nfiles: 404 + path: roof_seg diff --git a/decimate.py b/decimate.py new file mode 100644 index 00000000..c359cb98 --- /dev/null +++ b/decimate.py @@ -0,0 +1,108 @@ +from copy import deepcopy +import os +from collections import OrderedDict + +import trimesh as tm +import numpy as np +import torch + +from options.pl_options import PLOptions +from data import DataLoader +from data.segmentation_data import Mesh +from util.util import pad +from train_pl import MeshSegmenter + + +def show_mesh(mesh, label): + edges = mesh.edges + vertices = mesh.vs + vertex_label = np.zeros(len(vertices)) + for e_l, e in zip(label[0], edges): + if e_l == 1: + vertex_label[e] = 1 + faces = mesh.faces + vertex_colors = np.array([[255, 100, 0, 255], [0, 100, 255, 255]])[vertex_label.astype(int)] + trimesh = tm.Trimesh(faces=faces, vertices=vertices, vertex_colors=vertex_colors) + trimesh.show() + return trimesh + + +def simplify_rooftop(roof_segment: tm.Trimesh, n_triangles) -> tm.Trimesh: + """ + Perform mesh 
simplification based on the desired triangle count
+    :param roof_segment: tm.Trimesh - submesh of the roof
+    :param n_triangles: int - number of triangles the simplified mesh should contain
+    :return: tm.Trimesh - simplified mesh
+    """
+    n_triangles = max([n_triangles, 5])
+    segment = roof_segment.simplify_quadratic_decimation(n_triangles)
+
+    return segment
+
+
+def load_obj(path, opt, mean, std):
+    mesh = Mesh(file=path, opt=opt, hold_history=True, export_folder=opt.export_folder)
+    meta = {}
+    meta['mesh'] = [mesh]
+    meta['path'] = [path]
+    edge_features = mesh.extract_features()
+    edge_features = pad(edge_features, opt.ninput_edges)
+    edge_features = (edge_features - mean) / std
+    meta['edge_features'] = np.expand_dims(edge_features, 0)
+    meta['label'] = np.array([])
+    meta['soft_label'] = np.array([])
+    return meta
+
+
+def run_decimation(epoch=-1):
+    opt = PLOptions().parse()
+    opt.serial_batches = True  # no shuffle
+    dataset = DataLoader(opt)
+    model = MeshSegmenter(opt)
+
+    device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
+    checkpoint = torch.load(opt.model_path, map_location=device)
+
+    state_dict = checkpoint['state_dict']
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        new_key = k.replace('.module', '')
+        new_state_dict[new_key] = v
+    model.load_state_dict(new_state_dict)
+
+    for i, data in enumerate(dataset):
+        if i != 21 and i != 22:
+            continue
+        print(i, data['path'])
+        obj_name = os.path.basename(data['path'][0])
+        print(obj_name)
+
+        torch.cuda.empty_cache()
+
+        mesh = deepcopy(data['mesh'][0])
+        pred_class = model.forward(data).max(1)[1]
+        tm_mesh = show_mesh(mesh, label=pred_class)
+
+        torch.cuda.empty_cache()
+
+        for desired_triangle_area in [0.5, 0.8, 1, 1.2, 1.4, 1.7, 2, 2.5, 3.5, 5, 7]:
+            print(desired_triangle_area)
+
+            tm_mesh_new = simplify_rooftop(tm_mesh, int(tm_mesh.area / desired_triangle_area))
+
+            new_obj_name = obj_name[:-4] + '_' + str(desired_triangle_area) + '.obj'
+            obj_path = os.path.join(opt.decimation_dir, new_obj_name)
+
+            with open(obj_path, mode='w') as f:
+                f.write(tm.exchange.obj.export_obj(tm_mesh_new))
+            data_new = load_obj(obj_path, opt, dataset.dataset.mean, dataset.dataset.std)
+
+            mesh_new = deepcopy(data_new['mesh'][0])
+            pred_class_new = model.forward(data_new).max(1)[1]
+            show_mesh(mesh_new, label=pred_class_new)
+
+            torch.cuda.empty_cache()
+            # os.unlink(f.name)
+
+if __name__ == '__main__':
+    run_decimation()
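Reviewer note: `simplify_rooftop` derives its face-count target from the surface area, so each pass of the loop above keeps the mean triangle area roughly constant. A minimal standalone sketch of that relationship, assuming trimesh (with the open3d-backed decimation pinned in requirements.txt) is installed; the icosphere is only a stand-in for a real rooftop segment:

import trimesh as tm

mesh = tm.creation.icosphere(subdivisions=4)    # stand-in for a rooftop submesh
desired_triangle_area = 1.4                     # in the mesh's squared units
n_triangles = max(int(mesh.area / desired_triangle_area), 5)
decimated = mesh.simplify_quadratic_decimation(n_triangles)
print(len(mesh.faces), '->', len(decimated.faces))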
diff --git a/models/layers/mesh_pool.py b/models/layers/mesh_pool.py
index 394d0fc9..b903ceab 100644
--- a/models/layers/mesh_pool.py
+++ b/models/layers/mesh_pool.py
@@ -56,6 +56,7 @@ def __pool_main(self, mesh_index):
         self.__updated_fe[mesh_index] = fe

     def __pool_edge(self, mesh, edge_id, mask, edge_groups):
+        # skip the edge if it is a boundary edge or any of its neighboring edges lie on a boundary
         if self.has_boundaries(mesh, edge_id):
             return False
         elif self.__clean_side(mesh, edge_id, mask, edge_groups, 0)\
@@ -74,7 +75,7 @@ def __pool_edge(self, mesh, edge_id, mask, edge_groups):
     def __clean_side(self, mesh, edge_id, mask, edge_groups, side):
         if mesh.edges_count <= self.__out_target:
             return False
-        invalid_edges = MeshPool.__get_invalids(mesh, edge_id, edge_groups, side)
+        invalid_edges = MeshPool.__get_invalids(mesh, edge_id, edge_groups, side)  # triplets of edges sharing the same vertex
         while len(invalid_edges) != 0 and mesh.edges_count > self.__out_target:
             self.__remove_triplete(mesh, mask, edge_groups, invalid_edges)
             if mesh.edges_count <= self.__out_target:
@@ -116,6 +117,12 @@ def __pool_side(self, mesh, edge_id, mask, edge_groups, side):
     def __get_invalids(mesh, edge_id, edge_groups, side):
         info = MeshPool.__get_face_info(mesh, edge_id, side)
         key_a, key_b, side_a, side_b, other_side_a, other_side_b, other_keys_a, other_keys_b = info
+
+        # if we have a separate triangle not connected to anything
+        if len(set(other_keys_a).intersection([key_a, key_b, edge_id])) == 2 or \
+                len(set(other_keys_b).intersection([key_a, key_b, edge_id])) == 2:
+            return []
+
         shared_items = MeshPool.__get_shared_items(other_keys_a, other_keys_b)
         if len(shared_items) == 0:
             return []
@@ -135,6 +142,7 @@ def __get_invalids(mesh, edge_id, edge_groups, side):
             MeshPool.__union_groups(mesh, edge_groups, middle_edge, update_key_a)
             MeshPool.__union_groups(mesh, edge_groups, key_b, update_key_b)
             MeshPool.__union_groups(mesh, edge_groups, middle_edge, update_key_b)
+            return [key_a, key_b, middle_edge]

     @staticmethod
diff --git a/models/layers/mesh_prepare.py b/models/layers/mesh_prepare.py
index 47e827c7..aba184e0 100644
--- a/models/layers/mesh_prepare.py
+++ b/models/layers/mesh_prepare.py
@@ -11,11 +11,12 @@ def fill_mesh(mesh2fill, file: str, opt):
         mesh_data = from_scratch(file, opt)
         np.savez_compressed(load_path, gemm_edges=mesh_data.gemm_edges, vs=mesh_data.vs, edges=mesh_data.edges,
                             edges_count=mesh_data.edges_count, ve=mesh_data.ve, v_mask=mesh_data.v_mask,
-                            filename=mesh_data.filename, sides=mesh_data.sides,
+                            filename=mesh_data.filename, sides=mesh_data.sides, faces=mesh_data.faces,
                             edge_lengths=mesh_data.edge_lengths, edge_areas=mesh_data.edge_areas,
                             features=mesh_data.features)
     mesh2fill.vs = mesh_data['vs']
     mesh2fill.edges = mesh_data['edges']
+    mesh2fill.faces = mesh_data['faces']
     mesh2fill.gemm_edges = mesh_data['gemm_edges']
     mesh2fill.edges_count = int(mesh_data['edges_count'])
     mesh2fill.ve = mesh_data['ve']
@@ -51,9 +52,9 @@ def __getitem__(self, item):
     mesh_data.filename = 'unknown'
     mesh_data.edge_lengths = None
     mesh_data.edge_areas = []
-    mesh_data.vs, faces = fill_from_file(mesh_data, file)
+    mesh_data.vs, mesh_data.faces = fill_from_file(mesh_data, file)
     mesh_data.v_mask = np.ones(len(mesh_data.vs), dtype=bool)
-    faces, face_areas = remove_non_manifolds(mesh_data, faces)
+    faces, face_areas = remove_non_manifolds(mesh_data, mesh_data.faces)
     if opt.num_aug > 1:
         faces = augmentation(mesh_data, opt, faces)
     build_gemm(mesh_data, faces, face_areas)
@@ -155,7 +156,7 @@ def build_gemm(mesh, faces, face_areas):
         sides[edge_key][nb_count[edge_key] - 2] = nb_count[edge2key[faces_edges[(idx + 1) % 3]]] - 1
         sides[edge_key][nb_count[edge_key] - 1] = nb_count[edge2key[faces_edges[(idx + 2) % 3]]] - 2
     mesh.edges = np.array(edges, dtype=np.int32)
-    mesh.gemm_edges = np.array(edge_nb, dtype=np.int64)
+    mesh.gemm_edges = np.array(edge_nb, dtype=np.int64)  # [n_edges, 4] - each edge with its 4 one-ring neighbors
    mesh.sides = np.array(sides, dtype=np.int64)
    mesh.edges_count = edges_count
    mesh.edge_areas = np.array(mesh.edge_areas, dtype=np.float32) / np.sum(face_areas)  # TODO: what's the difference between edge_areas and edge_lengths?
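Reviewer note: the [n_edges, 4] `gemm_edges` matrix annotated above is what both the pooling guard (`has_boundaries`) and `get_side_points` below key off: rows that still hold the -1 placeholder belong to boundary edges with fewer than four one-ring neighbors. A minimal sketch of that check, assuming a mesh object already prepared by `fill_mesh`:

import numpy as np

def boundary_edge_ids(gemm_edges):
    # gemm_edges: [n_edges, 4]; -1 marks a missing one-ring neighbor
    gemm = np.asarray(gemm_edges)
    return np.where((gemm == -1).any(axis=1))[0]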
@@ -195,10 +196,12 @@ def slide_verts(mesh, prct):
     for vi in vids:
         if shifted < target:
             edges = mesh.ve[vi]
-            if min(dihedral[edges]) > 2.65:
-                edge = mesh.edges[np.random.choice(edges)]
-                vi_t = edge[1] if vi == edge[0] else edge[0]
-                nv = mesh.vs[vi] + np.random.uniform(0.2, 0.5) * (mesh.vs[vi_t] - mesh.vs[vi])
+            if len(dihedral[edges]) == 0:
+                continue
+            if min(dihedral[edges]) > 2.65:  # slide only if every pair of faces around vi is flat enough
+                edge = mesh.edges[np.random.choice(edges)]  # take one random incident edge
+                vi_t = edge[1] if vi == edge[0] else edge[0]  # take its opposite vertex
+                nv = mesh.vs[vi] + np.random.uniform(0.2, 0.5) * (mesh.vs[vi_t] - mesh.vs[vi])  # shift vi towards it
                 mesh.vs[vi] = nv
                 shifted += 1
         else:
@@ -366,18 +369,26 @@ def get_edge_points(mesh):

 def get_side_points(mesh, edge_id):
+    """
+    Return 4 point indices for an edge: the 2 indices of the vertices forming the edge itself, plus the 2 indices
+    of the vertices opposite the edge in the two triangles sharing it.
+    [edge_a[0], edge_a[1], opposite_vertex_A, opposite_vertex_B]
+
+    If the edge lies on the boundary, the third and fourth entries are identical and refer to the same vertex:
+    [edge_a[0], edge_a[1], opposite_vertex_A, opposite_vertex_A]
+    """
     # if mesh.gemm_edges[edge_id, side] == -1:
     #     return mesh.get_side_points(edge_id, ((side + 2) % 4))
     # else:
     edge_a = mesh.edges[edge_id]
-    if mesh.gemm_edges[edge_id, 0] == -1:
+    if mesh.gemm_edges[edge_id, 0] == -1:  # the edge lies on the boundary and its LEFT face is missing
         edge_b = mesh.edges[mesh.gemm_edges[edge_id, 2]]
         edge_c = mesh.edges[mesh.gemm_edges[edge_id, 3]]
     else:
         edge_b = mesh.edges[mesh.gemm_edges[edge_id, 0]]
         edge_c = mesh.edges[mesh.gemm_edges[edge_id, 1]]
-    if mesh.gemm_edges[edge_id, 2] == -1:
+    if mesh.gemm_edges[edge_id, 2] == -1:  # the edge lies on the boundary and its RIGHT face is missing
         edge_d = mesh.edges[mesh.gemm_edges[edge_id, 0]]
         edge_e = mesh.edges[mesh.gemm_edges[edge_id, 1]]
     else:
diff --git a/models/losses.py b/models/losses.py
new file mode 100644
index 00000000..8b271778
--- /dev/null
+++ b/models/losses.py
@@ -0,0 +1,197 @@
+"""Common image segmentation losses.
+"""
+
+import torch
+
+from torch.nn import functional as F
+
+
+def bce_loss(true, logits, pos_weight=None):
+    """Computes the weighted binary cross-entropy loss.
+    Args:
+        true: a tensor of shape [B, 1, H, W].
+        logits: a tensor of shape [B, 1, H, W]. Corresponds to
+            the raw output or logits of the model.
+        pos_weight: a scalar representing the weight attributed
+            to the positive class. This is especially useful for
+            an imbalanced dataset.
+    Returns:
+        bce_loss: the weighted binary cross-entropy loss.
+    """
+    bce_loss = F.binary_cross_entropy_with_logits(
+        logits.float(),
+        true.float(),
+        pos_weight=pos_weight,
+    )
+    return bce_loss
+
+
+def ce_loss(true, logits, weights, ignore=255):
+    """Computes the weighted multi-class cross-entropy loss.
+    Args:
+        true: a tensor of shape [B, 1, H, W].
+        logits: a tensor of shape [B, C, H, W]. Corresponds to
+            the raw output or logits of the model.
+        weights: a tensor of shape [C,]. The weights attributed
+            to each class.
+        ignore: the class index to ignore.
+    Returns:
+        ce_loss: the weighted multi-class cross-entropy loss.
+    """
+    true = true.squeeze(-1).squeeze(1)
+    logits = logits.squeeze(-1)
+
+    ce_loss = F.cross_entropy(
+        logits.float(),
+        true.long(),
+        ignore_index=ignore,
+        weight=weights,
+    )
+    return ce_loss
+
+
+def dice_loss(true, logits, eps=1e-7):
+    """Computes the Sørensen–Dice loss.
+ Note that PyTorch optimizers minimize a loss. In this + case, we would like to maximize the dice loss so we + return the negated dice loss. + Args: + true: a tensor of shape [B, 1, H, W]. + logits: a tensor of shape [B, C, H, W]. Corresponds to + the raw output or logits of the model. + eps: added to the denominator for numerical stability. + Returns: + dice_loss: the Sørensen–Dice loss. + """ + num_classes = logits.shape[1] + if num_classes == 1: + true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)] + true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() + true_1_hot_f = true_1_hot[:, 0:1, :, :] + true_1_hot_s = true_1_hot[:, 1:2, :, :] + true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1) + pos_prob = torch.sigmoid(logits) + neg_prob = 1 - pos_prob + probas = torch.cat([pos_prob, neg_prob], dim=1) + else: + true_1_hot = torch.eye(num_classes)[true.squeeze(1)] + true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() + probas = F.softmax(logits, dim=1) + true_1_hot = true_1_hot.type(logits.type()) + dims = (0,) + tuple(range(2, true.ndimension())) + intersection = torch.sum(probas * true_1_hot, dims) + cardinality = torch.sum(probas + true_1_hot, dims) + dice_loss = (2. * intersection / (cardinality + eps)).mean() + return (1 - dice_loss) + + +def jaccard_loss(true, logits, eps=1e-7): + """Computes the Jaccard loss, a.k.a the IoU loss. + Note that PyTorch optimizers minimize a loss. In this + case, we would like to maximize the jaccard loss so we + return the negated jaccard loss. + Args: + true: a tensor of shape [B, H, W] or [B, 1, H, W]. + logits: a tensor of shape [B, C, H, W]. Corresponds to + the raw output or logits of the model. + eps: added to the denominator for numerical stability. + Returns: + jacc_loss: the Jaccard loss. + """ + + num_classes = logits.shape[1] + if num_classes == 1: + true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)] + true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() + true_1_hot_f = true_1_hot[:, 0:1, :, :] + true_1_hot_s = true_1_hot[:, 1:2, :, :] + true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1) + pos_prob = torch.sigmoid(logits) + neg_prob = 1 - pos_prob + probas = torch.cat([pos_prob, neg_prob], dim=1) + else: + true_1_hot = torch.eye(num_classes)[true.squeeze(1)] + true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() + probas = F.softmax(logits, dim=1) + true_1_hot = true_1_hot.type(logits.type()) + dims = (0,) + tuple(range(2, true.ndimension())) + intersection = torch.sum(probas * true_1_hot, dims) + cardinality = torch.sum(probas + true_1_hot, dims) + union = cardinality - intersection + jacc_loss = (intersection / (union + eps)).mean() + return (1 - jacc_loss) + + +def tversky_loss(true, logits, alpha, beta, eps=1e-7): + """Computes the Tversky loss [1]. + Args: + true: a tensor of shape [B, H, W] or [B, 1, H, W]. + logits: a tensor of shape [B, C, H, W]. Corresponds to + the raw output or logits of the model. + alpha: controls the penalty for false positives. + beta: controls the penalty for false negatives. + eps: added to the denominator for numerical stability. + Returns: + tversky_loss: the Tversky loss. 
Notes:
+        alpha = beta = 0.5 => dice coeff
+        alpha = beta = 1 => tanimoto coeff
+        alpha + beta = 1 => F beta coeff
+    References:
+        [1]: https://arxiv.org/abs/1706.05721
+    """
+    num_classes = logits.shape[1]
+    if num_classes == 1:
+        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
+        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
+        true_1_hot_f = true_1_hot[:, 0:1, :, :]
+        true_1_hot_s = true_1_hot[:, 1:2, :, :]
+        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
+        pos_prob = torch.sigmoid(logits)
+        neg_prob = 1 - pos_prob
+        probas = torch.cat([pos_prob, neg_prob], dim=1)
+    else:
+        true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
+        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
+        probas = F.softmax(logits, dim=1)
+    true_1_hot = true_1_hot.type(logits.type())
+    dims = (0,) + tuple(range(2, true.ndimension()))
+    intersection = torch.sum(probas * true_1_hot, dims)
+    fps = torch.sum(probas * (1 - true_1_hot), dims)
+    fns = torch.sum((1 - probas) * true_1_hot, dims)
+    num = intersection
+    denom = intersection + (alpha * fps) + (beta * fns)
+    tversky_loss = (num / (denom + eps)).mean()
+    return (1 - tversky_loss)
+
+
+def ce_dice(true, pred, weights=torch.tensor([0.5, 2])):
+    if weights is not None:
+        weights = torch.as_tensor(weights).to(pred.device)  # avoid re-wrapping an existing tensor
+
+    return ce_loss(true, pred, weights) + \
+        dice_loss(true, pred)
+
+
+def ce_jaccard(true, pred, weights=torch.tensor([0.5, 2])):
+    if weights is not None:
+        weights = torch.as_tensor(weights).to(pred.device)  # avoid re-wrapping an existing tensor
+
+    return ce_loss(true, pred, weights) + \
+        jaccard_loss(true, pred)
+
+
+def focal_loss(true, pred):
+    raise NotImplementedError  # placeholder, not wired into networks.define_loss yet
+
+
+def postprocess(true, pred):  # strip the -1 padding and reshape to the [B, C, H, W]-style layout the losses expect
+    num_classes = pred.shape[1]
+    true = true.view(-1)
+    pred = pred.view(num_classes, -1)
+    not_padding = true != -1
+    true = true[not_padding]
+    pred = pred[:, not_padding]
+    true = true.view(1, 1, -1, 1)
+    pred = pred.view(1, num_classes, -1, 1)
+    return true, pred
\ No newline at end of file
diff --git a/models/mesh_classifier.py b/models/mesh_classifier.py
index 9ce50cb3..fd382621 100644
--- a/models/mesh_classifier.py
+++ b/models/mesh_classifier.py
@@ -1,4 +1,6 @@
 import torch
+import torchmetrics
+
 from .
import networks from os.path import join from util.util import seg_accuracy, print_network @@ -32,7 +34,7 @@ def __init__(self, opt): self.net = networks.define_classifier(opt.input_nc, opt.ncf, opt.ninput_edges, opt.nclasses, opt, self.gpu_ids, opt.arch, opt.init_type, opt.init_gain) self.net.train(self.is_train) - self.criterion = networks.define_loss(opt).to(self.device) + self.criterion = networks.define_loss(opt) if self.is_train: self.optimizer = torch.optim.Adam(self.net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) @@ -58,7 +60,7 @@ def forward(self): return out def backward(self, out): - self.loss = self.criterion(out, self.labels) + self.loss = self.criterion(self.labels, out) self.loss.backward() def optimize_parameters(self): @@ -74,17 +76,23 @@ def load_network(self, which_epoch): """load model from disk""" save_filename = '%s_net.pth' % which_epoch load_path = join(self.save_dir, save_filename) + self.load_weights(load_path) + + def load_weights(self, load_path): net = self.net if isinstance(net, torch.nn.DataParallel): net = net.module print('loading the model from %s' % load_path) # PyTorch newer than 0.4 (e.g., built from # GitHub source), you can remove str() on self.device - state_dict = torch.load(load_path, map_location=str(self.device)) - if hasattr(state_dict, '_metadata'): - del state_dict._metadata - net.load_state_dict(state_dict) + saved_dict = torch.load(load_path, map_location=str(self.device)) + if hasattr(saved_dict, '_metadata'): + del saved_dict._metadata + current_dict = net.state_dict() + filtered_dict = {k: v for k, v in saved_dict.items() if saved_dict[k].shape == current_dict[k].shape} + current_dict.update(filtered_dict) + net.load_state_dict(current_dict) def save_network(self, which_epoch): """save model to disk""" @@ -115,6 +123,20 @@ def test(self): correct = self.get_accuracy(pred_class, label_class) return correct, len(label_class) + def get_metrics(self, acc_metric, f1_metric, iou_metric): + with torch.no_grad(): + out = self.forward() + pred_class = out.data.max(1)[1] + label_class = self.labels + label_class[label_class == -1] = 0 + pred_class = pred_class.to(self.device) + label_class = label_class.to(self.device) + + acc = acc_metric(pred_class, label_class) + f1 = f1_metric(pred_class, label_class) + iou = iou_metric(pred_class, label_class) + return acc, f1, iou + def get_accuracy(self, pred, labels): """computes accuracy for classification / segmentation """ if self.opt.dataset_mode == 'classification': diff --git a/models/networks.py b/models/networks.py index c2a13e2e..8e4634e9 100644 --- a/models/networks.py +++ b/models/networks.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from models.layers.mesh_pool import MeshPool from models.layers.mesh_unpool import MeshUnpool +from .losses import ce_jaccard, dice_loss, jaccard_loss, ce_loss, ce_dice ############################################################################### @@ -87,7 +88,7 @@ def init_net(net, init_type, init_gain, gpu_ids): assert(torch.cuda.is_available()) net.cuda(gpu_ids[0]) net = net.cuda() - net = torch.nn.DataParallel(net, gpu_ids) + net = torch.nn.DataParallel(net, gpu_ids if gpu_ids else [0]) if init_type != 'none': init_weights(net, init_type, init_gain) return net @@ -114,7 +115,18 @@ def define_loss(opt): if opt.dataset_mode == 'classification': loss = torch.nn.CrossEntropyLoss() elif opt.dataset_mode == 'segmentation': - loss = torch.nn.CrossEntropyLoss(ignore_index=-1) + device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else 
torch.device('cpu') + weights = torch.FloatTensor(opt.loss_weights).to(device) + + losses = { + 'ce': functools.partial(ce_loss, weights=weights), + 'dice': dice_loss, + 'jaccard': jaccard_loss, + 'ce_dice': functools.partial(ce_dice, weights=weights), + 'ce_jaccard': functools.partial(ce_jaccard, weights=weights) + } + + loss = losses.get(opt.loss) return loss ############################################################################## diff --git a/options/base_options.py b/options/base_options.py index 61f21ce0..1d053a65 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -11,7 +11,7 @@ def __init__(self): def initialize(self): # data params - self.parser.add_argument('--dataroot', required=True, help='path to meshes (should have subfolders train, test)') + self.parser.add_argument('--dataroot', required=False, help='path to meshes (should have subfolders train, test)') self.parser.add_argument('--dataset_mode', choices={"classification", "segmentation"}, default='classification') self.parser.add_argument('--ninput_edges', type=int, default=750, help='# of input edges (will include dummy edges)') self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples per epoch') @@ -26,6 +26,11 @@ def initialize(self): self.parser.add_argument('--num_groups', type=int, default=16, help='# of groups for groupnorm') self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') self.parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') + self.parser.add_argument('--loss', type=str, default='ce_dice', + help='loss function; possible values: ce, dice, jaccard, ce_dice, ce_jaccard') + self.parser.add_argument('--loss_weights', nargs='+', default=[0.5, 2], type=float, + help='weights for loss function, used only with ce/ce_dice/ce_jaccard losses') + # general params self.parser.add_argument('--num_threads', default=3, type=int, help='# threads for loading data') self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU') @@ -38,10 +43,10 @@ def initialize(self): # self.initialized = True - def parse(self): + def parse(self, args=None): if not self.initialized: self.initialize() - self.opt, unknown = self.parser.parse_known_args() + self.opt, unknown = self.parser.parse_known_args(args) self.opt.is_train = self.is_train # train or test str_ids = self.opt.gpu_ids.split(',') diff --git a/options/pl_options.py b/options/pl_options.py new file mode 100644 index 00000000..e2cd7949 --- /dev/null +++ b/options/pl_options.py @@ -0,0 +1,24 @@ +from .train_options import TrainOptions + +class PLOptions(TrainOptions): + def initialize(self): + TrainOptions.initialize(self) + self.parser.add_argument('--gpus', type=int, default=1) + self.parser.add_argument('--max_epochs', type=int, default=200) + self.parser.add_argument('--warmup_epochs', type=int, default=50) + self.parser.add_argument('--nclasses', type=int, default=2) + self.parser.add_argument('--input_nc', type=int, default=5) + self.parser.add_argument('--class_weights', nargs='+', default=[0.5, 2], type=float) + self.parser.add_argument('--from_pretrained', type=str, default=None) + self.parser.add_argument('--optimizer', choices=['adam', 'sgd', 'adamw'], type=str, default='adam') + self.parser.add_argument('--weight_decay', type=float, default=0.0002) + + self.parser.add_argument('--progress_bar_refresh_rate', type=int, default=20) + self.parser.add_argument('--default_root_dir', default='checkpoints/', + help='pytorch-lightning log path') + # options used for decimation script only + self.parser.add_argument('--model_path', default='checkpoints/lightning_logs/version_0/checkpoints/epoch=95-val_acc_epoch=0.00.ckpt', + help='.ckpt file with trained model') + self.parser.add_argument('--decimation_dir', + default='datasets/roof_seg/obj_new', + help='augmented meshes are saved here') diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..9a449a63 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +pytorch-lightning==1.2.9 +ray==1.9.0 +trimesh==3.8.19 +open3d==0.11.2 \ No newline at end of file diff --git a/scripts/roof_seg/train.sh b/scripts/roof_seg/train.sh new file mode 100644 index 00000000..78aa7f5a --- /dev/null +++ b/scripts/roof_seg/train.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +## run the training +python train_pl.py \ +--dataroot datasets/roof_seg \ +--name roof_seg \ +--arch meshunet \ +--dataset_mode segmentation \ +--ncf 32 64 128 256 \ +--ninput_edges 16000 \ +--pool_res 12000 10500 9000 \ +--resblocks 3 \ +--batch_size 1 \ +--lr 0.005 \ +--num_aug 20 \ +--slide_verts 0.2 \ +--warmup_epochs 300 + +#--from_pretrained checkpoints/coseg_aliens/latest_net.pth \ No newline at end of file diff --git a/scripts/roof_seg/tune.sh b/scripts/roof_seg/tune.sh new file mode 100644 index 00000000..1c9b3521 --- /dev/null +++ b/scripts/roof_seg/tune.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +## run the training +python tuning.py \ +--dataroot $(pwd)/datasets/roof_seg \ +--name roof_seg \ +--arch meshunet \ +--dataset_mode segmentation \ +--ninput_edges 16000 \ +--pool_res 12000 10500 9000 \ +--batch_size 1 \ +--num_aug 20 \ +--max_epochs 100 +#--from_pretrained checkpoints/coseg_aliens/latest_net.pth \ No newline at end of file diff --git a/test.py b/test.py index 15492f5b..708e386d 100644 --- a/test.py +++ b/test.py @@ -1,25 +1,102 @@ +from copy import deepcopy + +import numpy as np +import torch +import torchmetrics +import trimesh as tm + from options.test_options import TestOptions from data import 
DataLoader from models import create_model from util.writer import Writer -def run_test(epoch=-1): +def edges_to_path(edges, color=tm.visual.color.random_color()): + lines = np.asarray(edges) + args = tm.path.exchange.misc.lines_to_path(lines) + colors = [color for _ in range(len(args['entities']))] + path = tm.path.Path3D(**args, colors=colors) + return path + + +def show_edges(mesh, label, colors=[[0,0,0,255], [120,120,120,255]]): + colors = np.array(colors) + edges = mesh.vs[mesh.edges] + tm.Scene([edges_to_path(e, colors[int(l)]) for e, l in zip(edges, label)]).show() + + +# def run_test(epoch=-1): +# print('Running Test') +# opt = TestOptions().parse() +# opt.serial_batches = True # no shuffle +# dataset = DataLoader(opt) +# model = create_model(opt) +# writer = Writer(opt) +# # test +# writer.reset_counter() +# for i, data in enumerate(dataset): +# mesh = deepcopy(data['mesh'][0]) +# model.set_input(data) +# +# # pred_class = model.forward().max(1)[1] +# # # show_mesh(mesh, pred_class[0]) +# # edges = mesh.edges +# # vertices = mesh.vs +# # vertex_label = np.zeros(len(vertices)) +# # for e_l, e in zip(pred_class[0], edges): +# # if e_l == 1: +# # vertex_label[e] = 1 +# # faces = mesh.faces +# # vertex_colors = np.array([tm.visual.random_color(), tm.visual.random_color()])[vertex_label.astype(int)] +# # tm.Trimesh(faces=faces, vertices=vertices, vertex_colors=vertex_colors).show() +# +# ncorrect, nexamples = model.test() +# +# writer.update_counter(ncorrect, nexamples) +# writer.print_acc(epoch, writer.acc) +# return writer.acc + + +def run_test(epoch=-1, data_phase='test'): print('Running Test') opt = TestOptions().parse() opt.serial_batches = True # no shuffle + opt.phase = data_phase dataset = DataLoader(opt) model = create_model(opt) writer = Writer(opt) # test writer.reset_counter() - for i, data in enumerate(dataset): - model.set_input(data) - ncorrect, nexamples = model.test() - writer.update_counter(ncorrect, nexamples) - writer.print_acc(epoch, writer.acc) - return writer.acc + + metrics = [ + torchmetrics.Accuracy(num_classes=2, average='macro').to(model.device), + torchmetrics.Accuracy(num_classes=2).to(model.device), + torchmetrics.IoU(num_classes=2).to(model.device), + torchmetrics.F1(num_classes=2, average='macro').to(model.device) + ] + with torch.no_grad(): + for i, data in enumerate(dataset): + model.set_input(data) + out = model.forward() + pred_class = out.data.max(1)[1] + label_class = model.labels + pred_class = pred_class.to(model.device) + label_class = label_class.to(model.device) + not_padding = label_class != -1 + label_class = label_class[not_padding] + pred_class = pred_class[not_padding] + + for m in metrics: + m(pred_class, label_class) + # print(f"Metrics on 3D model {i} - accuracy: {acc}, F1: {f1}, IoU: {iou}") + # writer.print_acc(epoch, writer.acc) + metric_vals = [] + for m in metrics: + m_name = str(m).split('(')[0] + metric_vals.append(f'{m_name}: {m.compute()}') + metrics_str = ' '.join(metric_vals) + print(f'epoch: {epoch}, {data_phase.upper()} {metrics_str}') if __name__ == '__main__': - run_test() + run_test() \ No newline at end of file diff --git a/train.py b/train.py index 41b326b7..27e28a60 100644 --- a/train.py +++ b/train.py @@ -4,6 +4,8 @@ from models import create_model from util.writer import Writer from test import run_test +import torch + if __name__ == '__main__': opt = TrainOptions().parse() @@ -21,6 +23,7 @@ epoch_iter = 0 for i, data in enumerate(dataset): + iter_start_time = time.time() if total_steps % opt.print_freq == 0: 
t_data = iter_start_time - iter_data_time @@ -45,7 +48,7 @@ print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) model.save_network('latest') - model.save_network(epoch) + # model.save_network(epoch) print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) @@ -54,7 +57,10 @@ writer.plot_model_wts(model, epoch) if epoch % opt.run_test_freq == 0: - acc = run_test(epoch) - writer.plot_acc(acc, epoch) + run_test(epoch, 'train') + run_test(epoch, 'test') + # writer.plot_acc(acc, epoch) + + torch.cuda.empty_cache() writer.close() diff --git a/train_pl.py b/train_pl.py new file mode 100644 index 00000000..cbe6db9b --- /dev/null +++ b/train_pl.py @@ -0,0 +1,126 @@ +import argparse +import os +import random + +import torch +import pytorch_lightning as pl +from torch.utils.data import Dataset, DataLoader +import glob +import json +import numpy as np +import torchmetrics +from options.pl_options import PLOptions +from data import DataLoader +from models import create_model +from models.losses import postprocess +from models.losses import ce_jaccard +import warnings +from models import networks +from models.mesh_classifier import ClassifierModel +warnings.filterwarnings("ignore") + + +class MeshSegmenter(pl.LightningModule, ClassifierModel): + + def __init__(self, opt): + pl.LightningModule.__init__(self) + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.optimizer = None + self.edge_features = None + self.labels = None + self.mesh = None + self.soft_label = None + self.loss = None + self.nclasses = opt.nclasses + + # load/define networks + self.net = networks.define_classifier(opt.input_nc, opt.ncf, opt.ninput_edges, opt.nclasses, opt, + self.gpu_ids, opt.arch, opt.init_type, opt.init_gain) + self.criterion = networks.define_loss(opt) + if opt.from_pretrained is not None: + print('Loaded pretrained weights:', opt.from_pretrained) + self.load_weights(opt.from_pretrained) + if self.training: + self.train_metrics = torch.nn.ModuleList([ + torchmetrics.Accuracy(),# (num_classes=opt.nclasses, average='macro'), + torchmetrics.IoU(num_classes=opt.nclasses), + torchmetrics.F1(num_classes=opt.nclasses, average='macro') + ]) + self.val_metrics = torch.nn.ModuleList([ + torchmetrics.Accuracy(), #num_classes=opt.nclasses, average='macro'), + torchmetrics.IoU(num_classes=opt.nclasses), + torchmetrics.F1(num_classes=opt.nclasses, average='macro') + ]) + + def step(self, batch, metrics, metric_prefix=''): + out = self.forward(batch) + true, pred = postprocess(self.labels, out) + loss = self.criterion(true, pred) + + true = true.view(-1) + pred = pred.argmax(1).view(-1) + + prefix = metric_prefix + for m in metrics: + val = m(pred, true) + metric_name = str(m).split('(')[0] + self.log(prefix + metric_name.lower(), val, logger=True, prog_bar=True, on_epoch=True) + self.log(prefix + 'loss', loss, on_epoch=True) + return loss + + def training_step(self, batch, idx): + + return self.step(batch, self.train_metrics) + + def validation_step(self, batch, idx): + return self.step(batch, self.val_metrics, metric_prefix='val_') + + def forward(self, data): + input_edge_features = torch.from_numpy(data['edge_features']).float() + if 'label' in data: + self.labels = torch.from_numpy(data['label']).long().to(self.device) + self.edge_features = input_edge_features.to(self.device).requires_grad_(self.training) + self.mesh = data['mesh'] + return self.net(self.edge_features, self.mesh) + + def on_train_epoch_end(self, unused=None): + for 
m in self.train_metrics:
+            m.reset()
+
+    def on_validation_epoch_end(self) -> None:
+        for m in self.val_metrics:
+            m.reset()
+
+    def train_dataloader(self):
+        self.opt.phase = 'train'
+        return DataLoader(self.opt)
+
+    def val_dataloader(self):
+        self.opt.phase = 'test'
+        return DataLoader(self.opt)
+
+    def configure_optimizers(self):
+        if self.opt.optimizer == 'adam':
+            opt = torch.optim.Adam(self.net.parameters(), lr=self.opt.lr, weight_decay=self.opt.weight_decay)
+        elif self.opt.optimizer == 'sgd':
+            opt = torch.optim.SGD(self.net.parameters(), lr=self.opt.lr,
+                                  momentum=0.9,
+                                  weight_decay=self.opt.weight_decay)
+        elif self.opt.optimizer == 'adamw':
+            opt = torch.optim.AdamW(self.net.parameters(), lr=self.opt.lr, weight_decay=self.opt.weight_decay)
+        sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, self.opt.warmup_epochs)
+        # sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, self.opt.max_epochs * 2)
+        return [opt], [sched]
+
+
+if __name__ == '__main__':
+    from pytorch_lightning.callbacks import ModelCheckpoint
+    args = PLOptions().parse()
+    model = MeshSegmenter(args)
+    trainer = pl.Trainer.from_argparse_args(args,
+                                            callbacks=[ModelCheckpoint(monitor='val_iou',
+                                                                       mode='max',
+                                                                       save_top_k=3,
+                                                                       filename='{epoch:02d}-{val_iou:.2f}')])
+    trainer.fit(model)
\ No newline at end of file
diff --git a/tuning.py b/tuning.py
new file mode 100644
index 00000000..3a0cf54a
--- /dev/null
+++ b/tuning.py
@@ -0,0 +1,61 @@
+import json
+
+import torch.cuda
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning import Trainer
+from train_pl import *
+from ray import tune
+from ray.tune.integration.pytorch_lightning import TuneReportCallback
+
+args = PLOptions().parse()
+
+
+def train_segmentation(config):
+    for k, v in config.items():
+        args.__dict__[k] = v
+
+    model = MeshSegmenter(args)
+    callback_tune = TuneReportCallback(metrics='val_iou', on="validation_end")
+    callback_lightning = ModelCheckpoint(monitor='val_iou', mode='max', save_top_k=3,
+                                         filename='{epoch:02d}-{val_iou:.2f}')
+
+    # callback_tune_f1 = TuneReportCallback(metrics='val_f1', on="validation_end")
+    # callback_lightning_f1 = ModelCheckpoint(monitor='val_f1', mode='max', save_top_k=3,
+    #                                         filename='{epoch:02d}-{val_acc_epoch:.2f}')
+
+    trainer = Trainer.from_argparse_args(args, callbacks=[callback_tune, callback_lightning])
+    trainer.fit(model)
+    torch.cuda.empty_cache()
+
+
+if __name__ == '__main__':
+    # Execute the hyperparameter grid search
+    config = {
+        'resblocks': tune.grid_search([2, 3, 4]),
+        'ncf': tune.grid_search([[64, 128, 256, 512], [32, 64, 128, 256], [16, 32, 64, 128]]),
+        'slide_verts': tune.grid_search([0.1, 0.2]),
+        'lr': tune.grid_search([0.01, 0.001]),
+        'optimizer': tune.grid_search(['adam', 'sgd', 'adamw']),
+        'warmup_epochs': tune.grid_search([200, 100, 50]),
+        'weight_decay': tune.grid_search([0, 0.0002]),
+    }
+
+    ## CPU only
+    # analysis = tune.run(
+    #     train_segmentation,
+    #     config=config, num_samples=1, resources_per_trial={"cpu": 1}, mode='max')
+
+    # GPU
+    analysis = tune.run(
+        tune.with_parameters(train_segmentation),
+        config=config, num_samples=1, resources_per_trial={"gpu": 1, 'cpu': 1})
+
+    # Saving the results
+    best_config = analysis.get_best_config(metric='val_iou', mode="max")
+    print("Best config: ", best_config)
+
+    with open(os.path.join(args.checkpoints_dir, 'roof_seg', 'best_config.json'), 'w') as f:
+        json.dump(best_config, f)
+
+    # Get a dataframe for analyzing trial results.
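+    # results_df: one row per completed trial, combining the reported metrics with the sampled config values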
+    df = analysis.results_df
+    print(df)
\ No newline at end of file
diff --git a/util/mesh_viewer.py b/util/mesh_viewer.py
index c7214fb8..b0d34dc5 100644
--- a/util/mesh_viewer.py
+++ b/util/mesh_viewer.py
@@ -118,6 +118,8 @@ def fix_vertices():
         if len(splitted_line) >= 4:
             edge_v = [int(c) - 1 for c in splitted_line[1:-1]]
             edge_c = int(splitted_line[-1])
+            if edge_c < 0:
+                continue
             add_to_edges()
     vs = V(vs)
@@ -142,8 +144,7 @@ def view_meshes(*files, offset=.2):
 if __name__=='__main__':
     import argparse
     parser = argparse.ArgumentParser("view meshes")
-    parser.add_argument('--files', nargs='+', default=['checkpoints/human_seg/meshes/shrec__14_0.obj',
-                                                       'checkpoints/human_seg/meshes/shrec__14_3.obj'], type=str,
+    parser.add_argument('--files', nargs='+', default=['/home/ihahanov/Projects/meshcnn/checkpoints/roof_seg/meshes/basnett_0.obj'], type=str,
                         help="list of 1 or more .obj files")
     args = parser.parse_args()
diff --git a/util/util.py b/util/util.py
index 562c22f6..db42c733 100644
--- a/util/util.py
+++ b/util/util.py
@@ -66,3 +66,18 @@ def calculate_entropy(np_array):
         entropy -= a * np.log(a)
     entropy /= np.log(np_array.shape[0])
     return entropy
+
+
+def remove_padding(label_class, pred_class):
+    num_classes = pred_class.size()[1]
+    label_class, pred_class = label_class.flatten(), pred_class.flatten()
+
+    not_padding = label_class != -1
+    label_class = label_class[not_padding]
+    label_class = label_class.view(1, -1)
+
+    not_padding = not_padding.repeat(num_classes)
+    pred_class = pred_class[not_padding]
+    pred_class = pred_class.view(1, num_classes, -1)
+
+    return label_class, pred_class
\ No newline at end of file
diff --git a/visualize.py b/visualize.py
new file mode 100644
index 00000000..d30ba4b2
--- /dev/null
+++ b/visualize.py
@@ -0,0 +1,58 @@
+from options.test_options import TestOptions
+from data import DataLoader
+from models import create_model
+from util.writer import Writer
+import trimesh as tm
+import numpy as np
+from copy import deepcopy
+import torch
+
+
+def edges_to_path(edges, color=tm.visual.color.random_color()):
+    lines = np.asarray(edges)
+    args = tm.path.exchange.misc.lines_to_path(lines)
+    colors = [color for _ in range(len(args['entities']))]
+    path = tm.path.Path3D(**args, colors=colors)
+    return path
+
+
+def show_edges(mesh, label, colors=[[0, 0, 0, 255], [120, 120, 120, 255]]):
+    colors = np.array(colors)
+    edges = mesh.vs[mesh.edges]
+    tm.Scene([edges_to_path(e, colors[int(l)]) for e, l in zip(edges, label)]).show()
+
+
+def show_mesh(mesh, label):
+    edges = mesh.edges
+    vertices = mesh.vs
+    vertex_label = np.zeros(len(vertices))
+    for e_l, e in zip(label[0], edges):
+        if e_l == 1:
+            vertex_label[e] = 1
+    faces = mesh.faces
+    vertex_colors = np.array([[255, 100, 0, 255], [0, 100, 255, 255]])[vertex_label.astype(int)]
+    tm.Trimesh(faces=faces, vertices=vertices, vertex_colors=vertex_colors).show()
+
+def run_test(epoch=-1):
+    print('Running Test')
+    opt = TestOptions().parse()
+    opt.serial_batches = True  # no shuffle
+    dataset = DataLoader(opt)
+    model = create_model(opt)
+    writer = Writer(opt)
+    # test
+    writer.reset_counter()
+    for i, data in enumerate(dataset):
+        torch.cuda.empty_cache()
+        mesh = deepcopy(data['mesh'][0])
+
+        # show_mesh(mesh, data['label'][0])
+        model.set_input(data)
+
+        pred_class = model.forward().max(1)[1]
+        # show_mesh(mesh, pred_class[0])
+        show_mesh(mesh, label=pred_class)
+        torch.cuda.empty_cache()
+
+if __name__ == '__main__':
+    run_test()
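Reviewer note: both visualizers project edge labels down to vertices the same way — any vertex incident to a positive ('roof') edge is painted as roof. A condensed sketch of that projection, assuming edges is an [n_edges, 2] int array and edge_label an [n_edges] array over {0, 1}:

import numpy as np

def edge_labels_to_vertex_labels(edges, edge_label, n_vertices):
    vertex_label = np.zeros(n_vertices, dtype=int)
    positive = np.asarray(edge_label) == 1
    vertex_label[np.asarray(edges)[positive].ravel()] = 1  # both endpoints of every roof edge
    return vertex_label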
diff --git a/visualize_annotations.py b/visualize_annotations.py
new file mode 100644
index 00000000..4c2cbc05
--- /dev/null
+++ b/visualize_annotations.py
@@ -0,0 +1,33 @@
+from options.test_options import TestOptions
+from data import DataLoader
+from visualize import show_mesh
+import trimesh as tm
+import numpy as np
+
+
+def edges_to_path(edges, color=tm.visual.color.random_color()):
+    lines = np.asarray(edges)
+    args = tm.path.exchange.misc.lines_to_path(lines)
+    colors = [color for _ in range(len(args['entities']))]
+    path = tm.path.Path3D(**args, colors=colors)
+    return path
+
+
+def show_edges(mesh, label, colors=[[0, 0, 0, 255], [120, 120, 120, 255]]):
+    colors = np.array(colors)
+    edges = mesh.vs[mesh.edges]
+    tm.Scene([edges_to_path(e, colors[int(l)]) for e, l in zip(edges, label)]).show()
+
+
+def run_test(epoch=-1):
+    print('Running Test')
+    opt = TestOptions().parse()
+    opt.serial_batches = True  # no shuffle
+    dataset = DataLoader(opt)
+    for i, data in enumerate(dataset):
+        print(data['path'])
+
+        show_mesh(data['mesh'][0], data['label'])
+
+if __name__ == '__main__':
+    run_test()
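Reviewer note: with the S3 remote added in .dvc/config and the datasets/roof_seg.dvc pointer committed above, the dataset should be retrievable through the standard DVC workflow (assumes AWS credentials with read access to the machinelearning-assets bucket):

pip install "dvc[s3]"
dvc pull datasets/roof_seg.dvc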