from typing import Optional, Tuple

import torch as T
import torch.nn as nn
import torch.nn.functional as F

from ioblocks import GaussianMixtureIOLayer, FSQ
from transformer import Stack, ShapeRotator, Block as PerfBlock, GPTOutput, CACHE_FILL_VALUE, FFNN, Norm
from tokenizer import make_tokenizer
from utils import si_module, exists, isnt, tqdm0, print0, default, print0_colored
from utils import load_ckpt


# Audio samples per latent frame: encoding maps N samples to N // 2000 latent
# frames, and decoding maps N frames back to N * 2000 samples.
_SAMPLES_PER_LATENT = 2000
# Rolling context kept by the streaming encode/decode caches. The audio window
# is 6 * 16_000 samples (presumably 6 s at 16 kHz -- confirm against the
# tokenizer); the latent window is the matching 6 * 8 = 48 frames.
_AUDIO_CACHE_SAMPLES = 6 * 16_000
_LATENT_CACHE_FRAMES = 6 * 8


def _keep_last(x: T.Tensor, n: int, dim: int) -> T.Tensor:
    """Return the last ``n`` entries of ``x`` along ``dim``.

    Unlike ``x[:, -n:]`` this yields an empty slice when ``n == 0`` (instead of
    the whole tensor) and the whole tensor when ``n`` exceeds its length.
    """
    size = x.shape[dim]
    start = max(size - n, 0)
    return x.narrow(dim, start, size - start)


@si_module
class LatentQuantizer(nn.Module):
    """Optional input projection + FFNN + FSQ compressor over latent frames.

    When ``from_pretrained`` is set, all weights are loaded (strictly) from
    that checkpoint after the submodules are built.
    """

    class Config:
        compressor_config: Optional[FSQ.Config] = None
        dim: Optional[int] = None
        ff_dim: Optional[int] = None
        # When None the input projection is an identity.
        input_dim: int = None
        # Pair of strings forwarded to ``load_ckpt``.
        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()

        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert exists(c.compressor_config), f'hmm {c}'

        self.compressor = c.compressor_config()
        self.ffnn = FFNN(c.dim, c.ff_dim)
        self.input = nn.Linear(c.input_dim, c.dim) if exists(c.input_dim) else nn.Identity()

        if exists(c.from_pretrained):
            self.load_state_dict(checkpoint)

    @T.no_grad()
    def forward(self, x, return_latent=False, known_latent=None):
        """Quantize ``x``, or map known code indices back to codes.

        Args:
            x: input latents, documented as (B, S, D); ignored when
                ``known_latent`` is given.
            return_latent: if True, also return the compressor's token output.
            known_latent: code indices decoded via ``compressor.indices_to_codes``.

        Returns:
            The quantized codes, or ``(codes, tokens)`` when ``return_latent``.
        """
        if exists(known_latent):
            return self.compressor.indices_to_codes(known_latent)

        x = self.input(x)
        x = self.ffnn(x)
        x, tokens = self.compressor(x)

        if return_latent:
            return x, tokens
        return x


