Spaces:
Runtime error
Runtime error
from sentence_transformers import SentenceTransformer, util, CrossEncoder | |
from datasets import load_dataset | |
import pandas as pd | |
import torch | |
import gradio as gr | |
import whisper | |
import pathlib, os | |
# --- Data and model loading (runs once, at import time) ---

# Optional Hugging Face token taken from the "auth_key" environment variable
# (None when unset); passed to both dataset downloads below.
auth_token = os.environ.get("auth_key")

# Netflix catalogue dataset.
netflix = load_dataset('hugginglearners/netflix-shows', use_auth_token=auth_token)

# Whisper speech-recognition model ("small" checkpoint).
asr_model = whisper.load_model("small")

# Keep only the columns shown in the result tables, in display order.
netflix_df = netflix['train'].to_pandas()[
    ['type', 'title', 'country', 'description', 'release_year',
     'rating', 'duration', 'listed_in', 'cast']
]
# Show descriptions; position i corresponds to row i of netflix_df.
passages = netflix_df['description'].tolist()

# Bi-encoder used to embed search queries.
model = SentenceTransformer('all-mpnet-base-v2')

# Precomputed description embeddings, converted to a float32 tensor.
# Assumed to be row-aligned with netflix_df, since search hits index into both
# via corpus_id — TODO confirm against the embedding dataset.
flix_ds = load_dataset("nickmuchi/netflix-shows-mpnet-embeddings", use_auth_token=auth_token)
_embedding_matrix = flix_ds["train"].to_pandas().to_numpy()
dataset_embeddings = torch.from_numpy(_embedding_matrix).to(torch.float)
del _embedding_matrix

# Cross-encoder used to re-rank the bi-encoder hits.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
def display_df_as_table(model, top_k, score='score'):
    """Build a results table from the first ``top_k`` search hits.

    Args:
        model: ranked list of hit dicts. Each dict holds a ``'corpus_id'`` row
            position into ``netflix_df`` plus one or more score keys. (The
            name is historical; this is a list of hits, not a model.)
        top_k: number of leading hits to keep.
        score: key of the score to display (``'score'`` for the bi-encoder,
            ``'cross-score'`` for the cross-encoder re-ranker).

    Returns:
        A DataFrame with a ``Score`` column (rounded to 2 decimals) followed by
        the ``netflix_df`` columns of each hit, one row per hit, in hit order.
    """
    hits = model[:int(top_k)]
    # Look rows up by position instead of merging on the description text.
    # A text merge duplicates rows (and attaches the wrong titles) whenever
    # two shows share the same description.
    df = netflix_df.iloc[[hit['corpus_id'] for hit in hits]].reset_index(drop=True)
    scores = pd.Series([hit[score] for hit in hits], dtype=float).round(2)
    df.insert(0, 'Score', scores)
    return df
def asr(audio):
    """Transcribe ``audio`` with the module-level Whisper model.

    Args:
        audio: audio input forwarded unchanged to ``asr_model.transcribe``.
            In this app the Gradio Audio components use ``type='filepath'``,
            so it is a file path.

    Returns:
        The ``'text'`` field of Whisper's transcription result.
    """
    return asr_model.transcribe(audio)["text"]
def asr_inputs(audio, upload):
    """Transcribe whichever audio input the user provided.

    The microphone recording takes precedence over an uploaded file.

    Args:
        audio: microphone recording (a file path in this app), or a falsy
            value if nothing was recorded.
        upload: uploaded audio file (a file path in this app), or a falsy
            value if nothing was uploaded.

    Returns:
        The transcript of the chosen input. Returns an empty string when
        neither input is set; previously that case raised UnboundLocalError.
    """
    source = audio or upload
    if not source:
        return ""
    return asr(source)
