# Datasculptor's picture
# Duplicate from AIGC-Audio/AudioGPT
# 98f685a
# -*- coding: utf-8 -*-
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio import transforms
from torchlibrosa.augmentation import SpecAugmentation
from .utils import mean_with_lens, max_with_lens, \
init, pack_wrapper, generate_length_mask, PositionalEncoding
def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero ``layer.bias``.

    Meant for Linear / Conv layers. Layers without a ``bias`` attribute, or
    whose bias is ``None``, only get their weight initialised.
    """
    nn.init.xavier_uniform_(layer.weight)
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.)
def init_bn(bn):
    """Reset a batch-norm layer's affine parameters to weight=1, bias=0."""
    bn.weight.data.fill_(1.)
    bn.bias.data.fill_(0.)
class BaseEncoder(nn.Module):
    """Abstract base class for the audio encoders in this module.

    It only stores the configured dimensions; concrete encoders must
    override :meth:`forward`.

    Args:
        spec_dim: spectrogram feature dimension.
        fc_feat_dim: clip-level ("fc") feature dimension.
        attn_feat_dim: frame-level ("attn") feature dimension.
    """

    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
        super().__init__()
        self.spec_dim, self.fc_feat_dim, self.attn_feat_dim = (
            spec_dim, fc_feat_dim, attn_feat_dim)

    def forward(self, x):
        # Subclasses return a dict of embeddings, roughly:
        #   fc_emb:       [N, fc_emb_dim]
        #   attn_emb:     [N, attn_max_len, attn_emb_dim]
        #   attn_emb_len: [N,]
        raise NotImplementedError
class Block2D(nn.Module):
    """BatchNorm2d -> bias-free Conv2d -> LeakyReLU(0.1), as one Sequential."""

    def __init__(self, cin, cout, kernel_size=3, padding=1):
        super().__init__()
        norm = nn.BatchNorm2d(cin)
        conv = nn.Conv2d(cin, cout, kernel_size=kernel_size,
                         padding=padding, bias=False)
        act = nn.LeakyReLU(inplace=True, negative_slope=0.1)
        self.block = nn.Sequential(norm, conv, act)

    def forward(self, x):
        return self.block(x)
class LinearSoftPool(nn.Module):
    """Linear-softmax temporal pooling.

    Returns ``sum(p ** 2) / sum(p)`` over ``pooldim``. The result lies close
    to the maximum of ``p`` while staying differentiable. From the paper
    "A Comparison of Five Multiple Instance Learning Pooling Functions for
    Sound Event Detection with Weak Labeling"
    (https://arxiv.org/abs/1810.09050).
    """

    def __init__(self, pooldim=1):
        super().__init__()
        self.pooldim = pooldim

    def forward(self, logits, time_decision):
        # ``logits`` keeps the signature shared with the other pooling
        # modules; it is not used by this pooling.
        dim = self.pooldim
        numerator = time_decision.pow(2).sum(dim)
        denominator = time_decision.sum(dim)
        return numerator / denominator
class MeanPool(nn.Module):
    """Average ``decision`` along ``pooldim``; ``logits`` is ignored."""

    def __init__(self, pooldim=1):
        super().__init__()
        self.pooldim = pooldim

    def forward(self, logits, decision):
        return decision.mean(dim=self.pooldim)
class AttentionPool(nn.Module):
    """Attention-weighted temporal pooling.

    A linear layer maps ``logits`` to scores. The scores are clamped to
    [-15, 15] and softmax-normalised along ``pooldim``, and the resulting
    weights average ``decision`` along the same dimension.

    Args:
        inputdim: last-dimension size of ``logits``.
        outputdim: number of attention scores per frame; presumably equal to
            the last dimension of ``decision`` (they are multiplied
            elementwise).
        pooldim: dimension to pool over.
        **kwargs: accepted and ignored.
    """

    def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
        super().__init__()
        self.inputdim = inputdim
        self.outputdim = outputdim
        self.pooldim = pooldim
        self.transform = nn.Linear(inputdim, outputdim)
        self.activ = nn.Softmax(dim=self.pooldim)
        self.eps = 1e-7

    def forward(self, logits, decision):
        scores = torch.clamp(self.transform(logits), -15, 15)
        weights = self.activ(scores)
        weighted_sum = (decision * weights).sum(self.pooldim)
        # eps guards against a zero normaliser.
        return weighted_sum / (weights.sum(self.pooldim) + self.eps)
