Spaces:
Build error
Build error
import whisper | |
import gradio as gr | |
import datetime | |
import subprocess | |
import torch | |
import pyannote.audio | |
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding | |
from pyannote.audio import Audio | |
from pyannote.core import Segment | |
import wave | |
import contextlib | |
import math | |
from sklearn.cluster import AgglomerativeClustering | |
import numpy as np | |
# Speech-to-text model, loaded once when the module is imported.
model = whisper.load_model("large-v2")

# Speaker-embedding model used to characterise each transcript segment.
# Use the GPU when one is present, otherwise fall back to the CPU instead of
# failing at start-up on a machine without CUDA.
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
def transcribe(audio, num_speakers):
    """Transcribe an audio file and label every segment with a speaker.

    Args:
        audio: Path of the audio file to process.
        num_speakers: Requested number of speakers. It is rounded, then
            clamped to at least 1 and at most the number of segments.

    Returns:
        The transcript text built by ``get_output``, or an empty string when
        the transcription contains no segments.
    """
    path = convert_to_wav(audio)

    result = model.transcribe(path)
    segments = result["segments"]
    if not segments:
        # Nothing was transcribed: there is nothing to cluster, and asking
        # for zero clusters over zero embeddings would only fail.
        return ''

    num_speakers = min(max(round(num_speakers), 1), len(segments))

    if len(segments) == 1:
        # A single segment cannot be clustered; it is the only speaker.
        segments[0]['speaker'] = 'SPEAKER 1'
    else:
        duration = get_duration(path)
        embeddings = make_embeddings(path, segments, duration)
        add_speaker_labels(segments, embeddings, num_speakers)

    return get_output(segments)
def convert_to_wav(path):
    """Return the path of a WAV version of ``path``.

    A path that already ends in ``.wav`` (in any letter case) is returned
    unchanged. Anything else is converted with ffmpeg into ``audio.wav`` in
    the current working directory, overwriting any previous conversion.

    Args:
        path: Path of the input audio file.

    Returns:
        ``path`` itself for WAV input, otherwise ``'audio.wav'``.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero
            status, so a stale ``audio.wav`` is never silently reused.
    """
    if path.lower().endswith('.wav'):
        return path

    # '-y' overwrites the output of an earlier request without prompting.
    subprocess.run(['ffmpeg', '-y', '-i', path, 'audio.wav'], check=True)
    return 'audio.wav'
def get_duration(path):
    """Return the length, in seconds, of the WAV file at ``path``."""
    with wave.open(path, 'r') as wav_file:
        frame_count = wav_file.getnframes()
        sample_rate = wav_file.getframerate()
    # Number of frames divided by frames-per-second gives seconds.
    return frame_count / float(sample_rate)
def make_embeddings(path, segments, duration):
    """Build the matrix of speaker embeddings, one row per segment.

    Args:
        path: Path of the audio file the segments refer to.
        segments: Transcript segments; each is passed to
            ``segment_embedding``.
        duration: Length of the audio in seconds, passed through to
            ``segment_embedding``.

    Returns:
        A ``(len(segments), 192)`` array in which NaN entries have been
        replaced by ``np.nan_to_num``.
    """
    embedding_size = 192
    matrix = np.zeros((len(segments), embedding_size))
    for row, item in enumerate(segments):
        matrix[row] = segment_embedding(path, item, duration)
    cleaned = np.nan_to_num(matrix)
    return cleaned
# Module-level audio reader shared by every call to segment_embedding.
audio = Audio()


def segment_embedding(path, segment, duration):
    """Return the speaker embedding of a single transcript segment.

    Args:
        path: Path of the audio file to crop from.
        segment: Mapping with ``"start"`` and ``"end"`` timestamps.
        duration: Length of the audio; the segment end is capped at it.

    Returns:
        Whatever ``embedding_model`` returns for the cropped waveform.
    """
    # Whisper overshoots the end timestamp in the last segment
    clip_end = min(duration, segment["end"])
    clip = Segment(segment["start"], clip_end)
    waveform, _sample_rate = audio.crop(path, clip)
    # Indexing with None adds a leading axis before calling the model.
    batched = waveform[None]
    return embedding_model(batched)
def add_speaker_labels(segments, embeddings, num_speakers):
    """Cluster the embeddings and store a speaker name on each segment.

    Each segment gets a ``"speaker"`` entry of the form ``'SPEAKER <n>'``,
    where ``<n>`` is the segment's cluster label plus one. The segments are
    modified in place and nothing is returned.

    Args:
        segments: Transcript segments, in the same order as ``embeddings``.
        embeddings: One embedding row per segment.
        num_speakers: Number of clusters to form.
    """
    fitted = AgglomerativeClustering(num_speakers).fit(embeddings)
    cluster_ids = fitted.labels_
    for position, entry in enumerate(segments):
        # Cluster labels start at 0; speaker numbering starts at 1.
        entry["speaker"] = 'SPEAKER ' + str(cluster_ids[position] + 1)
def time(secs):
    """Return ``secs`` as a ``timedelta`` rounded to whole seconds."""
    whole_seconds = round(secs)
    return datetime.timedelta(seconds=whole_seconds)
def get_output(segments):
    """Format speaker-labelled segments as a readable transcript.

    A header line ``"<speaker> <start time>"`` is written before the first
    segment and whenever the speaker differs from the previous segment's.
    Each segment's text follows, with one trailing space after it.

    Args:
        segments: Sequence of mappings with ``"speaker"``, ``"start"`` and
            ``"text"`` entries.

    Returns:
        The transcript as one string; empty when ``segments`` is empty.
    """
    parts = []
    for index, segment in enumerate(segments):
        speaker = segment["speaker"]
        if index == 0 or segments[index - 1]["speaker"] != speaker:
            parts.append("\n" + speaker + ' ' + str(time(segment["start"])) + '\n')
        text = segment["text"]
        # Drop only a leading space separator. Slicing off the first
        # character unconditionally would truncate text that has none.
        if text.startswith(' '):
            text = text[1:]
        parts.append(text + ' ')
    # The first header starts with a newline that should not open the output.
    return "".join(parts)[1:]
# Web UI: an uploaded audio file and a speaker count go in, the transcript
# text produced by transcribe() comes out.
audio_input = gr.inputs.Audio(source="upload", type="filepath")
speaker_count_input = gr.inputs.Number(default=2, label="Number of Speakers")
transcript_output = gr.outputs.Textbox(label='Transcript')

interface = gr.Interface(
    title='Whisper with Speaker Recognition',
    fn=transcribe,
    inputs=[audio_input, speaker_count_input],
    outputs=[transcript_output],
    debug=True,
)
interface.launch()