Custom DINO loss function
bjura committed Mar 13, 2024
1 parent 54523d1 commit 0a2f048
Showing 2 changed files with 31 additions and 8 deletions.
21 changes: 21 additions & 0 deletions eval_knn.py
@@ -27,6 +27,8 @@
import vision_transformer as vits

from pipnet_for_dino import load_pipnet_for_dino
import numpy as np
import matplotlib.pyplot as plt


def extract_feature_pipeline(args):
@@ -83,6 +85,25 @@ def extract_feature_pipeline(args):
print("Extracting features for val set...")
test_features = extract_features(model, data_loader_val, args.use_cuda)

    # For a range of activation thresholds, count how many prototype
    # activations per training image reach each threshold and plot a
    # histogram of those per-image counts.
    thresh = [0.5, 0.6, 0.7, 0.8, 0.9, 1]
    fig = plt.figure()
    for idx, thr in enumerate(thresh):
        plt.subplot(2, 3, idx + 1)
        counts = np.sum(train_features.detach().cpu().numpy() >= thr, axis=1).flatten()
        plt.hist(counts, bins=(1000 if thr == 1 else 10))
        plt.title('# >= ' + str(thr))
    plt.tight_layout()
    plt.savefig('train_histograms_of_high_prototype_activations.png')

    # The same histograms for the validation features.
    fig = plt.figure()
    for idx, thr in enumerate(thresh):
        plt.subplot(2, 3, idx + 1)
        counts = np.sum(test_features.detach().cpu().numpy() >= thr, axis=1).flatten()
        plt.hist(counts, bins=(1000 if thr == 1 else 10))
        plt.title('# >= ' + str(thr))
    plt.tight_layout()
    plt.savefig('test_histograms_of_high_prototype_activations.png')

    if utils.get_rank() == 0:
        train_features = nn.functional.normalize(train_features, dim=1, p=2)
        test_features = nn.functional.normalize(test_features, dim=1, p=2)
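Each row of train_features / test_features is assumed to hold one image's PIP-Net prototype activations in [0, 1], so np.sum(features >= thr, axis=1) counts, per image, how many prototypes fire at or above the threshold. A minimal sketch of that computation on toy data (not from the commit):

import numpy as np

# Toy stand-in for an (n_images, n_prototypes) activation matrix in [0, 1].
rng = np.random.default_rng(0)
features = rng.random((8, 5))

thr = 0.5
counts = np.sum(features >= thr, axis=1)  # per-image count of activations >= thr
print(counts)  # one integer between 0 and 5 per image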
18 changes: 10 additions & 8 deletions main_dino.py
@@ -47,7 +47,7 @@ def get_args_parser():
    # Model parameters
    parser.add_argument('--arch', default='vit_small', type=str,
        choices=['vit_tiny', 'vit_small', 'vit_base', 'xcit', 'deit_tiny', 'deit_small'] \
                + torchvision_archs + torch.hub.list("facebookresearch/xcit:main") + ['pipnet'],
                + torchvision_archs + ['pipnet'], # torch.hub.list("facebookresearch/xcit:main"),
        help="""Name of architecture to train. For quick experiments with ViTs,
        we recommend using vit_tiny or vit_small.""")
    parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels
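Dropping the hub query from the argparse choices avoids a network round trip: torch.hub.list clones the GitHub repository to enumerate its entry points, so the script would fail to even parse arguments offline. A hedged alternative sketch, not from this commit, that keeps the XCiT names whenever the hub is reachable:

import torch

# Sketch only: fall back to an empty list when the hub is unreachable,
# so offline runs can still select 'pipnet' or a torchvision architecture.
try:
    xcit_archs = torch.hub.list("facebookresearch/xcit:main")
except Exception:
    xcit_archs = []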
@@ -173,12 +173,14 @@ def train_dino(args):
        )
        teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
        embed_dim = student.embed_dim

    # if the network is a XCiT
    elif args.arch in torch.hub.list("facebookresearch/xcit:main"):
        student = torch.hub.load('facebookresearch/xcit:main', args.arch,
                                 pretrained=False, drop_path_rate=args.drop_path_rate)
        teacher = torch.hub.load('facebookresearch/xcit:main', args.arch, pretrained=False)
        embed_dim = student.embed_dim
    #elif args.arch in torch.hub.list("facebookresearch/xcit:main"):
    #    student = torch.hub.load('facebookresearch/xcit:main', args.arch,
    #                             pretrained=False, drop_path_rate=args.drop_path_rate)
    #    teacher = torch.hub.load('facebookresearch/xcit:main', args.arch, pretrained=False)
    #    embed_dim = student.embed_dim

    # otherwise, we check if the architecture is in torchvision models
    elif args.arch in torchvision_models.__dict__.keys():
        student = torchvision_models.__dict__[args.arch]()
@@ -402,7 +404,7 @@ def forward(self, student_output, teacher_output, epoch):

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        # New target: centered activations rectified, temperature-scaled, and capped at 1
        # (equivalent to min(1, relu(teacher_output - center) / temp); see the note below).
        teacher_out = torch.minimum(torch.ones_like(teacher_output), (teacher_output + F.relu(-(teacher_output - self.center)) - self.center) / temp)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
@@ -412,7 +414,7 @@ def forward(self, student_output, teacher_output, epoch):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                # log is taken of the raw student activations (assumed to lie in (0, 1],
                # e.g. PIP-Net prototype scores) rather than of a softmax distribution.
                loss = torch.sum(-q * torch.log(student_out[v]), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
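The rewritten teacher term simplifies algebraically: F.relu(-(t - c)) equals max(0, c - t), so t + max(0, c - t) - c is max(t - c, 0), and the whole line computes min(1, relu(teacher_output - center) / temp). In other words, the softmax target is replaced by centered, rectified, temperature-scaled activations capped at 1. A small sketch verifying the equivalence on toy tensors (not from the commit):

import torch
import torch.nn.functional as F

t = torch.rand(4, 16)   # stand-in teacher outputs
c = torch.rand(1, 16)   # stand-in center
temp = 0.04

original = torch.minimum(torch.ones_like(t), (t + F.relu(-(t - c)) - c) / temp)
simplified = torch.clamp(F.relu(t - c) / temp, max=1.0)
assert torch.allclose(original, simplified)

Note that this target is no longer a probability distribution, which matches the student side of the change: the cross-entropy now takes torch.log of the raw student activations instead of F.log_softmax over logits.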
