Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dinov2/configs/ssl_default_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ train:
cache_dataset: true
centering: "centering" # or "sinkhorn_knopp"
unfreeze_last_n_blocks: 40
gradient_accumulation_steps: 1 # micro-batches per optimizer step; >1 enables grad accumulation
pretrained_weights: '' # path to a teacher_checkpoint.pth from a prior training run
student:
arch: vit_large
patch_size: 16
Expand Down
107 changes: 107 additions & 0 deletions dinov2/configs/train/vitg14_reg4_highres.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# High-resolution fine-tuning stage — as described in the Midnight paper (Section 2).
#
# This config is designed to be run AFTER standard-resolution training
# (vitg14_reg4.yaml). It fine-tunes the model at 392×392 global crops for
# 120 000 optimizer steps (~96 epochs), matching the Midnight-92k/392 recipe.
#
# Key differences from vitg14_reg4.yaml:
#   crops.global_crops_size : 224 → 392
#   crops.local_crops_size  : 98  → 168
#   train.batch_size_per_gpu: 48  → 16   (392px uses ~3× more GPU memory)
#   gradient_accumulation_steps: 1 → 3   (effective batch stays 16×8×3 = 384)
#   optim.base_lr           : 2e-4 → 1e-4
#   optim.epochs            : —    → 96  (96 × 1250 steps = 120 000 opt. steps)
#   train.use_pretrained    : True → False
#   train.pretrained_weights: ""   → set via CLI or edit below
#   train.skip_checkpointer : true → false (save FSDP checkpoints for resuming)
#
# Usage:
#   Set STAGE1_CHECKPOINT in run_highres.sh to point at the teacher_checkpoint.pth
#   produced by vitg14_reg4 training (found under output_vitg14/eval/<iter>/).
#   Then run: bash run_highres.sh

dino:
  head_n_prototypes: 131072
  head_bottleneck_dim: 384
  do_kde: True
  kde_loss_weight: .05
  # KoLeo regularizer explicitly disabled: both the weight is zeroed and the
  # flag is off, so it stays inert regardless of which one the code checks.
  koleo_loss_weight: 0
  do_koleo: False

ibot:
  loss_weight: 1.0
  mask_sample_probability: 0.5
  mask_ratio_min_max:
  - 0.1
  - 0.45
  separate_head: true
  head_n_prototypes: 131072

train:
  # ── Data ──────────────────────────────────────────────────────────────────
  sample_list_path: /block/TCGA/sample_dataset_30.txt
  streaming_from_hf: false
  streaming_dataset_path: medarc/TCGA-12K-parquet

  # ── Batch / accumulation ──────────────────────────────────────────────────
  # 392px global crops need ~3× the memory of 224px, so per-GPU batch is
  # reduced from 48 to 16. Three accumulation steps restore the effective
  # total batch size to 16 × 8 GPUs × 3 = 384.
  batch_size_per_gpu: 16
  gradient_accumulation_steps: 3

  # ── Initialisation ────────────────────────────────────────────────────────
  # Do NOT reload Meta's DINOv2 backbone; load from our own stage-1 checkpoint.
  # Set pretrained_weights here OR pass it as a CLI override:
  #   train.pretrained_weights=/path/to/teacher_checkpoint.pth
  use_pretrained: False
  pretrained_weights: ""

  # ── Checkpointing ─────────────────────────────────────────────────────────
  # Enable so that this stage can be interrupted and resumed.
  skip_checkpointer: false

  # ── Misc ──────────────────────────────────────────────────────────────────
  centering: sinkhorn_knopp
  # NOTE(review): with gradient accumulation each optimizer step consumes 3
  # data-loader batches; OFFICIAL_EPOCH_LENGTH presumably counts optimizer
  # steps (96 × 1250 = 120 000) — confirm against the training loop.
  OFFICIAL_EPOCH_LENGTH: 1250
  num_workers: 24
  prefetch_factor: 8

