Spaces:
Running
Running
from omegaconf import OmegaConf | |
from query import VectaraQuery | |
import os | |
import streamlit as st | |
from PIL import Image | |
import concurrent.futures | |
# Lowest search score that launch_app treats as a genuine match; a hit that
# scores below it is reported to the user as "no match".
SCORE_THRESHOLD = 0.70
# Page-wide style overrides injected by inject_custom_css().
# NOTE(review): the `.css-1lcbmhc.e1fqkh3o3` / `.css-1d391kg` selectors are
# auto-generated Streamlit class names that change between Streamlit releases;
# confirm they still match the sidebar in the deployed version.
CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;500;700&display=swap');
body {
    font-family: 'Roboto', sans-serif;
    background-color: #f5f5f5;
    color: #333;
    padding-top: 0px;
    margin-top: 0px;
}
.stApp {
    padding-top: 0px;
    margin-top: 0px;
}
.stButton>button {
    margin-top: 25px;
    background-color: #4CAF50;
    color: white;
    padding: 10px 24px;
    border: none;
    cursor: pointer;
    border-radius: 4px;
}
.stButton>button:hover {
    background-color: #45a049;
}
.stTextInput>div>input {
    padding: 10px;
    border-radius: 4px;
    border: 1px solid #ccc;
    font-size: 16px;
}
.stTextInput>div>input:focus {
    border-color: #007BFF;
    outline: none;
    box-shadow: 0 0 5px rgba(0, 123, 255, 0.5);
}
.centered {
    display: flex;
    justify-content: center;
    align-items: center;
}
.sidebar .stImage {
    display: flex;
    justify-content: center;
    margin-bottom: 20px;
}
.css-1lcbmhc.e1fqkh3o3 { /* This targets the sidebar */
    background-color: #ffffff !important;
    color: #333 !important;
}
.css-1d391kg { /* This targets the sidebar headings */
    color: #333 !important;
}
</style>
"""


def inject_custom_css():
    """Inject the app's custom stylesheet (CUSTOM_CSS) into the current page.

    The stylesheet is a static, trusted literal, so rendering it as raw HTML
    via ``unsafe_allow_html=True`` is safe here; Streamlit needs that flag to
    emit the ``<style>`` tag instead of escaping it.
    """
    st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
def fetch_summary(vq, matching_text, doc_id):
    """Ask the query client for a summary of a matched passage.

    A plain module-level callable wrapping ``vq.get_summary`` so it can be
    handed to an executor with ``executor.submit``.

    Args:
        vq: The query client; only its ``get_summary`` method is used.
        matching_text: The matching text produced by the preceding search.
        doc_id: Identifier of the document the match came from.

    Returns:
        Whatever ``vq.get_summary`` returns, unchanged.
    """
    summary = vq.get_summary(matching_text, doc_id)
    return summary
def _init_query_client():
    """Return the session's VectaraQuery client, creating it on first use.

    Credentials come from the VECTARA_CUSTOMER_ID, VECTARA_CORPUS_ID and
    VECTARA_API_KEY environment variables (a missing one raises KeyError).
    The config and client are cached in ``st.session_state`` so they are only
    built once per session instead of on every rerun.
    """
    if 'cfg' not in st.session_state:
        cfg = OmegaConf.create({
            'customer_id': str(os.environ['VECTARA_CUSTOMER_ID']),
            'corpus_id': str(os.environ['VECTARA_CORPUS_ID']),
            'api_key': str(os.environ['VECTARA_API_KEY']),
            'streaming': False,
        })
        st.session_state.cfg = cfg
        st.session_state.vq = VectaraQuery(cfg.api_key, cfg.customer_id, [cfg.corpus_id])
    return st.session_state.vq


def _render_header_and_sidebar():
    """Draw the cropped header banner and the explanatory sidebar."""
    header_image = Image.open('header-image-2.png')
    # Keep only the top 150 px of the banner image.
    cropped_image = header_image.crop((0, 0, header_image.width, 150))
    st.image(cropped_image, use_column_width=True)

    with st.sidebar:
        logo = Image.open('vectara-logo.png')
        st.markdown("## Welcome to Media Demo\n\n"
                    "This demo uses Vectara to find the movie where a quote is from.\n\n"
                    "Covers movies from this [playlist](https://www.youtube.com/playlist?list=PLHPTxTxtC0ibVZrT2_WKWUl2SAxsKuKwx) of free movies.")
        st.markdown("---")
        st.markdown(
            "## How this works?\n"
            "This app was built with [Vectara](https://vectara.com).\n"
        )
        st.markdown("---")
        st.image(logo, width=250)


def _split_match_url(match_url):
    """Split a match URL into (player URL, start offset in whole seconds).

    The timestamp is expected as a trailing ``&t=<seconds>s`` query parameter.
    A URL without it, or with an unparsable value, starts playback at 0
    instead of raising.
    """
    base_url, _, timestamp = match_url.partition('&t=')
    # cc_load_policy=1 asks the (YouTube) player to show captions by default,
    # so the quoted line is visible on screen.
    video_url = f"{base_url}&cc_load_policy=1"
    if timestamp.endswith('s'):
        timestamp = timestamp[:-1]
    try:
        start_seconds = int(float(timestamp))
    except (ValueError, OverflowError):
        start_seconds = 0
    return video_url, start_seconds


def _start_query(vq, quote, executor):
    """Search for *quote* and record the outcome in ``st.session_state``.

    Any summary belonging to the previous quote is discarded. When the best
    hit scores at least SCORE_THRESHOLD, its summary is requested on
    *executor* so it is generated while the video player is being rendered.
    """
    st.session_state.quote = quote
    st.session_state.prev_quote = quote
    st.session_state.pop('summary_future', None)
    st.session_state.pop('summary', None)

    if not quote.strip():
        # The input was cleared: show nothing rather than searching for ''.
        st.session_state.score = None
        st.session_state.movie_name = None
        return

    (st.session_state.movie_name, st.session_state.match_url,
     st.session_state.score, doc_id, matching_text) = vq.submit_query(quote)
    if st.session_state.score < SCORE_THRESHOLD:
        st.session_state.movie_name = None
    else:
        st.session_state.summary_future = executor.submit(
            fetch_summary, vq, matching_text, doc_id)


def _get_summary():
    """Return the current match's summary, waiting for it if still running.

    The result is cached in ``st.session_state.summary`` so later reruns do
    not wait again. Returns None when no summary was requested.
    """
    if 'summary' not in st.session_state:
        future = st.session_state.get('summary_future')
        if future is None:
            return None
        with st.spinner("Generating summary..."):
            st.session_state.summary = future.result()
    return st.session_state.summary


def _render_results():
    """Show the matched clip and its summary, or a "no match" message."""
    if not st.session_state.get('score'):
        return
    if st.session_state.movie_name is None:
        st.write("Sorry, I couldn't find a match for that quote. Please try another one.")
        return

    video_url, start_seconds = _split_match_url(st.session_state.match_url)
    _, video_col, summary_col = st.columns([1, 4, 3])
    with video_col:
        st.video(video_url, start_time=start_seconds)
    with summary_col:
        # The video element above is already rendered; now wait for the
        # background summary instead of forcing a full script rerun.
        summary = _get_summary()
        if summary is not None:
            st.markdown("**Summary:**")
            st.markdown(summary)


def launch_app():
    """Render one pass of the "Media Demo" Streamlit page.

    Streamlit re-executes the script on every interaction. When the quote in
    the text box differs from the last one searched, a new search runs.
    The matching clip (or a "no match" message) is shown, followed by a
    summary that is fetched in a background thread while the video renders.
    """
    # Call set_page_config before any other Streamlit command; older
    # Streamlit releases reject it otherwise.
    st.set_page_config(page_title="Media Demo", layout="wide")
    inject_custom_css()
    vq = _init_query_client()
    _render_header_and_sidebar()

    st.markdown("<center> <h3>\"Where did I hear that line?\"</h3> </center>", unsafe_allow_html=True)
    _, q_col, _ = st.columns([1, 4, 1])
    with q_col:
        quote = st.text_input("quote", label_visibility="hidden", placeholder="Enter a quote from a movie.")

    # The pool lives only for this script run. _render_results consumes the
    # summary future inside the `with`, so exiting it never waits on work
    # whose result would be thrown away.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        if quote != st.session_state.get('prev_quote', ''):
            _start_query(vq, quote, executor)
        _render_results()
# Script entry point: build the page only when this file is executed directly
# (as Streamlit does for the app script), not when it is imported.
if __name__ == "__main__":
    launch_app()