gaia / app.py
cakiki's picture
Update app.py
147eb37
raw
history blame
11.2 kB
import html
import json
import os
import pprint
from typing import Union

import requests
import streamlit as st
import streamlit.components.v1 as components
pp = pprint.PrettyPrinter(indent=2)
st.set_page_config(page_title="Gaia Search 🌖🌏", layout="wide")

# Force the light theme by (re)writing ./.streamlit/config.toml on every run.
working_dir = os.getcwd()
os.makedirs(os.path.join(working_dir, ".streamlit"), exist_ok=True)
theme_config_path = os.path.join(working_dir, ".streamlit/config.toml")
with open(theme_config_path, "w") as config_file:
    config_file.write('[theme]\nbase="light"')

# Display name shown in the corpus selector -> identifier sent to the backend.
# Insertion order matters: the selector's default index refers to it.
corpus_name_map = {
    "LAION": "laion",
    "ROOTS": "roots",
    "The Pile": "pile",
    "C4": "c4",
}
# Sidebar header: app title and a short description of the tool.
st.sidebar.markdown(
    """
<style>
.aligncenter {
text-align: center;
font-weight: bold;
font-size: 36px;
}
</style>
<p class="aligncenter">Gaia Search 🌖🌏</p>
<p>A search engine for large scale textual
corpora. Most of the datasets included in the tool are based on Common
Crawl. By using the tool, you are also bound by the Common Crawl terms
of use in respect of the content contained in the datasets.
</p>
""",
    unsafe_allow_html=True,
)
# Project links (GitHub, paper, Colab).
# NOTE(review): the Colab link has an empty href, so it points back at the
# current page — fill in the notebook URL or drop the link.
st.sidebar.markdown(
    """
<style>
.aligncenter {
text-align: center;
}
</style>
<p style='text-align: center'>
<a href="https://github.com/huggingface/gaia" style="color:#7978FF;">GitHub</a> | <a href="https://arxiv.org/abs/2306.01481" style="color:#7978FF;" >Paper</a> | <a href="" style="color:#7978FF;" >Colab</a>
</p>
""",
    unsafe_allow_html=True,
)
# <p class="aligncenter">
# <a href="" target="_blank">
# <img src="https://colab.research.google.com/assets/colab-badge.svg"/>
# </a>
# </p>
# Sidebar search controls: free-text query, corpus selector (display names
# from corpus_name_map, third entry pre-selected) and a result-count slider.
query = st.sidebar.text_input(label="Query", placeholder="Type your query here")
corpus_choices = tuple(corpus_name_map)
corpus = st.sidebar.selectbox("Corpus", corpus_choices, index=2)
max_results = st.sidebar.slider(
    "Max Results",
    min_value=1,
    max_value=100,
    value=10,
    step=1,
    help="Max Number of Documents to return",
)
# dark_mode_toggle = """
# <script>
# function load_image(id){
# console.log(id)
# var x = document.getElementById(id);
# console.log(x)
# if (x.style.display === "none") {
# x.style.display = "block";
# } else {
# x.style.display = "none";
# }
# };
# function myFunction() {
# var element = document.body;
# element.classList.toggle("dark-mode");
# }
# </script>
# <button onclick="myFunction()">Toggle dark mode</button>
# """
# st.sidebar.markdown(dark_mode_toggle, unsafe_allow_html=True)
# Fixed-position footer crediting the projects the app is built on.
# The CSS and the markup are kept as separate pieces and joined below.
_footer_style = """
<style>
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: white;
color: black;
text-align: center;
}
</style>
"""
_footer_body = """<div class="footer">
<p>Powered by <a href="https://huggingface.co/" >HuggingFace 🤗</a> and <a href="https://github.com/castorini/pyserini" >Pyserini 🦆</a></p>
</div>
"""
footer = _footer_style + _footer_body
st.sidebar.markdown(footer, unsafe_allow_html=True)
def scisearch(query, corpus, num_results=10):
    """Send *query* to the Gaia search backend and return its response.

    The backend URL is read from the ``address`` environment variable, or
    from ``address_roots`` when *corpus* is ``"roots"``.

    Args:
        query: Free-text query; surrounding whitespace is stripped.
        corpus: Backend corpus identifier (e.g. ``"pile"`` or ``"roots"``).
        num_results: Maximum number of documents to request (sent as ``k``).

    Returns:
        A ``(results, highlight_terms)`` tuple taken from the backend's JSON
        payload, or ``None`` when the query is ``None``/blank, the backend
        address is not configured, or the request/response handling fails.
    """
    print(query, corpus, num_results)
    # Check for None before calling .strip(), otherwise a missing query
    # raises AttributeError instead of being treated as "no query".
    if query is None:
        return None
    query = query.strip()
    if not query:
        return None

    env_var = "address_roots" if corpus == "roots" else "address"
    address = os.environ.get(env_var)
    if not address:
        print(f"Search backend not configured: set the {env_var!r} environment variable")
        return None

    post_data = {"query": query, "corpus": corpus, "k": num_results, "lang": "all"}
    try:
        response = requests.post(
            address,
            headers={"Content-type": "application/json"},
            data=json.dumps(post_data),
            timeout=60,
        )
        # Report HTTP errors directly rather than failing later while
        # parsing an error page as JSON.
        response.raise_for_status()
        payload = json.loads(response.text)
        return payload["results"], payload["highlight_terms"]
    except Exception as e:  # best-effort boundary to an external service
        print(f"Search request to {env_var} backend failed: {e!r}")
        return None
