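"""Gradio app that transcribes microphone input with a Wav2Vec2-BERT CTC model."""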
import gradio as gr
import torch
import torchaudio
from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
# Run on GPU when available; the model and inputs must share a device
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCTC.from_pretrained("anzorq/output").to(device)
processor = Wav2Vec2BertProcessor.from_pretrained("anzorq/output")
def transcribe_speech(audio):
    # Load the audio file
    waveform, sr = torchaudio.load(audio)

    # Resample to 16 kHz if needed
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        waveform = resampler(waveform)

    # Convert to mono if needed
    if waveform.dim() > 1:
        waveform = torch.mean(waveform, dim=0)

    # Normalize the audio
    waveform = waveform / torch.max(torch.abs(waveform))

    # Extract input features and move them to the model's device
    input_features = processor(
        waveform.numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features
    input_features = input_features.to(device)

    # Generate logits using the model
    with torch.no_grad():
        logits = model(input_features).logits

    # Decode the predicted ids to text
    pred_ids = torch.argmax(logits, dim=-1)[0]
    pred_text = processor.decode(pred_ids)
    return pred_text
# Define the Gradio interface
interface = gr.Interface(
    fn=transcribe_speech,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs="text",
    live=True,
)

# Launch the app
interface.launch()