Spaces:
Running
on
Zero
Running
on
Zero
# code adapted from: https://github.com/Stability-AI/stable-audio-tools | |
import torch | |
import torch.nn as nn | |
from torch import Tensor, einsum | |
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union | |
from einops import rearrange | |
import math | |
import comfy.ops | |
class LearnedPositionalEmbedding(nn.Module):
    """Learned Fourier-feature embedding for continuous scalars (e.g. time).

    Holds ``dim // 2`` learnable frequencies. ``forward`` maps a 1-D tensor of
    ``b`` scalars to a ``(b, dim + 1)`` tensor laid out as
    ``[x, sin(2*pi*w*x), cos(2*pi*w*x)]``.
    """

    def __init__(self, dim: int):
        super().__init__()
        # The sin half and the cos half must together fill ``dim`` features.
        assert (dim % 2) == 0
        # torch.empty leaves the values uninitialised; presumably they are
        # overwritten by a loaded checkpoint — NOTE(review): confirm.
        self.weights = nn.Parameter(torch.empty(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        # "b -> b 1" only accepts a 1-D input: one scalar per batch row.
        column = rearrange(x, "b -> b 1")
        row = rearrange(self.weights, "d -> 1 d")
        # Same left-to-right product order as before, so rounding is identical.
        phase = column * row * 2 * math.pi
        return torch.cat((column, phase.sin(), phase.cos()), dim=-1)
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build a learned Fourier embedding followed by a linear projection.

    Returns ``nn.Sequential(LearnedPositionalEmbedding(dim), Linear(dim + 1,
    out_features))``, where the Linear comes from ``comfy.ops.manual_cast``.
    """
    fourier = LearnedPositionalEmbedding(dim)
    # dim + 1: the Fourier block prepends the raw input to its dim sin/cos features.
    projection = comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features)
    # Positional construction keeps submodule names "0" and "1", so the
    # parameter / state-dict keys are unchanged.
    return nn.Sequential(fourier, projection)
class NumberEmbedder(nn.Module):
    """Embed numeric input of any shape into ``features``-wide vectors.

    Each scalar is embedded independently by a ``TimePositionalEmbedding``.
    The output keeps the input's shape and adds a trailing ``features`` axis.
    """

    def __init__(
        self,
        features: int,
        dim: int = 256,
    ):
        super().__init__()
        self.features = features
        # ``dim`` is the width of the learned Fourier block before projection.
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        """Return embeddings of shape ``(*x.shape, features)``.

        Non-tensor input (e.g. a list of floats) is first converted to a tensor
        on the device of the embedding's first parameter.
        """
        if not torch.is_tensor(x):
            param_device = next(self.embedding.parameters()).device
            x = torch.tensor(x, device=param_device)
        assert isinstance(x, Tensor)

        original_shape = x.shape
        # Flatten, so the embedding receives a 1-D batch of scalars.
        flat = rearrange(x, "... -> (...)")
        return self.embedding(flat).view(*original_shape, self.features)  # type: ignore
class Conditioner(nn.Module):
    """Base class for conditioning modules.

    Stores the input width ``dim`` and output width ``output_dim``, and builds
    ``proj_out``. ``proj_out`` is an ``nn.Linear(dim, output_dim)`` when the
    widths differ or ``project_out`` is set; otherwise it is an ``nn.Identity``.
    Subclasses must override ``forward``.
    """

    def __init__(
            self,
            dim: int,
            output_dim: int,
            project_out: bool = False
            ):
        super().__init__()

        self.dim = dim
        self.output_dim = output_dim

        needs_projection = project_out or dim != output_dim
        if needs_projection:
            self.proj_out = nn.Linear(dim, output_dim)
        else:
            self.proj_out = nn.Identity()

    def forward(self, x):
        raise NotImplementedError()
class NumberConditioner(Conditioner):
    '''
    Conditioner that takes a list of floats, clamps each one to
    [min_val, max_val], normalizes it to [0, 1], and embeds it with a
    NumberEmbedder.

    forward() returns a two-element list [embeddings, mask]; see forward.
    '''
    def __init__(self,
                output_dim: int,
                min_val: float=0,
                max_val: float=1
                ):
        # dim == output_dim and project_out is False, so the base class
        # builds proj_out as nn.Identity().
        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val

        self.embedder = NumberEmbedder(features=output_dim)

    def forward(self, floats, device=None):
        '''
        Embed a sequence of numbers.

        Args:
            floats: iterable of values accepted by ``float()``.
            device: device for the returned tensors. Defaults to the device
                of the embedder's first parameter.

        Returns:
            ``[float_embeds, mask]``:
            - ``float_embeds`` has shape ``(len(floats), 1, output_dim)``.
            - ``mask`` is a ``(len(floats), 1)`` tensor of ones on ``device``.
            NOTE(review): the mask is presumably consumed by the caller as a
            validity/attention mask — confirm.

        If ``min_val == max_val``, every input normalizes to 0 instead of NaN.
        '''
        # Cast the inputs to floats
        floats = [float(x) for x in floats]

        if device is None:
            device = next(self.embedder.parameters()).device

        # Allocate directly on the target device; avoids a CPU tensor plus a copy.
        floats = torch.tensor(floats, device=device)

        floats = floats.clamp(self.min_val, self.max_val)

        value_range = self.max_val - self.min_val
        if value_range == 0:
            # Degenerate range: after clamping every value equals min_val, so
            # (x - min_val) / 0 would be 0/0 = NaN. Map everything to 0 instead.
            normalized_floats = torch.zeros_like(floats)
        else:
            normalized_floats = (floats - self.min_val) / value_range

        # Cast floats to same type as embedder
        embedder_dtype = next(self.embedder.parameters()).dtype
        normalized_floats = normalized_floats.to(embedder_dtype)

        # The embedder returns (n, output_dim); add a length-1 sequence axis.
        float_embeds = self.embedder(normalized_floats).unsqueeze(1)

        return [float_embeds, torch.ones(float_embeds.shape[0], 1, device=device)]