Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# | |
# This source code is licensed under the Apache License, Version 2.0 | |
# found in the LICENSE file in the root directory of this source tree. | |
import logging

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
logger = logging.getLogger("dinov2")


def _lossfunc_torch(t, s, temp):
    """Return sum(t * log_softmax(s / temp)) over the last dimension (pure PyTorch).

    This is the *negative* cross-entropy between targets ``t`` and the
    temperature-scaled softmax of logits ``s``. It works for inputs of any rank
    and reduces only the last dimension.
    """
    return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1)


try:
    from xformers.ops import cross_entropy
except ImportError:

    def lossfunc(t, s, temp):
        """Negative cross-entropy of targets ``t`` vs. softmax(``s`` / ``temp``) over the last dim."""
        return _lossfunc_torch(t, s, temp)

else:

    def lossfunc(t, s, temp):
        """Negative cross-entropy of targets ``t`` vs. softmax(``s`` / ``temp``) over the last dim.

        Rank-2 and rank-3 inputs go through xformers' ``cross_entropy``, computed
        in float32. Inputs of any other rank previously fell through and returned
        ``None``. They now use the pure-PyTorch formula instead.
        NOTE(review): the xformers path is assumed to match ``_lossfunc_torch``
        numerically — confirm against the xformers documentation.
        """
        s = s.float()
        t = t.float()
        if s.ndim == 2:
            # The xformers kernel is called with a leading batch dim, so add one and strip it again.
            return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0)
        if s.ndim == 3:
            return -cross_entropy(s, t, temp, bw_inplace=True)
        return _lossfunc_torch(t, s, temp)