class MMPool(nn.Module):
    """Sum of 2-D average pooling and 2-D max pooling over the same window."""

    def __init__(self, dims):
        super().__init__()
        self.avgpool = nn.AvgPool2d(dims)
        self.maxpool = nn.MaxPool2d(dims)

    def forward(self, x):
        averaged = self.avgpool(x)
        maximum = self.maxpool(x)
        return averaged + maximum
def parse_poolingfunction(poolingfunction_name='mean', **kwargs):
    """Build a temporal pooling module from its name.

    All returned modules pool over dimension 1.

    Args:
        poolingfunction_name: ``'mean'``, ``'linear'`` or ``'attention'``
            (case-insensitive).
        **kwargs: for ``'attention'``, must provide ``inputdim`` and
            ``outputdim``; ignored for the other names.

    Returns:
        A ``MeanPool``, ``LinearSoftPool`` or ``AttentionPool`` instance.

    Raises:
        ValueError: if ``poolingfunction_name`` is not one of the above.
    """
    name = poolingfunction_name.lower()
    if name == 'mean':
        return MeanPool(pooldim=1)
    if name == 'linear':
        return LinearSoftPool(pooldim=1)
    if name == 'attention':
        return AttentionPool(inputdim=kwargs['inputdim'],
                             outputdim=kwargs['outputdim'])
    # Fail here: returning None would only fail later, when the "module" is called.
    raise ValueError(
        f"pooling function {poolingfunction_name} not supported")
def embedding_pooling(x, lens, pooling="mean"):
    """Pool frame-level embeddings into one vector per sequence.

    Args:
        x: embeddings of shape [N, T, D].
        lens: number of valid frames per sequence, length N.
        pooling: ``"max"``, ``"mean"``, ``"mean+max"`` (sum of both), or
            ``"last"`` (the embedding at index ``lens - 1``).

    Returns:
        Tensor of shape [N, D].

    Raises:
        Exception: if ``pooling`` is not one of the supported names.
    """
    if pooling == "max":
        return max_with_lens(x, lens)
    if pooling == "mean":
        return mean_with_lens(x, lens)
    if pooling == "mean+max":
        return mean_with_lens(x, lens) + max_with_lens(x, lens)
    if pooling == "last":
        # torch.gather needs an int64 index on the same device as ``x``;
        # ``lens`` may be a Python list, an int32 tensor, or a CPU tensor.
        last_idx = torch.as_tensor(lens, device=x.device).long() - 1
        # indices: [N, 1, D]
        indices = last_idx.reshape(-1, 1, 1).expand(-1, 1, x.size(-1))
        return torch.gather(x, 1, indices).squeeze(1)
    raise Exception(f"pooling method {pooling} not support")
