Spaces:
Build error
Build error
import os | |
import gradio as gr | |
import torch | |
from torch.utils.data import DataLoader | |
from utils.unifiedmodel import RRUMDataset | |
from utils.huggingface_model_wrapper import YoutubeVideoSimilarityModel | |
from utils.helper_funcs import get_example_videos, update_youtube_embedded_html, get_input_data_df | |
# Location of the RegretsReporter example data; the RR_EXAMPLES_URL
# environment variable takes precedence over the built-in default.
RR_EXAMPLES_URL = os.getenv(
    'RR_EXAMPLES_URL',
    'https://public-data.telemetry.mozilla.org/api/v1/tables/telemetry_derived'
    '/regrets_reporter_study/v1/files/000000000000.json',
)
NUM_RR_EXAMPLES = 5

# get_example_videos yields two collections; each one is later offered as a
# clickable examples table in the UI when it is non-empty.
example_videos, example_videos_rr = get_example_videos(
    RR_EXAMPLES_URL,
    NUM_RR_EXAMPLES,
)

# Text shown at the top of the demo page (rendered as Markdown).
demo_title = 'Mozilla RegretsReporter YouTube video similarity'
demo_description = f'''
# {demo_title}
This demo showcases the YouTube video semantic similarity model developed as part of the RegretsReporter research project at Mozilla Foundation. You can read more about the project [here](https://foundation.mozilla.org/en/youtube/user-controls/) and about the semantic similarity model [here](https://foundation.mozilla.org/en/blog/the-regretsreporter-user-controls-study-machine-learning-to-measure-semantic-similarity-of-youtube-videos/). Note: the model is multilingual so you can try it with non-English videos too while it probably works the best with English videos.
This demo works by inserting two YouTube video URLs below and clicking the Run button. After a few seconds, you will see model's predicted probability of how similar those two videos are. You can copy URLs from YouTube or also try out a few predefined examples by clicking them on the examples table.
'''

# HTML displayed in a video-details panel while its URL textbox is empty.
placeholder_youtube_embedded_html = '''
<p>Insert video URL first</p>
'''

# Two pretrained variants are loaded once at start-up: model_wt is used when
# both transcripts are available, model_nt otherwise.
model_wt = YoutubeVideoSimilarityModel.from_pretrained(
    'mozilla-foundation/youtube_video_similarity_model_wt',
)
model_nt = YoutubeVideoSimilarityModel.from_pretrained(
    'mozilla-foundation/youtube_video_similarity_model_nt',
)

# Cross-encoder identifier handed to RRUMDataset when building model inputs.
cross_encoder_model_name_or_path = model_wt.cross_encoder_model_name_or_path
def get_video_similarity(video1_url, video2_url):
    """Predict how similar two YouTube videos are.

    Builds the input frame for the two URLs, selects ``model_wt`` when both
    transcripts are present (``model_nt`` otherwise), runs one forward pass
    and converts the raw output to a probability with the logistic sigmoid.

    Args:
        video1_url: URL of the first YouTube video.
        video2_url: URL of the second YouTube video.

    Returns:
        A string such as ``'YouTube videos are 87% similar'``.

    Raises:
        gr.Error: If building the dataset or running the model fails. The
            original exception is attached as ``__cause__``.
        Exception: Anything raised by ``get_input_data_df`` propagates
            unchanged, since that call happens before the guarded section.
    """
    df = get_input_data_df(video1_url, video2_url)

    # A single missing transcript on either side disables transcript usage.
    with_transcript = not (
        df['regret_transcript'].isna().any()
        or df['recommendation_transcript'].isna().any()
    )
    model = model_wt if with_transcript else model_nt

    try:
        dataset = RRUMDataset(
            df,
            with_transcript=with_transcript,
            label_col=None,
            cross_encoder_model_name_or_path=cross_encoder_model_name_or_path,
        )
        data_loader = DataLoader(
            dataset.test_dataset,
            shuffle=False,
            batch_size=1,
            num_workers=0,
            pin_memory=False,
        )
        with torch.inference_mode():
            # Only the first batch of the loader is evaluated.
            raw_output = model(next(iter(data_loader)))
            pred = torch.special.expit(raw_output).squeeze().tolist()
    except Exception as err:
        # `except Exception` (not a bare `except`) lets KeyboardInterrupt and
        # SystemExit through; chaining keeps the root cause in the traceback.
        raise gr.Error(
            'There was error in getting a prediction from the model, please try again.'
        ) from err

    return f'YouTube videos are {pred:.0%} similar'
# Build the demo page: URL inputs and buttons, the prediction label with an
# expandable video-details panel, optional example tables, and event wiring.
# NOTE(review): the nesting below was reconstructed from a flattened listing;
# confirm that the examples column belongs below the main row.
with gr.Blocks(title=demo_title) as demo:
    gr.Markdown(demo_description)

    with gr.Row():
        # First column: the two URL textboxes and the action buttons.
        with gr.Column():
            input_text1 = gr.Textbox(
                label='Video 1', placeholder='Insert first YouTube video URL')
            input_text2 = gr.Textbox(
                label='Video 2', placeholder='Insert second YouTube video URL')
            inputs = [input_text1, input_text2]
            with gr.Row():
                clear_btn = gr.Button('Clear', variant='secondary')
                run_btn = gr.Button('Run', variant='primary')

        # Second column: the prediction and the collapsible video details.
        with gr.Column():
            output_label = gr.Label(label='Model prediction')
            outputs = [output_label]
            with gr.Accordion('See video details', open=False):
                with gr.Row():
                    with gr.Column():
                        video_embedded1 = gr.HTML(
                            value=placeholder_youtube_embedded_html)
                    with gr.Column():
                        video_embedded2 = gr.HTML(
                            value=placeholder_youtube_embedded_html)

    # Example tables are only rendered when the corresponding data is present.
    with gr.Column():
        if example_videos:
            examples = gr.Examples(examples=example_videos, inputs=inputs)
        if example_videos_rr:
            examples_rr = gr.Examples(
                examples=example_videos_rr,
                inputs=inputs,
                label='Example bad recommendations from the RegretsReporter report')

    def inputs_change(input, position):
        """Refresh one video-details panel and reset the prediction label.

        Args:
            input: Current text of the URL textbox; a falsy value restores
                the placeholder HTML.
            position: 1 targets the first panel; any other value targets
                the second panel.

        Returns:
            A dict mapping the affected HTML component to its new value and
            ``output_label`` to ``None``.
        """
        embedded_value = (
            update_youtube_embedded_html(input, position)
            if input
            else placeholder_youtube_embedded_html
        )
        target = video_embedded1 if position == 1 else video_embedded2
        return {target: embedded_value, output_label: None}

    run_btn.click(fn=get_video_similarity, inputs=inputs, outputs=outputs)

    # The output label needs no explicit reset here: it gets cleared anyway
    # by inputs_change() once the textboxes change.
    clear_btn.click(
        lambda value_1, value_2: (None, None),
        inputs=inputs,
        outputs=inputs,
        queue=False,
    )

    input_text1.change(
        lambda url: inputs_change(url, 1),
        inputs=input_text1,
        outputs=[video_embedded1, output_label],
        queue=False,
    )
    input_text2.change(
        lambda url: inputs_change(url, 2),
        inputs=input_text2,
        outputs=[video_embedded2, output_label],
        queue=False,
    )
# Turn on request queuing, then start serving the app.
demo.queue().launch()