|
|
|
import nltk |
|
import librosa |
|
import IPython.display |
|
import torch |
|
import gradio as gr |
|
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC |
|
# Fetch the Punkt data used by nltk.sent_tokenize in correct_casing.
nltk.download("punkt")

# A single checkpoint name is shared by the tokenizer and the CTC model so
# the two always come from the same pretrained weights.
model_name = "facebook/wav2vec2-base-960h"
tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
|
|
def load_data(input_file):
    """Load an audio file as a mono waveform sampled at 16 kHz.

    Args:
        input_file: Path to the audio file, passed straight to
            ``librosa.load``.

    Returns:
        The waveform returned by ``librosa.load``: down-mixed to mono and
        resampled to 16 kHz.
    """
    # Ask librosa for the target rate and mono down-mix directly. This does a
    # single resampling pass (instead of default-rate load + second resample)
    # and avoids calling ``librosa.resample`` with positional sample rates,
    # which newer librosa releases reject because they are keyword-only.
    speech, _ = librosa.load(input_file, sr=16000, mono=True)
    return speech
|
def correct_casing(input_sentence):
    """Capitalize the first character of each sentence in the given text.

    The text is split with ``nltk.sent_tokenize``; the adjusted sentences are
    joined back together with single spaces.
    """
    capitalized = []
    for sentence in nltk.sent_tokenize(input_sentence):
        first, rest = sentence[0], sentence[1:]
        capitalized.append(first.capitalize() + rest)
    return " ".join(capitalized)
|
|
|
def asr_transcript(input_file):
    """Transcribe an audio file with one forward pass of the model.

    Args:
        input_file: Path to the audio file; it is loaded via ``load_data``.

    Returns:
        The decoded transcription, lower-cased and then passed through
        ``correct_casing``.
    """
    speech = load_data(input_file)
    input_values = tokenizer(speech, return_tensors="pt").input_values

    # Inference only: disable autograd so activations are not retained for a
    # backward pass that never happens.
    with torch.no_grad():
        logits = model(input_values).logits

    # Take the highest-scoring id along the last axis, then decode the first
    # (only) item of the batch.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.decode(predicted_ids[0])

    return correct_casing(transcription.lower())
|
def asr_transcript_long(input_file, tokenizer=tokenizer, model=model):
    """Transcribe a long audio file by streaming it in fixed-size blocks.

    Args:
        input_file: Path to the audio file. ``None`` (no file supplied)
            yields an empty transcript.
        tokenizer: Callable that turns a waveform into ``input_values`` and
            decodes predicted ids; defaults to the module-level tokenizer.
        model: Model whose output exposes ``.logits``; defaults to the
            module-level model.

    Returns:
        The per-block transcripts joined with single spaces, truncated to
        the first 3800 characters.
    """
    # The audio input is optional in the UI, so there may be nothing to read.
    if input_file is None:
        return ""

    sample_rate = librosa.get_samplerate(input_file)

    # Frames of ``sample_rate`` samples with an equal hop are one-second,
    # non-overlapping frames; 20 of them per block gives 20-second chunks.
    stream = librosa.stream(
        input_file,
        block_length=20,
        frame_length=sample_rate,
        hop_length=sample_rate,
        mono=True,  # explicit: every block is already down-mixed
    )

    pieces = []
    for speech in stream:
        if sample_rate != 16000:
            # Keyword arguments work on both old and new librosa; newer
            # releases no longer accept the sample rates positionally.
            speech = librosa.resample(
                speech, orig_sr=sample_rate, target_sr=16000
            )

        input_values = tokenizer(speech, return_tensors="pt").input_values

        # Inference only: no need to track gradients.
        with torch.no_grad():
            logits = model(input_values).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = tokenizer.decode(predicted_ids[0])

        piece = correct_casing(transcription.lower())
        if piece:  # skip blocks that decoded to nothing (e.g. silence)
            pieces.append(piece)

    # Join with spaces so words at block boundaries do not run together.
    return " ".join(pieces)[:3800]
|
gr.Interface(asr_transcript_long, |
|
|
|
inputs = gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Upload your audio file here"), |
|
outputs = gr.outputs.Textbox(type="str",label="Output Text"), |
|
title="English Audio Transcriptor", |
|
description = "This tool transcribes your audio to the text", |
|
|
|
|
|
|