@si_module
class TransformerVAE(nn.Module):
    """Transformer "resynthesizer" that predicts the next latent frame.

    Layers below ``plex_layer`` run on the projected input latents. At
    ``plex_layer`` the quantized input latents, rolled left by ``plex_roll``
    steps (with the last position optionally overwritten by externally chosen
    tokens), are projected and added to the residual stream. In ``split`` mode
    the remaining layers run once per channel, each with its own KV-cache slot.
    """

    class Config:
        io_config: Optional[GaussianMixtureIOLayer.Config] = None
        stack_config: Optional[Stack.Config] = None
        quantizer_config: Optional[LatentQuantizer.Config] = None
        # Layer at which quantized latents are injected; None -> layers // 2.
        plex_layer: int = None
        plex_roll: int = 1
        split: bool = True
        # Pair of strings forwarded to ``load_ckpt``.
        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()

        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'hmm {c}'

        self.io = c.io_config()
        self.stack = c.stack_config()

        # Honor an explicit plex_layer; otherwise inject halfway up the stack.
        self.plex_layer = c.plex_layer if exists(c.plex_layer) else c.stack_config.layers // 2
        assert 0 <= self.plex_layer < c.stack_config.layers, (
            f'plex_layer must be in [0, {c.stack_config.layers}), got {self.plex_layer}')
        self.plex_roll = c.plex_roll
        self.plex_dim = c.quantizer_config.dim

        assert self.plex_dim is not None and c.stack_config.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}'
        self.plex_projection = nn.Linear(self.plex_dim, c.stack_config.dim)
        self.out_norm = Norm(c.stack_config.dim)

        if c.split:
            self.io2 = c.io_config()
            self.plex_projection2 = nn.Linear(self.plex_dim, c.stack_config.dim)

            # io2 only contributes its input layer / temp_sample; its output
            # heads are dropped (outputs for both channels go through self.io).
            self.io2.fc_loc = None
            self.io2.fc_scale = None
            self.io2.fc_weight = None

        kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
        head_dim = c.stack_config.dim // c.stack_config.n_head
        # In split mode every layer at/above plex_layer runs twice, so it
        # needs a second cache slot.
        self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0)
        cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim]
        self.cache_shape = cache_shape
        self.cache = [None] * self.cache_num_layers

        if exists(c.from_pretrained):
            result = self.load_state_dict(checkpoint, strict=False)
            print0_colored(result, 'yellow')

        # Built after loading: the quantizer loads its own checkpoint.
        self.quantizer = c.quantizer_config().eval()
        # requires_grad_ actually freezes the parameters; assigning the
        # attribute on a Module would only set a plain Python attribute.
        self.quantizer.requires_grad_(False)

    @T.no_grad()
    def quantize(self, x):
        """Quantize ``x`` under bf16 autocast; per channel-half in split mode."""
        if self.c.split:
            x1, x2 = x.chunk(2, dim=-1)
            with T.autocast(device_type='cuda', dtype=T.bfloat16):
                quantized1 = self.quantizer(x1)
                quantized2 = self.quantizer(x2)
            return quantized1, quantized2
        else:
            with T.autocast(device_type='cuda', dtype=T.bfloat16):
                return self.quantizer(x)

    @T.no_grad()
    def untokenize(self, token_data):
        """Map code indices back to quantizer codes."""
        return self.quantizer(None, known_latent=token_data)

    def init_cache(self, bsize, device, dtype, length: Optional[int] = None):
        """Allocate KV caches filled with CACHE_FILL_VALUE.

        The tensor is allocated batch-first and transposed so that
        ``self.cache[i]`` selects the cache for slot ``i``. ``length``
        overrides the configured sequence length.
        """
        cache_shape = self.cache_shape.copy()
        cache_shape[1] = length or cache_shape[1]
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)

    def deinit_cache(self):
        """Drop the KV caches (layers then run without a cache)."""
        self.cache = [None] * self.cache_num_layers

    @T.no_grad()
    def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, float]] = None):
        """Run the resynthesizer over ``data``.

        Args:
            data: input latents; in split mode the last dim holds both channels.
            next_tokens: code indices written into the last (rolled) position
                of the quantized stream -- a pair in split mode, a single
                tensor otherwise.
            temps: passed to ``temp_sample``; when None the raw outputs are
                returned instead of a sample.

        Returns:
            ``io.output`` results (a pair in split mode) when ``temps`` is None,
            else the sampled last frame (channels concatenated on the last dim
            in split mode).
        """
        if self.c.split:
            x1, x2 = data.chunk(2, dim=-1)
            x = self.io.input(x1) + self.io2.input(x2)
        else:
            x = self.io.input(data)

        cache_idx = 0
        for l, layer in enumerate(self.stack.layers):
            if l == self.plex_layer:
                if self.c.split:
                    plex1, plex2 = self.quantize(data)
                    plex1 = T.roll(plex1, -self.c.plex_roll, dims=1)
                    plex2 = T.roll(plex2, -self.c.plex_roll, dims=1)
                    if exists(next_tokens):
                        plex1[:, -1:] = self.untokenize(next_tokens[0])
                        plex2[:, -1:] = self.untokenize(next_tokens[1])
                    x1 = x + self.plex_projection(plex1)
                    x2 = x + self.plex_projection2(plex2)
                else:
                    plex = self.quantize(data)
                    plex = T.roll(plex, -self.c.plex_roll, dims=1)
                    if exists(next_tokens):
                        plex[:, -1:] = self.untokenize(next_tokens)
                    x = x + self.plex_projection(plex)

            if l < self.plex_layer:
                x = layer(x, kv=self.cache[l])
            else:
                if self.c.split:
                    # Upper layers are shared between channels but each pass
                    # gets its own cache slot after the first plex_layer slots.
                    x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx])
                    cache_idx += 1
                    x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx])
                    cache_idx += 1
                else:
                    x = layer(x, kv=self.cache[l])

        with T.autocast(device_type='cuda', dtype=T.bfloat16):
            if self.c.split:
                x1, x2 = self.out_norm(x1), self.out_norm(x2)
                out1, out2 = self.io.output(x1), self.io.output(x2)
            else:
                x = self.out_norm(x)
                out = self.io.output(x)

        if isnt(temps):
            if self.c.split:
                return out1, out2
            else:
                return out
        else:
            if self.c.split:
                next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :]
                next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :]
                next_data = T.cat([next_data1, next_data2], dim=-1)
                return next_data
            else:
                next_data = self.io.temp_sample(out, temps)[:, -1:, :]
                return next_data


