Spaces:
Starting
on
L40S
Starting
on
L40S
File size: 6,240 Bytes
dbac20f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
from typing import Literal, Optional
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip import create_model_from_pretrained
from torchvision.transforms import Normalize
from mmaudio.ext.autoencoder import AutoEncoderModule
from mmaudio.ext.mel_converter import MelConverter
from mmaudio.ext.synchformer import Synchformer
from mmaudio.model.utils.distributions import DiagonalGaussianDistribution
def patch_clip(clip_model):
    """Rebind ``clip_model.encode_text`` so it returns per-token hidden states.

    The replacement runs token embedding, positional embedding, the text
    transformer and the final layer norm, then returns that sequence as-is
    (optionally L2-normalised over the last dimension).

    Reference implementation this is adapted from:
    https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269

    Args:
        clip_model: model object exposing ``transformer``, ``token_embedding``,
            ``positional_embedding``, ``ln_final`` and ``attn_mask``.

    Returns:
        The same ``clip_model`` object, patched in place.
    """

    def _encode_text_hidden_states(self, text, normalize: bool = False):
        target_dtype = self.transformer.get_cast_dtype()

        # [batch_size, n_ctx, d_model]
        embedded = self.token_embedding(text).to(target_dtype)
        embedded = embedded + self.positional_embedding.to(target_dtype)

        # [batch_size, n_ctx, transformer.width]
        hidden = self.ln_final(self.transformer(embedded, attn_mask=self.attn_mask))

        if normalize:
            return F.normalize(hidden, dim=-1)
        return hidden

    # Bind the function to this particular instance so it receives `self`.
    clip_model.encode_text = _encode_text_hidden_states.__get__(clip_model)
    return clip_model
