Spaces:
Runtime error
Runtime error
import nltk | |
import whisper | |
from pytube import YouTube | |
import streamlit as st | |
from sentence_transformers import SentenceTransformer, util | |
nltk.download('punkt') | |
def init_sentence_model(embedding_model):
    """Build the sentence-embedding model used for semantic search.

    Args:
        embedding_model: Model identifier handed unchanged to
            ``SentenceTransformer``.

    Returns:
        The ``SentenceTransformer`` instance constructed from that identifier.
    """
    sentence_encoder = SentenceTransformer(embedding_model)
    return sentence_encoder
def init_whisper(whisper_size):
    """Load the Whisper speech-to-text model of the requested size.

    Args:
        whisper_size: Model size name handed unchanged to
            ``whisper.load_model``.

    Returns:
        The model object returned by ``whisper.load_model``.
    """
    speech_model = whisper.load_model(whisper_size)
    return speech_model
def inference(link):
    """Download a YouTube video's audio track and transcribe it with Whisper.

    Relies on the module-level ``whisper_model`` having been assigned before
    this function is called.

    Args:
        link: URL of the YouTube video.

    Returns:
        The ``'segments'`` entry of the transcription result.

    Raises:
        ValueError: If the video exposes no audio-only stream.
    """
    audio_stream = YouTube(link).streams.filter(only_audio=True).first()
    if audio_stream is None:
        # Fail with an actionable message instead of an opaque IndexError.
        raise ValueError(f"No audio-only stream available for {link!r}")

    # The audio is always saved under the fixed name "audio.mp4".
    path = audio_stream.download(filename="audio.mp4")

    # A DecodingOptions object used to be built here but was never passed to
    # transcribe(), so it had no effect on the result and has been removed.
    results = whisper_model.transcribe(path)
    return results['segments']
def get_embeddings(segments):
    """Encode the merged transcript chunks with the sentence model.

    Relies on the module-level ``model`` having been assigned before this
    function is called.

    Args:
        segments: Mapping whose ``"text"`` entry holds the texts to encode.

    Returns:
        Whatever ``model.encode`` returns for those texts.
    """
    chunk_texts = segments["text"]
    return model.encode(chunk_texts)
def format_segments(segments, window=10):
    """Merge consecutive transcript segments into chunks of ``window`` items.

    Each chunk's text is the space-joined text of its segments, and its start
    is the ``'start'`` value of the first segment in the chunk. The final
    chunk may contain fewer than ``window`` segments.

    Args:
        segments: Sequence of mappings, each with ``'text'`` and ``'start'``
            keys.
        window: Number of segments per chunk; must be at least 1.

    Returns:
        A dict with two parallel lists: ``'text'`` (merged chunk texts) and
        ``'start'`` (start value of each chunk).

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    window = int(window)  # tolerate integral floats coming from UI widgets
    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window}")

    chunk_starts = range(0, len(segments), window)
    return {
        # Slice the full window: a fixed width of 5 here would silently drop
        # every segment past the fifth in each chunk when window > 5.
        'text': [
            " ".join(seg['text'] for seg in segments[i:i + window])
            for i in chunk_starts
        ],
        'start': [segments[i]['start'] for i in chunk_starts],
    }
# Settings form: widget values inside a form only change when it is submitted.
with st.form("transcribe"):
    yt_link = st.text_input("Youtube link")
    whisper_size = st.selectbox("Whisper model size", ("small", "base", "large"))
    embedding_model = st.text_input("Embedding model name", value='all-mpnet-base-v2')
    top_k = st.number_input("Number of query results", value=5)
    window = st.number_input("Number of segments per result", value=10)
    transcribe_submit = st.form_submit_button("Submit")

# Remember across reruns that the form has been submitted at least once, so
# the search UI stays visible while the user types queries.
if transcribe_submit and 'start_search' not in st.session_state:
    st.session_state.start_search = True

if 'start_search' in st.session_state:
    if not yt_link:
        st.warning("Please enter a Youtube link.")
    else:
        # Streamlit re-executes the whole script on every interaction. Keep
        # the expensive results (model, merged segments, embeddings) in the
        # session and rebuild them only when the form settings change, rather
        # than re-downloading and re-transcribing for every query.
        index_key = (yt_link, whisper_size, embedding_model, window)
        if ('search_index_key' not in st.session_state
                or st.session_state['search_index_key'] != index_key):
            # inference() and get_embeddings() read these module-level names.
            model = init_sentence_model(embedding_model)
            whisper_model = init_whisper(whisper_size)
            segments = format_segments(inference(yt_link), window)
            embeddings = get_embeddings(segments)
            st.session_state['search_index'] = (model, segments, embeddings)
            st.session_state['search_index_key'] = index_key
        model, segments, embeddings = st.session_state['search_index']

        query = st.text_input('Enter a query')
        if query:
            query_embedding = model.encode(query)
            results = util.semantic_search(query_embedding, embeddings, top_k=int(top_k))

            # Start a query string with "?" when the link has none, otherwise
            # extend the existing one with "&".
            separator = '&' if '?' in yt_link else '?'
            entries = []
            for result in results[0]:
                idx = result['corpus_id']
                # Use whole seconds for the "t" parameter; segment starts are
                # presumably fractional — confirm against Whisper's output.
                seconds = int(segments['start'][idx])
                entries.append(
                    segments['text'][idx]
                    + f"... [Watch at timestamp]({yt_link}{separator}t={seconds}s)"
                )
            st.markdown("\n\n".join(entries), unsafe_allow_html=True)