student:
  arch: vit_giant2
  patch_size: 14
  drop_path_rate: 0.4
  ffn_layer: swiglufused
  block_chunks: 4
  num_register_tokens: 4

teacher:
  momentum_teacher: 0.994

optim:
  # 96 epochs × 1250 steps/epoch = 120 000 optimizer steps (matches paper)
  epochs: 96
  early_stop: 96
  # Reduced LR for the fine-tuning stage (paper: 1e-4)
  base_lr: 1.0e-04
  warmup_epochs: 5
  weight_decay_end: 0.2
  layerwise_decay: 1.0

crops:
  # ── Resolution increase (core change for high-res stage) ─────────────────
  global_crops_size: 392  # was 224; (392/14)^2 = 784 patch tokens per image
  local_crops_size: 168   # was 98
  global_crops_scale:
  - 0.32
  - 1.0
  local_crops_scale:
  - 0.05
  - 0.32
  local_crops_number: 8

evaluation:
  eval_period_iterations: 5000
  bach_root: /block/eva-data/bach
  breakhis_root: /block/eva-data/breakhis
  pcam_root: /block/eva-data/patch_camelyon
4 changes: 2 additions & 2 deletions dinov2/train/ssl_meta_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def backprop_loss(self, loss):
else:
loss.backward()

def forward_backward(self, images, teacher_temp):
def forward_backward(self, images, teacher_temp, accumulation_steps=1):
n_global_crops = 2
assert n_global_crops == 2
n_local_crops = self.cfg.crops.local_crops_number
Expand Down Expand Up @@ -355,7 +355,7 @@ def get_teacher_output():
# accumulate loss
loss_accumulator += self.ibot_loss_weight * ibot_patch_loss

self.backprop_loss(loss_accumulator)
self.backprop_loss(loss_accumulator / accumulation_steps)

self.fsdp_synchronize_streams()

Expand Down
73 changes: 69 additions & 4 deletions dinov2/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,45 @@ def _mlp_kind(block):
student_backbone.norm.bias.copy_(model_pretrained.norm.bias)


def _load_from_teacher_checkpoint(cfg, model):
    """
    Initialise student and teacher backbones from a teacher_checkpoint.pth produced
    by do_test(). The file has the structure {"teacher": model.teacher.state_dict()},
    where keys look like "backbone.patch_embed.proj.weight" (FSDP "module." prefix
    already stripped by do_test). This is used to start high-resolution fine-tuning
    from a checkpoint of the standard-resolution training stage.

    Args:
        cfg: config whose ``cfg.train.pretrained_weights`` points at the checkpoint.
        model: SSL meta-arch with ``student.backbone`` and ``teacher.backbone``.
    """
    ckpt_path = cfg.train.pretrained_weights
    logger.info("Loading backbone from teacher checkpoint: %s", ckpt_path)
    state = torch.load(ckpt_path, map_location="cpu")

    # teacher_checkpoint.pth is saved as {"teacher": <state_dict>}; fall back to
    # the raw dict for checkpoints saved without the wrapper.
    teacher_sd = state.get("teacher", state)

    # Strip a residual FSDP "module." prefix. Only the *leading* prefix is
    # removed — str.replace("module.", "") would also mangle keys that merely
    # contain "module." somewhere in the middle.
    teacher_sd = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in teacher_sd.items()
    }

    # Keep only backbone keys, then strip the "backbone." prefix
    backbone_sd = {
        k[len("backbone."):]: v
        for k, v in teacher_sd.items()
        if k.startswith("backbone.")
    }

    with torch.no_grad():
        missing, unexpected = model.student.backbone.load_state_dict(backbone_sd, strict=False)
    if missing:
        logger.warning("Missing keys loading pretrained backbone (student): %s", missing)
    if unexpected:
        logger.warning("Unexpected keys loading pretrained backbone (student): %s", unexpected)

    # teacher will be synced from student in prepare_for_distributed_training(),
    # but copy explicitly here too for clarity. Log mismatches on this side as
    # well instead of discarding load_state_dict's result.
    missing_t, unexpected_t = model.teacher.backbone.load_state_dict(backbone_sd, strict=False)
    if missing_t:
        logger.warning("Missing keys loading pretrained backbone (teacher): %s", missing_t)
    if unexpected_t:
        logger.warning("Unexpected keys loading pretrained backbone (teacher): %s", unexpected_t)

    logger.info("Loaded backbone weights from %s", ckpt_path)


