diff --git a/metrics.py b/metrics.py index 50efc3d..d309de1 100644 --- a/metrics.py +++ b/metrics.py @@ -23,12 +23,12 @@ def compute_imagewise_retrieval_metrics( auroc = metrics.roc_auc_score( anomaly_ground_truth_labels, anomaly_prediction_weights ) - + precision, recall, _ = metrics.precision_recall_curve( anomaly_ground_truth_labels, anomaly_prediction_weights ) auc_pr = metrics.auc(recall, precision) - + return {"auroc": auroc, "fpr": fpr, "tpr": tpr, "threshold": thresholds} @@ -88,7 +88,7 @@ def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_mask def compute_pro(masks, amaps, num_th=200): df = pd.DataFrame([], columns=["pro", "fpr", "threshold"]) - binary_amaps = np.zeros_like(amaps, dtype=np.bool) + binary_amaps = np.zeros_like(amaps, dtype=np.bool_) min_th = amaps.min() max_th = amaps.max() @@ -112,11 +112,11 @@ def compute_pro(masks, amaps, num_th=200): fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum() fpr = fp_pixels / inverse_masks.sum() - df = df.append({"pro": np.mean(pros), "fpr": fpr, "threshold": th}, ignore_index=True) + df = pd.concat([df, pd.DataFrame([{"pro": np.mean(pros), "fpr": fpr, "threshold": th}])], ignore_index=True) # Normalize FPR from 0 ~ 1 to 0 ~ 0.3 df = df[df["fpr"] < 0.3] df["fpr"] = df["fpr"] / df["fpr"].max() pro_auc = metrics.auc(df["fpr"], df["pro"]) - return pro_auc \ No newline at end of file + return pro_auc diff --git a/run.sh b/run.sh index cc12de4..12262e2 100644 --- a/run.sh +++ b/run.sh @@ -1,9 +1,11 @@ -datapath=/data4/MVTec_ad -datasets=('screw' 'pill' 'capsule' 'carpet' 'grid' 'tile' 'wood' 'zipper' 'cable' 'toothbrush' 'transistor' 'metal_nut' 'bottle' 'hazelnut' 'leather') +# datapath=/data4/MVTec_ad +datapath=/content +# datasets=('screw' 'pill' 'capsule' 'carpet' 'grid' 'tile' 'wood' 'zipper' 'cable' 'toothbrush' 'transistor' 'metal_nut' 'bottle' 'hazelnut' 'leather') +datasets=('pill') dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '"${dataset}"; done)) python3 main.py \ ---gpu 4 
\ +--gpu 1 \ --seed 0 \ --log_group simplenet_mvtec \ --log_project MVTecAD_Results \ @@ -16,7 +18,7 @@ net \ --pretrain_embed_dimension 1536 \ --target_embed_dimension 1536 \ --patchsize 3 \ ---meta_epochs 40 \ +--meta_epochs 15 \ --embedding_size 256 \ --gan_epochs 4 \ --noise_std 0.015 \ diff --git a/simplenet.py b/simplenet.py index ff5e3b2..9788c23 100644 --- a/simplenet.py +++ b/simplenet.py @@ -57,10 +57,10 @@ def forward(self,x): class Projection(torch.nn.Module): - + def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0): super(Projection, self).__init__() - + if out_planes is None: out_planes = in_planes self.layers = torch.nn.Sequential() @@ -68,31 +68,31 @@ def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0): _out = None for i in range(n_layers): _in = in_planes if i == 0 else _out - _out = out_planes - self.layers.add_module(f"{i}fc", + _out = out_planes + self.layers.add_module(f"{i}fc", torch.nn.Linear(_in, _out)) if i < n_layers - 1: # if layer_type > 0: - # self.layers.add_module(f"{i}bn", + # self.layers.add_module(f"{i}bn", # torch.nn.BatchNorm1d(_out)) if layer_type > 1: self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(.2)) self.apply(init_weight) - + def forward(self, x): - + # x = .1 * self.layers(x) + x x = self.layers(x) return x class TBWrapper: - + def __init__(self, log_dir): self.g_iter = 0 self.logger = SummaryWriter(log_dir=log_dir) - + def step(self): self.g_iter += 1 @@ -111,7 +111,7 @@ def load( pretrain_embed_dimension, # 1536 target_embed_dimension, # 1536 patchsize=3, # 3 - patchstride=1, + patchstride=1, embedding_size=None, # 256 meta_epochs=1, # 40 aed_meta_epochs=1, @@ -195,7 +195,7 @@ def show_mem(): self.discriminator.to(self.device) self.dsc_opt = torch.optim.Adam(self.discriminator.parameters(), lr=self.dsc_lr, weight_decay=1e-5) self.dsc_schl = torch.optim.lr_scheduler.CosineAnnealingLR(self.dsc_opt, (meta_epochs - aed_meta_epochs) * gan_epochs, self.dsc_lr*.4) - 
self.dsc_margin= dsc_margin + self.dsc_margin= dsc_margin self.model_dir = "" self.dataset_name = "" @@ -204,14 +204,14 @@ def show_mem(): def set_model_dir(self, model_dir, dataset_name): - self.model_dir = model_dir + self.model_dir = model_dir os.makedirs(self.model_dir, exist_ok=True) self.ckpt_dir = os.path.join(self.model_dir, dataset_name) os.makedirs(self.ckpt_dir, exist_ok=True) self.tb_dir = os.path.join(self.ckpt_dir, "tb") os.makedirs(self.tb_dir, exist_ok=True) self.logger = TBWrapper(self.tb_dir) #SummaryWriter(log_dir=tb_dir) - + def embed(self, data): if isinstance(data, torch.utils.data.DataLoader): @@ -276,16 +276,16 @@ def _embed(self, images, detach=True, provide_patch_shapes=False, evaluation=Fal _features = _features.reshape(len(_features), -1, *_features.shape[-3:]) features[i] = _features features = [x.reshape(-1, *x.shape[-3:]) for x in features] - + # As different feature backbones & patching provide differently # sized features, these are brought into the correct form here. 
features = self.forward_modules["preprocessing"](features) # pooling each feature to same channel and stack together - features = self.forward_modules["preadapt_aggregator"](features) # further pooling + features = self.forward_modules["preadapt_aggregator"](features) # further pooling return features, patch_shapes - + def test(self, training_data, test_data, save_segmentation_images): ckpt_path = os.path.join(self.ckpt_dir, "models.ckpt") @@ -328,7 +328,7 @@ def test(self, training_data, test_data, save_segmentation_images): if save_segmentation_images: self.save_segmentation_images(test_data, segmentations, scores) - + auroc = metrics.compute_imagewise_retrieval_metrics( scores, anomaly_labels )["auroc"] @@ -340,9 +340,9 @@ def test(self, training_data, test_data, save_segmentation_images): full_pixel_auroc = pixel_scores["auroc"] return auroc, full_pixel_auroc - + def _evaluate(self, test_data, scores, segmentations, features, labels_gt, masks_gt): - + scores = np.squeeze(np.array(scores)) img_min_scores = scores.min(axis=-1) img_max_scores = scores.max(axis=-1) @@ -350,7 +350,7 @@ def _evaluate(self, test_data, scores, segmentations, features, labels_gt, masks # scores = np.mean(scores, axis=0) auroc = metrics.compute_imagewise_retrieval_metrics( - scores, labels_gt + scores, labels_gt )["auroc"] if len(masks_gt) > 0: @@ -377,18 +377,18 @@ def _evaluate(self, test_data, scores, segmentations, features, labels_gt, masks # segmentations, masks_gt full_pixel_auroc = pixel_scores["auroc"] - pro = metrics.compute_pro(np.squeeze(np.array(masks_gt)), + pro = metrics.compute_pro(np.squeeze(np.array(masks_gt)), norm_segmentations) else: - full_pixel_auroc = -1 + full_pixel_auroc = -1 pro = -1 return auroc, full_pixel_auroc, pro - - + + def train(self, training_data, test_data): - + state_dict = {} ckpt_path = os.path.join(self.ckpt_dir, "ckpt.pth") if os.path.exists(ckpt_path): @@ -403,21 +403,25 @@ def train(self, training_data, test_data): self.predict(training_data, 
"train_") scores, segmentations, features, labels_gt, masks_gt = self.predict(test_data) auroc, full_pixel_auroc, anomaly_pixel_auroc = self._evaluate(test_data, scores, segmentations, features, labels_gt, masks_gt) - + return auroc, full_pixel_auroc, anomaly_pixel_auroc - + def update_state_dict(d): - + state_dict["discriminator"] = OrderedDict({ - k:v.detach().cpu() + k:v.detach().cpu() for k, v in self.discriminator.state_dict().items()}) if self.pre_proj > 0: state_dict["pre_projection"] = OrderedDict({ - k:v.detach().cpu() + k:v.detach().cpu() for k, v in self.pre_projection.state_dict().items()}) best_record = None for i_mepoch in range(self.meta_epochs): + print('----------------------------') + print('----------------------------') + print('----------------------------') + print(f'Epoch: {i_mepoch + 1}') self._train_discriminator(training_data) @@ -439,23 +443,24 @@ def update_state_dict(d): # state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()}) elif auroc == best_record[0] and full_pixel_auroc > best_record[1]: best_record[1] = full_pixel_auroc - best_record[2] = pro + best_record[2] = pro update_state_dict(state_dict) # state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()}) print(f"----- {i_mepoch} I-AUROC:{round(auroc, 4)}(MAX:{round(best_record[0], 4)})" f" P-AUROC{round(full_pixel_auroc, 4)}(MAX:{round(best_record[1], 4)}) -----" f" PRO-AUROC{round(pro, 4)}(MAX:{round(best_record[2], 4)}) -----") - + torch.save(state_dict, ckpt_path) + torch.save(state_dict, ckpt_path) - + return best_record - + def _train_discriminator(self, input_data): """Computes and sets the support features for SPADE.""" _ = self.forward_modules.eval() - + if self.pre_proj > 0: self.pre_projection.train() self.discriminator.train() @@ -483,7 +488,7 @@ def _train_discriminator(self, input_data): true_feats = self.pre_projection(self._embed(img, evaluation=False)[0]) else: true_feats = self._embed(img, 
evaluation=False)[0] - + noise_idxs = torch.randint(0, self.mix_noise, torch.Size([true_feats.shape[0]])) noise_one_hot = torch.nn.functional.one_hot(noise_idxs, num_classes=self.mix_noise).to(self.device) # (N, K) noise = torch.stack([ @@ -495,7 +500,7 @@ def _train_discriminator(self, input_data): scores = self.discriminator(torch.cat([true_feats, fake_feats])) true_scores = scores[:len(true_feats)] fake_scores = scores[len(fake_feats):] - + th = self.dsc_margin p_true = (true_scores.detach() >= th).sum() / len(true_scores) p_fake = (fake_scores.detach() < -th).sum() / len(fake_scores) @@ -516,17 +521,17 @@ def _train_discriminator(self, input_data): self.backbone_opt.step() self.dsc_opt.step() - loss = loss.detach().cpu() + loss = loss.detach().cpu() all_loss.append(loss.item()) all_p_true.append(p_true.cpu().item()) all_p_fake.append(p_fake.cpu().item()) - + if len(embeddings_list) > 0: self.auto_noise[1] = torch.cat(embeddings_list).std(0).mean(-1) - + if self.cos_lr: self.dsc_schl.step() - + all_loss = sum(all_loss) / len(input_data) all_p_true = sum(all_p_true) / len(input_data) all_p_fake = sum(all_p_fake) / len(input_data) @@ -584,7 +589,7 @@ def _predict(self, images): self.discriminator.eval() with torch.no_grad(): features, patch_shapes = self._embed(images, - provide_patch_shapes=True, + provide_patch_shapes=True, evaluation=True) if self.pre_proj > 0: features = self.pre_projection(features) diff --git a/utils.py b/utils.py index e489f3f..8fdbb81 100644 --- a/utils.py +++ b/utils.py @@ -102,7 +102,7 @@ def set_torch_device(gpu_ids): if len(gpu_ids): # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_ids[0]) - return torch.device("cuda:{}".format(gpu_ids[0])) + return torch.device("cuda:{}".format(gpu_ids[0]-1)) return torch.device("cpu")