import whisper
import gradio as gr
import datetime
import subprocess
import torch
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio import Audio
from pyannote.core import Segment
import wave
import contextlib
from sklearn.cluster import AgglomerativeClustering
import numpy as np
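
# Whisper handles transcription; the SpeechBrain ECAPA-TDNN model (loaded
# via pyannote.audio) produces the speaker embeddings used for diarization.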
model = whisper.load_model("large-v2")
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
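
# Main entry point: transcribe the upload, then cluster per-segment speaker
# embeddings into num_speakers groups and format the labeled transcript.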
def transcribe(audio, num_speakers):
    path = convert_to_wav(audio)
    result = model.transcribe(path)
    segments = result["segments"]
    num_speakers = max(round(num_speakers), 1)
    if len(segments) < num_speakers:
        num_speakers = len(segments)
    if len(segments) == 1:
        segments[0]['speaker'] = 'SPEAKER 1'
    else:
        duration = get_duration(path)
        embeddings = make_embeddings(path, segments, duration)
        add_speaker_labels(segments, embeddings, num_speakers)
    output = get_output(segments)
    return output
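
# Ensure the input is a WAV file that the wave module and pyannote can read.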
def convert_to_wav(path):
    if not path.endswith('.wav'):
        # Re-encode to 16 kHz mono: the ECAPA embedding model expects
        # 16 kHz audio, and the wave module needs PCM input.
        subprocess.call(['ffmpeg', '-y', '-i', path, '-ar', '16000', '-ac', '1', 'audio.wav'])
        path = 'audio.wav'
    return path
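
# Length of the WAV file in seconds, from the frame count and sample rate.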
def get_duration(path):
    with contextlib.closing(wave.open(path, 'r')) as f:
        frames = f.getnframes()
        rate = f.getframerate()
        return frames / float(rate)
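
# Compute one speaker embedding per Whisper segment (192-dim ECAPA vectors).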
def make_embeddings(path, segments, duration):
    embeddings = np.zeros(shape=(len(segments), 192))
    for i, segment in enumerate(segments):
        embeddings[i] = segment_embedding(path, segment, duration)
    return np.nan_to_num(embeddings)
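
# pyannote helper for cropping a waveform out of the audio file.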
audio = Audio()
def segment_embedding(path, segment, duration):
    start = segment["start"]
    # Whisper overshoots the end timestamp in the last segment
    end = min(duration, segment["end"])
    clip = Segment(start, end)
    waveform, sample_rate = audio.crop(path, clip)
    # Add a batch dimension before running the embedding model.
    return embedding_model(waveform[None])
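
# Cluster the segment embeddings into num_speakers groups and label each
# segment with its cluster id.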
def add_speaker_labels(segments, embeddings, num_speakers):
    clustering = AgglomerativeClustering(n_clusters=num_speakers).fit(embeddings)
    labels = clustering.labels_
    for i in range(len(segments)):
        segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)
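
# Format a timestamp in seconds as H:MM:SS.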
def time(secs):
    return datetime.timedelta(seconds=round(secs))
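
# Render the transcript, printing a speaker header whenever the speaker
# changes; segment["text"][1:] drops Whisper's leading space.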
def get_output(segments):
    output = ''
    for i, segment in enumerate(segments):
        if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
            output += "\n" + segment["speaker"] + ' ' + str(time(segment["start"])) + '\n'
        output += segment["text"][1:] + ' '
    return output[1:]
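
# Gradio UI: upload an audio file, pick the expected number of speakers,
# and get back a speaker-labeled transcript (uses the legacy gr.inputs API).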
gr.Interface(
    title='Whisper with Speaker Recognition',
    fn=transcribe,
    inputs=[
        gr.inputs.Audio(source="upload", type="filepath"),
        gr.inputs.Number(default=2, label="Number of Speakers")
    ],
    outputs=[
        gr.outputs.Textbox(label='Transcript')
    ]).launch(debug=True)