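# NOTE: the lines below are residue from the hosting site's file view, kept as comments.
# jiuku's picture
# Duplicate from haoheliu/audioldm-text-to-audio-generation
# 4039be3
# raw / history blame
# 33.1 kB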
""" CLAP Model
Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
Adapted to the Audio Task.
"""
from collections import OrderedDict
from dataclasses import dataclass
from email.mime import audio
from typing import Tuple, Union, Callable, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from .timm_model import TimmModel
import logging
from .utils import freeze_batch_norm_2d
from .pann_model import create_pann_model
from .htsat import create_htsat_model
from transformers import BertModel, RobertaModel, BartModel
from transformers.tokenization_utils_base import BatchEncoding
class MLPLayers(nn.Module):
    """Feed-forward stack of ``Linear -> activation -> Dropout`` stages.

    One stage is built for every consecutive pair in ``units``; the activation
    and dropout that would follow the last ``Linear`` are removed, so the stack
    ends on a plain linear layer.

    Parameters
    ----------
    units: sequence of int, optional
        Layer widths, input first. Defaults to ``[512, 512, 512]``.
    nonlin: nn.Module, optional
        Activation module inserted after every hidden ``Linear``. The same
        instance is reused at every position of this stack. Defaults to a
        fresh ``nn.ReLU()`` per ``MLPLayers`` instance.
    dropout: float
        Dropout probability used after every hidden activation.
    """

    def __init__(self, units=None, nonlin=None, dropout=0.1):
        super(MLPLayers, self).__init__()
        # ``None`` sentinels replace the former mutable defaults: a list
        # literal and a single ``nn.ReLU()`` object that every instance shared.
        units = [512, 512, 512] if units is None else list(units)
        self.nonlin = nn.ReLU() if nonlin is None else nonlin
        self.dropout = dropout

        sequence = []
        for u0, u1 in zip(units[:-1], units[1:]):
            sequence.append(nn.Linear(u0, u1))
            sequence.append(self.nonlin)
            sequence.append(nn.Dropout(self.dropout))
        # Drop the trailing activation + dropout so the output is linear.
        sequence = sequence[:-2]

        self.sequential = nn.Sequential(*sequence)

    def forward(self, X):
        """Apply the stack to ``X`` and return the result."""
        X = self.sequential(X)
        return X
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv, 3x3 conv, optional avgpool, 1x1 conv.

    Every convolution uses stride 1. When ``stride > 1`` the spatial reduction
    is done by an average pool placed after the second convolution, and the
    shortcut path is likewise an average pool followed by a stride-1 1x1 conv.
    The block outputs ``planes * expansion`` channels.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion
        strided = stride > 1

        # Main path (sub-module creation order is kept as-is).
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if strided:
            self.avgpool = nn.AvgPool2d(stride)
        else:
            self.avgpool = nn.Identity()

        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        # Shortcut path: only needed when the resolution or the channel count
        # of the input differs from that of the main path's output.
        self.downsample = None
        if strided or inplanes != planes * Bottleneck.expansion:
            shortcut = OrderedDict()
            shortcut["-1"] = nn.AvgPool2d(stride)
            shortcut["0"] = nn.Conv2d(
                inplanes, out_planes, 1, stride=1, bias=False
            )
            shortcut["1"] = nn.BatchNorm2d(out_planes)
            self.downsample = nn.Sequential(shortcut)

    def forward(self, x: torch.Tensor):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.conv1(x)
        out = self.relu(self.bn1(out))

        out = self.conv2(out)
        out = self.relu(self.bn2(out))

        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
class AttentionPool2d(nn.Module):
    """Pool an NCHW feature map into one vector per sample via multi-head attention.

    The map is flattened to ``H*W`` tokens, the token mean is prepended as an
    extra token, a learned positional embedding is added, and self-attention is
    run over the sequence. The output at the prepended (mean) position is
    returned.
    """

    def __init__(
        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
    ):
        super().__init__()
        num_tokens = spacial_dim**2 + 1  # one slot per location plus the mean token
        self.positional_embedding = nn.Parameter(
            torch.randn(num_tokens, embed_dim) / embed_dim**0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # The output projection falls back to embed_dim when output_dim is falsy.
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Return the attention-pooled features, one row per batch element."""
        batch, channels = x.shape[0], x.shape[1]
        locations = x.shape[2] * x.shape[3]

        # NCHW -> (HW)NC
        tokens = x.reshape(batch, channels, locations).permute(2, 0, 1)
        # Prepend the mean token: (HW+1)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)

        packed_bias = torch.cat(
            [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
        )
        attended, _ = F.multi_head_attention_forward(
            query=tokens,
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )

        # Position 0 is the prepended mean token.
        return attended[0]
class ModifiedResNet(nn.Module):
    """
    A ResNet variant that differs from torchvision's in three ways:
    - the stem has 3 convolutions instead of 1 and ends in an average pool rather than a max pool;
    - strided convolutions are anti-aliased: an avgpool precedes any convolution with stride > 1;
    - the final pooling layer is QKV attention (AttentionPool2d) instead of an average pool.
    """

    def __init__(self, layers, output_dim, heads, image_size=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.image_size = image_size

        # Three-convolution stem; only the first convolution is strided.
        stem_width = width // 2
        self.conv1 = nn.Conv2d(
            3, stem_width, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.conv2 = nn.Conv2d(
            stem_width, stem_width, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(stem_width)
        self.conv3 = nn.Conv2d(stem_width, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages. _inplanes is updated by _make_layer as stages are built.
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        for index, multiplier in enumerate((2, 4, 8), start=1):
            stage = self._make_layer(width * multiplier, layers[index], stride=2)
            setattr(self, f"layer{index + 1}", stage)

        # width * 32 is the channel count produced by the last stage.
        feature_dim = width * 32
        self.attnpool = AttentionPool2d(
            image_size // 32, feature_dim, heads, output_dim
        )

        self.init_parameters()

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage of ``blocks`` Bottlenecks; only the first may be strided."""
        stage = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        stage.extend(Bottleneck(self._inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def init_parameters(self):
        """Re-initialise the attention-pool projections and zero each block's last BN weight."""
        if self.attnpool is not None:
            pool = self.attnpool
            std = pool.c_proj.in_features**-0.5
            for projection in (pool.q_proj, pool.k_proj, pool.v_proj, pool.c_proj):
                nn.init.normal_(projection.weight, std=std)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            for param_name, param in stage.named_parameters():
                if param_name.endswith("bn3.weight"):
                    nn.init.zeros_(param)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Disable gradients for all parameters; optionally freeze BatchNorm via freeze_batch_norm_2d."""
        assert (
            unlocked_groups == 0
        ), "partial locking not currently supported for this model"
        for parameter in self.parameters():
            parameter.requires_grad = False
        if freeze_bn_stats:
            freeze_batch_norm_2d(self)

    def stem(self, x):
        """Run the three conv/BN/ReLU stem stages followed by the average pool."""
        stem_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stem_stages:
            x = self.relu(norm(conv(x)))
        return self.avgpool(x)

    def forward(self, x):
        """Stem, four residual stages, then attention pooling."""
        x = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.attnpool(x)
class LayerNorm(nn.LayerNorm):
    """torch LayerNorm whose output is cast back to the input's dtype (fp16 handling)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = F.layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        )
        # Return in whatever dtype the caller passed in.
        return normalized.to(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return gate * x
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention and a 4x-wide MLP, each with a residual add."""

    def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
        super().__init__()
        hidden_dim = d_model * 4

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)

        # Sub-module names are part of the state-dict layout; keep them as-is.
        feed_forward = OrderedDict()
        feed_forward["c_fc"] = nn.Linear(d_model, hidden_dim)
        feed_forward["gelu"] = act_layer()
        feed_forward["c_proj"] = nn.Linear(hidden_dim, d_model)
        self.mlp = nn.Sequential(feed_forward)

        self.ln_2 = LayerNorm(d_model)

    def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Self-attention over ``x``; only the attended values are returned."""
        attended = self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)
        return attended[0]

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Apply the attention sub-layer, then the MLP sub-layer, each added to its input."""
        attn_out = self.attention(self.ln_1(x), attn_mask=attn_mask)
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        x = x + mlp_out
        return x
class Transformer(nn.Module):
    """A sequence of ``layers`` ResidualAttentionBlocks of the same width, applied in order."""

    def __init__(
        self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, act_layer=act_layer))
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Feed ``x`` through every block, passing the same ``attn_mask`` to each."""
        hidden = x
        for block in self.resblocks:
            hidden = block(hidden, attn_mask=attn_mask)
        return hidden
class VisualTransformer(nn.Module):
    """Patch-embedding transformer encoder.

    A strided convolution turns the 3-channel input into a grid of patch
    tokens, a learned class token is prepended, positional embeddings are
    added, and the sequence goes through a Transformer. The class-token output
    is layer-normalised and multiplied by ``proj``.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int,
        act_layer: Callable = nn.GELU,
    ):
        super().__init__()
        self.image_size = image_size
        self.output_dim = output_dim

        # Non-overlapping patches: kernel size equals stride.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width**-0.5
        num_tokens = (image_size // patch_size) ** 2 + 1  # patches + class token
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(num_tokens, width)
        )
        self.ln_pre = LayerNorm(width)

        # The attribute is named text_branch although it encodes patch tokens here.
        self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Disable gradients for every parameter; ``freeze_bn_stats`` is accepted but unused."""
        assert (
            unlocked_groups == 0
        ), "partial locking not currently supported for this model"
        for parameter in self.parameters():
            parameter.requires_grad = False

    def forward(self, x: torch.Tensor):
        """Encode ``x`` and return the projected class-token features."""
        patches = self.conv1(x)  # shape = [*, width, grid, grid]
        patches = patches.reshape(patches.shape[0], patches.shape[1], -1)
        patches = patches.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # Broadcast the class embedding to one token per batch element.
        class_tokens = self.class_embedding.to(patches.dtype) + torch.zeros(
            patches.shape[0],
            1,
            patches.shape[-1],
            dtype=patches.dtype,
            device=patches.device,
        )
        tokens = torch.cat([class_tokens, patches], dim=1)  # [*, grid ** 2 + 1, width]
        tokens = tokens + self.positional_embedding.to(tokens.dtype)
        tokens = self.ln_pre(tokens)

        tokens = tokens.permute(1, 0, 2)  # NLD -> LND
        tokens = self.text_branch(tokens)
        tokens = tokens.permute(1, 0, 2)  # LND -> NLD

        # Index 0 along the sequence axis is the class token.
        features = self.ln_post(tokens[:, 0, :])
        if self.proj is not None:
            features = features @ self.proj
        return features
@dataclass
class CLAPVisionCfg:
    """Vision-tower settings (layer layout, width, patch/image size, timm options).

    Nothing in the visible part of this file reads this config.
    """

    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224
    # a valid model name overrides layers, width, patch_size
    # (annotated Optional because the default is None)
    timm_model_name: Optional[str] = None
    # use (imagenet) pretrained weights for named model
    timm_model_pretrained: bool = False
    # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_pool: str = "avg"
    # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj: str = "linear"
# Audio branch configuration
@dataclass
class CLAPAudioCfp:
    """Settings handed to the audio-branch factory selected by ``model_type``.

    CLAP accepts "PANN" or "HTSAT" as ``model_type`` and reads
    ``clip_samples`` in ``audio_infer``; the remaining fields are passed
    through to the factory functions as part of this object.
    """

    # Which audio encoder to build, and its variant name.
    model_type: str = "PANN"
    model_name: str = "Cnn14"
    sample_rate: int = 48000

    # Front-end / feature parameters (presumably spectrogram settings —
    # TODO confirm against the audio branch implementations).
    audio_length: int = 1024
    window_size: int = 1024
    hop_size: int = 1024
    fmin: int = 50
    fmax: int = 14000
    class_num: int = 527
    mel_bins: int = 64

    # Window length, in samples, used by CLAP.audio_infer when slicing audio.
    clip_samples: int = 480000
@dataclass
class CLAPTextCfg:
    """Text-branch settings; every field is required (no defaults).

    CLAP dispatches on ``model_type`` ("transformer", "bert", "roberta" or
    "bart"). ``vocab_size``, ``width``, ``heads`` and ``layers`` are only read
    when building the "transformer" branch; ``context_length`` sizes the
    attention mask for every branch type.
    """

    context_length: int
    vocab_size: int
    width: int
    heads: int
    layers: int
    model_type: str
class CLAP(nn.Module):
    """Audio/text model with one encoder per modality and a shared embedding size.

    The audio branch is built by ``create_pann_model`` or
    ``create_htsat_model``; the text branch is either a ``Transformer`` defined
    in this file or a pretrained BERT / RoBERTa / BART encoder from
    ``transformers``. Each branch's output goes through a two-layer projection
    to ``joint_embed_shape`` features, and ``forward`` additionally applies an
    ``MLPLayers`` transform per modality.

    Parameters
    ----------
    embed_dim: int
        Feature size of the audio branch's ``"embedding"`` output (input size
        of the audio projection).
    audio_cfg: CLAPAudioCfp or dict
        Audio-branch configuration; a dict is expanded into ``CLAPAudioCfp``.
    text_cfg: CLAPTextCfg or dict
        Text-branch configuration; a dict is expanded into ``CLAPTextCfg``.
    quick_gelu: bool
        Use ``QuickGELU`` rather than ``nn.GELU`` in the "transformer" text branch.
    enable_fusion: bool
        Forwarded to the audio-branch factory.
    fusion_type: str
        Forwarded to the audio-branch factory.
    joint_embed_shape: int
        Size of the joint embedding space.
    mlp_act: str
        Activation inside the projection heads: "relu" or "gelu".

    Raises
    ------
    NotImplementedError
        If ``mlp_act`` is neither "relu" nor "gelu".
    RuntimeError
        If the audio or text ``model_type`` is not supported.
    """

    def __init__(
        self,
        embed_dim: int,
        audio_cfg: CLAPAudioCfp,
        text_cfg: CLAPTextCfg,
        quick_gelu: bool = False,
        enable_fusion: bool = False,
        fusion_type: str = "None",
        joint_embed_shape: int = 512,
        mlp_act: str = "relu",
    ):
        super().__init__()
        if isinstance(audio_cfg, dict):
            audio_cfg = CLAPAudioCfp(**audio_cfg)
        if isinstance(text_cfg, dict):
            text_cfg = CLAPTextCfg(**text_cfg)

        self.audio_cfg = audio_cfg
        self.text_cfg = text_cfg
        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type
        self.joint_embed_shape = joint_embed_shape
        self.mlp_act = mlp_act

        self.context_length = text_cfg.context_length

        # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
        # memory efficient in recent PyTorch releases (>= 1.10).
        # NOTE: timm models always use native GELU regardless of quick_gelu flag.
        act_layer = QuickGELU if quick_gelu else nn.GELU

        # A single activation instance is shared by the text and audio
        # projection heads (it holds no parameters).
        if mlp_act == "relu":
            mlp_act_layer = nn.ReLU()
        elif mlp_act == "gelu":
            mlp_act_layer = nn.GELU()
        else:
            raise NotImplementedError(f"Unsupported mlp_act: {mlp_act!r}")

        # ---- audio branch ----
        if audio_cfg.model_type == "PANN":
            self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
        elif audio_cfg.model_type == "HTSAT":
            self.audio_branch = create_htsat_model(
                audio_cfg, enable_fusion, fusion_type
            )
        else:
            logging.error(f"Model config for {audio_cfg.model_type} not found")
            raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")

        # ---- text branch ----
        # Each case builds the encoder and records the feature width its
        # projection head must accept. Sub-module creation order is unchanged.
        if text_cfg.model_type == "transformer":
            self.text_branch = Transformer(
                width=text_cfg.width,
                layers=text_cfg.layers,
                heads=text_cfg.heads,
                act_layer=act_layer,
            )
            self.vocab_size = text_cfg.vocab_size
            self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
            # Filled in by init_text_branch_parameters().
            self.positional_embedding = nn.Parameter(
                torch.empty(self.context_length, text_cfg.width)
            )
            self.ln_final = LayerNorm(text_cfg.width)
            text_width = text_cfg.width
        elif text_cfg.model_type == "bert":
            self.text_branch = BertModel.from_pretrained("bert-base-uncased")
            text_width = 768
        elif text_cfg.model_type == "roberta":
            self.text_branch = RobertaModel.from_pretrained("roberta-base")
            text_width = 768
        elif text_cfg.model_type == "bart":
            self.text_branch = BartModel.from_pretrained("facebook/bart-base")
            text_width = 768
        else:
            logging.error(f"Model config for {text_cfg.model_type} not found")
            raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")

        self.text_transform = self._make_transform()
        self.text_projection = self._make_projection(text_width, mlp_act_layer)
        self.text_branch_type = text_cfg.model_type

        # ---- audio heads ----
        self.audio_transform = self._make_transform()
        self.audio_projection = self._make_projection(embed_dim, mlp_act_layer)

        # Learnable log-temperatures, one per modality; forward() returns exp().
        self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)

        self.init_text_branch_parameters()

    def _make_transform(self):
        """Build the joint-space MLP (three equal widths, dropout 0.1)."""
        size = self.joint_embed_shape
        return MLPLayers(units=[size, size, size], dropout=0.1)

    def _make_projection(self, in_features, activation):
        """Build a Linear -> activation -> Linear head into the joint space."""
        size = self.joint_embed_shape
        return nn.Sequential(
            nn.Linear(in_features, size),
            activation,
            nn.Linear(size, size),
        )

    def init_text_branch_parameters(self):
        """Initialise the "transformer" text branch (if used) and reset both logit scales."""
        if self.text_branch_type == "transformer":
            nn.init.normal_(self.token_embedding.weight, std=0.02)
            nn.init.normal_(self.positional_embedding, std=0.01)

            width = self.text_branch.width
            proj_std = (width**-0.5) * ((2 * self.text_branch.layers) ** -0.5)
            attn_std = width**-0.5
            fc_std = (2 * width) ** -0.5
            for block in self.text_branch.resblocks:
                nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
                nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
                nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
                nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
        nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))

    def build_attention_mask(self):
        """Return a ``context_length`` x ``context_length`` additive causal mask.

        Entries strictly above the diagonal are ``-inf``; the diagonal and
        everything below it are zero.
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keep -inf strictly above the diagonal, zero elsewhere
        return mask

    def encode_audio(self, audio, device):
        """Run the audio branch on ``audio`` and return its output (indexed by key by callers)."""
        return self.audio_branch(
            audio, mixup_lambda=None, device=device
        )  # mix lambda needs to add

    @staticmethod
    def _on_device(tensor, device):
        """Move ``tensor`` to ``device`` with a non-blocking copy."""
        return tensor.to(device=device, non_blocking=True)

    def encode_text(self, text, device):
        """Encode ``text`` with the configured text branch and project it.

        Parameters
        ----------
        text: torch.Tensor or mapping
            Token ids for the "transformer" branch; for "bert", "roberta" and
            "bart" a mapping with "input_ids" and "attention_mask" (plus
            "token_type_ids" for "bert").
        device: torch.device
            Device the inputs are moved to before encoding.

        Returns
        -------
        torch.Tensor
            Output of ``text_projection`` (not L2-normalised).

        Raises
        ------
        RuntimeError
            If the text branch type is unknown.
        """
        branch = self.text_branch_type

        if branch == "transformer":
            text = self._on_device(text, device)
            x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
            x = x + self.positional_embedding
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.text_branch(x, attn_mask=self.attn_mask)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = self.ln_final(x)
            # x.shape = [batch_size, n_ctx, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            eot_features = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
            return self.text_projection(eot_features)

        if branch == "bert":
            pooled = self.text_branch(
                input_ids=self._on_device(text["input_ids"], device),
                attention_mask=self._on_device(text["attention_mask"], device),
                token_type_ids=self._on_device(text["token_type_ids"], device),
            )["pooler_output"]
            return self.text_projection(pooled)

        if branch == "roberta":
            pooled = self.text_branch(
                input_ids=self._on_device(text["input_ids"], device),
                attention_mask=self._on_device(text["attention_mask"], device),
            )["pooler_output"]
            return self.text_projection(pooled)

        if branch == "bart":
            hidden = self.text_branch(
                input_ids=self._on_device(text["input_ids"], device),
                attention_mask=self._on_device(text["attention_mask"], device),
            )["encoder_last_hidden_state"]
            # Mean over dim 1 of the encoder's last hidden state.
            return self.text_projection(torch.mean(hidden, dim=1))

        logging.error(f"Model type {self.text_branch_type} not found")
        raise RuntimeError(f"Model type {self.text_branch_type} not found.")

    def forward(self, audio, text, device=None):
        """Forward audio and text into the CLAP

        Parameters
        ----------
        audio: torch.Tensor (batch_size, audio_length)
            the time-domain audio input / the batch of mel_spec and longer list.
        text: torch.Tensor () // need to add
            the text token input
        device: torch.device, optional
            taken from ``audio.device`` (or ``text.device``) when omitted.

        Returns
        ----------
        - both inputs ``None``: ``(logit_scale_a.exp(), logit_scale_t.exp())``
        - only ``text``: the projected text features (not normalised)
        - only ``audio``: the projected audio features (not normalised)
        - both: ``(audio_features, text_features, audio_features_mlp,
          text_features_mlp, logit_scale_a.exp(), logit_scale_t.exp())`` where
          the first two are L2-normalised.
        """
        if device is None:
            if audio is not None:
                device = audio.device
            elif text is not None:
                device = text.device

        if audio is None and text is None:
            # a hack to get the logit scale
            return self.logit_scale_a.exp(), self.logit_scale_t.exp()
        elif audio is None:
            return self.encode_text(text, device=device)
        elif text is None:
            return self.audio_projection(
                self.encode_audio(audio, device=device)["embedding"]
            )

        audio_features = self.audio_projection(
            self.encode_audio(audio, device=device)["embedding"]
        )
        audio_features = F.normalize(audio_features, dim=-1)

        text_features = self.encode_text(text, device=device)
        text_features = F.normalize(text_features, dim=-1)

        audio_features_mlp = self.audio_transform(audio_features)
        text_features_mlp = self.text_transform(text_features)

        # Four outputs: audio features (basic & MLP), text features (basic & MLP)
        return (
            audio_features,
            text_features,
            audio_features_mlp,
            text_features_mlp,
            self.logit_scale_a.exp(),
            self.logit_scale_t.exp(),
        )

    def get_logit_scale(self):
        """Return ``(logit_scale_a.exp(), logit_scale_t.exp())``."""
        return self.logit_scale_a.exp(), self.logit_scale_t.exp()

    def get_text_embedding(self, data):
        """Get the text embedding from the model

        Parameters
        ----------
        data: mapping of str to torch.Tensor
            tokenised text inputs; entries with fewer than two dimensions get
            a leading batch dimension. The mapping itself is not modified.

        Returns
        ----------
        text_embed: torch.Tensor
            a tensor of text_embeds (N, D)
        """
        device = next(self.parameters()).device
        # Work on a fresh dict so the caller's tensors/mapping stay untouched.
        batch = {}
        for k in data:
            value = data[k].to(device)
            if value.dim() < 2:
                value = value.unsqueeze(0)
            batch[k] = value
        text_embeds = self.encode_text(batch, device=device)
        text_embeds = F.normalize(text_embeds, dim=-1)

        return text_embeds

    def get_audio_embedding(self, data):
        """Get the audio embedding from the model

        Parameters
        ----------
        data: a list of dict
            the audio input dict list from 'get_audio_feature' method

        Returns
        ----------
        audio_embed: torch.Tensor
            a tensor of audio_embeds (N, D)
        """
        device = next(self.parameters()).device
        # Stack each key across the list into one batched tensor on the model device.
        input_dict = {
            k: torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(device)
            for k in data[0].keys()
        }

        audio_embeds = self.audio_projection(
            self.encode_audio(input_dict, device=device)["embedding"]
        )
        audio_embeds = F.normalize(audio_embeds, dim=-1)

        return audio_embeds

    @staticmethod
    def _squeeze_batch(outputs):
        """Drop the leading size-1 batch dimension from every entry of ``outputs``."""
        return {key: value.squeeze(dim=0) for key, value in outputs.items()}

    def audio_infer(self, audio, hopsize=None, device=None):
        """Forward one audio and produce the audio embedding

        Parameters
        ----------
        audio: (audio_length)
            the time-domain audio input, notice that it must be only one input
        hopsize: int
            the overlap hopsize as the sliding window; defaults to
            ``audio_cfg.clip_samples`` (non-overlapping windows) when omitted

        Returns
        ----------
        output_dict: {
            key: [n, (embedding_shape)] if "HTS-AT"
            or
            key: [(embedding_shape)] if "PANN"
        }
            the list of key values of the audio branch

        NOTE(review): the previous code indexed the branch output with an
        undefined name ``key`` and therefore always raised NameError. Every key
        of the branch output is now returned, which assumes all of its values
        are tensors - confirm against the audio branch implementations.
        """
        assert not self.training, "the inference mode must be run at eval stage"

        model_type = self.audio_cfg.model_type
        if model_type == "PANN":
            audio_input = audio.unsqueeze(dim=0)
            return self._squeeze_batch(self.encode_audio(audio_input, device=device))
        if model_type != "HTSAT":
            # Unknown model types produce an empty result.
            return {}

        clip_samples = self.audio_cfg.clip_samples
        # Tile short audio so that it covers (most of) one clip.
        audio_len = len(audio)
        k = clip_samples // audio_len
        if k > 1:
            audio = audio.repeat(k)
            audio_len = len(audio)

        # The old code called min(None, audio_len) here, which raises TypeError.
        if hopsize is None:
            hopsize = clip_samples
        hopsize = min(hopsize, audio_len)

        if audio_len > clip_samples:
            # Sliding windows of clip_samples, plus one window aligned to the end.
            windows = [
                audio[pos : pos + clip_samples].clone()
                for pos in range(0, audio_len - clip_samples, hopsize)
            ]
            windows.append(audio[-clip_samples:].clone())
            audio_input = torch.stack(windows)
            return dict(self.encode_audio(audio_input, device=device))

        audio_input = audio.unsqueeze(dim=0)
        return self._squeeze_batch(self.encode_audio(audio_input, device=device))
def convert_weights_to_fp16(model: nn.Module):
    """Convert applicable model parameters to fp16, in place.

    Converted via ``model.apply``:
    - weight and bias of every ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Linear``;
    - the projection weights and the ``in_proj_bias`` / ``bias_k`` /
      ``bias_v`` tensors of every ``nn.MultiheadAttention``;
    - tensor-valued attributes named ``text_projection`` or ``proj``.

    Attributes with those names that are modules rather than tensors (for
    example an ``nn.Sequential`` projection head) are skipped here; the layers
    inside them are handled when ``apply`` visits them. Previously such an
    attribute raised AttributeError because modules have no ``.data``.
    """

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            attention_tensors = [
                *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                "in_proj_bias",
                "bias_k",
                "bias_v",
            ]
            for attr in attention_tensors:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ("text_projection", "proj"):
            attr = getattr(l, name, None)
            # Only raw tensors / parameters carry ``.data``.
            if isinstance(attr, torch.Tensor):
                attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)
# Ignore the state dict of the vision part
def build_model_from_openai_state_dict(
    state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
):
    """Build a CLAP model and load the non-visual weights of an OpenAI CLIP checkpoint.

    Parameters
    ----------
    state_dict: dict
        Checkpoint weights. Must contain "logit_scale", which initialises both
        ``logit_scale_a`` and ``logit_scale_t``. Keys starting with "visual."
        and the keys "logit_scale", "input_resolution", "context_length" and
        "vocab_size" are not loaded. The mapping is left unmodified.
    model_cfg: mapping
        Provides "embed_dim", "audio_cfg" and "text_cfg"; the two configs may
        be dicts or already-built config objects.
    enable_fusion: bool
        Forwarded to ``CLAP``.
    fusion_type: str
        Forwarded to ``CLAP``.

    Returns
    -------
    CLAP
        The model in eval mode, loaded with ``strict=False``.

    Raises
    ------
    KeyError
        If ``state_dict`` has no "logit_scale" entry or ``model_cfg`` lacks a
        required key.

    NOTE(review): because loading is non-strict, checkpoint keys that do not
    match a CLAP parameter name (e.g. ones prefixed "transformer.", while CLAP
    names its text encoder ``text_branch``) are silently ignored - confirm
    this is intended.
    """
    embed_dim = model_cfg["embed_dim"]
    audio_cfg = model_cfg["audio_cfg"]
    text_cfg = model_cfg["text_cfg"]

    # Accept dict configs as well as already-constructed config objects.
    if isinstance(audio_cfg, dict):
        audio_cfg = CLAPAudioCfp(**audio_cfg)
    if isinstance(text_cfg, dict):
        text_cfg = CLAPTextCfg(**text_cfg)

    model = CLAP(
        embed_dim,
        audio_cfg=audio_cfg,
        text_cfg=text_cfg,
        quick_gelu=True,  # OpenAI models were trained with QuickGELU
        enable_fusion=enable_fusion,
        fusion_type=fusion_type,
    )

    # Build a filtered copy instead of editing the caller's state dict.
    skipped_keys = {"logit_scale", "input_resolution", "context_length", "vocab_size"}
    weights = OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        if not key.startswith("visual.") and key not in skipped_keys
    )
    # The single CLIP temperature seeds both modality-specific temperatures.
    weights["logit_scale_a"] = state_dict["logit_scale"]
    weights["logit_scale_t"] = state_dict["logit_scale"]
    # Carry over torch's versioning metadata, which a plain copy would drop.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        weights._metadata = metadata

    # not use fp16
    # convert_weights_to_fp16(model)
    model.load_state_dict(weights, strict=False)
    return model.eval()
def trace_model(model, batch_size=256, device=torch.device("cpu")):
    """Put ``model`` in eval mode and trace it with ``torch.jit.trace_module``.

    Example inputs are an all-ones audio batch of shape
    ``(batch_size, model.audio_cfg.audio_length)`` and an all-zero integer
    token batch of shape ``(batch_size, model.context_length)``. The traced
    module is returned.

    NOTE(review): the traced method list names ``encode_image``, which the
    CLAP class in this file does not define, and ``encode_text`` is given no
    ``device`` argument although CLAP.encode_text requires one - confirm this
    function still works with the current model.
    """
    model.eval()
    audio_length = model.audio_cfg.audio_length

    dummy_audio = torch.ones((batch_size, audio_length), device=device)
    dummy_tokens = torch.zeros(
        (batch_size, model.context_length), dtype=torch.int, device=device
    )

    example_inputs = dict(
        forward=(dummy_audio, dummy_tokens),
        encode_text=(dummy_tokens,),
        encode_image=(dummy_audio,),
    )
    model = torch.jit.trace_module(model, inputs=example_inputs)

    # Writes the pre-trace audio_length back onto the traced module's audio_cfg.
    model.audio_cfg.audio_length = audio_length  # Question: what does this do?
    return model