class Cdur5Encoder(BaseEncoder):
    """CNN + bidirectional GRU encoder over spectrograms.

    The CNN uses LPPool2d windows of (2, 4), (2, 4) and (1, 4), so time is
    downsampled by 4 and frequency by 64. The channel and remaining
    frequency axes are flattened and fed to a BiGRU with 128 units per
    direction.

    Args:
        spec_dim: number of spectrogram frequency bins.
        fc_feat_dim, attn_feat_dim: forwarded to ``BaseEncoder``.
        pooling: how ``fc_emb`` is pooled from the GRU outputs
            (see ``embedding_pooling``).
    """

    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
        super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
        self.pooling = pooling
        self.features = nn.Sequential(
            Block2D(1, 32),
            nn.LPPool2d(4, (2, 4)),
            Block2D(32, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (2, 4)),
            Block2D(128, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (1, 4)),
            nn.Dropout(0.3),
        )
        # Infer the GRU input size (channels * remaining frequency bins) with
        # a dummy forward pass. Run it in eval mode so the BatchNorm running
        # statistics are not updated with random noise (no_grad alone does
        # not prevent that).
        was_training = self.features.training
        self.features.eval()
        with torch.no_grad():
            probe_shape = self.features(
                torch.randn(1, 1, 500, spec_dim)).shape
        self.features.train(was_training)
        rnn_input_dim = probe_shape[1] * probe_shape[-1]
        self.gru = nn.GRU(rnn_input_dim,
                          128,
                          bidirectional=True,
                          batch_first=True)
        self.apply(init)

    def forward(self, input_dict):
        """Encode a batch of spectrograms.

        Args:
            input_dict: dict with ``"spec"`` ([N, T, spec_dim]), ``"spec_len"``
                (valid frames per item) and optional ``"upsample"`` (default
                False). If set, the GRU outputs are interpolated back to T
                frames and the lengths are left unchanged.

        Returns:
            dict with ``"attn_emb"`` (GRU outputs), ``"fc_emb"`` (pooled with
            ``self.pooling``) and ``"attn_emb_len"``.
        """
        x = input_dict["spec"]
        # Copy so the in-place division below cannot touch the caller's data.
        lens = torch.as_tensor(copy.deepcopy(input_dict["spec_len"]))
        # Read the flag without writing a default back into the caller's dict.
        upsample = input_dict.get("upsample", False)
        _, T, _ = x.shape
        x = x.unsqueeze(1)  # [N, 1, T, D]
        x = self.features(x)  # [N, C, T/4, D']
        x = x.transpose(1, 2).contiguous().flatten(-2)  # [N, T/4, C*D']
        x, _ = self.gru(x)
        if upsample:
            x = nn.functional.interpolate(
                x.transpose(1, 2),
                T,
                mode='linear',
                align_corners=False).transpose(1, 2)
        else:
            lens //= 4
        attn_emb = x
        fc_emb = embedding_pooling(x, lens, self.pooling)
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": lens
        }
def conv_conv_block(in_channel, out_channel):
    """Two stacked stages of bias-free 3x3 Conv2d -> BatchNorm2d -> ReLU.

    With padding=1 the spatial size is preserved. The result is a
    6-element ``nn.Sequential``.
    """
    def conv3x3(cin):
        return nn.Conv2d(cin, out_channel, kernel_size=3,
                         bias=False, padding=1)

    layers = []
    for cin in (in_channel, out_channel):
        layers.extend([conv3x3(cin),
                       nn.BatchNorm2d(out_channel),
                       nn.ReLU(True)])
    return nn.Sequential(*layers)
class Cdur8Encoder(BaseEncoder):
    """CNN (four double-conv stages with avg+max pooling) + BiGRU encoder.

    Each frequency bin is first batch-normalised (``init_bn``). The CNN
    halves time twice and frequency four times, then averages the
    frequency axis away. The 512-d frame features pass through a linear +
    ReLU embedding and a BiGRU with 256 units per direction.

    Args:
        spec_dim: number of spectrogram frequency bins.
        fc_feat_dim, attn_feat_dim: forwarded to ``BaseEncoder``.
        pooling: how ``fc_emb`` is pooled from the GRU outputs.
    """

    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
        super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
        self.pooling = pooling
        self.features = nn.Sequential(
            conv_conv_block(1, 64),
            MMPool((2, 2)),
            nn.Dropout(0.2, True),
            conv_conv_block(64, 128),
            MMPool((2, 2)),
            nn.Dropout(0.2, True),
            conv_conv_block(128, 256),
            MMPool((1, 2)),
            nn.Dropout(0.2, True),
            conv_conv_block(256, 512),
            MMPool((1, 2)),
            nn.Dropout(0.2, True),
            nn.AdaptiveAvgPool2d((None, 1)),
        )
        self.init_bn = nn.BatchNorm2d(spec_dim)
        self.embedding = nn.Linear(512, 512)
        self.gru = nn.GRU(512, 256, bidirectional=True, batch_first=True)
        self.apply(init)

    def forward(self, input_dict):
        spec = input_dict["spec"]
        # Deep copy: the in-place floor division below must not reach the
        # caller's length data.
        frame_lens = torch.as_tensor(copy.deepcopy(input_dict["spec_len"]))
        # Put frequency on the channel axis so init_bn normalises each bin.
        h = spec.unsqueeze(1).transpose(1, 3)  # [B, D, T, 1]
        h = self.init_bn(h).transpose(1, 3)  # [B, 1, T, D]
        h = self.features(h)  # [B, 512, T/4, 1]
        h = h.transpose(1, 2).contiguous().flatten(-2)  # [B, T/4, 512]
        h = F.dropout(h, p=0.5, training=self.training)
        h = F.relu_(self.embedding(h))
        h, _ = self.gru(h)
        frame_lens //= 4
        return {
            "attn_emb": h,
            "fc_emb": embedding_pooling(h, frame_lens, self.pooling),
            "attn_emb_len": frame_lens
        }
