Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import numpy as np | |
class PositionalEmbedding(nn.Module):
    """Sinusoidal positional (timestep) embedding.

    Adapted from https://github.com/NVlabs/edm

    A 1-D tensor of positions is combined (outer product) with a fixed bank of
    ``num_channels // 2`` frequencies; the cosine and sine of the result are
    concatenated along the channel axis.

    Args:
        num_channels: Requested embedding width. ``num_channels // 2``
            frequencies are created, so the output has
            ``2 * (num_channels // 2)`` channels.
        max_positions: Frequency ``i`` is
            ``(1 / max_positions) ** (2 * i / num_channels)``.
        endpoint: Stored on the module for interface compatibility; it does
            not currently influence the frequencies.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

        exponents = torch.arange(
            start=0, end=self.num_channels // 2, dtype=torch.float32
        )
        exponents = 2 * exponents / self.num_channels
        # Registered as a buffer so it follows the module across devices and
        # is saved in the state dict, without being a trainable parameter.
        self.register_buffer("freqs", (1 / self.max_positions) ** exponents)

    def forward(self, x):
        """Embed a 1-D tensor of positions.

        Args:
            x: 1-D tensor of positions (``torch.outer`` only accepts 1-D
                operands).

        Returns:
            Tensor of shape ``(len(x), 2 * (num_channels // 2))`` laid out as
            ``[cos | sin]``. Floating-point inputs get their own dtype back;
            for any other input dtype the promoted dtype of the outer product
            is returned unchanged.
        """
        # Remember the caller's dtype *before* computing the phase. The
        # previous code rebound ``x`` first, which turned the final cast into
        # a no-op and silently upcast half-precision inputs.
        in_dtype = x.dtype
        phase = torch.outer(x, self.freqs)
        out = torch.cat([phase.cos(), phase.sin()], dim=1)
        # Never cast back to an integer dtype: that would truncate the
        # cos/sin values to 0 or +-1.
        if in_dtype.is_floating_point:
            out = out.to(in_dtype)
        return out
# ----------------------------------------------------------------------------
# Timestep embedding used in the NCSN++ architecture.
class FourierEmbedding(nn.Module):
    """Random Fourier-feature embedding.

    Adapted from https://github.com/NVlabs/edm

    ``num_channels // 2`` frequencies are sampled once from a standard normal
    distribution, multiplied by ``scale``, and kept as a non-trainable buffer.

    Args:
        num_channels: Requested embedding width; the output has
            ``2 * (num_channels // 2)`` channels.
        scale: Multiplier applied to the sampled frequencies.
    """

    def __init__(self, num_channels, scale=16):
        super().__init__()
        half_width = num_channels // 2
        sampled = torch.randn(half_width) * scale
        self.register_buffer("freqs", sampled)

    def forward(self, x):
        """Embed a 1-D tensor ``x`` as ``[cos | sin]`` of ``2*pi*freqs*x``.

        The frequencies are cast to ``x``'s dtype before the outer product, so
        the result has shape ``(len(x), 2 * (num_channels // 2))``.
        """
        angular = (self.freqs * (2 * np.pi)).to(x.dtype)
        phase = torch.outer(x, angular)
        return torch.cat([phase.cos(), phase.sin()], dim=1)