Spaces:
Build error
Build error
# !/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# @Time : 2021/3/9 16:33 | |
# @Author : dongchao yang | |
# @File : train.py | |
from itertools import zip_longest | |
import numpy as np | |
from scipy import ndimage | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import time | |
from torchlibrosa.augmentation import SpecAugmentation | |
from torchlibrosa.stft import Spectrogram, LogmelFilterBank | |
import math | |
from sklearn.cluster import KMeans | |
import os | |
import time | |
from functools import partial | |
# import timm | |
# from timm.models.layers import DropPath, to_2tuple, trunc_normal_ | |
import warnings | |
from functools import partial | |
# from timm.models.registry import register_model | |
# from timm.models.vision_transformer import _cfg | |
# from mmdet.utils import get_root_logger | |
# from mmcv.runner import load_checkpoint | |
# from mmcv.runner import _load_checkpoint, load_state_dict | |
# import mmcv.runner | |
import copy | |
from collections import OrderedDict | |
import io | |
import re | |
# Debug switch; not read anywhere in the visible part of this file.
DEBUG = 0

# Sound-event class names (192 entries), grouped here by initial letter.
# NOTE(review): the order presumably matches the class indices used by the
# models / checkpoints — do not reorder without confirming.
event_labels = [
    'Alarm', 'Alarm_clock', 'Animal', 'Applause', 'Arrow', 'Artillery_fire',
    'Babbling', 'Baby_laughter', 'Bark', 'Basketball_bounce', 'Battle_cry',
    'Bell', 'Bird', 'Bleat', 'Bouncing', 'Breathing', 'Buzz',
    'Camera', 'Cap_gun', 'Car', 'Car_alarm', 'Cat', 'Caw', 'Cheering',
    'Child_singing', 'Choir', 'Chop', 'Chopping_(food)', 'Clapping',
    'Clickety-clack', 'Clicking', 'Clip-clop', 'Cluck', 'Coin_(dropping)',
    'Computer_keyboard', 'Conversation', 'Coo', 'Cough', 'Cowbell', 'Creak',
    'Cricket', 'Croak', 'Crow', 'Crowd',
    'DTMF', 'Dog', 'Door', 'Drill', 'Drip',
    'Engine', 'Engine_starting', 'Explosion',
    'Fart', 'Female_singing', 'Filing_(rasp)', 'Finger_snapping', 'Fire',
    'Fire_alarm', 'Firecracker', 'Fireworks', 'Frog',
    'Gasp', 'Gears', 'Giggle', 'Glass', 'Glass_shatter', 'Gobble', 'Groan',
    'Growling',
    'Hammer', 'Hands', 'Hiccup', 'Honk', 'Hoot', 'Howl', 'Human_sounds',
    'Human_voice',
    'Insect',
    'Laughter', 'Liquid',
    'Machine_gun', 'Male_singing', 'Mechanisms', 'Meow', 'Moo', 'Motorcycle',
    'Mouse', 'Music',
    'Oink', 'Owl',
    'Pant', 'Pant_(dog)', 'Patter', 'Pig', 'Plop', 'Pour', 'Power_tool', 'Purr',
    'Quack',
    'Radio', 'Rain_on_surface', 'Rapping', 'Rattle', 'Reversing_beeps',
    'Ringtone', 'Roar', 'Run', 'Rustle',
    'Scissors', 'Scrape', 'Scratch', 'Screaming', 'Sewing_machine', 'Shout',
    'Shuffle', 'Shuffling_cards', 'Singing', 'Single-lens_reflex_camera',
    'Siren', 'Skateboard', 'Sniff', 'Snoring', 'Speech', 'Speech_synthesizer',
    'Spray', 'Squeak', 'Squeal', 'Steam', 'Stir', 'Surface_contact',
    'Tap', 'Tap_dance', 'Telephone_bell_ringing', 'Television', 'Tick',
    'Tick-tock', 'Tools', 'Train', 'Train_horn', 'Train_wheels_squealing',
    'Truck', 'Turkey', 'Typewriter', 'Typing',
    'Vehicle', 'Video_game_sound',
    'Water', 'Whimper_(dog)', 'Whip', 'Whispering', 'Whistle', 'Whistling',
    'Whoop', 'Wind', 'Writing',
    'Yip',
    # Lower-case labels follow all capitalised ones.
    'and_pans', 'bird_song', 'bleep', 'clink', 'cock-a-doodle-doo',
    'crinkling', 'dove', 'dribble', 'eructation', 'faucet', 'flapping_wings',
    'footsteps', 'gunfire', 'heartbeat', 'infant_cry', 'kid_speaking',
    'man_speaking', 'mastication', 'mice', 'river', 'rooster', 'silverware',
    'skidding', 'smack', 'sobbing', 'speedboat', 'splatter', 'surf', 'thud',
    'thwack', 'toot', 'truck_horn', 'tweet', 'vroom', 'waterfowl',
    'woman_speaking',
]
def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None,
                    revise_keys=None):
    """Load a checkpoint into ``model``, adapting the first patch embedding.

    Follows ``mmcv.runner.load_checkpoint`` with one project-specific step:
    ``patch_embed1.proj.weight`` is summed over its input-channel axis (dim 1,
    keepdim), so a checkpoint whose stem takes several input channels can
    initialise a single-channel stem.

    Args:
        model (nn.Module): Module to load the checkpoint into.
        filename (str): Checkpoint location. When mmcv is installed its loader
            is used (URLs / ``torchvision://`` / ``open-mmlab://`` URIs work);
            otherwise ``torch.load`` is used, which needs a local path.
        map_location: Forwarded to the underlying loader.
        strict (bool): Whether the checkpoint keys must match the model exactly.
        logger: Forwarded to mmcv's helpers when mmcv is available.
        revise_keys (list[tuple[str, str]] | None): ``(pattern, replacement)``
            regex pairs applied to every key, in order. ``None`` (default)
            means ``[(r'^module\\.', '')]``, i.e. strip a DataParallel prefix.

    Returns:
        The object loaded from ``filename`` (a dict / OrderedDict).

    Raises:
        RuntimeError: If the loaded object is not a dict.
    """
    try:
        from mmcv.runner import _load_checkpoint, load_state_dict
    except ImportError:
        # mmcv (or its 1.x ``runner`` module) is unavailable: use plain torch.
        def _load_checkpoint(path, location=None, _logger=None):
            return torch.load(path, map_location=location)

        def load_state_dict(module, state, strict_flag=False, _logger=None):
            return module.load_state_dict(state, strict=strict_flag)

    if revise_keys is None:
        # A fresh list per call instead of a shared mutable default.
        revise_keys = [(r'^module\.', '')]

    checkpoint = _load_checkpoint(filename, map_location, logger)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # Training checkpoints wrap the weights in a 'state_dict' entry.
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    # Normalise key names before touching individual entries.
    metadata = getattr(state_dict, '_metadata', OrderedDict())
    for pattern, replacement in revise_keys:
        state_dict = OrderedDict(
            (re.sub(pattern, replacement, key), value)
            for key, value in state_dict.items())
    # Also drop a detector-style 'backbone.' prefix wherever it appears.
    state_dict = OrderedDict(
        (key.replace('backbone.', ''), value)
        for key, value in state_dict.items())

    # Collapse the stem's input channels: (out, in, kh, kw) -> (out, 1, kh, kw).
    stem_key = 'patch_embed1.proj.weight'
    if stem_key in state_dict:
        state_dict[stem_key] = torch.sum(state_dict[stem_key], dim=1, keepdim=True)

    # Keep metadata in state_dict (used by load_state_dict for versioning).
    state_dict._metadata = metadata
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
def init_weights(m):
    """Initialise a single module in place; meant for ``module.apply(init_weights)``.

    * Conv1d / Conv2d: Kaiming-normal weight, zero bias (if any).
    * BatchNorm2d: weight 1, zero bias (if any).
    * Linear: Kaiming-uniform weight, zero bias (if any).
    Any other module type is left untouched.
    """
    def _zero_bias(module):
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

    if isinstance(m, (nn.Conv2d, nn.Conv1d)):
        nn.init.kaiming_normal_(m.weight)
        _zero_bias(m)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        _zero_bias(m)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight)
        _zero_bias(m)