class Cnn10Encoder(BaseEncoder):
    """CNN10-style encoder: four double-conv stages with 2x2 average pooling.

    Time is downsampled by 16 and the frequency axis is averaged away.
    Frame embeddings are the 512-d CNN outputs. The clip embedding is
    mean+max pooling followed by dropout and a linear + ReLU layer.
    """

    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
        super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
        self.features = nn.Sequential(
            conv_conv_block(1, 64),
            nn.AvgPool2d((2, 2)),
            nn.Dropout(0.2, True),
            conv_conv_block(64, 128),
            nn.AvgPool2d((2, 2)),
            nn.Dropout(0.2, True),
            conv_conv_block(128, 256),
            nn.AvgPool2d((2, 2)),
            nn.Dropout(0.2, True),
            conv_conv_block(256, 512),
            nn.AvgPool2d((2, 2)),
            nn.Dropout(0.2, True),
            nn.AdaptiveAvgPool2d((None, 1)),
        )
        self.init_bn = nn.BatchNorm2d(spec_dim)
        self.embedding = nn.Linear(512, 512)
        self.apply(init)

    def forward(self, input_dict):
        spec = input_dict["spec"]
        # Deep copy so the in-place division below leaves the input intact.
        frame_lens = torch.as_tensor(copy.deepcopy(input_dict["spec_len"]))
        # Batch-normalise each frequency bin (frequency moved to channels).
        h = spec.unsqueeze(1).transpose(1, 3)  # [N, D, T, 1]
        h = self.init_bn(h).transpose(1, 3)  # [N, 1, T, D]
        h = self.features(h)  # [N, 512, T/16, 1]
        attn_emb = h.transpose(1, 2).contiguous().flatten(-2)  # [N, T/16, 512]
        frame_lens //= 16
        pooled = embedding_pooling(attn_emb, frame_lens, "mean+max")
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        fc_emb = F.relu_(self.embedding(pooled))
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": frame_lens
        }
class ConvBlock(nn.Module):
    """Two bias-free 3x3 conv + BatchNorm + ReLU layers, then 2-D pooling.

    Pooling is chosen per call in :meth:`forward`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.init_weight()

    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply both conv layers, then pool.

        Args:
            input: 4-D feature map.
            pool_size: pooling kernel size.
            pool_type: ``'max'``, ``'avg'`` or ``'avg+max'`` (sum of both).

        Raises:
            Exception: for any other ``pool_type``.
        """
        out = F.relu_(self.bn1(self.conv1(input)))
        out = F.relu_(self.bn2(self.conv2(out)))
        if pool_type == 'max':
            return F.max_pool2d(out, kernel_size=pool_size)
        if pool_type == 'avg':
            return F.avg_pool2d(out, kernel_size=pool_size)
        if pool_type == 'avg+max':
            return (F.avg_pool2d(out, kernel_size=pool_size)
                    + F.max_pool2d(out, kernel_size=pool_size))
        raise Exception('Incorrect argument!')
