File size: 6,879 Bytes
a2ce4ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
import torch
from speechbrain.pretrained import Pretrained
class Speech_Emotion_Diarization(Pretrained):
    """A ready-to-use SED interface (audio -> emotions and their durations)

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)

    Example
    -------
    >>> from speechbrain.pretrained import Speech_Emotion_Diarization
    >>> tmpdir = getfixture("tmpdir")
    >>> sed_model = Speech_Emotion_Diarization.from_hparams(source="speechbrain/emotion-diarization-wavlm-large", savedir=tmpdir,) # doctest: +SKIP
    >>> sed_model.diarize_file("speechbrain/emotion-diarization-wavlm-large/example.wav") # doctest: +SKIP
    """

    # Names must match the modules accessed below via ``self.mods`` (note
    # ``encode_batch`` uses ``wav2vec2``, not ``wav2vec``).
    MODULES_NEEDED = ["input_norm", "wav2vec2", "output_mlp"]

    def diarize_file(self, path):
        """Get emotion diarization of a spoken utterance.

        Arguments
        ---------
        path : str
            Path to audio file which to diarize.

        Returns
        -------
        dict
            Maps ``path`` to a list of ``{"start", "end", "emotion"}``
            dictionaries, one per merged emotion segment.
        """
        waveform = self.load_audio(path)
        # Fake a batch of one utterance at full relative length.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        return self.diarize_batch(batch, rel_length, [path])

    def encode_batch(self, wavs, wav_lens):
        """Encodes audios into fine-grained emotional embeddings

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels]. A 1-D tensor is
            treated as a batch of one.
        wav_lens : torch.tensor or None
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding. If None, every waveform is treated
            as full length.

        Returns
        -------
        torch.tensor
            The encoded batch
        """
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)

        # Assign full length if wav_lens is not assigned
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = self.mods.input_norm(wavs, wav_lens)
        return self.mods.wav2vec2(wavs)

    def diarize_batch(self, wavs, wav_lens, batch_id):
        """Get emotion diarization of a batch of waveforms.

        The waveforms should already be in the model's desired format;
        ``self.load_audio`` (as used by ``diarize_file``) produces a
        suitably converted signal for a file on disk.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.tensor or None
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.
        batch_id : sequence
            One identifier per waveform; used as the keys of the result.

        Returns
        -------
        dict
            Maps each id in ``batch_id`` to a list of
            ``{"start", "end", "emotion"}`` dictionaries.
        """
        outputs = self.encode_batch(wavs, wav_lens)
        averaged_out = self.hparams.avg_pool(outputs)
        outputs = self.mods.output_mlp(averaged_out)
        outputs = self.hparams.log_softmax(outputs)
        # Only the per-frame arg-max class index is needed; the score is not.
        _, index = torch.max(outputs, dim=-1)
        preds = self.hparams.label_encoder.decode_torch(index)
        return self.preds_to_diarization(preds, batch_id)

    def preds_to_diarization(self, prediction, batch_id):
        """Convert frame-wise predictions into a dictionary of
        diarization results.

        Arguments
        ---------
        prediction : list of list
            Per-utterance sequence of decoded frame labels.
        batch_id : sequence
            Identifier of each utterance; ``batch_id[i]`` labels
            ``prediction[i]``.

        Returns
        -------
        dict
            A dictionary with the start/end of each emotion, keyed by
            utterance id. An utterance with no frames maps to ``[]``.
        """
        # NOTE(review): 0.02 is presumably the encoder frame duration in
        # seconds (20 ms) -- confirm against the model's hparams.
        hop = self.hparams.stride * 0.02
        window = self.hparams.window_length * 0.02
        # NOTE(review): frames produced from padding of shorter utterances
        # are not trimmed using wav_lens, so they are labelled too -- confirm
        # whether batched inputs of unequal length are expected here.
        results = {}
        for i, pred in enumerate(prediction):
            utt_id = batch_id[i]
            segments = []
            for j, emotion in enumerate(pred):
                start = round(hop * j, 2)
                end = round(start + window, 2)
                segments.append([utt_id, start, end, emotion])
            # Guard: the merge helper needs at least one sub-segment.
            if segments:
                segments = merge_ssegs_same_emotion_adjacent(segments)
            results[utt_id] = [
                {"start": seg[1], "end": seg[2], "emotion": seg[3]}
                for seg in segments
            ]
        return results

    def forward(self, wavs, wav_lens=None, batch_id=None):
        """Get emotion diarization for a batch of waveforms.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms (a 1-D tensor is treated as one utterance).
        wav_lens : torch.tensor or None
            Relative lengths; None means every waveform is full length.
        batch_id : sequence or None
            Identifiers used as result keys; defaults to ``0..batch-1``.

        Returns
        -------
        dict
            Same structure as ``diarize_batch``.
        """
        if batch_id is None:
            n_utts = wavs.shape[0] if wavs.dim() > 1 else 1
            batch_id = list(range(n_utts))
        return self.diarize_batch(wavs, wav_lens, batch_id)
def is_overlapped(end1, start2):
    """Tell whether a segment starting at ``start2`` overlaps one ending at ``end1``.

    Segments that merely touch (``start2 == end1``) count as overlapping.

    Arguments
    ---------
    end1 : float
        End time of the first segment.
    start2 : float
        Start time of the second segment.

    Returns
    -------
    overlapped : bool
        True if the segments overlap, otherwise False.

    Example
    -------
    >>> is_overlapped(5.5, 3.4)
    True
    >>> is_overlapped(5.5, 6.4)
    False
    """
    begins_after_first_ends = start2 > end1
    return not begins_after_first_ends
def merge_ssegs_same_emotion_adjacent(lol):
    """Merge adjacent sub-segs if they are the same emotion.

    Consecutive sub-segments are merged when they overlap or touch
    (according to ``is_overlapped``) and carry the same emotion label. The
    merged segment keeps the first start time and takes the end time of
    the last sub-segment merged into it. Only consecutive entries are
    compared, so ``lol`` should be in temporal order.

    The input is not modified: every returned segment is a new list.

    Arguments
    ---------
    lol : list of list
        Each list contains [utt_id, sseg_start, sseg_end, emo_label].

    Returns
    -------
    new_lol : list of list
        new_lol contains adjacent segments merged from the same emotion ID.
        An empty input yields an empty list.

    Example
    -------
    >>> from speechbrain.utils.EDER import merge_ssegs_same_emotion_adjacent
    >>> lol=[['u1', 0.0, 7.0, 'a'],
    ... ['u1', 7.0, 9.0, 'a'],
    ... ['u1', 9.0, 11.0, 'n'],
    ... ['u1', 11.0, 13.0, 'n'],
    ... ['u1', 13.0, 15.0, 'n'],
    ... ['u1', 15.0, 16.0, 'a']]
    >>> merge_ssegs_same_emotion_adjacent(lol)
    [['u1', 0.0, 9.0, 'a'], ['u1', 9.0, 15.0, 'n'], ['u1', 15.0, 16.0, 'a']]
    """
    new_lol = []
    for sseg in lol:
        if new_lol:
            current = new_lol[-1]
            # IF sub-segments overlap AND have the same emotion THEN merge
            if is_overlapped(current[2], sseg[1]) and current[3] == sseg[3]:
                current[2] = sseg[2]  # just update the end time
                continue
        # Copy so that later merges never mutate the caller's data.
        new_lol.append(list(sseg))
    return new_lol
|