class FeaturesUtils(nn.Module):
    """Inference-only wrapper around CLIP, Synchformer, the audio autoencoder and
    the mel converter.

    The module is permanently kept in eval mode (see :meth:`train`) and every
    encode/decode entry point runs under ``torch.inference_mode()``.
    """

    def __init__(
        self,
        *,
        tod_vae_ckpt: Optional[str] = None,
        bigvgan_vocoder_ckpt: Optional[str] = None,
        synchformer_ckpt: Optional[str] = None,
        enable_conditions: bool = True,
        mode: Optional[Literal['16k', '44k']] = None,
    ):
        """Build the requested sub-models.

        Args:
            tod_vae_ckpt: path to the VAE checkpoint; when ``None`` no autoencoder
                is created and ``self.tod`` is ``None``.
            bigvgan_vocoder_ckpt: vocoder checkpoint path forwarded to
                ``AutoEncoderModule``.
            synchformer_ckpt: path to the Synchformer state dict; required when
                ``enable_conditions`` is true.
            enable_conditions: when true, load CLIP, its tokenizer and Synchformer.
            mode: autoencoder mode forwarded to ``AutoEncoderModule``; required
                when ``tod_vae_ckpt`` is given.

        Raises:
            ValueError: if a required checkpoint path or ``mode`` is missing.
        """
        super().__init__()

        # Fail fast on missing arguments before any model is created.
        if enable_conditions and synchformer_ckpt is None:
            raise ValueError('synchformer_ckpt is required when enable_conditions=True')
        if tod_vae_ckpt is not None and mode is None:
            raise ValueError("mode ('16k' or '44k') is required when tod_vae_ckpt is given")

        if enable_conditions:
            self.clip_model = patch_clip(
                create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',
                                             return_transform=False))
            self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                             std=[0.26862954, 0.26130258, 0.27577711])

            self.synchformer = Synchformer()
            self.synchformer.load_state_dict(
                torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))

            self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')  # same as 'ViT-H-14'
        else:
            self.clip_model = None
            self.clip_preprocess = None
            self.synchformer = None
            self.tokenizer = None

        if tod_vae_ckpt is not None:
            self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
                                         vocoder_ckpt_path=bigvgan_vocoder_ckpt,
                                         mode=mode)
        else:
            self.tod = None
        self.mel_converter = MelConverter()

    def compile(self):
        """Wrap the loaded sub-models' entry points with ``torch.compile``.

        Components that were not loaded (``None``) are skipped.
        """
        if self.clip_model is not None:
            self.encode_video_with_clip = torch.compile(self.encode_video_with_clip)
            self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
            self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
        if self.synchformer is not None:
            self.synchformer = torch.compile(self.synchformer)
        # The VAE is optional; only compile its paths when it exists.
        if self.tod is not None:
            self.tod.encode = torch.compile(self.tod.encode)
            self.decode = torch.compile(self.decode)
            self.vocode = torch.compile(self.vocode)

    def train(self, mode: bool = True) -> 'FeaturesUtils':
        """Keep the module in eval mode regardless of ``mode``.

        ``mode`` is accepted (with the same default as ``nn.Module.train``) only
        for interface compatibility; it is ignored.
        """
        return super().train(False)

    @torch.inference_mode()
    def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        """Encode video frames with the CLIP image encoder.

        Args:
            x: frames shaped (B, T, 3, 384, 384).
            batch_size: number of frames per ``encode_image`` call; a negative
                value processes all ``B * T`` frames at once.

        Returns:
            Normalised per-frame features shaped (B, T, D).
        """
        assert self.clip_model is not None, 'CLIP is not loaded'
        # x: (B, T, C, H, W) H/W: 384
        b, t, c, h, w = x.shape
        assert c == 3 and h == 384 and w == 384

        x = self.clip_preprocess(x)
        x = rearrange(x, 'b t c h w -> (b t) c h w')

        num_frames = b * t
        if batch_size < 0:
            batch_size = num_frames
        outputs = [
            self.clip_model.encode_image(x[i:i + batch_size], normalize=True)
            for i in range(0, num_frames, batch_size)
        ]
        x = torch.cat(outputs, dim=0)
        return rearrange(x, '(b t) d -> b t d', b=b)

    @torch.inference_mode()
    def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        """Encode video frames with Synchformer over overlapping segments.

        The clip is cut into segments of 16 frames taken every 8 frames; trailing
        frames that do not fill a whole segment are dropped.

        Args:
            x: frames shaped (B, T, 3, 224, 224) with ``T >= 16``.
            batch_size: number of videos per Synchformer call; a negative value
                processes the whole batch at once.

        Returns:
            Synchformer output with dimensions 1 and 2 flattened together.
        """
        assert self.synchformer is not None, 'Synchformer is not loaded'
        # x: (B, T, C, H, W) H/W: 224
        b, t, c, h, w = x.shape
        assert c == 3 and h == 224 and w == 224

        # partition the video
        segment_size = 16
        step_size = 8
        assert t >= segment_size, f'need at least {segment_size} frames, got {t}'
        num_segments = (t - segment_size) // step_size + 1
        segments = [
            x[:, i * step_size:i * step_size + segment_size] for i in range(num_segments)
        ]
        x = torch.stack(segments, dim=1)  # (B, S, T, C, H, W)

        if batch_size < 0:
            batch_size = b
        outputs = [self.synchformer(x[i:i + batch_size]) for i in range(0, b, batch_size)]
        return torch.cat(outputs, dim=0).flatten(start_dim=1, end_dim=2)

    @torch.inference_mode()
    def encode_text(self, text: list[str]) -> torch.Tensor:
        """Tokenize ``text`` and return the CLIP text encoder output (normalised)."""
        assert self.clip_model is not None, 'CLIP is not loaded'
        assert self.tokenizer is not None, 'Tokenizer is not loaded'
        # tokens: (B, L)
        tokens = self.tokenizer(text).to(self.device)
        return self.clip_model.encode_text(tokens, normalize=True)

    @torch.inference_mode()
    def encode_audio(self, x) -> DiagonalGaussianDistribution:
        """Convert audio ``x`` to a mel representation and encode it with the VAE."""
        assert self.tod is not None, 'VAE is not loaded'
        # x: (B * L)
        mel = self.mel_converter(x)
        return self.tod.encode(mel)

    @torch.inference_mode()
    def vocode(self, mel: torch.Tensor) -> torch.Tensor:
        """Run the autoencoder module's vocoder on ``mel``."""
        assert self.tod is not None, 'VAE is not loaded'
        return self.tod.vocode(mel)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latents ``z`` with the VAE after swapping dimensions 1 and 2."""
        assert self.tod is not None, 'VAE is not loaded'
        return self.tod.decode(z.transpose(1, 2))

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype
|