class Cnn14Encoder(nn.Module):
    """CNN14 audio encoder that consumes raw waveforms.

    Waveforms become 64-bin log-mel spectrograms (32 ms window, 10 ms hop),
    are optionally SpecAugment-ed while training, and pass through six
    ``ConvBlock`` stages. The first five pool time and frequency by 2, so
    the frame rate drops by ``downsample_ratio`` (32).

    Args:
        sample_rate: input sample rate in Hz. Only 32000 and 16000 are
            supported (others raise ``KeyError``); it selects the mel
            filterbank's upper frequency.
    """

    def __init__(self, sample_rate=32000):
        super().__init__()
        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        # Logmel spectrogram extractor
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=32 * sample_rate // 1000,
            win_length=32 * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB()
        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(time_drop_width=64,
                                               time_stripes_num=2,
                                               freq_drop_width=8,
                                               freq_stripes_num=2)
        self.bn0 = nn.BatchNorm2d(64)
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
        self.downsample_ratio = 32
        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.init_weight()

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)

    def load_pretrained(self, pretrained):
        """Load weights from a COLA, PANNs or CLAP checkpoint.

        Only entries whose name and shape both match this model are copied;
        all other parameters keep their current values.

        Args:
            pretrained: anything ``torch.load`` accepts (e.g. a file path).

        Raises:
            Exception: if the checkpoint layout is not recognised.
        """
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(pretrained, map_location="cpu")
        if "model" in checkpoint:
            model_state = checkpoint["model"]
            if any(key.startswith("backbone.") for key in model_state):  # COLA
                state_dict = {
                    key.replace("backbone.", ""): value
                    for key, value in model_state.items()
                    if key.startswith("backbone.")
                }
            else:  # PANNs
                state_dict = model_state
        elif "state_dict" in checkpoint:  # CLAP
            full_state = checkpoint["state_dict"]
            state_dict = {
                key.replace('audio_encoder.', ''): value
                for key, value in full_state.items()
                if "audio_encoder" in key
            }
        else:
            raise Exception("Unkown checkpoint format")
        model_dict = self.state_dict()
        pretrained_dict = {
            k: v for k, v in state_dict.items()
            if k in model_dict and model_dict[k].shape == v.shape
        }
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict, strict=True)

    def forward(self, input_dict):
        """Encode a batch of waveforms.

        Args:
            input_dict: dict with
                - ``"wav"``: waveforms, (batch_size, n_samples)
                - ``"wav_len"``: number of valid samples per item
                - ``"specaug"`` (optional, default False): apply SpecAugment
                  when the module is in training mode.

        Returns:
            dict with ``"fc_emb"`` (batch_size, 2048), ``"attn_emb"``
            (batch_size, frames // 32, 2048) and ``"attn_emb_len"``.
        """
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        # SpecAugment is opt-in; a missing flag means "off".
        specaug = input_dict.get("specaug", False)
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)  # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)  # (batch_size, 1, time_steps, mel_bins)
        # SpecAugment
        if self.training and specaug:
            x = self.spec_augmenter(x)
        # bn0 normalises each mel bin: move mel_bins onto the channel axis.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)  # average over the remaining mel bins
        attn_emb = x.transpose(1, 2)  # (batch_size, time_steps // 32, 2048)
        # Valid frames per item: n_samples // hop + 1 spectrogram frames,
        # then floor-divided by the CNN's time downsampling.
        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        x_max = max_with_lens(attn_emb, feat_length)
        x_mean = mean_with_lens(attn_emb, feat_length)
        x = x_max + x_mean
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        fc_emb = F.dropout(x, p=0.5, training=self.training)
        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
        return output_dict
class RnnEncoder(BaseEncoder):
    """Recurrent encoder over pre-extracted frame-level features.

    Args:
        spec_dim, fc_feat_dim, attn_feat_dim: forwarded to ``BaseEncoder``.
            ``attn_feat_dim`` is the RNN input feature size.
        pooling: how ``fc_emb`` is pooled from the RNN outputs
            (see ``embedding_pooling``).
        **kwargs: ``hidden_size`` (512), ``bidirectional`` (False),
            ``num_layers`` (1), ``dropout`` (0.2), ``rnn_type`` (name of a
            ``torch.nn`` recurrent class, "GRU") and ``in_bn`` (False; apply
            BatchNorm1d to the input features before the RNN).
    """

    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim,
                 pooling="mean", **kwargs):
        super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
        self.pooling = pooling
        self.hidden_size = kwargs.get('hidden_size', 512)
        self.bidirectional = kwargs.get('bidirectional', False)
        self.num_layers = kwargs.get('num_layers', 1)
        self.dropout = kwargs.get('dropout', 0.2)
        self.rnn_type = kwargs.get('rnn_type', "GRU")
        self.in_bn = kwargs.get('in_bn', False)
        self.embed_dim = self.hidden_size * (self.bidirectional + 1)
        self.network = getattr(nn, self.rnn_type)(
            attn_feat_dim,
            self.hidden_size,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional,
            # torch only applies RNN dropout between stacked layers and warns
            # when dropout > 0 with a single layer, so pass 0 in that case.
            dropout=self.dropout if self.num_layers > 1 else 0.0,
            batch_first=True)
        if self.in_bn:
            # This BatchNorm normalises the RNN *input*, whose feature size is
            # attn_feat_dim. Sizing it with embed_dim (the RNN output size)
            # crashed whenever the two differ.
            self.bn = nn.BatchNorm1d(attn_feat_dim)
        self.apply(init)

    def forward(self, input_dict):
        """Run the RNN over variable-length frame features.

        Args:
            input_dict: dict with ``"attn"`` (features, [N, T, attn_feat_dim])
                and ``"attn_len"`` (valid frames per item).

        Returns:
            dict with ``"attn_emb"`` (RNN outputs), ``"fc_emb"`` (pooled with
            ``self.pooling``) and ``"attn_emb_len"`` (the lengths).
        """
        x = input_dict["attn"]
        lens = torch.as_tensor(input_dict["attn_len"])
        # x: [N, T, E]
        if self.in_bn:
            x = pack_wrapper(self.bn, x, lens)
        out = pack_wrapper(self.network, x, lens)
        # out: [N, T, hidden]
        attn_emb = out
        fc_emb = embedding_pooling(out, lens, self.pooling)
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": lens
        }