class iBOTPatchLoss(nn.Module):
    """iBOT masked-patch loss: cross-entropy between teacher targets and student patch logits.

    Keeps an exponential-moving-average ``center`` buffer of shape
    ``(1, 1, patch_out_dim)`` that ``softmax_center_teacher`` subtracts from
    teacher logits. Updating the center is split into two steps so that
    communication can overlap with other work:
    ``reduce_center_update`` launches an all-reduce (asynchronously when
    distributed), and ``apply_center_update`` waits for it and applies the EMA.

    Args:
        patch_out_dim: size of the last dimension of the patch logits.
        student_temp: temperature applied to student logits.
        center_momentum: EMA momentum used when updating ``center``.
    """

    def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, 1, patch_out_dim))
        # State for the deferred center update (reduce_center_update -> apply_center_update).
        self.updated = True
        self.reduce_handle = None
        self.len_teacher_patch_tokens = None
        self.async_batch_center = None

    @torch.no_grad()
    def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp):
        """Return softmax((teacher_patch_tokens - center) / teacher_temp) over the last dim.

        Any pending center update is applied first. Runs without autograd
        because the result is used as a target.

        WARNING: ``self.center`` is float32, so lower-precision inputs are
        promoted to float32 by the subtraction.
        NOTE(review): assumes the last dim of ``teacher_patch_tokens`` equals
        ``patch_out_dim`` so that ``center`` broadcasts — confirm with callers.
        """
        self.apply_center_update()
        return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1)

    @torch.no_grad()
    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3):
        """Compute teacher targets with Sinkhorn-Knopp normalization.

        Args:
            teacher_output: 2-D logits, rows = samples, columns = prototypes
                (it is transposed with ``.t()``).
            teacher_temp: temperature applied before exponentiation.
            n_masked_patches_tensor: this rank's number of samples. It is summed
                across ranks when a process group is initialized. The caller's
                tensor is not modified.
            n_iterations: number of row/column normalization rounds.

        Returns:
            Float32 tensor with the same shape as ``teacher_output`` whose rows
            each sum to 1.
        """
        teacher_output = teacher_output.float()
        # Q is K-by-B (prototypes x samples) for consistency with the paper's notation.
        Q = torch.exp(teacher_output / teacher_temp).t()
        # Total number of samples across ranks. Work on a copy: all_reduce is in-place
        # and must not overwrite the caller's tensor.
        B = torch.as_tensor(n_masked_patches_tensor, device=Q.device).clone()
        distributed = dist.is_initialized()
        if distributed:
            dist.all_reduce(B)
        K = Q.shape[0]  # number of prototypes

        # Make the matrix sum to 1 (globally across ranks).
        sum_Q = torch.sum(Q)
        if distributed:
            dist.all_reduce(sum_Q)
        Q /= sum_Q

        for _ in range(n_iterations):
            # Normalize each row: total weight per prototype must be 1/K.
            sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
            if distributed:
                dist.all_reduce(sum_of_rows)
            Q /= sum_of_rows
            Q /= K
            # Normalize each column: total weight per sample must be 1/B.
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= B

        Q *= B  # the columns must sum to 1 so that Q is an assignment
        return Q.t()

    def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        student_patch_tokens: (B, N, D) tensor
        teacher_patch_tokens: (B, N, D) tensor
        student_masks_flat: (B, N) tensor

        Returns the negated mean, over images, of the per-image average
        (over masked patches) of sum(t * log_softmax(s / student_temp)).
        """
        t = teacher_patch_tokens
        s = student_patch_tokens
        loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
        # Convert the mask to float before counting so bool masks are handled
        # uniformly. clamp avoids 0/0 for images with no masked patch.
        masks = student_masks_flat.float()
        loss = torch.sum(loss * masks, dim=-1) / masks.sum(dim=-1).clamp(min=1.0)
        return -loss.mean()

    def forward_masked(
        self,
        student_patch_tokens_masked,
        teacher_patch_tokens_masked,
        student_masks_flat,
        n_masked_patches=None,
        masks_weight=None,
    ):
        """Loss over patches that have already been gathered at the masked positions.

        Args:
            student_patch_tokens_masked: student logits of the masked patches.
            teacher_patch_tokens_masked: teacher targets for the same patches.
            student_masks_flat: 2-D mask used for boolean indexing (so expected
                to be a bool tensor). Its first dim is the divisor of the result.
            n_masked_patches: if given, only the first ``n_masked_patches``
                entries of the per-patch loss are kept.
            masks_weight: per-patch weights. Defaults to
                1 / (number of masked patches in that patch's row of the mask).

        Returns:
            Negated weighted sum of the per-patch ``lossfunc`` values, divided by
            ``student_masks_flat.shape[0]``.
        """
        t = teacher_patch_tokens_masked
        s = student_patch_tokens_masked
        loss = lossfunc(t, s, self.student_temp)
        if masks_weight is None:
            # Weight each masked patch by 1 / (#masked patches in its row), so every
            # row contributes equally regardless of how many patches it masks.
            per_row = 1 / student_masks_flat.float().sum(-1).clamp(min=1.0)
            masks_weight = per_row.unsqueeze(-1).expand_as(student_masks_flat)[student_masks_flat]
        if n_masked_patches is not None:
            loss = loss[:n_masked_patches]
        loss = loss * masks_weight
        return -loss.sum() / student_masks_flat.shape[0]

    @torch.no_grad()
    def update_center(self, teacher_patch_tokens):
        """Start an EMA update of ``center``. It is applied later by ``apply_center_update``."""
        self.reduce_center_update(teacher_patch_tokens)

    @torch.no_grad()
    def reduce_center_update(self, teacher_patch_tokens):
        """Compute this rank's contribution to the new center and start the all-reduce."""
        self.updated = False
        self.len_teacher_patch_tokens = len(teacher_patch_tokens)
        # Mean over dim 1, summed over dim 0. apply_center_update divides by
        # len(teacher_patch_tokens) * world_size to turn the sum into a mean.
        self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True)
        if dist.is_initialized():
            self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)

    @torch.no_grad()
    def apply_center_update(self):
        """Apply a pending center update, if any, waiting for its all-reduce to finish."""
        if self.updated:
            return
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        if self.reduce_handle is not None:
            self.reduce_handle.wait()
            self.reduce_handle = None  # the collective is done; drop the work handle
        _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size)
        self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)
        self.updated = True