Spaces:
Running
on
L40S
Running
on
L40S
#original code from https://github.com/genmoai/models under apache 2.0 license | |
# import functools | |
import math | |
import torch | |
def centers(start: float, stop, num, dtype=None, device=None):
    """Return the midpoints of ``num`` equal-width bins covering ``[start, stop]``.

    The range is split into ``num + 1`` evenly spaced boundaries with
    ``torch.linspace``; each returned value is the average of two neighbouring
    boundaries.

    Args:
        start (float): Lower end of the range.
        stop (float): Upper end of the range.
        num (int): Number of bins (and therefore of returned centers).
        dtype (torch.dtype): Data type passed to ``torch.linspace``.
        device (torch.device): Device passed to ``torch.linspace``.
    Returns:
        centers (Tensor): Midpoint of every bin. Shape: (num,).
    """
    boundaries = torch.linspace(start, stop, steps=num + 1, dtype=dtype, device=device)
    lower = boundaries[:-1]
    upper = boundaries[1:]
    return (lower + upper) / 2
# @functools.lru_cache(maxsize=1) | |
def create_position_matrix(
    T: int,
    pH: int,
    pW: int,
    device: torch.device,
    dtype: torch.dtype,
    *,
    target_area: float = 36864,
):
    """
    Build the (t, h, w) position of every token in a T x pH x pW latent grid.

    The temporal coordinate is the integer frame index 0..T-1. The spatial
    coordinates are bin centers spanning [-pH*scale/2, pH*scale/2] and
    [-pW*scale/2, pW*scale/2] with ``scale = sqrt(target_area / (pW * pH))``,
    so the spatial grid always covers ``target_area`` regardless of resolution.

    Args:
        T: int - Temporal dimension
        pH: int - Height dimension after patchify
        pW: int - Width dimension after patchify
        device: Device of the returned tensor.
        dtype: Data type of the returned tensor.
        target_area: Spatial area the (pH, pW) grid is rescaled to.
    Returns:
        pos: [T * pH * pW, 3] - position matrix, columns ordered (t, h, w)
    """
    # torch.meshgrid requires every input to share one dtype. Previously `t`
    # used `dtype` while `h`/`w` used the default float dtype, which failed for
    # e.g. bf16/fp16/float64. Build all axes in a common floating dtype
    # (float32 stays float32, float64 stays float64, bf16/fp16 are computed in
    # float32 to avoid coarse linspace values) and cast only at the end.
    compute_dtype = torch.promote_types(dtype, torch.float32)

    # Create 1D tensors for each dimension
    t = torch.arange(T, dtype=compute_dtype)
    # Positionally interpolate to area 36864.
    # (3072x3072 frame with 16x16 patches = 192x192 latents).
    # This automatically scales rope positions when the resolution changes.
    # We use a large target area so the model is more sensitive
    # to changes in the learned pos_frequencies matrix.
    scale = math.sqrt(target_area / (pW * pH))
    w = centers(-pW * scale / 2, pW * scale / 2, pW, dtype=compute_dtype)
    h = centers(-pH * scale / 2, pH * scale / 2, pH, dtype=compute_dtype)

    # Use meshgrid to create 3D grids
    grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")

    # Stack and reshape the grids.
    pos = torch.stack([grid_t, grid_h, grid_w], dim=-1)  # [T, pH, pW, 3]
    pos = pos.view(-1, 3)  # [T * pH * pW, 3]
    return pos.to(dtype=dtype, device=device)
def compute_mixed_rotation(
    freqs: torch.Tensor,
    pos: torch.Tensor,
):
    """
    Project each 3-dim position into per-head, per-head-dim 1D frequencies.

    For every token the rotation angle is
    ``angle[..., h, f] = sum_d pos[..., d] * freqs[d, h, f]``, i.e. a learned
    linear mix of the (t, row, col) coordinates for each head and frequency.

    Args:
        freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
        pos: [..., 3] - position of each token (typically [N, 3]); cast to the
            dtype and device of ``freqs`` before the projection
    Returns:
        freqs_cos: [..., num_heads, num_freqs] - cosine components
        freqs_sin: [..., num_heads, num_freqs] - sine components
    """
    assert freqs.ndim == 3
    # "..." lets pos carry any leading dims; for pos of shape [N, 3] this is
    # exactly the einsum "Nd,dhf->Nhf".
    freqs_sum = torch.einsum("...d,dhf->...hf", pos.to(freqs), freqs)
    freqs_cos = torch.cos(freqs_sum)
    freqs_sin = torch.sin(freqs_sum)
    return freqs_cos, freqs_sin