# NOTE: the next six lines are web file-viewer residue, not part of the module.
# Realcat
# add: GIM (https://github.com/xuelunshen/gim)
# c0283b3
# raw
# history blame
# 1.85 kB
import math
from typing import List, Optional, Tuple
import torch
def to_sequence(map):
    """Turn a map into a sequence by merging its last two dimensions.

    The last two dimensions are flattened into one, which is then swapped
    with the dimension before it: ``(..., C, H, W) -> (..., H * W, C)``.
    """
    flattened = map.flatten(-2)
    sequence = flattened.transpose(-1, -2)
    return sequence
def to_map(sequence):
    """Reshape a sequence into a square map: ``(..., N, C) -> (..., C, E, E)``.

    ``N`` (the size of the second-to-last dimension) must be a perfect
    square, ``N == E * E``.

    Args:
        sequence: tensor whose second-to-last dimension holds the flattened
            square grid.

    Returns:
        The tensor with its last two dimensions swapped and the (new) last
        dimension unflattened into ``[E, E]``.

    Raises:
        AssertionError: if ``N`` is not a perfect square.
    """
    n = sequence.shape[-2]
    e = math.isqrt(n)
    assert e * e == n, f"sequence length {n} is not a perfect square"
    # The reshaped tensor must be returned; without the `return` the function
    # silently yields None.
    return sequence.transpose(-1, -2).unflatten(-1, [e, e])
def pad_to_length(
    x,
    length: int,
    pad_dim: int = -2,
    mode: str = "zeros",  # zeros, ones, random, random_c
    bounds: Tuple[Optional[float], Optional[float]] = (None, None),
):
    """Pad ``x`` along ``pad_dim`` until that dimension has size ``length``.

    The padding is appended after the existing entries of ``x``.

    Args:
        x: tensor to pad.
        length: target size of dimension ``pad_dim``; must be at least the
            current size.
        pad_dim: dimension along which to pad.
        mode: how the padding is filled:
            ``"zeros"`` / ``"ones"``: constant fill;
            ``"random"``: uniform samples between ``bounds``; a bound that
            is ``None`` is replaced by ``x.min()`` / ``x.max()``;
            ``"random_c"``: for every index ``i`` of the last dimension,
            uniform samples between ``x[..., i].min()`` and
            ``x[..., i].max()``; ``bounds`` is used instead when ``x`` has
            no entries along ``pad_dim``.
        bounds: ``(low, high)`` limits for the random modes.

    Returns:
        ``x`` itself if it already has the requested length, otherwise a new
        tensor made of ``x`` followed by the padding along ``pad_dim``.

    Raises:
        AssertionError: if ``x`` is longer than ``length`` along ``pad_dim``.
        ValueError: if ``mode`` is not one of the supported names.
    """
    shape = list(x.shape)
    d = x.shape[pad_dim]
    assert d <= length, f"cannot pad dimension of size {d} to length {length}"
    if d == length:
        return x
    # From here on `shape` describes the padding block, not `x`.
    shape[pad_dim] = length - d

    low, high = bounds
    # Random padding follows the dtype of `x` when it is a floating-point
    # tensor, like the "zeros"/"ones" modes do. For other dtypes (which
    # cannot be sampled uniformly) the default dtype is used (dtype=None).
    rand_dtype = x.dtype if x.is_floating_point() else None

    if mode == "zeros":
        xn = torch.zeros(*shape, device=x.device, dtype=x.dtype)
    elif mode == "ones":
        xn = torch.ones(*shape, device=x.device, dtype=x.dtype)
    elif mode == "random":
        low = low if low is not None else x.min()
        high = high if high is not None else x.max()
        xn = torch.empty(*shape, device=x.device, dtype=rand_dtype).uniform_(
            low, high
        )
    elif mode == "random_c":
        # One column of samples per index of the last dimension; the bounds
        # are only the fallback for an empty sequence (d == 0).
        # NOTE(review): this indexes x[..., i] over the padding block's last
        # dimension, so it presumably expects pad_dim not to be the last
        # dimension -- confirm against callers.
        xn = torch.cat(
            [
                torch.empty(
                    *shape[:-1], 1, device=x.device, dtype=rand_dtype
                ).uniform_(
                    x[..., i].min() if d > 0 else low,
                    x[..., i].max() if d > 0 else high,
                )
                for i in range(shape[-1])
            ],
            dim=-1,
        )
    else:
        raise ValueError(mode)
    return torch.cat([x, xn], dim=pad_dim)
def pad_and_stack(
    sequences: List[torch.Tensor],
    length: Optional[int] = None,
    pad_dim: int = -2,
    **kwargs,
):
    """Pad every tensor to a common length and stack them on a new dim 0.

    Args:
        sequences: tensors to pad and stack.
        length: target size along ``pad_dim``; when ``None`` the largest
            size found among ``sequences`` is used.
        pad_dim: dimension along which each tensor is padded.
        **kwargs: forwarded unchanged to ``pad_to_length``.

    Returns:
        The padded tensors stacked along a new leading dimension.
    """
    target = length
    if target is None:
        target = max(seq.shape[pad_dim] for seq in sequences)

    padded = []
    for seq in sequences:
        padded.append(pad_to_length(seq, target, pad_dim, **kwargs))
    return torch.stack(padded, 0)