@si_module
class HertzDevModel(nn.Module):
    """Autoregressive token model over audio latents.

    Bundles the token transformer with a frozen ``TransformerVAE``
    resynthesizer (tokens -> next latent) and an audio tokenizer
    (audio <-> latents) so it can go audio -> latents -> next latent -> audio.
    """

    class Config:
        dim: int
        vocab_size: int
        stack_config: Optional[Stack.Config] = None
        latent_size: int = 32

        split: bool = True

        quantizer_config: Optional[LatentQuantizer.Config] = None
        resynthesizer_config: Optional[TransformerVAE.Config] = None

        # Pair of strings forwarded to ``load_ckpt``.
        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()

        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert (exists(c.stack_config)), f'hmm {c}'

        self.input = nn.Linear(c.latent_size, c.dim)
        if self.c.split:
            self.input2 = nn.Linear(c.latent_size, c.dim)

        self.shape_rotator = ShapeRotator(c.stack_config.dim//c.stack_config.n_head, c.stack_config.seq_len, theta=c.stack_config.theta)

        self.layers = nn.ModuleList([
            PerfBlock(
                dim=c.stack_config.dim,
                layer_id=l,
                n_head=c.stack_config.n_head,
                kv_heads=c.stack_config.kv_heads,
                ff_dim=c.stack_config.ff_dim,
                eps=c.stack_config.eps,
                shape_rotator=self.shape_rotator,
            ) for l in range(c.stack_config.layers)
        ])

        self.output = GPTOutput(c.dim, c.vocab_size)
        if self.c.split:
            self.output2 = GPTOutput(c.dim, c.vocab_size)

        self.cache = [None] * c.stack_config.layers
        self.kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
        self.head_dim = c.stack_config.dim // c.stack_config.n_head

        if exists(c.from_pretrained):
            result = self.load_state_dict(checkpoint, strict=False)
            print0_colored(result, 'yellow')

        self.resynthesizer = c.resynthesizer_config().eval()
        # requires_grad_ actually freezes the parameters.
        self.resynthesizer.requires_grad_(False)

        self.audio_tokenizer = make_tokenizer(device='cpu')
        # Streaming context, only populated while use_audio_cache is True.
        self.audio_cache = None
        self.audio_latent_cache = None
        self.use_audio_cache = False

    @T.no_grad()
    def tokenize(self, audio_data):
        """Encode audio to latents, returning only the frames for the new audio.

        ``audio_data`` is indexed as (B, C, samples) -- TODO confirm; when
        C == 2 each channel is encoded separately and the latents are
        concatenated on the last dim. While the audio cache is active the new
        audio is prefixed with up to ``_AUDIO_CACHE_SAMPLES`` of previous audio
        so the encoder sees context.
        """
        # Measured before the cache is prepended so only new frames are returned.
        new_frames = audio_data.shape[-1] // _SAMPLES_PER_LATENT
        if exists(self.audio_cache):
            audio_data = T.cat([self.audio_cache, audio_data], dim=-1)
            self.audio_cache = audio_data[..., -_AUDIO_CACHE_SAMPLES:]
        elif self.use_audio_cache:
            self.audio_cache = audio_data[..., -_AUDIO_CACHE_SAMPLES:]

        if audio_data.shape[1] == 2:
            enc_ch1 = self.audio_tokenizer.latent_from_data(audio_data[:, 0:1])
            enc_ch2 = self.audio_tokenizer.latent_from_data(audio_data[:, 1:2])
            latents = T.cat([enc_ch1, enc_ch2], dim=-1)
        else:
            latents = self.audio_tokenizer.latent_from_data(audio_data)
        return _keep_last(latents, new_frames, dim=1)

    @T.no_grad()
    def untokenize(self, token_data):
        """Decode latents to audio, returning only the audio for the new frames.

        ``token_data`` is indexed as (B, frames, latent); a last dim of
        ``2 * latent_size`` is treated as two channels concatenated on that
        dim (the layout ``tokenize`` produces) and decoded per channel. While
        the cache is active up to ``_LATENT_CACHE_FRAMES`` previous frames are
        prepended for decoder context.
        """
        # Measured before the cache is prepended so only new audio is returned.
        new_frames = token_data.shape[1]
        if exists(self.audio_latent_cache):
            token_data = T.cat([self.audio_latent_cache, token_data], dim=1)
            self.audio_latent_cache = token_data[:, -_LATENT_CACHE_FRAMES:]
        elif self.use_audio_cache:
            self.audio_latent_cache = token_data[:, -_LATENT_CACHE_FRAMES:]

        latent_size = self.c.latent_size
        if token_data.shape[-1] == 2 * latent_size:
            # Channels are split on the latent (last) dim, not the time dim.
            dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[..., :latent_size])
            dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[..., latent_size:])
            audio = T.cat([dec_ch1, dec_ch2], dim=1)
        else:
            audio = self.audio_tokenizer.data_from_latent(token_data)
        return _keep_last(audio, new_frames * _SAMPLES_PER_LATENT, dim=-1)

    def init_cache(self, bsize, device, dtype, length: Optional[int] = None):
        """Allocate KV caches here and in the resynthesizer; enable audio caching."""
        cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim]
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
        self.resynthesizer.init_cache(bsize, device, dtype, length)
        self.use_audio_cache = True

    def deinit_cache(self):
        """Drop all KV and streaming audio/latent caches."""
        self.cache = [None] * len(self.layers)
        self.resynthesizer.deinit_cache()
        self.audio_cache = None
        self.audio_latent_cache = None
        self.use_audio_cache = False

    @T.no_grad()
    def forward(self, data):
        """Return next-token logits for ``data`` (a pair of heads in split mode)."""
        if self.c.split:
            x1, x2 = data.chunk(2, dim=-1)
            x = self.input(x1) + self.input2(x2)
        else:
            x = self.input(data)

        for l, layer in enumerate(self.layers):
            x = layer(x, kv=self.cache[l])

        if self.c.split:
            return self.output(x), self.output2(x)
        else:
            return self.output(x)

    @T.no_grad()
    def next_audio_from_audio(self, audio_data: T.Tensor, temps=(0.8, (0.5, 0.1))):
        """Encode ``audio_data``, predict one latent frame and decode its audio.

        In split mode only the second channel (the last ``latent_size`` dims)
        is decoded; otherwise the whole predicted latent is decoded.
        """
        latents_in = self.tokenize(audio_data)
        next_latents = self.next_latent(latents_in, temps)
        if self.c.split:
            next_model_latent = next_latents[..., self.c.latent_size:]
        else:
            # A non-split latent is only latent_size wide; slicing past it
            # would leave nothing to decode.
            next_model_latent = next_latents
        audio_decoded = self.untokenize(next_model_latent)[..., -_SAMPLES_PER_LATENT:]
        return audio_decoded

    @T.no_grad()
    def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))):
        """Sample the next token(s) and resynthesize the next latent frame.

        ``temps[0]`` is the softmax temperature for token sampling; ``temps[1]``
        is forwarded to the resynthesizer's sampling.
        """
        if self.c.split:
            logits1, logits2 = self.forward(model_input)
            next_logits1 = logits1[:, -1]
            next_logits2 = logits2[:, -1]
            next_token1 = F.softmax(next_logits1 / temps[0], dim=-1).multinomial(1)
            next_token2 = F.softmax(next_logits2 / temps[0], dim=-1).multinomial(1)

            next_input = self.resynthesizer(model_input, next_tokens=(next_token1, next_token2), temps=temps[1])
        else:
            logits = self.forward(model_input)
            next_logits = logits[:, -1]
            next_token = F.softmax(next_logits / temps[0], dim=-1).multinomial(1)
            next_input = self.resynthesizer(model_input, next_tokens=next_token, temps=temps[1])

        return next_input

    @T.no_grad()
    def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor:
        """Autoregressively extend latent-space ``data``.

        Only accepts latent-space data (not raw audio). Generates up to
        ``gen_len`` (default: ``data.shape[1]``) new frames, capped so the total
        never exceeds ``stack_config.seq_len``. With ``use_cache`` only the
        newest frame is fed each step; caches are always released on exit,
        even if generation raises.
        """
        if use_cache:
            self.init_cache(data.shape[0], data.device, T.bfloat16)

        try:
            next_input = generated = data

            target_len = min(data.shape[1] + default(gen_len, data.shape[1]), self.c.stack_config.seq_len)

            for _ in tqdm0(range(data.shape[1], target_len)):
                model_input = next_input if use_cache else generated

                next_input = self.next_latent(model_input, temps)

                generated = T.cat([generated, next_input], dim=1)
        finally:
            if use_cache:
                self.deinit_cache()
        return generated