def _freeze_student_backbone_except_last_n(cfg, model):
n_unfrozen = cfg.train.unfreeze_last_n_blocks
student_backbone = model.student.backbone
Expand Down Expand Up @@ -391,8 +430,9 @@ def __init__(self, size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.22
]
super().__init__(ops)

eval_size = cfg.crops.global_crops_size # matches inference resolution (392 in high-res stage)
transform = _ResizeAndCrop(
size=224,
size=eval_size,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
)
Expand Down Expand Up @@ -1141,19 +1181,27 @@ def _worker_init(_):
# training loop

iteration = start_iter
accumulation_steps = int(getattr(cfg.train, "gradient_accumulation_steps", 1))
if accumulation_steps < 1:
accumulation_steps = 1
if accumulation_steps > 1:
logger.info("Gradient accumulation enabled: %d micro-batches per optimizer step", accumulation_steps)

logger.info("Starting training from iteration {}".format(start_iter))
metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
header = "Training"

for data in metric_logger.log_every(
# Keep a named reference to the generator so we can call next() on it inside
# the loop body to fetch the additional micro-batches for gradient accumulation.
log_every_iter = metric_logger.log_every(
data_loader,
10,
header,
eta_target_iter + 1,
start_iter,
):
)
for data in log_every_iter:
if iteration >= early_stop_iter:
logger.info("Early stopping at iteration {}".format(iteration))
if cfg.evaluation.eval_period_iterations >= 0:
Expand Down Expand Up @@ -1194,7 +1242,22 @@ def _worker_init(_):

optimizer.zero_grad(set_to_none=True)

loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
# First micro-batch (already fetched by the outer for-loop)
loss_dict = model.forward_backward(data, teacher_temp=teacher_temp, accumulation_steps=accumulation_steps)

# Remaining micro-batches (gradient accumulation)
for _accum in range(accumulation_steps - 1):
try:
extra_data = next(log_every_iter)
except StopIteration:
break
extra_loss_dict = model.forward_backward(extra_data, teacher_temp=teacher_temp, accumulation_steps=accumulation_steps)
for k in extra_loss_dict:
loss_dict[k] = loss_dict[k] + extra_loss_dict[k]

# Average the accumulated loss tensors so logging reflects per-sample values
if accumulation_steps > 1:
loss_dict = {k: v / accumulation_steps for k, v in loss_dict.items()}

# clip gradients

Expand Down Expand Up @@ -1265,6 +1328,8 @@ def main(args):
#Load model here from pretrained.
if cfg.train.use_pretrained:
_load_pretrained_backbone(cfg, model)
elif getattr(cfg.train, "pretrained_weights", ""):
_load_from_teacher_checkpoint(cfg, model)
_freeze_student_backbone_except_last_n(cfg, model)

