# Switti / models / helpers.py
# realantonvoronov
# init commit
# 55ca09f
import torch
from torch import nn as nn
from torch.nn import functional as F
def sample_with_top_k_top_p_(
    logits_BlV: torch.Tensor,
    top_k: int = 0,
    top_p: float = 0.0,
    rng=None,
    num_samples=1,
) -> torch.Tensor:  # return idx, shaped (B, l, |num_samples|)
    """Sample token indices from logits after optional top-k / top-p filtering.

    The trailing underscore is literal: ``logits_BlV`` is modified IN PLACE,
    filtered-out entries are overwritten with ``-inf``. Pass a clone if the
    caller still needs the raw logits.

    Args:
        logits_BlV: 3-D tensor unpacked as ``(B, l, V)``; sampling happens
            over the last dimension.
        top_k: if > 0, keep only entries not smaller than the k-th largest
            logit of each row. Values larger than ``V`` are clamped to ``V``
            (i.e. nothing is filtered) instead of making ``topk`` raise.
        top_p: if > 0, drop the low-probability tail whose cumulative
            probability is ``<= 1 - top_p``; the most likely entry is always
            kept.
        rng: forwarded to ``torch.multinomial`` as ``generator``.
        num_samples: number of draws per row. A non-negative value samples
            with replacement; a negative value samples ``abs(num_samples)``
            draws without replacement.

    Returns:
        Tensor of sampled indices shaped ``(B, l, abs(num_samples))``.
    """
    B, l, V = logits_BlV.shape

    if top_k > 0:
        # Clamp: asking for more candidates than the vocabulary holds must
        # behave like "keep everything" rather than crash inside topk.
        k = min(top_k, V)
        kth_largest = logits_BlV.topk(k, largest=True, sorted=False, dim=-1)[0].amin(
            dim=-1, keepdim=True
        )
        logits_BlV.masked_fill_(logits_BlV < kth_largest, -torch.inf)

    if top_p > 0:
        # Ascending sort puts the low-probability tail first, so the running
        # sum at each position is the mass of that entry and everything below.
        sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
        sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
        # The last sorted position is the most likely entry: never remove it.
        sorted_idx_to_remove[..., -1:] = False
        # Scatter the mask from sorted order back to vocabulary order.
        idx_to_remove = sorted_idx_to_remove.scatter(
            sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove
        )
        logits_BlV.masked_fill_(idx_to_remove, -torch.inf)

    # torch.multinomial only accepts 1-D/2-D input, so flatten (B, l) into one
    # batch dimension for sampling and restore it afterwards.
    replacement = num_samples >= 0
    num_samples = abs(num_samples)
    probs_2d = logits_BlV.softmax(dim=-1).view(-1, V)
    sampled = torch.multinomial(
        probs_2d,
        num_samples=num_samples,
        replacement=replacement,
        generator=rng,
    )
    return sampled.view(B, l, num_samples)
def gumbel_softmax_with_rng(
logits: torch.Tensor,
tau: float = 1,
hard: bool = False,
eps: float = 1e-10,
dim: int = -1,
rng: torch.Generator | None = None,
) -> torch.Tensor:
if rng is None:
return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)
gumbels = (
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
.exponential_(generator=rng)
.log()
)
gumbels = (logits + gumbels) / tau
y_soft = gumbels.softmax(dim)
if hard:
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(
logits, memory_format=torch.legacy_contiguous_format
).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
else:
ret = y_soft
return ret
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):  # taken from timm
    """Randomly zero whole samples of ``x`` along dim 0.

    Args:
        x: input tensor; its first dimension is treated as the per-sample axis.
        drop_prob: probability of zeroing a sample.
        training: when False the input is returned untouched.
        scale_by_keep: if True (and the keep probability is > 0), surviving
            samples are divided by the keep probability.

    Returns:
        ``x`` itself when nothing is dropped (``drop_prob == 0`` or not
        training), otherwise ``x`` multiplied by a per-sample 0/scale mask.
    """
    if not training or drop_prob == 0.0:
        return x

    p_keep = 1 - drop_prob
    # One mask entry per sample, shaped (N, 1, ..., 1) so it broadcasts over
    # any number of trailing dimensions.
    trailing_ones = (1,) * (x.ndim - 1)
    mask = x.new_empty((x.shape[0],) + trailing_ones)
    mask.bernoulli_(p_keep)
    if scale_by_keep and p_keep > 0.0:
        mask.div_(p_keep)
    return x * mask
class DropPath(nn.Module):  # taken from timm
    """Module wrapper around ``drop_path``.

    Dropping is only active while the module is in training mode
    (``self.training``); the decision is delegated entirely to ``drop_path``.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        # Probability of zeroing a sample, forwarded to drop_path.
        self.drop_prob = drop_prob
        # Whether surviving samples are rescaled, forwarded to drop_path.
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Apply ``drop_path`` to ``x`` using this module's settings."""
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        """Show the configured drop probability in the module's repr."""
        # nn.Module.__repr__ already adds the surrounding parentheses, so only
        # the key=value text is returned here.
        return f"drop_prob={round(self.drop_prob, 3):0.3f}"