class Cnn14RnnEncoder(nn.Module):
    """``Cnn14Encoder`` frontend followed by an ``RnnEncoder``.

    Args:
        sample_rate: forwarded to ``Cnn14Encoder``.
        pretrained: optional checkpoint loaded into the CNN.
        freeze_cnn: disable gradients for all CNN parameters (requires
            ``pretrained``).
        freeze_cnn_bn: keep the CNN's batch-norm layers in eval mode even
            when the model is switched to training mode.
        pooling: pooling used by the RNN encoder for ``fc_emb``.
        **kwargs: forwarded to ``RnnEncoder``.
    """

    def __init__(self, sample_rate=32000, pretrained=None,
                 freeze_cnn=False, freeze_cnn_bn=False,
                 pooling="mean", **kwargs):
        super().__init__()
        self.cnn = Cnn14Encoder(sample_rate)
        self.rnn = RnnEncoder(64, 2048, 2048, pooling, **kwargs)
        if pretrained is not None:
            self.cnn.load_pretrained(pretrained)
        if freeze_cnn:
            assert pretrained is not None, "cnn is not pretrained but frozen"
            for param in self.cnn.parameters():
                param.requires_grad = False
        self.freeze_cnn_bn = freeze_cnn_bn

    def train(self, mode=True):
        """Set training mode, matching ``nn.Module.train(mode=True)``.

        When ``freeze_cnn_bn`` is set, the CNN's batch-norm layers are put
        back into eval mode afterwards. Returns ``self``.
        """
        super().train(mode=mode)
        if self.freeze_cnn_bn:
            def bn_eval(module):
                if "BatchNorm" in module.__class__.__name__:
                    module.eval()
            self.cnn.apply(bn_eval)
        return self

    def forward(self, input_dict):
        output_dict = self.cnn(input_dict)
        # Rename the CNN outputs to the input keys RnnEncoder reads.
        output_dict["attn"] = output_dict.pop("attn_emb")
        output_dict["attn_len"] = output_dict.pop("attn_emb_len")
        return self.rnn(output_dict)
