Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,471 Bytes
4450790 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
import torch
import torch.nn as nn
from torch import Tensor, einsum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from einops import rearrange
import math
import comfy.ops
class LearnedPositionalEmbedding(nn.Module):
    """Fourier-feature embedding of scalars with learned frequencies.

    Used for continuous time. The module holds ``dim // 2`` learnable
    frequency weights. Each input scalar is mapped to ``dim + 1`` features:
    the scalar itself, followed by the sines and then the cosines of the
    scalar multiplied by every frequency (scaled by ``2 * pi``).
    """

    def __init__(self, dim: int):
        """Create the frequency parameter; ``dim`` must be even."""
        super().__init__()
        assert (dim % 2) == 0
        # torch.empty leaves the values uninitialised.
        # NOTE(review): presumably overwritten by a loaded checkpoint - confirm.
        self.weights = nn.Parameter(torch.empty(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        """Embed a 1-D tensor of shape ``(b,)`` into shape ``(b, dim + 1)``."""
        column = rearrange(x, "b -> b 1")
        frequencies = rearrange(self.weights, "d -> 1 d")
        # Broadcasts (b, 1) against (1, d) to get one phase per (input, frequency).
        phases = column * frequencies * 2 * math.pi
        # Feature order: raw value, sines, cosines.
        return torch.cat((column, phases.sin(), phases.cos()), dim=-1)
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build a two-stage embedding: learned Fourier features, then a linear map.

    The first stage is a ``LearnedPositionalEmbedding(dim)``, which emits
    ``dim + 1`` features per input; the second stage projects those features
    to ``out_features``.
    """
    fourier_features = LearnedPositionalEmbedding(dim)
    # in_features is dim + 1 because the embedding prepends the raw input value.
    projection = comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(fourier_features, projection)
class NumberEmbedder(nn.Module):
    """Embed numbers of any shape into ``features``-dimensional vectors."""

    def __init__(self, features: int, dim: int = 256):
        """Store the output width and build the underlying time embedding."""
        super().__init__()
        self.features = features
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        """Return embeddings of shape ``(*x.shape, features)``.

        Non-tensor input is converted to a tensor placed on the device of the
        embedding's first parameter.
        """
        if torch.is_tensor(x):
            values = x
        else:
            values = torch.tensor(x, device=next(self.embedding.parameters()).device)
        assert isinstance(values, Tensor)
        original_shape = values.shape
        # The embedding works on a flat batch; restore the input shape afterwards.
        flat = rearrange(values, "... -> (...)")
        return self.embedding(flat).view(*original_shape, self.features)  # type: ignore
class Conditioner(nn.Module):
    """Base class for conditioners; subclasses must implement ``forward``.

    ``proj_out`` is a linear layer from ``dim`` to ``output_dim`` when the two
    sizes differ or when ``project_out`` is set; otherwise it is an identity.
    """

    def __init__(self, dim: int, output_dim: int, project_out: bool = False):
        """Record the dimensions and create the output projection."""
        super().__init__()
        self.dim = dim
        self.output_dim = output_dim

        if project_out or dim != output_dim:
            self.proj_out = nn.Linear(dim, output_dim)
        else:
            self.proj_out = nn.Identity()

    def forward(self, x):
        """Not implemented on the base class."""
        raise NotImplementedError()
class NumberConditioner(Conditioner):
    """Conditioner that embeds numbers normalized to a fixed range.

    Each input number is clamped to ``[min_val, max_val]``, rescaled to
    ``[0, 1]`` and passed through a ``NumberEmbedder``.
    """

    def __init__(self,
                 output_dim: int,
                 min_val: float = 0,
                 max_val: float = 1
                 ):
        """Store the value range and build an embedder of width ``output_dim``."""
        super().__init__(output_dim, output_dim)
        self.min_val = min_val
        self.max_val = max_val
        self.embedder = NumberEmbedder(features=output_dim)

    def forward(self, floats, device=None):
        """Embed an iterable of numbers.

        Args:
            floats: Iterable of values convertible with ``float()``.
            device: Device for the created tensors. Defaults to the device of
                the embedder's first parameter.

        Returns:
            A two-element list ``[embeddings, mask]`` where ``embeddings`` has
            shape ``(n, 1, output_dim)`` and ``mask`` is a tensor of ones with
            shape ``(n, 1)``, for ``n`` input values.
        """
        # One parameter lookup provides both the default device and the dtype.
        reference_param = next(self.embedder.parameters())
        if device is None:
            device = reference_param.device

        # Cast the inputs to floats and build the tensor directly on the
        # target device instead of allocating on the host and copying.
        values = torch.tensor([float(x) for x in floats], device=device)
        values = values.clamp(self.min_val, self.max_val)
        # NOTE(review): min_val == max_val makes this 0 / 0 and yields NaN.
        normalized_floats = (values - self.min_val) / (self.max_val - self.min_val)

        # Cast floats to same type as embedder
        normalized_floats = normalized_floats.to(reference_param.dtype)

        float_embeds = self.embedder(normalized_floats).unsqueeze(1)
        mask = torch.ones(float_embeds.shape[0], 1, device=device)
        return [float_embeds, mask]
|