From df36d5470fce1ed9c97371fa12835657bb2c88cb Mon Sep 17 00:00:00 2001 From: s12359e <46375313+s12359e@users.noreply.github.com> Date: Thu, 7 Aug 2025 22:01:37 +0800 Subject: [PATCH] feat: allow custom backbone weights --- README.md | 4 +++- bin/run_patchcore.py | 17 +++++++++++++++-- src/patchcore/utils.py | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 16a73bd..e408b6b 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,7 @@ python bin/run_patchcore.py \ patch_core # We now pass all PatchCore-related parameters. -b wideresnet50 # Which backbone to use. +-bp /path/to/simclr.pth # Optional path to custom backbone weights. -le layer2 -le layer3 # Which layers to extract features from. --faiss_on_gpu # If similarity-searches should be performed on GPU. --pretrain_embed_dimension 1024 --target_embed_dimension 1024 # Dimensionality of features extracted from backbone layer(s) and final aggregated PatchCore Dimensionality @@ -145,7 +146,8 @@ allows you to log all training & test performances online to Weights-and-Biases Finally, due to the effectiveness and efficiency of PatchCore, we also incorporate the option to use an ensemble of backbone networks and network featuremaps. For this, provide the list of backbones to -use (as listed in `/src/anomaly_detection/backbones.py`) with `-b ` and optionally +use (as listed in `/src/anomaly_detection/backbones.py`) with `-b ` and optionally +provide paths to custom pretrained weights via `-bp ` following the same order. Given their ordering, denote the layers to extract with `-le idx.`. An example with three different backbones would look something like diff --git a/bin/run_patchcore.py b/bin/run_patchcore.py index 8666b2b..3c59f1e 100644 --- a/bin/run_patchcore.py +++ b/bin/run_patchcore.py @@ -240,6 +240,9 @@ def mask_transform(mask): @main.command("patch_core") # Pretraining-specific parameters. 
@click.option("--backbone_names", "-b", type=str, multiple=True, default=[]) +@click.option( + "--backbone_paths", "-bp", type=click.Path(exists=True), multiple=True, default=[] +) @click.option("--layers_to_extract_from", "-le", type=str, multiple=True, default=[]) # Parameters for Glue-code (to merge different parts of the pipeline. @click.option("--pretrain_embed_dimension", type=int, default=1024) @@ -258,6 +261,7 @@ def mask_transform(mask): @click.option("--faiss_num_workers", type=int, default=8) def patch_core( backbone_names, + backbone_paths, layers_to_extract_from, pretrain_embed_dimension, target_embed_dimension, @@ -272,6 +276,13 @@ def patch_core( faiss_num_workers, ): backbone_names = list(backbone_names) + backbone_paths = list(backbone_paths) + if len(backbone_paths) == 0: + backbone_paths = [None] * len(backbone_names) + elif len(backbone_paths) != len(backbone_names): + raise click.ClickException( + "--backbone_paths must match number of --backbone_names" + ) if len(backbone_names) > 1: layers_to_extract_from_coll = [[] for _ in range(len(backbone_names))] for layer in layers_to_extract_from: @@ -283,8 +294,8 @@ def patch_core( def get_patchcore(input_shape, sampler, device): loaded_patchcores = [] - for backbone_name, layers_to_extract_from in zip( - backbone_names, layers_to_extract_from_coll + for backbone_name, layers_to_extract_from, backbone_path in zip( + backbone_names, layers_to_extract_from_coll, backbone_paths ): backbone_seed = None if ".seed-" in backbone_name: @@ -293,6 +304,8 @@ def get_patchcore(input_shape, sampler, device): ) backbone = patchcore.backbones.load(backbone_name) backbone.name, backbone.seed = backbone_name, backbone_seed + if backbone_path is not None: + patchcore.utils.load_pretrained_weights(backbone, backbone_path, device) nn_method = patchcore.common.FaissNN(faiss_on_gpu, faiss_num_workers) diff --git a/src/patchcore/utils.py b/src/patchcore/utils.py index 5cc724e..3403ee1 100644 --- a/src/patchcore/utils.py 
+++ b/src/patchcore/utils.py @@ -173,3 +173,36 @@ def compute_and_store_final_results( mean_metrics = {"mean_{0}".format(key): item for key, item in mean_metrics.items()} return mean_metrics + + +def load_pretrained_weights(model, weight_path, device): + """Load a backbone checkpoint into ``model``. + + The checkpoint is expected to be a ``torch.load``-able object and may + optionally contain a ``state_dict`` entry. Key prefixes such as + ``module.``, ``model.``, or ``backbone.`` are stripped to allow loading + checkpoints produced by frameworks like SimCLR. + + Args: + model: Backbone network to load weights into. + weight_path: Path to the checkpoint file. + device: Torch device to map the checkpoint to. + """ + + checkpoint = torch.load(weight_path, map_location=device) + state_dict = checkpoint.get("state_dict", checkpoint) + + cleaned_state_dict = {} + for key, value in state_dict.items(): + for prefix in ("module.", "model.", "backbone."): + if key.startswith(prefix): + key = key[len(prefix) :] + cleaned_state_dict[key] = value + + missing, unexpected = model.load_state_dict(cleaned_state_dict, strict=False) + if missing: + LOGGER.warning("Missing keys when loading state dict: %s", missing) + if unexpected: + LOGGER.warning("Unexpected keys when loading state dict: %s", unexpected) + + return model