class TransformerEncoder(BaseEncoder):
    """Transformer encoder over frame features with a learned CLS token.

    Input features are projected to ``d_model`` (Linear -> ReLU -> Dropout ->
    LayerNorm). A learned CLS embedding is prepended, and the CLS output
    becomes ``fc_emb``.

    Args:
        spec_dim, fc_feat_dim, attn_feat_dim: forwarded to ``BaseEncoder``.
            ``attn_feat_dim`` is the input feature size.
        d_model: transformer width (an integer).
        **kwargs: ``dropout`` (0.2), ``nhead`` (d_model // 64), ``nlayers``
            (2) and ``dim_feedforward`` (4 * d_model).
    """

    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, d_model, **kwargs):
        super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
        self.d_model = d_model
        dropout = kwargs.get("dropout", 0.2)
        self.nhead = kwargs.get("nhead", self.d_model // 64)
        self.nlayers = kwargs.get("nlayers", 2)
        self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
        self.attn_proj = nn.Sequential(
            nn.Linear(attn_feat_dim, self.d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(self.d_model)
        )
        layer = nn.TransformerEncoderLayer(d_model=self.d_model,
                                           nhead=self.nhead,
                                           dim_feedforward=self.dim_feedforward,
                                           dropout=dropout)
        self.model = nn.TransformerEncoder(layer, self.nlayers)
        self.cls_token = nn.Parameter(torch.zeros(d_model))
        self.init_params()

    def init_params(self):
        """Xavier-uniform initialise every parameter with more than one dim."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, input_dict):
        """Encode frame features.

        Args:
            input_dict: dict with ``"attn"`` ([bs, T, attn_feat_dim]) and
                ``"attn_len"`` (valid frames per item). Neither is modified.

        Returns:
            dict with ``"attn_emb"`` ([bs, T + 1, d_model], CLS first),
            ``"fc_emb"`` (the CLS output) and ``"attn_emb_len"`` (lengths + 1).
        """
        attn_feat = input_dict["attn"]
        # torch.as_tensor returns the caller's tensor itself when it already
        # is one, so build a new tensor instead of incrementing in place
        # (that would alter input_dict["attn_len"] on every call).
        # +1 accounts for the prepended CLS token.
        attn_feat_len = torch.as_tensor(input_dict["attn_len"]) + 1
        attn_feat = self.attn_proj(attn_feat)  # [bs, T, d_model]
        cls_emb = self.cls_token.reshape(1, 1, self.d_model).repeat(
            attn_feat.size(0), 1, 1)
        attn_feat = torch.cat((cls_emb, attn_feat), dim=1)  # [bs, T+1, d_model]
        attn_feat = attn_feat.transpose(0, 1)  # [T+1, bs, d_model]
        # Assumes generate_length_mask marks valid positions with True; the
        # inversion gives True at padding, as src_key_padding_mask expects.
        src_key_padding_mask = ~generate_length_mask(
            attn_feat_len, attn_feat.size(0)).to(attn_feat.device)
        output = self.model(attn_feat, src_key_padding_mask=src_key_padding_mask)
        attn_emb = output.transpose(0, 1)
        fc_emb = attn_emb[:, 0]
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": attn_feat_len
        }
class Cnn14TransformerEncoder(nn.Module):
    """``Cnn14Encoder`` frontend followed by a ``TransformerEncoder``.

    Args:
        sample_rate: forwarded to ``Cnn14Encoder``.
        pretrained: optional checkpoint loaded into the CNN.
        freeze_cnn: disable gradients for all CNN parameters (requires
            ``pretrained``).
        freeze_cnn_bn: keep the CNN's batch-norm layers in eval mode even
            when the model is switched to training mode.
        d_model: transformer width, forwarded to ``TransformerEncoder``.
            NOTE(review): the default ``"mean"`` is not a usable width
            (``TransformerEncoder`` computes ``d_model // 64``), so callers
            must pass an integer.
        **kwargs: forwarded to ``TransformerEncoder``.
    """

    def __init__(self, sample_rate=32000, pretrained=None,
                 freeze_cnn=False, freeze_cnn_bn=False,
                 d_model="mean", **kwargs):
        super().__init__()
        self.cnn = Cnn14Encoder(sample_rate)
        self.trm = TransformerEncoder(64, 2048, 2048, d_model, **kwargs)
        if pretrained is not None:
            self.cnn.load_pretrained(pretrained)
        if freeze_cnn:
            assert pretrained is not None, "cnn is not pretrained but frozen"
            for param in self.cnn.parameters():
                param.requires_grad = False
        self.freeze_cnn_bn = freeze_cnn_bn

    def train(self, mode=True):
        """Set training mode, matching ``nn.Module.train(mode=True)``.

        When ``freeze_cnn_bn`` is set, the CNN's batch-norm layers are put
        back into eval mode afterwards. Returns ``self``.
        """
        super().train(mode=mode)
        if self.freeze_cnn_bn:
            def bn_eval(module):
                if "BatchNorm" in module.__class__.__name__:
                    module.eval()
            self.cnn.apply(bn_eval)
        return self

    def forward(self, input_dict):
        output_dict = self.cnn(input_dict)
        # Rename the CNN outputs to the input keys TransformerEncoder reads.
        output_dict["attn"] = output_dict.pop("attn_emb")
        output_dict["attn_len"] = output_dict.pop("attn_emb_len")
        return self.trm(output_dict)