# Placeholder tags that redacted documents carry as "PI:<TAG>".
PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
PII_PREFIX = "PI:"


def process_pii(text):
    """Swap each ``PI:<TAG>`` placeholder in *text* for a highlighted badge.

    Args:
        text: String that may contain placeholders such as ``PI:EMAIL``.

    Returns:
        *text* with every placeholder for a tag in ``PII_TAGS`` replaced by
        ``<b><mark ...>REDACTED <TAG></mark></b>`` markup.
    """
    badge_template = """<b><mark style="background: Fuchsia; color: Lime;">REDACTED {}</mark></b>"""
    substitutions = {PII_PREFIX + tag: badge_template.format(tag) for tag in PII_TAGS}
    for placeholder, badge in substitutions.items():
        text = text.replace(placeholder, badge)
    return text
def highlight_string(paragraph: str, highlight_terms: list) -> str:
    """Render *paragraph* as HTML, bolding tokens that appear in *highlight_terms*.

    The paragraph is split on whitespace and re-joined with single spaces.
    Every token is HTML-escaped before any markup is added, so untrusted
    document text cannot inject tags or scripts into the rendered page.
    PII placeholders are expanded afterwards by ``process_pii``; they contain
    no characters that escaping would alter.

    Args:
        paragraph: Raw document text.
        highlight_terms: Tokens to emphasise (exact, case-sensitive match);
            ``None`` is treated as "nothing to highlight".

    Returns:
        An HTML fragment.
    """
    terms = highlight_terms or ()
    tokens_html = " ".join(
        f"<b>{html.escape(token)}</b>" if token in terms else html.escape(token)
        for token in paragraph.split()
    )
    return process_pii(tokens_html)
def extract_lang_from_docid(docid):
    """Return the second ``_``-separated field of *docid*.

    The caller displays this field as the document's language.
    Raises IndexError when *docid* contains no underscore.
    """
    fields = docid.split("_", 2)
    return fields[1]
def format_result(result, highlight_terms):
    """Render a single search hit as an HTML paragraph.

    Args:
        result: Hit mapping with at least ``"text"`` and ``"docid"`` keys.
        highlight_terms: Tokens to bold inside the hit text.

    Returns:
        ``<p>`` markup showing the language (derived from the docid), the
        document ID, and the highlighted text.
    """
    text, doc_id = result["text"], result["docid"]
    body_html = highlight_string(text, highlight_terms)
    lang = extract_lang_from_docid(doc_id)
    inner = f"""
<span style='font-size:14px; font-family: Arial; color:MediumAquaMarine'>Language: {lang} | </span>
<span style='font-size:14px; font-family: Arial; color:#7978FF; text-align: left;'>Document ID: {doc_id} | </span><br>
<span style='font-family: Arial;'>{body_html}</span><br>
<br>
"""
    return f"<p>{inner}</p>"
def _render_roots_results(hits_by_lang: dict, highlight_terms: list) -> str:
    """Render ROOTS hits as one collapsible <details> section per language."""
    sections = []
    for lang, results_for_lang in hits_by_lang.items():
        print("Processing language", lang)
        safe_lang = html.escape(str(lang))
        if not results_for_lang:
            sections.append(
                """<div style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>
No results for language: <b>{}</b></div>""".format(safe_lang)
            )
            continue
        results_html = "".join(
            format_result(result, highlight_terms) for result in results_for_lang
        )
        sections.append(
            f"""
<details>
<summary style='font-family: Arial; color:MediumAquaMarine; text-align: left; line-height: 3em'>
Results for language: <b>{safe_lang}</b>
</summary>
{results_html}
</details>"""
        )
    return "".join(sections)