def init_layer(layer):
    """Initialize a Linear or Convolutional layer.

    The weight gets Xavier-uniform values; a bias, when the layer has one,
    is reset to zero.
    """
    nn.init.xavier_uniform_(layer.weight)
    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)
def init_bn(bn):
    """Reset a BatchNorm layer's affine parameters: bias to 0, weight to 1."""
    for param, value in ((bn.bias, 0.), (bn.weight, 1.)):
        param.data.fill_(value)
class MaxPool(nn.Module):
    """Max pooling of ``decision`` along ``pooldim``.

    ``logits`` is accepted for interface parity with the other pooling
    modules and is ignored.
    """

    def __init__(self, pooldim=1):
        super().__init__()
        self.pooldim = pooldim

    def forward(self, logits, decision):
        values, _indices = torch.max(decision, dim=self.pooldim)
        return values
class LinearSoftPool(nn.Module):
    """Linear-softmax pooling: ``sum(p**2) / sum(p)`` along ``pooldim``.

    Each value weights itself, so the result lies near the largest value of
    ``time_decision``. A 1e-7 term keeps the denominator non-zero. ``logits``
    is ignored.

    Reference: "A Comparison of Five Multiple Instance Learning Pooling
    Functions for Sound Event Detection with Weak Labeling",
    https://arxiv.org/abs/1810.09050
    """

    def __init__(self, pooldim=1):
        super().__init__()
        self.pooldim = pooldim

    def forward(self, logits, time_decision):
        numerator = time_decision.pow(2).sum(self.pooldim)
        denominator = time_decision.sum(self.pooldim) + 1e-7
        return numerator / denominator