def get_hertz_dev_config(is_split=True, use_pure_audio_ablation=False):
    """Build the ``HertzDevModel.Config`` for one of the released checkpoint sets.

    Args:
        is_split: select the two-channel (split) checkpoints.
        use_pure_audio_ablation: only consulted when ``is_split`` is False;
            selects the alternative token-model checkpoint.

    Returns:
        A ``HertzDevModel.Config`` whose resynthesizer uses ``checkpoints[0]``
        and whose token model uses ``checkpoints[1]``.
    """
    if is_split:
        checkpoints = [('inference_care_50000', 'e4ff4fe5c7e9f066410d2a5673b7a935'), ('inference_scion_54000', 'cb8bc484423922747b277ebc2933af5d')]
    elif not use_pure_audio_ablation:
        checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_caraway_112000', 'fcb8368ef8ebf7712f3e31e6856da580')]
    else:
        checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_syrup_110000', '353c48f553f1706824c11f3bb6a049e9')]

    quantizer_config = LatentQuantizer.Config(
        from_pretrained=('inference_volcano_3', 'd42bf674022c5f84b051d5d7794f6169'),
        compressor_config=FSQ.Config(
            levels=[8, 8, 8, 8, 8],
            dim=2048,
            num_codebooks=1,
            keep_num_codebooks_dim=None,
            scale=None,
            allowed_dtypes=['float32', 'float64', 'bfloat16'],
            channel_first=False,
            projection_has_bias=True,
            return_indices=True,
            force_quantization_f32=True,
            use_rms=False
        ),
        dim=2048,
        ff_dim=8192,
        input_dim=32
    )

    resynthesizer_config = TransformerVAE.Config(
        io_config=GaussianMixtureIOLayer.Config(
            latent_dim=32,
            dim=4096,
            num_components=8,
        ),
        stack_config=Stack.Config(
            layers=8,
            dim=4096,
            seq_len=8192,
            n_head=16,
            ff_dim=11008,
            kv_heads=16,
            eps=1e-5,
            theta=10_000
        ),
        quantizer_config=quantizer_config,
        plex_layer=None,
        plex_roll=1,
        split=is_split,
        from_pretrained=checkpoints[0],
    )

    return HertzDevModel.Config(
        dim=4096,
        vocab_size=32_768,
        stack_config=Stack.Config(
            layers=32,
            dim=4096,
            seq_len=2048,
            n_head=32,
            ff_dim=None,
            kv_heads=None,
            eps=1e-5,
            theta=10_000,
        ),
        quantizer_config=quantizer_config,
        resynthesizer_config=resynthesizer_config,
        split=is_split,
        from_pretrained=checkpoints[1],
    )