model.prepare_for_distributed_training()
Expand Down
8 changes: 6 additions & 2 deletions dinov2/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@ def apply_scaling_rules_to_cfg(cfg): # to fix
if cfg.optim.scaling_rule == "sqrt_wrt_1024":
base_lr = cfg.optim.base_lr
cfg.optim.lr = base_lr
cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
# Use the effective batch size (micro-batch × GPUs × accumulation steps) so
# that gradient accumulation doesn't artificially shrink the learning rate.
accum = int(getattr(cfg.train, "gradient_accumulation_steps", 1))
effective_batch = cfg.train.batch_size_per_gpu * distributed.get_global_size() * accum
cfg.optim.lr *= math.sqrt(effective_batch / 1024.0)
logger.info(f"sqrt scaling learning rate; base: {base_lr}, effective_batch: {effective_batch}, new: {cfg.optim.lr}")
else:
raise NotImplementedError
return cfg
Expand Down
71 changes: 71 additions & 0 deletions run_highres.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env bash
# High-resolution fine-tuning stage (Midnight paper, Section 2).
#
# Run this AFTER standard-resolution training (run_1node.sh) has produced a
# teacher checkpoint. Typical path:
#   output_vitg14/eval/training_<iter>/teacher_checkpoint.pth
#
# Usage:
#   1. Set STAGE1_CHECKPOINT below to the path of your stage-1 teacher checkpoint.
#   2. bash run_highres.sh
#
# To resume an interrupted high-res run, set RESUME="True".

set -euo pipefail

# ── Distributed setup ────────────────────────────────────────────────────────
export MASTER_ADDR=$(hostname -I | awk '{print $1}')
export MASTER_PORT=29501  # different from stage-1 default (29500)

export NNODES=1
export NPROC_PER_NODE=8
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
export NODE_RANK=0

# ── Paths ────────────────────────────────────────────────────────────────────
CONFIG_FILE="./dinov2/configs/train/vitg14_reg4_highres.yaml"
OUTPUT_DIR="./output_vitg14_highres"

# Path to the teacher_checkpoint.pth saved at the end of stage-1 training.
# Example: ./output_vitg14/eval/training_149999/teacher_checkpoint.pth
STAGE1_CHECKPOINT="./output_vitg14/eval/training_149999/teacher_checkpoint.pth"

# Set to "True" to resume a previously started high-res run from its last
# FSDP checkpoint; set to "False" to start fresh from STAGE1_CHECKPOINT.
RESUME="False"

# ── Setup ────────────────────────────────────────────────────────────────────
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P)"
export DINOV2_RUN_SCRIPT="${REPO_ROOT}/$(basename "${BASH_SOURCE[0]}")"
export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}"

# Optional trailing arguments for train.py. Collected in an array so that
# (a) empty options vanish cleanly and (b) a checkpoint path containing
# spaces stays a single argument — an unquoted scalar expansion would
# word-split it (shellcheck SC2086).
EXTRA_ARGS=()

if [[ "${RESUME}" == "True" ]]; then
  echo "Resume mode: continuing from last FSDP checkpoint in ${OUTPUT_DIR}"
  # When resuming we do NOT pass pretrained_weights or --no-resume; the FSDP
  # checkpoint already contains the full model state.
else
  echo "Fresh start: loading backbone from ${STAGE1_CHECKPOINT}"
  if [[ ! -f "${STAGE1_CHECKPOINT}" ]]; then
    echo "ERROR: STAGE1_CHECKPOINT not found: ${STAGE1_CHECKPOINT}" >&2
    exit 1
  fi
  rm -rf "${OUTPUT_DIR}"
  EXTRA_ARGS+=("--no-resume" "train.pretrained_weights=${STAGE1_CHECKPOINT}")
fi
mkdir -p "${OUTPUT_DIR}"

# ── Launch ───────────────────────────────────────────────────────────────────
# The ${EXTRA_ARGS[@]+...} guard keeps the expansion valid under `set -u`
# on bash < 4.4, where expanding an empty array is treated as unset.
uv run torchrun \
  --nnodes "${NNODES}" \
  --nproc_per_node "${NPROC_PER_NODE}" \
  --node_rank "${NODE_RANK}" \
  --master_addr "${MASTER_ADDR}" \
  --master_port "${MASTER_PORT}" \
  dinov2/train/train.py \
  --config-file "${CONFIG_FILE}" \
  --output-dir "${OUTPUT_DIR}" \
  ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}