def semantic_search(query, top_k):
    """Retrieve Netflix titles matching ``query`` and re-rank them.

    Stage 1 (bi-encoder): embed the query with ``model`` and retrieve the
    ``top_k`` best-scoring rows of ``dataset_embeddings``.
    Stage 2 (cross-encoder): score each (query, description) pair with
    ``cross_encoder`` and re-sort the same hits by that score.

    Args:
        query: free-text search query (the Whisper transcript in the UI).
        top_k: number of results to retrieve and show. It is coerced to
            ``int`` in case the Gradio slider delivers a float.

    Returns:
        ``(bi_df, cross_df)``: tables built by ``display_df_as_table``, ordered
        by bi-encoder score and cross-encoder score respectively. Both tables
        are empty when ``query`` is blank.
    """
    top_k = int(top_k)
    if not query or not query.strip():
        # Nothing to search for (e.g. the transcript box was cleared);
        # skip both encoders and show empty tables.
        empty = display_df_as_table([], top_k)
        return empty, empty.copy()

    question_embedding = model.encode(query, convert_to_tensor=True).cpu()
    # One query in, so keep only the first (and only) result list.
    hits = util.semantic_search(question_embedding, dataset_embeddings, top_k=top_k)[0]

    # Re-rank: score every retrieved description against the query.
    cross_inp = [[query, netflix_df['description'].iloc[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    # Bi-encoder table, ordered by the retrieval score.
    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
    bi_df = display_df_as_table(hits, top_k)

    # Cross-encoder table, ordered by the re-ranker score
    # (display_df_as_table already rounds the Score column).
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    cross_df = display_df_as_table(hits, top_k, 'cross-score')
    return bi_df, cross_df
# Page heading; centred by the `h1#title` rule in `css` below.
title = """<h1 id="title">Voice Activated Netflix Semantic Search</h1>"""

# Markdown blurb rendered under the title.
description = """
Semantic Search is a way to generate search results based on the actual meaning of the query instead of a standard keyword search. I believe this way of searching provides more meaningful results when trying to find a good show to watch on Netflix. For example, one could say "Success, rags to riches story" as provided in the example below to generate shows or movies with a description that is semantically similar to the query.
The app uses OpenAI's SOTA ASR model, [Whisper](https://huggingface.co/spaces/openai/whisper), to convert speech to text.
- The app generates embeddings using the [All-Mpnet-Base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) model from Sentence Transformers.
- The model encodes the query and the description field from the [Netflix-Shows](https://huggingface.co/datasets/hugginglearners/netflix-shows) dataset, which contains 8800 shows and movies currently on Netflix, scraped from the web using Selenium.
- Similarity scores are then generated and sorted from highest to lowest. The user can select how many suggestions they need from the results.
- A Cross Encoder then re-ranks the top selections to further improve on the similarity scores.
- You will see 2 tables generated, one from the bi-encoder and the other from the cross encoder, which further enhances the similarity score rankings.
Enjoy and Search like you mean it!!
"""

# Example recordings: every .wav under audio_examples/ (searched recursively),
# sorted, each wrapped as a one-element row to match the single input
# component the examples are bound to.
example_audio = [[path.as_posix()] for path in sorted(pathlib.Path('audio_examples').rglob('*.wav'))]

# Markdown badge linking to the author's Twitter profile.
twitter_link = """
[![](https://img.shields.io/twitter/follow/nickmuchi?label=@nickmuchi&style=social)](https://twitter.com/nickmuchi)
"""

# Custom CSS passed to gr.Blocks: centre the page heading.
css = '''
h1#title {
  text-align: center;
}
'''
# Column headers for both result tables. They follow the order produced by
# display_df_as_table: Score first, then the selected netflix_df columns.
_RESULT_HEADERS = ['Similarity Score', 'Type', 'Title', 'Country', 'Description',
                   'Release Year', 'Rating', 'Duration', 'Category Listing', 'Cast']

# Gradio UI: record or upload audio -> transcribe -> search -> two result tables.
with gr.Blocks(css=css) as demo:
    with gr.Box():
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown(twitter_link)
    top_k = gr.Slider(minimum=3, maximum=10, value=3, step=1,
                      label='Number of Suggestions to Generate')
    with gr.Row():
        audio = gr.Audio(source='microphone', type='filepath',
                         label='Audio Input: Describe the Netflix show you would like to watch..')
        audio_file = gr.Audio(source='upload', type='filepath', label='Audio Upload')
    btn = gr.Button("Transcribe")
    with gr.Row():
        query = gr.Textbox(label='Transcribed Text')
    # Static labels: the old f'Top-{top_k} ...' interpolated the Slider
    # component object itself, not the selected number of results.
    with gr.Row():
        bi_output = gr.DataFrame(headers=_RESULT_HEADERS,
                                 label='Top Bi-Encoder Retrieval hits', wrap=True)
    with gr.Row():
        cross_output = gr.DataFrame(headers=_RESULT_HEADERS,
                                    label='Top Cross-Encoder Re-ranker hits', wrap=True)
    with gr.Row():
        examples = gr.Examples(examples=example_audio, inputs=[audio_file])

    # Transcribing fills the textbox; any textbox change triggers the search.
    btn.click(asr_inputs, inputs=[audio, audio_file], outputs=[query])
    query.change(semantic_search, inputs=[query, top_k],
                 outputs=[bi_output, cross_output], queue=True)
    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-netflix-shows-semantic-search)")

demo.launch(debug=True, enable_queue=True)