Realcat
add: efficientloftr
e02ffe6
raw
history blame
2.42 kB
import math
import torch
from torch import nn
class RoPEPositionEncodingSine(nn.Module):
    """
    Rotary position encoding (RoPE) built from sinusoidal tables for 2-dimensional images.

    ``sin``/``cos`` tables are precomputed once over a ``max_shape`` grid and stored as
    non-persistent buffers; ``forward`` crops them to the input size and rotates each
    consecutive channel pair of the input by the angle stored for its location.
    """

    def __init__(self, d_model: int, max_shape=(256, 256), npe=None, ropefp16: bool = True):
        """
        Args:
            d_model (int): channel count C of the features given to ``forward``.
                ``d_model // 4`` frequencies are generated per axis
                (presumably C is expected to be a multiple of 4 -- not checked here).
            max_shape (tuple): (H, W) of the precomputed tables. For 1/8 featmap,
                the max length of 256 corresponds to 2048 pixels.
            npe (sequence): ``(train_res_H, train_res_W, test_res_H, test_res_W)``.
                Row positions are scaled by ``train_res_H / test_res_H`` and column
                positions by ``train_res_W / test_res_W``. Required despite the
                ``None`` default.
            ropefp16 (bool): store the tables in float16 when True, float32 otherwise.

        Raises:
            AssertionError: if ``npe`` is None.
        """
        super().__init__()

        # Explicit raise instead of a bare `assert`: an assert is removed under
        # `python -O`, which would defer the failure to an opaque TypeError on
        # `npe[0]`. AssertionError is kept so callers see the same exception type.
        if npe is None:
            raise AssertionError(
                "npe must be provided as (train_res_H, train_res_W, test_res_H, test_res_W)"
            )
        train_res_H, train_res_W, test_res_H, test_res_W = npe[0], npe[1], npe[2], npe[3]

        # cumsum over a grid of ones yields 1-based row / column indices.
        grid = torch.ones(max_shape)
        i_position = grid.cumsum(0).float().unsqueeze(-1)  # [H, W, 1], row index
        j_position = grid.cumsum(1).float().unsqueeze(-1)  # [H, W, 1], column index

        # Rescale positions by the train/test resolution ratio of each axis.
        i_position = i_position * train_res_H / test_res_H
        j_position = j_position * train_res_W / test_res_W

        n_freq = d_model // 4
        div_term = torch.exp(torch.arange(0, n_freq, 1).float() * (-math.log(10000.0) / n_freq))
        div_term = div_term[None, None, :]  # [1, 1, C//4]

        i_angle = i_position * div_term  # [H, W, C//4]
        j_angle = j_position * div_term  # [H, W, C//4]

        # Angles are evaluated in float32; the slice assignments below cast to the
        # table dtype, so no per-branch `.half()` is needed.
        table_dtype = torch.float16 if ropefp16 else torch.float32
        sin = torch.zeros(*max_shape, d_model // 2, dtype=table_dtype)
        cos = torch.zeros(*max_shape, d_model // 2, dtype=table_dtype)
        # Even slots carry the row (i) angles, odd slots the column (j) angles.
        sin[:, :, 0::2] = torch.sin(i_angle)
        sin[:, :, 1::2] = torch.sin(j_angle)
        cos[:, :, 0::2] = torch.cos(i_angle)
        cos[:, :, 1::2] = torch.cos(j_angle)

        # Duplicate every entry so both channels of a rotated pair share one angle;
        # this doubles the last dim from C//2 to C.
        sin = sin.repeat_interleave(2, dim=-1)
        cos = cos.repeat_interleave(2, dim=-1)

        # persistent=False keeps the tables out of the module's state_dict.
        self.register_buffer('sin', sin.unsqueeze(0), persistent=False)  # [1, H, W, C]
        self.register_buffer('cos', cos.unsqueeze(0), persistent=False)  # [1, H, W, C]

    def forward(self, x: torch.Tensor, ratio=1) -> torch.Tensor:
        """
        Apply the rotary encoding to ``x``.

        Args:
            x: [N, H, W, C]
            ratio: unused; kept so existing call sites remain valid.

        Returns:
            Tensor of the same shape as ``x`` with the rotation applied.
        """
        # Crop the precomputed tables to the spatial size of the input.
        h, w = x.size(1), x.size(2)
        cos = self.cos[:, :h, :w, :]
        sin = self.sin[:, :h, :w, :]
        return (x * cos) + (self.rotate_half(x) * sin)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Map each consecutive pair ``(a, b)`` along the last dim to ``(-b, a)``."""
        x = x.unflatten(-1, (-1, 2))  # [..., C//2, 2]
        x1, x2 = x.unbind(dim=-1)
        return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)