import streamlit as st
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import wave
import numpy as np
import tempfile
import os

# Page configuration
st.set_page_config(
    page_title="Speech to Text Converter",
    page_icon="🎙️",
    layout="wide"
)

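# Cache the loaded pipeline across Streamlit reruns so the model is loaded
# only once per session (assumes a Streamlit version that provides
# st.cache_resource, available since 1.18).
@st.cache_resource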
def load_pipeline():
    """Load the model and processor, and create an ASR pipeline"""
    device = "cpu"  # run on CPU with float32; no GPU required
    torch_dtype = torch.float32
    model_id = "distil-whisper/distil-large-v3"  # distilled Whisper, English-only

    # Load model
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)

    # Load processor (tokenizer + feature extractor)
    processor = AutoProcessor.from_pretrained(model_id)

    # Create pipeline; 30-second chunking with batching lets it handle audio
    # longer than Whisper's native 30-second context window
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=8,
        torch_dtype=torch_dtype,
        device=device,
    )
    return pipe

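# Note: the Hugging Face ASR pipeline accepts either a path to an audio file
# or a dict like {"raw": <float32 numpy array>, "sampling_rate": <int>},
# which is how it is called in main() below.
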
def read_wav_file(wav_file):
    """Read a WAV file; return mono float32 audio in [-1, 1] at 16 kHz"""
    with wave.open(wav_file, 'rb') as wav:
        # Get WAV file parameters
        channels = wav.getnchannels()
        sample_width = wav.getsampwidth()
        sample_rate = wav.getframerate()
        n_frames = wav.getnframes()

        # Read raw audio data
        raw_data = wav.readframes(n_frames)

        # Convert bytes to a numpy array and normalize
        if sample_width == 1:
            # 8-bit WAV samples are unsigned and centered at 128,
            # so shift before scaling to keep the waveform zero-centered
            audio_data = np.frombuffer(raw_data, dtype=np.uint8)
            audio_data = (audio_data.astype(np.float32) - 128.0) / 128.0
        elif sample_width == 2:
            # 16-bit WAV samples are signed
            audio_data = np.frombuffer(raw_data, dtype=np.int16)
            audio_data = audio_data.astype(np.float32) / np.iinfo(np.int16).max
        else:
            raise ValueError(f"Unsupported sample width: {sample_width} bytes")

        # If stereo, convert to mono by averaging channels
        if channels == 2:
            audio_data = audio_data.reshape(-1, 2).mean(axis=1)

        # Resample to the 16 kHz rate the Whisper models expect
        if sample_rate != 16000:
            # Simple linear-interpolation resampling
            original_length = len(audio_data)
            desired_length = int(original_length * 16000 / sample_rate)
            indices = np.linspace(0, original_length - 1, desired_length)
            audio_data = np.interp(indices, np.arange(original_length), audio_data)

        return audio_data

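# Note: the linear-interpolation resampler above is a rough approximation that
# is usually adequate for speech. A sketch of a cleaner alternative, assuming
# SciPy is installed (not used here):
#
#     from scipy.signal import resample_poly
#     audio_data = resample_poly(audio_data, 16000, sample_rate)
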
def main():
    st.title("🎙️ Speech to Text Converter")
    st.markdown("### Upload a WAV file and convert speech to text")

    # Load pipeline
    with st.spinner("Loading model... This might take a few minutes the first time."):
        try:
            pipe = load_pipeline()
            st.success("Model loaded successfully! Ready to transcribe.")
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            return

    # File upload
    audio_file = st.file_uploader(
        "Upload your audio file",
        type=['wav'],
        help="Only WAV files are supported. For better performance, keep files under 5 minutes."
    )

    if audio_file is not None:
        # Write the upload to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
            tmp_file.write(audio_file.getvalue())
            temp_path = tmp_file.name

        try:
            # Display audio player
            st.audio(audio_file)

            # Add transcribe button
            if st.button("🎯 Transcribe Audio", type="primary"):
                progress_bar = st.progress(0)
                status_text = st.empty()

                try:
                    # Read audio file
                    status_text.text("Reading audio file...")
                    progress_bar.progress(25)
                    audio_data = read_wav_file(temp_path)

                    # Transcribe
                    status_text.text("Transcribing... This might take a while.")
                    progress_bar.progress(50)

                    # Use pipeline for transcription
                    result = pipe(
                        {"raw": audio_data, "sampling_rate": 16000},
                        return_timestamps=True
                    )

                    # Update progress
                    progress_bar.progress(100)
                    status_text.text("Transcription completed!")

                    # Display results
                    st.markdown("### Transcription Result:")
                    st.write(result["text"])

                    # Display timestamps if available
                    if "chunks" in result:
                        st.markdown("### Timestamps:")
                        for chunk in result["chunks"]:
                            st.write(f"{chunk['timestamp']}: {chunk['text']}")

                    # Download button
                    st.download_button(
                        label="📥 Download Transcription",
                        data=result["text"],
                        file_name="transcription.txt",
                        mime="text/plain"
                    )
                except Exception as e:
                    st.error(f"An error occurred: {str(e)}")
        finally:
            # Clean up temporary file
            if os.path.exists(temp_path):
                os.remove(temp_path)
    # Usage instructions
    with st.expander("ℹ️ Usage Instructions"):
        st.markdown("""
        ### Instructions:
        1. Upload a WAV file (16-bit PCM format recommended)
        2. Click 'Transcribe Audio'
        3. Wait for processing to complete
        4. View or download the transcription

        ### Notes:
        - Only WAV files are supported
        - Keep files under 5 minutes for best results
        - Audio should be clear with minimal background noise
        - The transcription includes timestamps for better reference
        """)
    # Footer
    st.markdown("---")
    st.markdown("Made with ❤️ using the Distil-Whisper model")

if __name__ == "__main__":
    main()
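
# To run this app locally (assuming the file is saved as app.py):
#     streamlit run app.py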