def _render_hit(corpus: str, hit: dict, highlight_terms: list) -> str:
    """Render one non-ROOTS hit: header, text and (for LAION) image links."""
    safe_docid = html.escape(str(hit["docid"]))
    parts = [
        f"""
<p class="searchresult" style="color: #7978FF;">Document ID: {safe_docid} | Score: {round(hit['score'], 2)}</p>
"""
    ]
    text_html = highlight_string(hit["text"], highlight_terms)
    if corpus == "laion":
        parts.append(
            f"""
<p style="color: #7978FF;">Caption:</p>
<p>{text_html}</p>
"""
        )
        # "meta" or "docs" may be absent or null; treat that as no images.
        image_docs = (hit.get("meta") or {}).get("docs") or []
        if image_docs:
            parts.append("""<p style="color: #7978FF;"> Image links:</p><ul>""")
            for subhit in image_docs:
                # Quote and escape the URL: unquoted attribute values break on
                # spaces or '>' and let crawled data inject markup.
                url = html.escape(str(subhit["URL"]), quote=True)
                parts.append(
                    f"""<li><a href="{url}" target="_blank" style="color:#ffcdf8; ">{url}</a></li>"""
                )
            parts.append("</ul>")
        parts.append("<hr>")
    else:
        parts.append(f"<p>{text_html}</p><hr>")
    return "".join(parts)


def process_results(corpus: str, hits: Union[list, dict], highlight_terms: list) -> str:
    """Render backend search hits as a single HTML fragment.

    Args:
        corpus: Backend corpus identifier (``"roots"``, ``"laion"``, ...).
        hits: For ``"roots"``, a mapping from language to that language's
            list of hit dicts (each with ``"text"`` and ``"docid"``); for any
            other corpus, a flat list of hit dicts with ``"docid"``,
            ``"score"`` and ``"text"`` (LAION hits may add ``"meta"``).
        highlight_terms: Tokens to bold inside hit text.

    Returns:
        The concatenated HTML for all hits, or ``""`` when there are none.
    """
    if corpus == "roots":
        return _render_roots_results(hits, highlight_terms)
    return " ".join(_render_hit(corpus, hit, highlight_terms) for hit in hits)
submit_button = st.sidebar.button("Search", type="primary")

# Styles prepended to the results fragment rendered by components.html.
_RESULTS_CSS = """
<head>
<link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
</head>
<style>
#searchresultsarea {
font-family: "Source Sans Pro", sans-serif;
}
#searchresultsnumber {
font-size: 0.8rem;
color: gray;
}
.searchresult h2 {
font-size: 19px;
line-height: 18px;
font-weight: normal;
color: rgb(7, 111, 222);
margin-bottom: 0px;
margin-top: 25px;
color: #7978FF;
}
.searchresult a {
font-size: 12px;
line-height: 12px;
color: green;
margin-bottom: 0px;
}
</style>
"""


def _render_notice(message):
    """Show *message* as a centred notice in place of search results."""
    components.html(
        """<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
{}
</p><br><hr><br>""".format(message)
    )


# Search when the button is clicked or whenever the query box is non-empty.
if submit_button or query:
    query = query.strip()
    if not query:
        _render_notice("Please provide a non-empty query.")
    else:
        backend_corpus = corpus_name_map[corpus]
        search_output = scisearch(query, backend_corpus, max_results)
        if search_output is None:
            # scisearch yields None when the request fails; unpacking it
            # directly would crash the page with a TypeError.
            _render_notice("The search request failed. Please try again later.")
        else:
            hits, highlight_terms = search_output
            html_results = process_results(backend_corpus, hits, highlight_terms)
            # ROOTS hits arrive grouped per language (a dict); others as a list.
            # Report what was actually returned, not the requested maximum.
            if isinstance(hits, dict):
                num_hits = sum(len(results or ()) for results in hits.values())
            else:
                num_hits = len(hits)
            rendered_results = f"""
<div id="searchresultsarea">
<br>
<p id="searchresultsnumber">About {num_hits} results</p>
{html_results}
</div>"""
            components.html(
                _RESULTS_CSS + rendered_results,
                height=800,
                scrolling=True,
            )