# AudioSep / models/clap_encoder.py
# badayvedat's picture
# Initial commit
# ae29df4
import random
import torch
import torch.nn as nn
import torchaudio
from models.CLAP.open_clip import create_model
from models.CLAP.training.data import get_audio_features
from transformers import RobertaTokenizer
from utils import ignore_warnings; ignore_warnings()
class CLAP_Encoder(nn.Module):
    """Frozen CLAP wrapper that turns audio or text queries into embeddings.

    The underlying model is built with ``create_model``; every parameter has
    ``requires_grad`` disabled and the model is switched to eval mode, so this
    module is only used for inference.
    """

    # Rate the waveform is resampled to before feature extraction.
    _CLAP_SAMPLE_RATE = 48000
    # Third positional argument handed to ``get_audio_features``; presumably
    # the maximum clip length in samples (480000 / 48000 = 10 s) -- TODO confirm.
    _CLAP_MAX_SAMPLES = 480000
    # 'hybird' is the historical (misspelled) flag value and must keep
    # working; the correctly spelled 'hybrid' is accepted as an alias.
    _HYBRID_MODALITIES = ('hybird', 'hybrid')

    def __init__(
        self,
        pretrained_path='checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt',
        sampling_rate=32000,
        amodel = "HTSAT-base",
    ):
        """Build the CLAP model and freeze it.

        Args:
            pretrained_path: Checkpoint path forwarded to ``create_model``.
            sampling_rate: Sampling rate of the waveforms later passed to
                ``_get_audio_embed``; that method only accepts 32000.
            amodel: Audio model name forwarded to ``create_model``
                (``"HTSAT-base"`` or ``'PANN-14'`` per the original note).
        """
        super().__init__()
        self.device = "cpu"
        self.precision = "fp32"
        self.amodel = amodel  # or 'PANN-14'
        self.tmodel = "roberta"  # the best text encoder in our training
        self.enable_fusion = False  # False if you do not want to use the fusion model
        self.fusion_type = "aff_2d"
        self.pretrained = pretrained_path
        self.sampling_rate = sampling_rate
        self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")

        self.model, self.model_cfg = create_model(
            self.amodel,
            self.tmodel,
            self.pretrained,
            precision=self.precision,
            device=self.device,
            enable_fusion=self.enable_fusion,
            fusion_type=self.fusion_type,
        )

        # The encoder is used as a fixed feature extractor: no gradients,
        # eval mode.
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

        self.encoder_type = 'CLAP'

    def batch_to_list(self, batch):
        """Split ``batch`` along its first dimension into a list of tensors."""
        return [batch[i] for i in range(batch.size(0))]

    def _get_audio_embed(self, batch):
        """Embed a batch of waveforms with the CLAP audio branch.

        Args:
            batch: Tensor of waveforms sampled at ``self.sampling_rate`` whose
                first dimension indexes the batch. The original notes give the
                shape both as ``[B, samples]`` and ``[bs, 1, t-samples]`` --
                TODO confirm which one callers pass.

        Returns:
            The detached audio embedding (``[bs, 512]`` per the original note).

        Raises:
            AssertionError: If ``self.sampling_rate`` is not 32000.
        """
        assert (
            self.sampling_rate == 32000
        ), "We only support 32000 sampling rate"

        with torch.no_grad():
            batch = torchaudio.functional.resample(
                batch,
                orig_freq=self.sampling_rate,
                new_freq=self._CLAP_SAMPLE_RATE,
            )
            # Per the original note, the 'fusion' truncate mode can be changed
            # to 'rand_trunc' when running without the fusion model.
            audio_dict_list = [
                get_audio_features(
                    {},
                    waveform,
                    self._CLAP_MAX_SAMPLES,
                    data_truncating="fusion",
                    data_filling="repeatpad",
                    audio_cfg=self.model_cfg["audio_cfg"],
                )
                for waveform in self.batch_to_list(batch)
            ]
            # [bs, 512]
            embed = self.model.get_audio_embedding(audio_dict_list)

        return embed.detach()

    def _get_text_embed(self, batch):
        """Embed a batch of text queries with the CLAP text branch.

        Args:
            batch: Sequence of strings. A bare ``str`` is treated as a
                one-item batch.

        Returns:
            The detached text embedding with one row per input text.
        """
        # A bare string would otherwise be measured by its character count
        # and "doubled" by string repetition below.
        if isinstance(batch, str):
            batch = [batch]

        # ``tokenizer`` squeezes dim 0, which would drop the batch dimension
        # of a one-item batch; duplicate the item and keep only the first row.
        double_batch = len(batch) == 1
        if double_batch:
            batch = batch * 2

        with torch.no_grad():
            text_data = self.tokenizer(batch)
            embed = self.model.get_text_embedding(text_data)

        if double_batch:
            embed = embed[0].unsqueeze(0)

        return embed.detach()

    def get_query_embed(self, modality, audio=None, text=None, use_text_ratio=0.5, device=None):
        """Return the query embedding for the requested modality.

        Args:
            modality: ``'audio'``, ``'text'``, or ``'hybird'`` / ``'hybrid'``.
                The hybrid mode picks text or audio at random per call.
            audio: Waveform batch, used when the audio branch is selected.
            text: Text batch, used when the text branch is selected.
            use_text_ratio: In hybrid mode, audio is used when
                ``random.random() > use_text_ratio``, text otherwise.
            device: Unused; kept for interface compatibility.

        Returns:
            The embedding cast to ``float``.

        Raises:
            NotImplementedError: If ``modality`` is not a supported value.
        """
        if modality == 'audio':
            embed = self._get_audio_embed(audio)
        elif modality == 'text':
            embed = self._get_text_embed(text)
        elif modality in self._HYBRID_MODALITIES:
            if random.random() > use_text_ratio:
                embed = self._get_audio_embed(audio)
            else:
                embed = self._get_text_embed(text)
        else:
            raise NotImplementedError("Please check flag 'training_modality'.")

        return embed.float()

    def tokenizer(self, text):
        """Tokenize ``text``, padded/truncated to 512 tokens, as PyTorch tensors.

        Returns:
            Dict mapping each tokenizer output name to its tensor with dim 0
            squeezed (dim 0 is only removed when its size is 1).
        """
        result = self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        return {k: v.squeeze(0) for k, v in result.items()}