# speechtotextt / app.py
# Author: alaahilal — "updated the model" (commit 4af3d61, verified)
import streamlit as st
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import wave
import numpy as np
import tempfile
import os
# Page configuration — must run before any other Streamlit call in the script.
# NOTE(review): the page_icon literal appears mojibake-garbled in this copy
# (likely a UTF-8 emoji decoded with the wrong codec) — confirm source encoding.
st.set_page_config(
    page_title="Speech to Text Converter",
    page_icon="๐ŸŽ™๏ธ",
    layout="wide"
)
@st.cache_resource
def load_pipeline():
    """Build and cache the speech-recognition pipeline.

    Loads Distil-Whisper large-v3 on CPU in float32 and wraps it in a
    transformers ASR pipeline. Cached by Streamlit so the model is only
    loaded once per server process, not on every rerun.
    """
    model_id = "distil-whisper/distil-large-v3"
    device = "cpu"
    dtype = torch.float32

    # Model weights (safetensors) placed on CPU; .to() returns the model,
    # so load-and-move can be chained.
    asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    ).to(device)

    # Processor bundles the tokenizer and the feature extractor.
    processor = AutoProcessor.from_pretrained(model_id)

    # Long-form audio is handled via 30 s chunking with batched decoding.
    return pipeline(
        "automatic-speech-recognition",
        model=asr_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=8,
        torch_dtype=dtype,
        device=device,
    )
def read_wav_file(wav_file):
    """Read a WAV file and return mono float32 audio at 16 kHz.

    Args:
        wav_file: Path (or file object) accepted by ``wave.open``.

    Returns:
        1-D ``np.float32`` array, roughly in [-1.0, 1.0], mono, resampled
        to 16 kHz (the rate the Whisper feature extractor expects).

    Raises:
        ValueError: If the sample width is not 1, 2, or 4 bytes.
    """
    with wave.open(wav_file, 'rb') as wav:
        channels = wav.getnchannels()
        sample_width = wav.getsampwidth()
        sample_rate = wav.getframerate()
        n_frames = wav.getnframes()
        raw_data = wav.readframes(n_frames)

    if sample_width == 1:
        dtype = np.uint8
    elif sample_width == 2:
        dtype = np.int16
    elif sample_width == 4:
        # 32-bit PCM, previously rejected — same normalization scheme applies.
        dtype = np.int32
    else:
        raise ValueError("Unsupported sample width")

    audio_data = np.frombuffer(raw_data, dtype=dtype).astype(np.float32)

    if sample_width == 1:
        # BUG FIX: 8-bit WAV samples are *unsigned*, centered at 128.
        # Dividing by the max alone yielded a [0, 1] signal with a large
        # DC offset; recenter first so silence maps to 0.0.
        audio_data = (audio_data - 128.0) / 128.0
    else:
        audio_data = audio_data / np.iinfo(dtype).max

    # Downmix any multi-channel audio (not just stereo) by averaging.
    if channels > 1:
        audio_data = audio_data.reshape(-1, channels).mean(axis=1)

    # Resample to 16 kHz via linear interpolation if necessary.
    if sample_rate != 16000:
        original_length = len(audio_data)
        desired_length = int(original_length * 16000 / sample_rate)
        indices = np.linspace(0, original_length - 1, desired_length)
        # np.interp returns float64; cast back so the dtype is consistent
        # with the non-resampled path.
        audio_data = np.interp(
            indices, np.arange(original_length), audio_data
        ).astype(np.float32)

    return audio_data
def main():
    """Render the Streamlit UI: load the model, accept a WAV upload, and
    display/download the transcription.

    Runs top-to-bottom on every Streamlit rerun; the model itself is cached
    by @st.cache_resource inside load_pipeline().
    """
    st.title("๐ŸŽ™๏ธ Speech to Text Converter")
    st.markdown("### Upload a WAV file and convert speech to text")
    # Load pipeline (cached — only slow on the first run of the process)
    with st.spinner("Loading model... This might take a few minutes the first time."):
        try:
            pipe = load_pipeline()
            st.success("Model loaded successfully! Ready to transcribe.")
        except Exception as e:
            # Without a model there is nothing to do — bail out of the page.
            st.error(f"Error loading model: {str(e)}")
            return
    # File upload (WAV only; read_wav_file handles decoding)
    audio_file = st.file_uploader(
        "Upload your audio file",
        type=['wav'],
        help="Only WAV files are supported. For better performance, keep files under 5 minutes."
    )
    if audio_file is not None:
        # Persist the upload to disk so wave.open can read it by path.
        # delete=False because the file is reopened after this `with` closes;
        # it is removed explicitly in the `finally` below.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
            tmp_file.write(audio_file.getvalue())
            temp_path = tmp_file.name
        try:
            # Display audio player
            st.audio(audio_file)
            # Add transcribe button
            if st.button("๐ŸŽฏ Transcribe Audio", type="primary"):
                progress_bar = st.progress(0)
                status_text = st.empty()
                try:
                    # Read audio file (mono float32 @ 16 kHz)
                    status_text.text("Reading audio file...")
                    progress_bar.progress(25)
                    audio_data = read_wav_file(temp_path)
                    # Transcribe
                    status_text.text("Transcribing... This might take a while.")
                    progress_bar.progress(50)
                    # Use pipeline for transcription; raw/sampling_rate dict is
                    # the pipeline's in-memory audio input format.
                    result = pipe(
                        {"raw": audio_data, "sampling_rate": 16000},
                        return_timestamps=True
                    )
                    # Update progress
                    progress_bar.progress(100)
                    status_text.text("Transcription completed!")
                    # Display results
                    st.markdown("### Transcription Result:")
                    st.write(result["text"])
                    # Display timestamps if available (present when
                    # return_timestamps=True produced chunked output)
                    if "chunks" in result:
                        st.markdown("### Timestamps:")
                        for chunk in result["chunks"]:
                            st.write(f"{chunk['timestamp']}: {chunk['text']}")
                    # Download button
                    # NOTE(review): clicking download triggers a rerun, which
                    # clears these results from the page — confirm acceptable.
                    st.download_button(
                        label="๐Ÿ“ฅ Download Transcription",
                        data=result["text"],
                        file_name="transcription.txt",
                        mime="text/plain"
                    )
                except Exception as e:
                    st.error(f"An error occurred: {str(e)}")
        finally:
            # Clean up temporary file regardless of success or failure
            if os.path.exists(temp_path):
                os.remove(temp_path)
    # Usage instructions (always visible, collapsed by default)
    with st.expander("โ„น๏ธ Usage Instructions"):
        st.markdown("""
        ### Instructions:
        1. Upload a WAV file (16-bit PCM format recommended)
        2. Click 'Transcribe Audio'
        3. Wait for processing to complete
        4. View or download the transcription
        ### Notes:
        - Only WAV files are supported
        - Keep files under 5 minutes for best results
        - Audio should be clear with minimal background noise
        - The transcription includes timestamps for better reference
        """)
    # Footer
    st.markdown("---")
    st.markdown(
        "Made with โค๏ธ using Distil-Whisper model"
    )
# Script entry point (Streamlit executes the file top-to-bottom on each rerun).
if __name__ == "__main__":
    main()