class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages followed by 2-D pooling.

    The convolutions keep the spatial size (stride 1, padding 1); ``forward``
    then pools by ``pool_size``.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        shared = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **shared)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **shared)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.init_weight()

    def init_weight(self):
        for conv in (self.conv1, self.conv2):
            init_layer(conv)
        for bn in (self.bn1, self.bn2):
            init_bn(bn)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply both conv stages, then pool.

        pool_type: 'avg', 'max' or 'avg+max' (sum of both poolings); any other
        value raises ``Exception``.
        """
        out = F.relu_(self.bn1(self.conv1(input)))
        out = F.relu_(self.bn2(self.conv2(out)))
        if pool_type == 'avg':
            return F.avg_pool2d(out, kernel_size=pool_size)
        if pool_type == 'max':
            return F.max_pool2d(out, kernel_size=pool_size)
        if pool_type == 'avg+max':
            return (F.avg_pool2d(out, kernel_size=pool_size)
                    + F.max_pool2d(out, kernel_size=pool_size))
        raise Exception('Incorrect argument!')
class ConvBlock_GLU(nn.Module):
    """Conv -> BatchNorm -> gated linear unit -> optional 2-D pooling.

    The convolution produces ``out_channels`` maps. The sigmoid of the first
    half gates the second half, so the block outputs ``out_channels // 2``
    channels.

    NOTE: padding is fixed at (1, 1) whatever ``kernel_size`` is. A 1x1 kernel
    therefore enlarges each spatial axis by 2, and a 5x5 kernel shrinks it
    by 2.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
        super(ConvBlock_GLU, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=kernel_size, stride=(1, 1),
                               padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def init_weight(self):
        init_layer(self.conv1)
        init_bn(self.bn1)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Gate the normalised conv output, then pool.

        pool_type: 'avg', 'max', 'avg+max' (sum of both), or 'None' / 'LP'
        (no pooling); any other value raises ``Exception``.
        """
        feats = self.bn1(self.conv1(input))
        half = feats.shape[1] // 2
        gated = self.sigmoid(feats[:, :half, :, :]) * feats[:, half:, :, :]
        if pool_type in ('None', 'LP'):
            # 'LP' is a placeholder: no LPPool2d is applied, so it is a no-op.
            return gated
        if pool_type == 'avg':
            return F.avg_pool2d(gated, kernel_size=pool_size)
        if pool_type == 'max':
            return F.max_pool2d(gated, kernel_size=pool_size)
        if pool_type == 'avg+max':
            return (F.avg_pool2d(gated, kernel_size=pool_size)
                    + F.max_pool2d(gated, kernel_size=pool_size))
        raise Exception('Incorrect argument!')
class Mul_scale_GLU(nn.Module):
    """Multi-scale gated-CNN feature extractor.

    Three parallel ConvBlock_GLU branches (1x1, 3x3 and 5x5 kernels, 32 gated
    channels each) read the same input. Their outputs are cropped to a common
    size and concatenated into 96 channels, then refined by six more gated
    blocks. The output has 512 channels, with time pooled by 8 and frequency
    pooled by 64.
    """

    def __init__(self):
        super(Mul_scale_GLU, self).__init__()
        self.conv_block1_1 = ConvBlock_GLU(in_channels=1, out_channels=64, kernel_size=(1, 1))  # 1*1
        self.conv_block1_2 = ConvBlock_GLU(in_channels=1, out_channels=64, kernel_size=(3, 3))  # 3*3
        self.conv_block1_3 = ConvBlock_GLU(in_channels=1, out_channels=64, kernel_size=(5, 5))  # 5*5
        # 3 branches x 32 gated channels = 96 input channels.
        self.conv_block2 = ConvBlock_GLU(in_channels=96, out_channels=128 * 2)
        self.conv_block3 = ConvBlock_GLU(in_channels=128, out_channels=128 * 2)
        self.conv_block4 = ConvBlock_GLU(in_channels=128, out_channels=256 * 2)
        self.conv_block5 = ConvBlock_GLU(in_channels=256, out_channels=256 * 2)
        self.conv_block6 = ConvBlock_GLU(in_channels=256, out_channels=512 * 2)
        self.conv_block7 = ConvBlock_GLU(in_channels=512, out_channels=512 * 2)
        # Replicate one step at the end of the last two axes to compensate for
        # the 5x5 branch shrinking its input.
        self.padding = nn.ReplicationPad2d((0, 1, 0, 1))

    def forward(self, input, fi=None):
        """Run the network.

        Args:
            input: 4-D tensor for a 1-channel Conv2d — presumably
                (batch, 1, time, freq); TODO confirm.
            fi: unused; kept for interface compatibility.
        """
        x1 = self.conv_block1_1(input, pool_size=(2, 2), pool_type='avg')
        x2 = self.conv_block1_2(input, pool_size=(2, 2), pool_type='avg')
        x3 = self.padding(self.conv_block1_3(input, pool_size=(2, 2), pool_type='avg'))
        # The branches differ by a step or two, because ConvBlock_GLU always
        # pads by 1. Crop each branch to the smallest extent. The original
        # hard-coded x1[:, :, :500, :32] with an uncropped x3, which only
        # lined up for a ~1000 x 64 input; the result is identical for that size.
        n_t = min(x1.shape[2], x2.shape[2], x3.shape[2])
        n_f = min(x1.shape[3], x2.shape[3], x3.shape[3])
        x = torch.cat([branch[:, :, :n_t, :n_f] for branch in (x1, x2, x3)], dim=1)

        x = self.conv_block2(x, pool_size=(2, 2), pool_type='None')
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 4), pool_type='None')
        x = self.conv_block5(x, pool_size=(2, 4), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 4), pool_type='None')
        x = self.conv_block7(x, pool_size=(1, 4), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        return x
class Cnn14(nn.Module):
    """PANNs-style Cnn14 trunk mapping a spectrogram to per-frame 128-d features.

    Six ConvBlocks (64 ... 2048 channels) pool time by 8 and frequency by 64.
    Each remaining frame is flattened (channels x freq) and projected by
    ``fc1`` to 128 values. ``fc1`` expects 2048 inputs, so the frequency axis
    must reduce to a single bin (e.g. 64 input bins).

    NOTE(review): spectrogram_extractor, logmel_extractor, spec_augmenter,
    bn0 and fc_audioset are built but never used by ``forward``. Presumably
    they are kept so pretrained checkpoints load; confirm before removing.
    """

    def __init__(self, sample_rate=32000, window_size=1024, hop_size=320, mel_bins=64, fmin=50,
                 fmax=14000, classes_num=527):
        super(Cnn14, self).__init__()
        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size, hop_length=hop_size, win_length=window_size,
            window='hann', center=True, pad_mode='reflect', freeze_parameters=True)
        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate, n_fft=window_size, n_mels=mel_bins, fmin=fmin, fmax=fmax,
            ref=1.0, amin=1e-10, top_db=None, freeze_parameters=True)
        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
                                               freq_drop_width=8, freq_stripes_num=2)
        self.bn0 = nn.BatchNorm2d(64)
        widths = (1, 64, 128, 256, 512, 1024, 2048)
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f'conv_block{idx}', ConvBlock(in_channels=c_in, out_channels=c_out))
        self.fc1 = nn.Linear(2048, 128, bias=True)
        self.fc_audioset = nn.Linear(128, classes_num, bias=True)
        self.init_weight()

    def init_weight(self):
        for layer in (self.fc1, self.fc_audioset):
            init_layer(layer)

    def forward(self, input_, mixup_lambda=None):
        """Return per-frame features of shape (batch, frames / 8, 128).

        Args:
            input_: 3-D tensor; a channel axis is inserted at dim 1 —
                presumably (batch, time, mel_bins); TODO confirm.
            mixup_lambda: unused.
        """
        stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)),
            (self.conv_block4, (1, 2)),
            (self.conv_block5, (1, 2)),
            (self.conv_block6, (1, 2)),
        )
        x = input_.unsqueeze(1)
        for block, pool in stages:
            x = block(x, pool_size=pool, pool_type='avg')
            x = F.dropout(x, p=0.2, training=self.training)
        # (B, C, T', F') -> (B, T', C * F')
        frames = x.transpose(1, 2).contiguous().flatten(-2)
        return self.fc1(frames)
class Cnn10_fi(nn.Module):
    """Four ConvBlocks (64 -> 128 -> 256 -> 512 channels) with optional
    per-example affine (FiLM-style) modulation after every block.

    Time is pooled by 2 * 2 * 2 * 1 = 8 and frequency by 2 * 2 * 4 * 4 = 64.
    """

    def __init__(self):
        super(Cnn10_fi, self).__init__()
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)

    @staticmethod
    def _modulate(x, fi):
        """Return ``fi[:, 0] * x + fi[:, 1]`` broadcast per example (``x`` if ``fi`` is None)."""
        # `is None` instead of `!= None`: comparing an array to None is
        # element-wise for NumPy and makes the `if` ambiguous.
        if fi is None:
            return x
        gamma = fi[:, 0].reshape(-1, 1, 1, 1)
        beta = fi[:, 1].reshape(-1, 1, 1, 1)
        return gamma * x + beta

    def forward(self, input, fi=None):
        """Run the network.

        Args:
            input: 4-D tensor for a 1-channel Conv2d — presumably
                (batch, 1, time, freq); TODO confirm.
            fi: optional tensor whose columns 0 and 1 hold one scale and one
                shift per example; ``None`` disables modulation.
        """
        stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 4)),
            (self.conv_block4, (1, 4)),
        )
        x = input
        for block, pool in stages:
            x = block(x, pool_size=pool, pool_type='avg')
            x = self._modulate(x, fi)
            x = F.dropout(x, p=0.2, training=self.training)
        return x
class Cnn10_mul_scale(nn.Module):
    """Cnn10 variant whose first stage is a multi-scale gated (GLU) front end.

    Three ConvBlock_GLU branches (1x1, 3x3, 5x5 kernels; 32 gated channels
    each) are cropped to a common size and concatenated into 96 channels.
    Three ConvBlocks follow (128 -> 256 -> 512 channels).

    Args:
        scale (int): selects the pooling schedule. Time is pooled by 8, 4 or 2
            for ``scale`` 8, 4 or 2, and not at all otherwise. Frequency is
            always pooled by 64 (2 * 2 * 4 * 4).
    """

    def __init__(self, scale=8):
        super(Cnn10_mul_scale, self).__init__()
        self.conv_block1_1 = ConvBlock_GLU(in_channels=1, out_channels=64, kernel_size=(1, 1))
        self.conv_block1_2 = ConvBlock_GLU(in_channels=1, out_channels=64, kernel_size=(3, 3))
        self.conv_block1_3 = ConvBlock_GLU(in_channels=1, out_channels=64, kernel_size=(5, 5))
        self.conv_block2 = ConvBlock(in_channels=96, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.scale = scale
        # Compensates for the 5x5 branch shrinking its input (see forward).
        self.padding = nn.ReplicationPad2d((0, 1, 0, 1))

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Run the network.

        Args:
            input: 4-D tensor for a 1-channel Conv2d — presumably
                (batch, 1, time, freq); TODO confirm.
            pool_size, pool_type: unused; kept for interface compatibility.
        """
        if self.scale == 8:
            pool_size1, pool_size2, pool_size3, pool_size4 = (2, 2), (2, 2), (2, 4), (1, 4)
        elif self.scale == 4:
            pool_size1, pool_size2, pool_size3, pool_size4 = (2, 2), (2, 2), (1, 4), (1, 4)
        elif self.scale == 2:
            pool_size1, pool_size2, pool_size3, pool_size4 = (2, 2), (1, 2), (1, 4), (1, 4)
        else:
            pool_size1, pool_size2, pool_size3, pool_size4 = (1, 2), (1, 2), (1, 4), (1, 4)

        x1 = self.conv_block1_1(input, pool_size=pool_size1, pool_type='avg')
        x2 = self.conv_block1_2(input, pool_size=pool_size1, pool_type='avg')
        x3 = self.padding(self.conv_block1_3(input, pool_size=pool_size1, pool_type='avg'))
        # ConvBlock_GLU always pads by 1, so the branches differ by a step or
        # two; crop each one to the smallest time/frequency extent. The
        # original also applied a hard-coded x1[:, :, :500, :32]. That capped
        # time at 500 frames, silently truncating longer inputs, e.g. every
        # input when the first stage does not pool time. It also broke the
        # concat for inputs wider than ~64 bins.
        n_t = min(x1.shape[2], x2.shape[2], x3.shape[2])
        n_f = min(x1.shape[3], x2.shape[3], x3.shape[3])
        x = torch.cat([branch[:, :, :n_t, :n_f] for branch in (x1, x2, x3)], dim=1)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=pool_size2, pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=pool_size3, pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=pool_size4, pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        return x
class Cnn10(nn.Module):
    """Four ConvBlocks (64 -> 128 -> 256 -> 512 channels) with a
    scale-dependent pooling schedule.

    Time is pooled by 8, 4 or 2 for ``scale`` 8, 4 or 2, and not at all
    otherwise. Frequency is always pooled by 64 (2 * 2 * 4 * 4).
    """

    # (scale, per-block pool sizes); any other scale uses _DEFAULT_POOLS.
    _POOL_SCHEDULES = (
        (8, ((2, 2), (2, 2), (2, 4), (1, 4))),
        (4, ((2, 2), (2, 2), (1, 4), (1, 4))),
        (2, ((2, 2), (1, 2), (1, 4), (1, 4))),
    )
    _DEFAULT_POOLS = ((1, 2), (1, 2), (1, 4), (1, 4))

    def __init__(self, scale=8):
        super(Cnn10, self).__init__()
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.scale = scale

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Run the four blocks; ``pool_size`` / ``pool_type`` are unused.

        ``input`` goes straight into a 1-channel Conv2d — presumably
        (batch, 1, time, freq); TODO confirm.
        """
        pools = next(
            (schedule for key, schedule in self._POOL_SCHEDULES if self.scale == key),
            self._DEFAULT_POOLS,
        )
        blocks = (self.conv_block1, self.conv_block2, self.conv_block3, self.conv_block4)
        x = input
        for block, pool in zip(blocks, pools):
            x = block(x, pool_size=pool, pool_type='avg')
            x = F.dropout(x, p=0.2, training=self.training)
        return x
class MeanPool(nn.Module):
    """Average of ``decision`` along ``pooldim``; ``logits`` is ignored."""

    def __init__(self, pooldim=1):
        super().__init__()
        self.pooldim = pooldim

    def forward(self, logits, decision):
        return decision.mean(dim=self.pooldim)
class ResPool(nn.Module):
    """Container holding a LinearSoftPool as ``linPool``.

    NOTE(review): no ``forward`` is defined on this class, and ``linPool``
    always pools over dim 1 regardless of ``pooldim``. Confirm whether this
    class is still used.
    """

    def __init__(self, pooldim=1):
        super(ResPool, self).__init__()
        self.linPool = LinearSoftPool(pooldim=1)
        self.pooldim = pooldim
class AutoExpPool(nn.Module):
    """Exponentially weighted pooling with a learnable per-class sharpness.

    Computes ``sum_t logits_t * exp(alpha * d_t) / sum_t exp(alpha * d_t)``
    along ``pooldim``, i.e. a ``softmax(alpha * decision)``-weighted average
    of ``logits``.

    Args:
        outputdim (int): number of classes (length of ``alpha``).
        pooldim (int): axis to pool over.
    """

    def __init__(self, outputdim=10, pooldim=1):
        super().__init__()
        self.outputdim = outputdim
        # Float fill value: torch.full infers int64 from an int fill, and an
        # integer tensor cannot be a grad-requiring nn.Parameter (the
        # original raised at construction).
        self.alpha = nn.Parameter(torch.full((outputdim, ), 1.0))
        self.pooldim = pooldim

    def forward(self, logits, decision):
        scaled = self.alpha * decision  # \alpha * P(Y|x) in the paper
        # softmax computes the same exp(s) / sum(exp(s)) ratio but shifts by
        # the max first, so large alpha * decision cannot overflow to inf/nan.
        weights = torch.softmax(scaled, dim=self.pooldim)
        return (logits * weights).sum(self.pooldim)
class SoftPool(nn.Module):
    """Softmax-weighted average of ``decision`` along ``pooldim``.

    The weights are ``softmax(decision / T)``; ``logits`` is ignored.
    """

    def __init__(self, T=1, pooldim=1):
        super().__init__()
        self.pooldim = pooldim
        self.T = T

    def forward(self, logits, decision):
        weights = (decision / self.T).softmax(dim=self.pooldim)
        return (decision * weights).sum(dim=self.pooldim)
class AutoPool(nn.Module):
    """Auto-pooling: ``decision`` averaged with ``softmax(alpha * decision)`` weights.

    ``alpha`` is a learnable per-class parameter initialised to 1, and
    ``logits`` is ignored. For (B, T, C) input pooled over dim 1, the result
    is (B, C).
    """

    def __init__(self, outputdim=10, pooldim=1):
        super().__init__()
        self.outputdim = outputdim
        self.alpha = nn.Parameter(torch.ones(outputdim))
        self.dim = pooldim

    def forward(self, logits, decision):
        # \alpha * P(Y|x) in the paper
        weight = (self.alpha * decision).softmax(dim=self.dim)
        return (decision * weight).sum(dim=self.dim)
class ExtAttentionPool(nn.Module):
    """Attention pooling over ``pooldim`` using weights computed from ``logits``.

    The attention projection starts at zero (uniform weights). ``decision``
    is ignored.
    """

    def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
        super().__init__()
        self.inputdim = inputdim
        self.outputdim = outputdim
        self.pooldim = pooldim
        self.attention = nn.Linear(inputdim, outputdim)
        nn.init.zeros_(self.attention.weight)
        nn.init.zeros_(self.attention.bias)
        self.activ = nn.Softmax(dim=self.pooldim)

    def forward(self, logits, decision):
        # Logits of shape (B, T, D), decision of shape (B, T, C)
        # NOTE(review): the product below broadcasts (B, D, 1, T) against
        # (B, T, C, 1). It only lines up when D == T (or one of them is 1),
        # which looks unintended; confirm the expected shapes.
        attn = self.activ(self.attention(logits) / self.outputdim)
        per_dim = logits.permute(0, 2, 1).contiguous().unsqueeze(-2)
        weighted = (per_dim * attn.unsqueeze(-1)).flatten(-2).contiguous()
        return weighted.sum(self.pooldim)
class AttentionPool(nn.Module):
    """Attention-weighted average of ``decision`` along ``pooldim``.

    Weights are ``softmax(clamp(W @ logits, -15, 15))``, taken along
    ``pooldim``. The weighted sum is divided by the total weight plus 1e-7.
    """

    def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
        super().__init__()
        self.inputdim = inputdim
        self.outputdim = outputdim
        self.pooldim = pooldim
        self.transform = nn.Linear(inputdim, outputdim)
        self.activ = nn.Softmax(dim=self.pooldim)
        self.eps = 1e-7

    def forward(self, logits, decision):
        # logits: (B, T, inputdim) -> scores (B, T, outputdim); clamping
        # bounds the softmax inputs.
        scores = self.transform(logits).clamp(-15, 15)
        w = self.activ(scores)
        weighted_sum = (decision * w).sum(self.pooldim)
        return weighted_sum / (w.sum(self.pooldim) + self.eps)
class Block2D(nn.Module):
    """Pre-activation style block: BatchNorm -> Conv2d (no bias) -> LeakyReLU(0.1).

    With the defaults (3x3 kernel, padding 1), spatial size is preserved.
    """

    def __init__(self, cin, cout, kernel_size=3, padding=1):
        super().__init__()
        layers = [
            nn.BatchNorm2d(cin),
            nn.Conv2d(cin, cout, kernel_size=kernel_size, padding=padding, bias=False),
            nn.LeakyReLU(inplace=True, negative_slope=0.1),
        ]
        # Kept as an nn.Sequential named `block` so state_dict keys stay block.0/1/2.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
class AudioCNN(nn.Module):
    """Four ConvBlocks -> frequency mean / time max -> 128-d embedding -> classifier.

    ``extract`` returns the 128-d embedding; ``forward`` returns
    ``(embedding, logits)``.
    """

    def __init__(self, classes_num):
        super(AudioCNN, self).__init__()
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.fc1 = nn.Linear(512, 128, bias=True)
        self.fc = nn.Linear(128, classes_num, bias=True)
        self.init_weights()

    def init_weights(self):
        # Only the classifier is re-initialised; fc1 keeps PyTorch's default init.
        init_layer(self.fc)

    def forward(self, input):
        """Input: (batch_size, time_steps, freq_bins). Returns (embedding, logits)."""
        embedding = self.extract(input)
        return embedding, self.fc(embedding)

    def extract(self, input):
        """Input: (batch_size, time_steps, freq_bins). Returns a (batch_size, 128) embedding.

        Example shapes (from the original comments): [128, 801, 168] becomes
        (128, 512, 50, 10) after the four 2x2 poolings.
        """
        x = input[:, None, :, :]  # (B, 1, T, F)
        for block in (self.conv_block1, self.conv_block2, self.conv_block3, self.conv_block4):
            x = block(x, pool_size=(2, 2), pool_type='avg')
        x = torch.mean(x, dim=3)  # average over frequency -> (B, 512, T')
        x, _ = torch.max(x, dim=2)  # max over time -> (B, 512)
        return self.fc1(x)
def parse_poolingfunction(poolingfunction_name='mean', **kwargs):
    """Build a temporal pooling module by name; all returned modules pool over dim 1.

    Args:
        poolingfunction_name (str): one of 'mean', 'max', 'linear',
            'expalpha', 'soft', 'auto', 'attention' (case-insensitive).
        **kwargs: 'outputdim' is required for 'expalpha', 'auto' and
            'attention'; 'inputdim' is also required for 'attention'.

    Returns:
        nn.Module: the pooling module.

    Raises:
        KeyError: if a required keyword argument is missing.
        ValueError: if the name is not recognised. The original silently
            returned ``None``, which only failed later when the result
            was called.
    """
    # Lambdas defer construction, so kwargs are only read for the chosen pooler.
    builders = {
        'mean': lambda: MeanPool(pooldim=1),
        'max': lambda: MaxPool(pooldim=1),
        'linear': lambda: LinearSoftPool(pooldim=1),
        'expalpha': lambda: AutoExpPool(outputdim=kwargs['outputdim'], pooldim=1),
        'soft': lambda: SoftPool(pooldim=1),
        'auto': lambda: AutoPool(outputdim=kwargs['outputdim']),
        'attention': lambda: AttentionPool(inputdim=kwargs['inputdim'],
                                           outputdim=kwargs['outputdim']),
    }
    try:
        build = builders[poolingfunction_name.lower()]
    except KeyError:
        raise ValueError(
            f'Unknown pooling function: {poolingfunction_name!r}') from None
    return build()
class conv1d(nn.Module):
    """1-D convolution followed by ReLU, with 'VALID' / 'SAME' padding modes.

    Args:
        nin, nout (int): input / output channels.
        kernel_size, stride, dilation (int): as for ``nn.Conv1d``.
        padding (str): 'VALID' (no padding) or 'SAME'. 'SAME' pads
            ``dilation * ((kernel_size - 1) // 2)`` on each side, which keeps
            the length for stride 1 and odd kernels.

    Raises:
        ValueError: for any other ``padding`` value.
    """

    def __init__(self, nin, nout, kernel_size=3, stride=1, padding='VALID', dilation=1):
        super(conv1d, self).__init__()
        if padding == 'VALID':
            dconv_pad = 0
        elif padding == 'SAME':
            dconv_pad = dilation * ((kernel_size - 1) // 2)
        else:
            raise ValueError("Padding Mode Error!")
        # Pass dilation through. The 'SAME' padding above is sized for a
        # dilated kernel, but the original never gave nn.Conv1d the
        # dilation, so dilation > 1 produced outputs longer than the input.
        self.conv = nn.Conv1d(nin, nout, kernel_size=kernel_size, stride=stride,
                              padding=dconv_pad, dilation=dilation)
        self.act = nn.ReLU()
        self.init_layer(self.conv)

    def init_layer(self, layer, nonlinearity='relu'):
        """Kaiming-normal weight (for ``nonlinearity``) and a constant 0.1 bias."""
        nn.init.kaiming_normal_(layer.weight, nonlinearity=nonlinearity)
        nn.init.constant_(layer.bias, 0.1)

    def forward(self, x):
        """Conv then ReLU; ``x`` is laid out as nn.Conv1d expects, e.g. (batch, nin, length)."""
        return self.act(self.conv(x))
class Atten_1(nn.Module):
    """Single-query attention around a context frame, added back residually.

    The frame at index ``context`` is the query, and every frame is a key.
    The softmax-weighted frames are averaged over the time axis (``mean``,
    so additionally divided by the number of frames) and added to the
    context frame. Input (B, T, D) -> output (B, D).
    """

    def __init__(self, input_dim, context=2, dropout_rate=0.2):
        super(Atten_1, self).__init__()
        self._matrix_k = nn.Linear(input_dim, input_dim // 4)
        self._matrix_q = nn.Linear(input_dim, input_dim // 4)
        self.relu = nn.ReLU()
        self.context = context
        self._dropout_layer = nn.Dropout(dropout_rate)
        self.init_layer(self._matrix_k)
        self.init_layer(self._matrix_q)

    def init_layer(self, layer, nonlinearity='leaky_relu'):
        """Kaiming-uniform weight (for ``nonlinearity``); zero bias when present."""
        nn.init.kaiming_uniform_(layer.weight, nonlinearity=nonlinearity)
        bias = getattr(layer, 'bias', None)
        if bias is not None:
            bias.data.fill_(0.)

    def forward(self, input_x):
        # Keys from every frame, query from the context frame (dropout on both,
        # keys first).
        keys = self._dropout_layer(self.relu(self._matrix_k(input_x)))  # (B, T, D/4)
        centre = input_x[:, self.context, :]  # (B, D)
        query = self._dropout_layer(self.relu(self._matrix_q(centre.unsqueeze(1))))  # (B, 1, D/4)
        scale = math.sqrt(keys.size(-1))
        scores = torch.matmul(keys, query.transpose(-2, -1) / scale).squeeze(2)  # (B, T)
        weights = F.softmax(scores, dim=-1)
        # Broadcasting over D gives the same product as repeating the weights D times.
        attended = (weights.unsqueeze(2) * input_x).mean(1)  # (B, D)
        return centre + attended
class Fusion(nn.Module):
    """Multiplicative fusion of two frame sequences.

    Each input gets a pointwise (kernel-size 1) ``conv1d`` projection to
    ``inputdim2 * n_fac`` channels. The two projections are multiplied
    element-wise, and the product is average-pooled along the channel axis
    in groups of ``n_fac``, back to ``inputdim2`` channels.
    """

    def __init__(self, inputdim, inputdim2, n_fac):
        super().__init__()
        self.fuse_layer1 = conv1d(inputdim, inputdim2 * n_fac, 1)
        self.fuse_layer2 = conv1d(inputdim2, inputdim2 * n_fac, 1)
        # AvgPool1d acts on the last axis, i.e. channels of a (B, T, C) tensor.
        self.avg_pool = nn.AvgPool1d(n_fac, stride=n_fac)

    def forward(self, embedding, mix_embed):
        """
        Args:
            embedding: (B, T, inputdim); the channel count is fixed by fuse_layer1.
            mix_embed: (B, T, inputdim2); the channel count is fixed by fuse_layer2.
        Returns:
            (B, T, inputdim2).
        """
        def _project(layer, seq):
            # conv1d works on (B, C, T); the inputs are (B, T, C).
            return layer(seq.permute(0, 2, 1)).permute(0, 2, 1)

        emb_proj = _project(self.fuse_layer1, embedding)  # (B, T, inputdim2 * n_fac)
        mix_proj = _project(self.fuse_layer2, mix_embed)  # (B, T, inputdim2 * n_fac)
        return self.avg_pool(emb_proj * mix_proj)
class CDur_fusion(nn.Module):
    """CDur variant that fuses the conditioning embedding multiplicatively.

    A Block2D/LPPool2d CNN pools time by 4 (2 * 2 * 1) and frequency by 64
    (4 * 4 * 4). ``Fusion`` combines each frame with the repeated embedding.
    A bidirectional GRU, a linear layer and a softmax output layer follow.

    Args:
        inputdim (int): number of frequency bins of the input. The GRU and
            Fusion expect 128 features per frame, so the frequency axis must
            reduce to one bin (e.g. 64 bins).
        outputdim (int): number of output classes (softmax over them).
        **kwargs: optional ``embed_dim`` (default 128) and ``n_fac``
            (default 2) for the fusion layer.
    """

    def __init__(self, inputdim, outputdim, **kwargs):
        super().__init__()
        self.features = nn.Sequential(
            Block2D(1, 32),
            nn.LPPool2d(4, (2, 4)),
            Block2D(32, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (2, 4)),
            Block2D(128, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (1, 4)),
            nn.Dropout(0.3),
        )
        # The original also ran a dummy forward through self.features here, to
        # compute an rnn_input_dim that was never used. That pass fed random
        # data into the BatchNorm running statistics, so it has been removed.
        self.gru = nn.GRU(128, 128, bidirectional=True, batch_first=True)
        # Fusion takes (embedding_dim, feature_dim, n_fac). The original
        # Fusion(128, 2) omitted an argument and always raised TypeError.
        # NOTE(review): defaults assume a 128-d embedding and n_fac=2; confirm
        # against the trained checkpoint.
        embed_dim = kwargs.get('embed_dim', 128)
        n_fac = kwargs.get('n_fac', 2)
        self.fusion = Fusion(embed_dim, 128, n_fac)
        self.fc = nn.Linear(256, 256)
        self.outputlayer = nn.Linear(256, outputdim)
        self.features.apply(init_weights)
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding):
        """
        Args:
            x: 3-D input unpacked as (batch, time, dim) — presumably a
                (batch, frames, mel_bins) spectrogram; TODO confirm.
            embedding: per-example vector; it is repeated over every CNN
                output frame before fusion.
        Returns:
            ``(decision_time[:, :, 0], decision_up)``. ``decision_time`` is the
            frame-level softmax over ``outputdim`` classes. ``decision_up`` is
            that output linearly interpolated back to the input frame count.
        """
        _, num_frames, _ = x.shape
        x = self.features(x.unsqueeze(1))  # (b, 1, t, d) -> (b, c, t', d')
        x = x.transpose(1, 2).contiguous().flatten(-2)  # (b, t', c * d')
        embedding = embedding.unsqueeze(1).repeat(1, x.shape[1], 1)
        x = self.fusion(embedding, x)
        if not hasattr(self, '_flattened'):
            self.gru.flatten_parameters()
        x, _ = self.gru(x)  # (b, t', 256)
        x = self.fc(x)
        decision_time = torch.softmax(self.outputlayer(x), dim=2)  # (b, t', outputdim)
        decision_up = torch.nn.functional.interpolate(
            decision_time.transpose(1, 2),
            num_frames,
            mode='linear',
            align_corners=False).transpose(1, 2)  # (b, num_frames, outputdim)
        return decision_time[:, :, 0], decision_up
class CDur(nn.Module):
    """CRNN sound-event detector conditioned on a per-example embedding.

    A Block2D/LPPool2d CNN pools time by 8 and frequency by 64 (three (2, 4)
    LP-poolings). Each frame's features are concatenated with the repeated
    embedding and passed through a bidirectional GRU, a linear layer and a
    softmax output layer.

    Args:
        inputdim (int): number of frequency bins of the input.
        outputdim (int): number of output classes (softmax over them).
        time_resolution: unused.
    """

    def __init__(self, inputdim, outputdim, time_resolution, **kwargs):
        super().__init__()
        self.features = nn.Sequential(
            Block2D(1, 32),
            nn.LPPool2d(4, (2, 4)),
            Block2D(32, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (2, 4)),
            Block2D(128, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (2, 4)),
            nn.Dropout(0.3),
        )
        # Dummy pass whose result is unused (the GRU width is hard-coded below).
        # It is kept because it consumes RNG state and updates the BatchNorm
        # running statistics, exactly as before.
        with torch.no_grad():
            self.features(torch.randn(1, 1, 500, inputdim))
        # Per-frame GRU input: CNN features plus the embedding — presumably
        # 128 + 128 = 256; confirm.
        self.gru = nn.GRU(256, 256, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, 256)
        self.outputlayer = nn.Linear(256, outputdim)
        self.features.apply(init_weights)
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding, one_hot=None):
        """
        Args:
            x: 3-D input unpacked as (batch, time, dim) — presumably a
                (batch, frames, mel_bins) spectrogram; TODO confirm.
            embedding: per-example vector, repeated over every CNN output
                frame and concatenated to the features.
            one_hot: unused.
        Returns:
            ``(probs[:, :, 0], upsampled)``: the frame-level softmax (first
            class only), and the full softmax linearly interpolated back to
            the input frame count.
        """
        _, num_frames, _ = x.shape
        feats = self.features(x.unsqueeze(1))
        feats = feats.transpose(1, 2).contiguous().flatten(-2)  # (b, t', c * d')
        repeated = embedding.unsqueeze(1).repeat(1, feats.shape[1], 1)
        joint = torch.cat((feats, repeated), dim=2)
        if not hasattr(self, '_flattened'):
            self.gru.flatten_parameters()
        rnn_out, _ = self.gru(joint)  # (b, t', 512)
        probs = torch.softmax(self.outputlayer(self.fc(rnn_out)), dim=2)
        upsampled = torch.nn.functional.interpolate(
            probs.transpose(1, 2), num_frames, mode='linear',
            align_corners=False).transpose(1, 2)
        return probs[:, :, 0], upsampled
class CDur_big(nn.Module):
    """Wider CRNN detector (four double-``Block2D`` stages) with concatenation conditioning.

    The pooling layers alone divide the time axis by 2*2*2*1 = 8 and the
    frequency axis by 2*2*4*4 = 64. (NOTE(review): this assumes ``Block2D``
    keeps the spatial size.) The conditioning embedding is concatenated to
    every pooled frame before a bidirectional GRU with 640 inputs.

    Args:
        inputdim: size of the last axis of ``x``.
        outputdim: number of classes predicted per frame.
        **kwargs: accepted and ignored.
    """

    def __init__(self, inputdim, outputdim, **kwargs):
        super().__init__()
        layers = [
            Block2D(1, 64),
            Block2D(64, 64),
            nn.LPPool2d(4, (2, 2)),
            Block2D(64, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (2, 2)),
            Block2D(128, 256),
            Block2D(256, 256),
            nn.LPPool2d(4, (2, 4)),
            Block2D(256, 512),
            Block2D(512, 512),
            nn.LPPool2d(4, (1, 4)),
            nn.Dropout(0.3),
        ]
        self.features = nn.Sequential(*layers)
        # Dummy pass whose output is discarded; kept because torch.randn
        # consumes the global RNG before the layers below are initialised.
        with torch.no_grad():
            dummy = torch.randn(1, 1, 500, inputdim)
            self.features(dummy)
        self.gru = nn.GRU(640, 512, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(1024, 256)  # 1024 = 2 directions x 512 hidden units
        self.outputlayer = nn.Linear(256, outputdim)
        for module in (self.features, self.outputlayer):
            module.apply(init_weights)

    def forward(self, x, embedding):
        """Return (class-0 probability at pooled rate, all probabilities upsampled).

        Args:
            x: input features, shape (batch, frames, inputdim).
            embedding: one conditioning vector per clip.
        """
        _, n_frames, _ = x.shape
        feats = self.features(x.unsqueeze(1))
        feats = feats.transpose(1, 2).contiguous().flatten(-2)
        cond = embedding.unsqueeze(1).repeat(1, feats.shape[1], 1)
        hidden = torch.cat((feats, cond), dim=2)
        if not hasattr(self, '_flattened'):
            self.gru.flatten_parameters()
        hidden = self.fc(self.gru(hidden)[0])
        probs = self.outputlayer(hidden).softmax(dim=2)
        upsampled = F.interpolate(
            probs.transpose(1, 2), n_frames, mode='linear', align_corners=False
        )
        return probs[..., 0], upsampled.transpose(1, 2)
class CDur_GLU(nn.Module):
    """CRNN detector whose front-end is ``Mul_scale_GLU``, conditioned by concatenation.

    Args:
        inputdim: accepted for interface parity; not used by this class.
        outputdim: number of classes predicted per frame.
        **kwargs: accepted and ignored.

    NOTE(review): the GRU expects 640 values per frame (front-end features
    plus the conditioning embedding).
    """

    def __init__(self, inputdim, outputdim, **kwargs):
        super().__init__()
        self.features = Mul_scale_GLU()
        self.gru = nn.GRU(640, 512, num_layers=1, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(1024, 256)  # 1024 = 2 directions x 512 hidden units
        self.outputlayer = nn.Linear(256, outputdim)
        # Only the output layer gets the custom initialisation here.
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding, one_hot=None):
        """Return (class-0 probability at pooled rate, all probabilities upsampled).

        Args:
            x: input features, shape (batch, frames, dim).
            embedding: one conditioning vector per clip.
            one_hot: accepted and ignored.
        """
        _, n_frames, _ = x.shape
        feats = self.features(x.unsqueeze(1))
        feats = feats.transpose(1, 2).contiguous().flatten(-2)
        cond = embedding.unsqueeze(1).repeat(1, feats.shape[1], 1)
        hidden = torch.cat((feats, cond), dim=2)
        if not hasattr(self, '_flattened'):
            self.gru.flatten_parameters()
        hidden = self.fc(self.gru(hidden)[0])
        probs = self.outputlayer(hidden).softmax(dim=2)
        upsampled = F.interpolate(
            probs.transpose(1, 2), n_frames, mode='linear', align_corners=False
        )
        return probs[..., 0], upsampled.transpose(1, 2)
class CDur_CNN14(nn.Module):
    """CRNN detector with a ``Cnn10`` front-end and concatenation conditioning.

    ``time_resolution`` selects the argument passed to ``Cnn10``:
    125 -> 8, 250 -> 4, 500 -> 2, anything else -> 0. (Presumably a temporal
    pooling factor; confirm against ``Cnn10``.)

    Args:
        inputdim: size of the last axis of ``x``.
        outputdim: number of classes predicted per frame.
        time_resolution: see above.
        **kwargs: accepted and ignored.
    """

    def __init__(self, inputdim, outputdim, time_resolution, **kwargs):
        super().__init__()
        cnn_arg = next(
            (arg for resolution, arg in ((125, 8), (250, 4), (500, 2))
             if time_resolution == resolution),
            0,
        )
        self.features = Cnn10(cnn_arg)
        # Dummy pass whose output is discarded; kept because torch.randn
        # consumes the global RNG before the layers below are initialised.
        with torch.no_grad():
            dummy = torch.randn(1, 1, 500, inputdim)
            self.features(dummy)
        self.gru = nn.GRU(640, 512, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(1024, 256)  # 1024 = 2 directions x 512 hidden units
        self.outputlayer = nn.Linear(256, outputdim)
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding, one_hot=None):
        """Return (class-0 probability at pooled rate, all probabilities upsampled).

        Args:
            x: input features, shape (batch, frames, inputdim).
            embedding: one conditioning vector per clip.
            one_hot: accepted and ignored.
        """
        _, n_frames, _ = x.shape
        feats = self.features(x.unsqueeze(1))
        feats = feats.transpose(1, 2).contiguous().flatten(-2)
        cond = embedding.unsqueeze(1).repeat(1, feats.shape[1], 1)
        hidden = torch.cat((feats, cond), dim=2)
        if not hasattr(self, '_flattened'):
            self.gru.flatten_parameters()
        hidden = self.fc(self.gru(hidden)[0])
        probs = self.outputlayer(hidden).softmax(dim=2)
        upsampled = F.interpolate(
            probs.transpose(1, 2), n_frames, mode='linear', align_corners=False
        )
        return probs[..., 0], upsampled.transpose(1, 2)
class CDur_CNN_mul_scale(nn.Module):
    """CRNN detector with a ``Cnn10_mul_scale`` front-end and concatenation conditioning.

    ``time_resolution`` selects the argument passed to ``Cnn10_mul_scale``:
    125 -> 8, 250 -> 4, 500 -> 2, anything else -> 0.

    Args:
        inputdim: accepted for interface parity; not used by this class.
        outputdim: number of classes predicted per frame.
        time_resolution: see above.
        **kwargs: accepted and ignored.
    """

    def __init__(self, inputdim, outputdim, time_resolution, **kwargs):
        super().__init__()
        cnn_arg = next(
            (arg for resolution, arg in ((125, 8), (250, 4), (500, 2))
             if time_resolution == resolution),
            0,
        )
        self.features = Cnn10_mul_scale(cnn_arg)
        self.gru = nn.GRU(640, 512, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(1024, 256)  # 1024 = 2 directions x 512 hidden units
        self.outputlayer = nn.Linear(256, outputdim)
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding, one_hot=None):
        """Return (class-0 probability at pooled rate, all probabilities upsampled).

        Args:
            x: input features, shape (batch, frames, dim).
            embedding: one conditioning vector per clip.
            one_hot: accepted and ignored.
        """
        _, n_frames, _ = x.shape
        feats = self.features(x.unsqueeze(1))
        feats = feats.transpose(1, 2).contiguous().flatten(-2)
        cond = embedding.unsqueeze(1).repeat(1, feats.shape[1], 1)
        hidden = torch.cat((feats, cond), dim=2)
        if not hasattr(self, '_flattened'):
            self.gru.flatten_parameters()
        hidden = self.fc(self.gru(hidden)[0])
        probs = self.outputlayer(hidden).softmax(dim=2)
        upsampled = F.interpolate(
            probs.transpose(1, 2), n_frames, mode='linear', align_corners=False
        )
        return probs[..., 0], upsampled.transpose(1, 2)
class CDur_CNN_mul_scale_fusion(nn.Module):
    """CRNN detector with a ``Cnn10_mul_scale`` front-end, conditioned via ``Fusion``.

    Instead of concatenation, the repeated conditioning embedding and the
    per-frame CNN features are combined by ``Fusion(128, 512, 2)``, whose
    output feeds a bidirectional GRU with 512 inputs.

    ``time_resolution`` selects the argument passed to ``Cnn10_mul_scale``:
    125 -> 8, 250 -> 4, 500 -> 2, anything else -> 0.

    Args:
        inputdim: accepted for interface parity; not used by this class.
        outputdim: number of classes predicted per frame.
        time_resolution: see above.
        **kwargs: accepted and ignored.
    """

    def __init__(self, inputdim, outputdim, time_resolution, **kwargs):
        super().__init__()
        cnn_arg = next(
            (arg for resolution, arg in ((125, 8), (250, 4), (500, 2))
             if time_resolution == resolution),
            0,
        )
        self.features = Cnn10_mul_scale(cnn_arg)
        # Submodule creation order (gru, fc, fusion, outputlayer) is kept as-is:
        # it determines both the RNG draws for initialisation and the
        # state_dict key order.
        self.gru = nn.GRU(512, 512, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(1024, 256)  # 1024 = 2 directions x 512 hidden units
        self.fusion = Fusion(128, 512, 2)
        self.outputlayer = nn.Linear(256, outputdim)
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding, one_hot=None):
        """Return (class-0 probability at pooled rate, all probabilities upsampled).

        Args:
            x: input features, shape (batch, frames, dim).
            embedding: one conditioning vector per clip; presumably
                (batch, 128) given ``Fusion(128, ...)`` -- TODO confirm.
            one_hot: accepted and ignored.
        """
        _, n_frames, _ = x.shape
        feats = self.features(x.unsqueeze(1))
        feats = feats.transpose(1, 2).contiguous().flatten(-2)
        cond = embedding.unsqueeze(1).repeat(1, feats.shape[1], 1)
        fused = self.fusion(cond, feats)
        if not hasattr(self, '_flattened'):
            self.gru.flatten_parameters()
        hidden = self.fc(self.gru(fused)[0])
        probs = self.outputlayer(hidden).softmax(dim=2)
        upsampled = F.interpolate(
            probs.transpose(1, 2), n_frames, mode='linear', align_corners=False
        )
        return probs[..., 0], upsampled.transpose(1, 2)
class RaDur_fusion(nn.Module):
    """Reference-conditioned frame-level detector: Cnn14 encoder + fusion CRNN head.

    The reference clip ``ref`` is encoded by ``Cnn14`` into a sequence of
    128-d embeddings and pooled into one conditioning vector. The pooling is
    a plain mean, or attention pooling over batch-normalised embeddings when
    ``model_config['att_pool'] == True``.

    That vector conditions a ``CDur_CNN_mul_scale_fusion`` head run on the
    mixture ``x``. When ``model_config['enhancement'] == True``,
    :meth:`orcal_EE` refines the vector with the mixture frames the first
    pass scored highest, then blends both passes' outputs.

    Args:
        model_config: mapping read for the keys ``'att_pool'``,
            ``'enhancement'``, ``'tao'`` (score threshold applied to the
            selected frames) and ``'top'`` (number of frames selected).
        inputdim, outputdim, time_resolution: forwarded to
            ``CDur_CNN_mul_scale_fusion``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, model_config, inputdim, outputdim, time_resolution, **kwargs):
        super().__init__()
        self.encoder = Cnn14()
        self.detection = CDur_CNN_mul_scale_fusion(inputdim, outputdim, time_resolution)
        self.softmax = nn.Softmax(dim=2)
        # Query/key projections for attention pooling of the reference embeddings.
        self.q = nn.Linear(128, 128)
        self.k = nn.Linear(128, 128)
        # Separate query/key projections for weighting the frames picked in orcal_EE.
        self.q_ee = nn.Linear(128, 128)
        self.k_ee = nn.Linear(128, 128)
        self.temperature = 11.3  # ~sqrt(128): dot-product attention scaling
        self.att_pool = model_config['att_pool']
        self.enhancement = model_config['enhancement']
        self.tao = model_config['tao']
        self.top = model_config['top']
        self.bn = nn.BatchNorm1d(128)
        self.EE_fusion = Fusion(128, 128, 4)

    # ------------------------------------------------------------------ attention

    def _scaled_attention(self, q, k):
        """Softmax((q as a length-1 query) @ k^T / temperature) -> (batch, 1, steps)."""
        attn = torch.bmm(q.unsqueeze(1), k.transpose(1, 2)) / self.temperature
        return self.softmax(attn)

    def get_w(self, q, k):
        """Attention weights of one query per item over the sequence ``k`` (``q``/``k`` projections)."""
        return self._scaled_attention(self.q(q), self.k(k))

    def get_w_ee(self, q, k):
        """Like :meth:`get_w` but with the ``q_ee``/``k_ee`` projections."""
        return self._scaled_attention(self.q_ee(q), self.k_ee(k))

    def attention_pooling(self, embeddings, mean_embedding):
        """Pool ``embeddings`` (batch, steps, dim) into (batch, dim) using ``mean_embedding`` as the query."""
        att_pool_w = self.get_w(mean_embedding, embeddings)
        return torch.bmm(att_pool_w, embeddings).squeeze(1)

    def select_topk_embeddings(self, scores, embeddings, k):
        """Gather the ``k`` highest-scoring steps of ``embeddings``.

        Args:
            scores: (batch, steps) score per step.
            embeddings: (batch, steps, dim) vectors to select from.
            k: number of steps to keep.

        Returns:
            ``(selected_embeddings, top_k)``: the selected vectors
            (batch, k, dim), in descending score order, and their scores
            (batch, k).
        """
        sorted_scores, order = scores.sort(descending=True, dim=1)
        top_k = sorted_scores[:, :k]
        idx_topk = order[:, :k].unsqueeze(2).expand(-1, -1, embeddings.shape[2])
        selected_embeddings = torch.gather(embeddings, 1, idx_topk)
        return selected_embeddings, top_k

    def sum_with_attention(self, embedding, top_k, selected_embeddings):
        """Weight each selected embedding by attention x its score.

        Frames whose score is not strictly greater than ``self.tao`` get
        weight zero. Despite the name, no sum is taken here; the caller
        averages the result.

        Returns:
            ``selected_embeddings`` scaled per step, same shape as the input.
        """
        att = self.get_w_ee(embedding, selected_embeddings).squeeze(1)  # (batch, k)
        att = att * (top_k * (top_k > self.tao))
        # Broadcast over the embedding axis (works for any embedding size).
        return selected_embeddings * att.unsqueeze(2)

    # ------------------------------------------------------------------ detection

    def _frontend(self, x):
        """Run the detector's CNN on ``x`` (batch, frames, dim); return one flattened vector per pooled frame."""
        feats = self.detection.features(x.unsqueeze(1))
        return feats.transpose(1, 2).contiguous().flatten(-2)

    @staticmethod
    def _tile(vec, n_steps):
        """Repeat a (batch, dim) vector ``n_steps`` times along a new time axis."""
        return vec.unsqueeze(1).repeat(1, n_steps, 1)

    def _detect(self, feats, cond, flatten=True):
        """Fuse ``cond`` with ``feats`` and run GRU -> fc -> softmax output of the detection head."""
        hidden = self.detection.fusion(cond, feats)
        # `_flattened` is never set, so flattening happens whenever requested.
        if flatten and not hasattr(self, '_flattened'):
            self.detection.gru.flatten_parameters()
        hidden, _ = self.detection.gru(hidden)
        hidden = self.detection.fc(hidden)
        return torch.softmax(self.detection.outputlayer(hidden), dim=2)

    @staticmethod
    def _upsample(decision_time, n_frames):
        """Linearly interpolate (batch, steps, classes) along steps to ``n_frames``."""
        return torch.nn.functional.interpolate(
            decision_time.transpose(1, 2),
            n_frames,
            mode='linear',
            align_corners=False,
        ).transpose(1, 2)

    def orcal_EE(self, x, embedding, label):
        """Two-pass detection with an embedding enhanced from the mixture itself.

        1. Detect with the reference ``embedding``.
        2. Take the ``self.top`` encoder frames of ``x`` with the highest
           class-0 probability, weight them (:meth:`sum_with_attention`),
           average them, and fuse the result with ``embedding``
           (``EE_fusion``).
        3. Detect again with the fused embedding.
        4. Blend the two passes. The second pass gets weight
           ``mean(top_k) / 2``, or 0 when that mean is not above
           ``self.tao``.

        NOTE(review): indexing encoder frames with detection-frame indices
        assumes both run at the same frame rate -- confirm.

        Args:
            x: mixture features, shape (batch, frames, dim).
            embedding: (batch, 128) conditioning vector.
            label: accepted and ignored.

        Returns:
            Blended frame-level probabilities at the pooled frame rate.
        """
        # Encoder runs first and the CNN front-end second, as in the original
        # order, so RNG use (e.g. dropout) stays in the same sequence.
        mixture_embedding = self.encoder(x)
        mixture_embedding = self.bn(mixture_embedding.transpose(1, 2)).transpose(1, 2)
        feats = self._frontend(x)
        n_steps = feats.shape[1]

        decision_time = self._detect(feats, self._tile(embedding, n_steps))

        selected_embeddings, top_k = self.select_topk_embeddings(
            decision_time[:, :, 0], mixture_embedding, self.top)
        selected_embeddings = self.sum_with_attention(embedding, top_k, selected_embeddings)
        mix_embedding = self._tile(selected_embeddings.mean(1), n_steps)
        mix_embedding = self.EE_fusion(mix_embedding, self._tile(embedding, n_steps))

        # The second pass does not re-flatten the GRU weights.
        decision_time_now = self._detect(feats, mix_embedding, flatten=False)

        confidence = top_k.mean(1)
        confidence = confidence * (confidence > self.tao)
        neg_w = (confidence / 2.0)[:, None, None]  # broadcasts over (steps, classes)
        pos_w = 1 - neg_w
        return decision_time * pos_w + neg_w * decision_time_now

    def forward(self, x, ref, label=None):
        """Detect the reference sound ``ref`` in the mixture ``x``.

        Args:
            x: mixture features, shape (batch, frames, dim).
            ref: reference-clip input for the ``Cnn14`` encoder.
            label: passed to :meth:`orcal_EE`, which ignores it.

        Returns:
            ``(decision_time[:, :, 0], decision_up, logit)``.
            ``decision_up`` has shape (batch, frames, outputdim).
            ``logit`` is a zero placeholder of shape (1,) created on ``x``'s
            device (previously hard-coded to CUDA).
        """
        _, n_frames, _ = x.shape
        logit = torch.zeros(1, device=x.device)

        embeddings = self.encoder(ref)
        mean_embedding = embeddings.mean(1)
        if self.att_pool == True:
            mean_embedding = self.bn(mean_embedding)
            embeddings = self.bn(embeddings.transpose(1, 2)).transpose(1, 2)
            embedding = self.attention_pooling(embeddings, mean_embedding)
        else:
            embedding = mean_embedding

        if self.enhancement == True:
            decision_time = self.orcal_EE(x, embedding, label)
        else:
            feats = self._frontend(x)
            decision_time = self._detect(feats, self._tile(embedding, feats.shape[1]))
        return decision_time[:, :, 0], self._upsample(decision_time, n_frames), logit