Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ python bin/run_patchcore.py \

patch_core # We now pass all PatchCore-related parameters.
-b wideresnet50 # Which backbone to use.
-bp /path/to/simclr.pth # Optional path to custom backbone weights.
-le layer2 -le layer3 # Which layers to extract features from.
--faiss_on_gpu # If similarity-searches should be performed on GPU.
--pretrain_embed_dimension 1024 --target_embed_dimension 1024 # Dimensionality of features extracted from backbone layer(s) and final aggregated PatchCore Dimensionality
Expand Down Expand Up @@ -145,7 +146,8 @@ allows you to log all training & test performances online to Weights-and-Biases

Finally, due to the effectiveness and efficiency of PatchCore, we also incorporate the option to use
an ensemble of backbone networks and network featuremaps. For this, provide the list of backbones to
use (as listed in `/src/anomaly_detection/backbones.py`) with `-b <backbone` and, given their
use (as listed in `/src/anomaly_detection/backbones.py`) with `-b <backbone>` and optionally
provide paths to custom pretrained weights via `-bp <path>` following the same order. Given their
ordering, denote the layers to extract with `-le idx.<layer_name>`. An example with three different
backbones would look something like

Expand Down
17 changes: 15 additions & 2 deletions bin/run_patchcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ def mask_transform(mask):
@main.command("patch_core")
# Pretraining-specific parameters.
@click.option("--backbone_names", "-b", type=str, multiple=True, default=[])
@click.option(
"--backbone_paths", "-bp", type=click.Path(exists=True), multiple=True, default=[]
)
@click.option("--layers_to_extract_from", "-le", type=str, multiple=True, default=[])
# Parameters for glue code (to merge different parts of the pipeline).
@click.option("--pretrain_embed_dimension", type=int, default=1024)
Expand All @@ -258,6 +261,7 @@ def mask_transform(mask):
@click.option("--faiss_num_workers", type=int, default=8)
def patch_core(
backbone_names,
backbone_paths,
layers_to_extract_from,
pretrain_embed_dimension,
target_embed_dimension,
Expand All @@ -272,6 +276,13 @@ def patch_core(
faiss_num_workers,
):
backbone_names = list(backbone_names)
backbone_paths = list(backbone_paths)
if len(backbone_paths) == 0:
backbone_paths = [None] * len(backbone_names)
elif len(backbone_paths) != len(backbone_names):
raise click.ClickException(
"--backbone_paths must match number of --backbone_names"
)
if len(backbone_names) > 1:
layers_to_extract_from_coll = [[] for _ in range(len(backbone_names))]
for layer in layers_to_extract_from:
Expand All @@ -283,8 +294,8 @@ def patch_core(

def get_patchcore(input_shape, sampler, device):
loaded_patchcores = []
for backbone_name, layers_to_extract_from in zip(
backbone_names, layers_to_extract_from_coll
for backbone_name, layers_to_extract_from, backbone_path in zip(
backbone_names, layers_to_extract_from_coll, backbone_paths
):
backbone_seed = None
if ".seed-" in backbone_name:
Expand All @@ -293,6 +304,8 @@ def get_patchcore(input_shape, sampler, device):
)
backbone = patchcore.backbones.load(backbone_name)
backbone.name, backbone.seed = backbone_name, backbone_seed
if backbone_path is not None:
patchcore.utils.load_pretrained_weights(backbone, backbone_path, device)

nn_method = patchcore.common.FaissNN(faiss_on_gpu, faiss_num_workers)

Expand Down
33 changes: 33 additions & 0 deletions src/patchcore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,36 @@ def compute_and_store_final_results(

mean_metrics = {"mean_{0}".format(key): item for key, item in mean_metrics.items()}
return mean_metrics


def load_pretrained_weights(model, weight_path, device):
    """Load a backbone checkpoint into ``model`` in place.

    The checkpoint is expected to be a ``torch.load``-able mapping and may
    optionally wrap the weights under a ``state_dict`` entry.  Wrapper key
    prefixes such as ``module.`` (DataParallel), ``model.``, or
    ``backbone.`` are stripped — repeatedly, so chained or reordered
    wrappers like ``backbone.module.conv1`` are handled too — to allow
    loading checkpoints produced by frameworks like SimCLR.

    Args:
        model: Backbone network to load weights into.
        weight_path: Path to the checkpoint file.
        device: Torch device (or ``map_location`` string) to map the
            checkpoint tensors to.

    Returns:
        The same ``model`` instance, with matching weights loaded.
    """

    checkpoint = torch.load(weight_path, map_location=device)
    # Checkpoints are either a bare state dict or wrap one under "state_dict".
    state_dict = checkpoint.get("state_dict", checkpoint)

    prefixes = ("module.", "model.", "backbone.")
    cleaned_state_dict = {}
    for key, value in state_dict.items():
        # Strip wrapper prefixes until none match.  A single ordered pass
        # would miss prefixes that repeat or appear out of tuple order
        # (e.g. "backbone.module.conv1" would keep "module."), causing the
        # weight to be silently skipped under strict=False below.
        stripped = True
        while stripped:
            stripped = False
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix) :]
                    stripped = True
        cleaned_state_dict[key] = value

    # strict=False tolerates checkpoint extras (e.g. SimCLR projection
    # heads) that are not part of the backbone; log what was skipped so
    # mismatches remain visible.
    missing, unexpected = model.load_state_dict(cleaned_state_dict, strict=False)
    if missing:
        LOGGER.warning("Missing keys when loading state dict: %s", missing)
    if unexpected:
        LOGGER.warning("Unexpected keys when loading state dict: %s", unexpected)

    return model