10 changes: 5 additions & 5 deletions metrics.py
@@ -23,12 +23,12 @@ def compute_imagewise_retrieval_metrics(
auroc = metrics.roc_auc_score(
anomaly_ground_truth_labels, anomaly_prediction_weights
)

precision, recall, _ = metrics.precision_recall_curve(
anomaly_ground_truth_labels, anomaly_prediction_weights
)
auc_pr = metrics.auc(recall, precision)

return {"auroc": auroc, "fpr": fpr, "tpr": tpr, "threshold": thresholds}
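Note: auc_pr is computed from the precision-recall curve in this hunk but never returned; the dict above still exposes only ROC-side values (fpr, tpr, and thresholds presumably come from a roc_curve call above the visible context). A minimal sketch of a return that also surfaces PR-AUC, assuming no caller depends on the exact keys:

return {
    "auroc": auroc,
    "auc_pr": auc_pr,  # surface the PR-AUC computed above
    "fpr": fpr,
    "tpr": tpr,
    "threshold": thresholds,
}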


@@ -88,7 +88,7 @@ def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_mask
def compute_pro(masks, amaps, num_th=200):

df = pd.DataFrame([], columns=["pro", "fpr", "threshold"])
- binary_amaps = np.zeros_like(amaps, dtype=np.bool)
+ binary_amaps = np.zeros_like(amaps, dtype=np.bool_)

min_th = amaps.min()
max_th = amaps.max()
@@ -112,11 +112,11 @@ def compute_pro(masks, amaps, num_th=200):
fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
fpr = fp_pixels / inverse_masks.sum()

- df = df.append({"pro": np.mean(pros), "fpr": fpr, "threshold": th}, ignore_index=True)
+ df = df._append({"pro": np.mean(pros), "fpr": fpr, "threshold": th}, ignore_index=True)

# Keep only the FPR range below 0.3, then rescale it to [0, 1]
df = df[df["fpr"] < 0.3]
df["fpr"] = df["fpr"] / df["fpr"].max()

pro_auc = metrics.auc(df["fpr"], df["pro"])
- return pro_auc
+ return pro_auc
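Note: two of these metrics.py edits are version shims. np.bool was removed in NumPy 1.24 (np.bool_ is the retained scalar type), and DataFrame.append was removed in pandas 2.0. DataFrame._append is a private method, though, so a more future-proof rewrite collects rows and builds the frame once; a runnable sketch with placeholder values, not what this PR does:

import numpy as np
import pandas as pd

rows = []
for th in np.linspace(0.0, 1.0, 5):  # stand-in for the min_th..max_th sweep
    rows.append({"pro": 0.9 * th,    # placeholder PRO value
                 "fpr": 0.3 * th,    # placeholder FPR value
                 "threshold": th})
df = pd.DataFrame(rows, columns=["pro", "fpr", "threshold"])
print(df)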
10 changes: 6 additions & 4 deletions run.sh
@@ -1,9 +1,11 @@
- datapath=/data4/MVTec_ad
- datasets=('screw' 'pill' 'capsule' 'carpet' 'grid' 'tile' 'wood' 'zipper' 'cable' 'toothbrush' 'transistor' 'metal_nut' 'bottle' 'hazelnut' 'leather')
+ # datapath=/data4/MVTec_ad
+ datapath=/content
+ # datasets=('screw' 'pill' 'capsule' 'carpet' 'grid' 'tile' 'wood' 'zipper' 'cable' 'toothbrush' 'transistor' 'metal_nut' 'bottle' 'hazelnut' 'leather')
+ datasets=('pill')
dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '"${dataset}"; done))

python3 main.py \
- --gpu 4 \
+ --gpu 1 \
--seed 0 \
--log_group simplenet_mvtec \
--log_project MVTecAD_Results \
@@ -16,7 +18,7 @@ net \
--pretrain_embed_dimension 1536 \
--target_embed_dimension 1536 \
--patchsize 3 \
- --meta_epochs 40 \
+ --meta_epochs 15 \
--embedding_size 256 \
--gan_epochs 4 \
--noise_std 0.015 \
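Note: the run.sh edits retarget the script at a Colab-style environment: data under /content, a single MVTec class instead of all fifteen, and a shorter schedule (15 meta-epochs rather than 40). dataset_flags expands to '-d pill' here. A quick pre-flight check of the layout this assumes (standard MVTec AD directory structure; the paths are assumptions):

import os

datapath, dataset = "/content", "pill"
for sub in ("train/good", "test", "ground_truth"):
    path = os.path.join(datapath, dataset, sub)
    print(path, "OK" if os.path.isdir(path) else "MISSING")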
89 changes: 47 additions & 42 deletions simplenet.py
@@ -57,42 +57,42 @@ def forward(self,x):


class Projection(torch.nn.Module):

def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0):
super(Projection, self).__init__()

if out_planes is None:
out_planes = in_planes
self.layers = torch.nn.Sequential()
_in = None
_out = None
for i in range(n_layers):
_in = in_planes if i == 0 else _out
- _out = out_planes
- self.layers.add_module(f"{i}fc",
+ _out = out_planes
+ self.layers.add_module(f"{i}fc",
torch.nn.Linear(_in, _out))
if i < n_layers - 1:
# if layer_type > 0:
- # self.layers.add_module(f"{i}bn",
+ # self.layers.add_module(f"{i}bn",
# torch.nn.BatchNorm1d(_out))
if layer_type > 1:
self.layers.add_module(f"{i}relu",
torch.nn.LeakyReLU(.2))
self.apply(init_weight)

def forward(self, x):

# x = .1 * self.layers(x) + x
x = self.layers(x)
return x
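Note: Projection is a thin per-feature MLP: n_layers Linear maps, with LeakyReLU(0.2) between them only when layer_type > 1 (the BatchNorm branch is commented out), so the default configuration reduces to a single Linear layer. A usage sketch with assumed dimensions (run.sh above implies 1536-dim embeddings):

import torch

proj = Projection(in_planes=1536, out_planes=1536, n_layers=1, layer_type=0)
feats = torch.randn(8, 1536)  # 8 patch features, dimension assumed
print(proj(feats).shape)      # torch.Size([8, 1536])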


class TBWrapper:

def __init__(self, log_dir):
self.g_iter = 0
self.logger = SummaryWriter(log_dir=log_dir)

def step(self):
self.g_iter += 1
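Note: TBWrapper just pairs a TensorBoard SummaryWriter with a global step counter; a usage sketch:

logger = TBWrapper("./tb")                            # writes event files to ./tb
logger.logger.add_scalar("loss", 0.5, logger.g_iter)  # log at the current step
logger.step()                                         # advance the step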

Expand All @@ -111,7 +111,7 @@ def load(
pretrain_embed_dimension, # 1536
target_embed_dimension, # 1536
patchsize=3, # 3
- patchstride=1,
+ patchstride=1,
embedding_size=None, # 256
meta_epochs=1, # 40
aed_meta_epochs=1,
@@ -195,7 +195,7 @@ def show_mem():
self.discriminator.to(self.device)
self.dsc_opt = torch.optim.Adam(self.discriminator.parameters(), lr=self.dsc_lr, weight_decay=1e-5)
self.dsc_schl = torch.optim.lr_scheduler.CosineAnnealingLR(self.dsc_opt, (meta_epochs - aed_meta_epochs) * gan_epochs, self.dsc_lr*.4)
- self.dsc_margin= dsc_margin
+ self.dsc_margin= dsc_margin

self.model_dir = ""
self.dataset_name = ""
@@ -204,14 +204,14 @@ def show_mem():

def set_model_dir(self, model_dir, dataset_name):

- self.model_dir = model_dir
+ self.model_dir = model_dir
os.makedirs(self.model_dir, exist_ok=True)
self.ckpt_dir = os.path.join(self.model_dir, dataset_name)
os.makedirs(self.ckpt_dir, exist_ok=True)
self.tb_dir = os.path.join(self.ckpt_dir, "tb")
os.makedirs(self.tb_dir, exist_ok=True)
self.logger = TBWrapper(self.tb_dir) #SummaryWriter(log_dir=tb_dir)


def embed(self, data):
if isinstance(data, torch.utils.data.DataLoader):
@@ -276,16 +276,16 @@ def _embed(self, images, detach=True, provide_patch_shapes=False, evaluation=Fal
_features = _features.reshape(len(_features), -1, *_features.shape[-3:])
features[i] = _features
features = [x.reshape(-1, *x.shape[-3:]) for x in features]

# As different feature backbones & patching provide differently
# sized features, these are brought into the correct form here.
features = self.forward_modules["preprocessing"](features) # pooling each feature to same channel and stack together
- features = self.forward_modules["preadapt_aggregator"](features) # further pooling
+ features = self.forward_modules["preadapt_aggregator"](features) # further pooling


return features, patch_shapes


def test(self, training_data, test_data, save_segmentation_images):

ckpt_path = os.path.join(self.ckpt_dir, "models.ckpt")
@@ -328,7 +328,7 @@ def test(self, training_data, test_data, save_segmentation_images):

if save_segmentation_images:
self.save_segmentation_images(test_data, segmentations, scores)

auroc = metrics.compute_imagewise_retrieval_metrics(
scores, anomaly_labels
)["auroc"]
@@ -340,17 +340,17 @@ def test(self, training_data, test_data, save_segmentation_images):
full_pixel_auroc = pixel_scores["auroc"]

return auroc, full_pixel_auroc

def _evaluate(self, test_data, scores, segmentations, features, labels_gt, masks_gt):

scores = np.squeeze(np.array(scores))
img_min_scores = scores.min(axis=-1)
img_max_scores = scores.max(axis=-1)
scores = (scores - img_min_scores) / (img_max_scores - img_min_scores)
# scores = np.mean(scores, axis=0)

auroc = metrics.compute_imagewise_retrieval_metrics(
- scores, labels_gt
+ scores, labels_gt
)["auroc"]

if len(masks_gt) > 0:
@@ -377,18 +377,18 @@ def _evaluate(self, test_data, scores, segmentations, features, labels_gt, masks
# segmentations, masks_gt
full_pixel_auroc = pixel_scores["auroc"]

- pro = metrics.compute_pro(np.squeeze(np.array(masks_gt)),
+ pro = metrics.compute_pro(np.squeeze(np.array(masks_gt)),
norm_segmentations)
else:
- full_pixel_auroc = -1
+ full_pixel_auroc = -1
pro = -1

return auroc, full_pixel_auroc, pro
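Note: after np.squeeze the image scores here are 1-D, so the axis=-1 min/max are scalars and the rescale is a global min-max to [0, 1]. AUROC is rank-based, so this monotone rescaling leaves it unchanged; a small check with toy values:

import numpy as np
from sklearn import metrics

scores = np.array([0.2, 0.9, 0.5, 0.7])
labels = np.array([0, 1, 0, 1])
normed = (scores - scores.min()) / (scores.max() - scores.min())
print(metrics.roc_auc_score(labels, scores))  # 1.0
print(metrics.roc_auc_score(labels, normed))  # 1.0, unchanged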


def train(self, training_data, test_data):


state_dict = {}
ckpt_path = os.path.join(self.ckpt_dir, "ckpt.pth")
if os.path.exists(ckpt_path):
@@ -403,21 +403,25 @@
self.predict(training_data, "train_")
scores, segmentations, features, labels_gt, masks_gt = self.predict(test_data)
auroc, full_pixel_auroc, anomaly_pixel_auroc = self._evaluate(test_data, scores, segmentations, features, labels_gt, masks_gt)

return auroc, full_pixel_auroc, anomaly_pixel_auroc

def update_state_dict(d):

state_dict["discriminator"] = OrderedDict({
- k:v.detach().cpu()
+ k:v.detach().cpu()
for k, v in self.discriminator.state_dict().items()})
if self.pre_proj > 0:
state_dict["pre_projection"] = OrderedDict({
- k:v.detach().cpu()
+ k:v.detach().cpu()
for k, v in self.pre_projection.state_dict().items()})
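Note: update_state_dict snapshots weights as detached CPU tensors, so the retained "best" copy neither pins GPU memory nor keeps the autograd graph alive. The same idiom as a standalone helper, a sketch rather than repo code:

from collections import OrderedDict

def cpu_snapshot(module):
    # Copy every parameter/buffer to host memory, detached from autograd.
    return OrderedDict(
        (k, v.detach().cpu()) for k, v in module.state_dict().items()
    )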

best_record = None
for i_mepoch in range(self.meta_epochs):
print('----------------------------')
print('----------------------------')
print('----------------------------')
print(f'Epoch: {i_mepoch + 1}')

self._train_discriminator(training_data)

@@ -439,23 +443,24 @@ def update_state_dict(d):
# state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()})
elif auroc == best_record[0] and full_pixel_auroc > best_record[1]:
best_record[1] = full_pixel_auroc
- best_record[2] = pro
+ best_record[2] = pro
update_state_dict(state_dict)
# state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()})

print(f"----- {i_mepoch} I-AUROC:{round(auroc, 4)}(MAX:{round(best_record[0], 4)})"
f" P-AUROC{round(full_pixel_auroc, 4)}(MAX:{round(best_record[1], 4)}) -----"
f" PRO-AUROC{round(pro, 4)}(MAX:{round(best_record[2], 4)}) -----")

- torch.save(state_dict, ckpt_path)

+ torch.save(state_dict, ckpt_path)

return best_record


def _train_discriminator(self, input_data):
"""Computes and sets the support features for SPADE."""
_ = self.forward_modules.eval()

if self.pre_proj > 0:
self.pre_projection.train()
self.discriminator.train()
@@ -483,7 +488,7 @@
true_feats = self.pre_projection(self._embed(img, evaluation=False)[0])
else:
true_feats = self._embed(img, evaluation=False)[0]

noise_idxs = torch.randint(0, self.mix_noise, torch.Size([true_feats.shape[0]]))
noise_one_hot = torch.nn.functional.one_hot(noise_idxs, num_classes=self.mix_noise).to(self.device) # (N, K)
noise = torch.stack([
@@ -495,7 +500,7 @@
scores = self.discriminator(torch.cat([true_feats, fake_feats]))
true_scores = scores[:len(true_feats)]
fake_scores = scores[len(fake_feats):]

th = self.dsc_margin
p_true = (true_scores.detach() >= th).sum() / len(true_scores)
p_fake = (fake_scores.detach() < -th).sum() / len(fake_scores)
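Note: p_true and p_fake only track how often true/fake scores clear the margin ±th; the loss lines themselves are collapsed in this diff. Per the SimpleNet paper the objective is a truncated l1 loss on the two score sets; an illustrative sketch, not the repo's exact lines:

import torch

th = 0.5                       # dsc_margin
true_scores = torch.randn(16)  # discriminator outputs on real features
fake_scores = torch.randn(16)  # outputs on noise-perturbed features
loss = (torch.clip(-true_scores + th, min=0)
        + torch.clip(fake_scores + th, min=0)).mean()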
@@ -516,17 +521,17 @@
self.backbone_opt.step()
self.dsc_opt.step()

- loss = loss.detach().cpu()
+ loss = loss.detach().cpu()
all_loss.append(loss.item())
all_p_true.append(p_true.cpu().item())
all_p_fake.append(p_fake.cpu().item())

if len(embeddings_list) > 0:
self.auto_noise[1] = torch.cat(embeddings_list).std(0).mean(-1)

if self.cos_lr:
self.dsc_schl.step()

all_loss = sum(all_loss) / len(input_data)
all_p_true = sum(all_p_true) / len(input_data)
all_p_fake = sum(all_p_fake) / len(input_data)
@@ -584,7 +589,7 @@ def _predict(self, images):
self.discriminator.eval()
with torch.no_grad():
features, patch_shapes = self._embed(images,
- provide_patch_shapes=True,
+ provide_patch_shapes=True,
evaluation=True)
if self.pre_proj > 0:
features = self.pre_projection(features)
2 changes: 1 addition & 1 deletion utils.py
@@ -102,7 +102,7 @@ def set_torch_device(gpu_ids):
if len(gpu_ids):
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_ids[0])
- return torch.device("cuda:{}".format(gpu_ids[0]))
+ return torch.device("cuda:{}".format(gpu_ids[0]-1))
return torch.device("cpu")
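Note: this change pairs with the --gpu 1 edit in run.sh: the id is now treated as 1-based, so gpu_ids[0]-1 resolves to cuda:0, the only device on a single-GPU Colab runtime (passing --gpu 4 as before would now request cuda:3 and fail there). A slightly more defensive variant, a sketch rather than the repo's code:

import torch

def set_torch_device(gpu_ids):
    if len(gpu_ids) and torch.cuda.is_available():
        # Clamp the 1-based id to the devices actually present.
        idx = min(max(gpu_ids[0] - 1, 0), torch.cuda.device_count() - 1)
        return torch.device("cuda:{}".format(idx))
    return torch.device("cpu")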

