# Hugging Face Space (status: Running)
import html
import json
import os
import pprint
from typing import Union

import requests
import streamlit as st
import streamlit.components.v1 as components
# Module-level pretty-printer configured with a two-space indent.
pp = pprint.PrettyPrinter(indent=2)

st.set_page_config(page_title="Gaia Search 🌖🌏", layout="wide")

# Create the local ``.streamlit`` directory if needed and (re)write a theme
# configuration that selects the light base theme.
streamlit_dir = os.path.join(os.getcwd(), ".streamlit")
os.makedirs(streamlit_dir, exist_ok=True)
with open(os.path.join(streamlit_dir, "config.toml"), "w") as file:
    file.write('[theme]\nbase="light"')

# Corpus names offered in the sidebar selectbox, mapped to the identifiers
# that are passed to the search backend.
corpus_name_map = {
    "LAION": "laion",
    "ROOTS": "roots",
    "The Pile": "pile",
    "C4": "c4",
}
# Sidebar title and short description of the tool (raw HTML, hence
# ``unsafe_allow_html``).
st.sidebar.markdown(
    """
<style>
.aligncenter {
text-align: center;
font-weight: bold;
font-size: 36px;
}
</style>
<p class="aligncenter">Gaia Search 🌖🌏</p>
<p>A search engine for large scale textual
corpora. Most of the datasets included in the tool are based on Common
Crawl. By using the tool, you are also bound by the Common Crawl terms
of use in respect of the content contained in the datasets.
</p>
""",
    unsafe_allow_html=True,
)

# Sidebar links row: GitHub | Paper | Colab.
# NOTE(review): the Colab link has an empty ``href`` — confirm the target URL.
st.sidebar.markdown(
    """
<style>
.aligncenter {
text-align: center;
}
</style>
<p style='text-align: center'>
<a href="https://github.com/huggingface/gaia" style="color:#7978FF;">GitHub</a> | <a href="https://arxiv.org/abs/2306.01481" style="color:#7978FF;" >Paper</a> | <a href="" style="color:#7978FF;" >Colab</a>
</p>
""",
    unsafe_allow_html=True,
)
# <p class="aligncenter">
# <a href="" target="_blank">
# <img src="https://colab.research.google.com/assets/colab-badge.svg"/>
# </a>
# </p>
# --- Sidebar input widgets -------------------------------------------------

# Free-text query box.
query = st.sidebar.text_input(placeholder="Type your query here", label="Query")

# Corpus picker; the option at position 2 of ``corpus_name_map`` is preselected.
corpus = st.sidebar.selectbox("Corpus", tuple(corpus_name_map), index=2)

# Upper bound on the number of documents requested from the backend.
max_results = st.sidebar.slider(
    "Max Results",
    min_value=1,
    max_value=100,
    value=10,
    step=1,
    help="Max Number of Documents to return",
)
# dark_mode_toggle = """ | |
# <script> | |
# function load_image(id){ | |
# console.log(id) | |
# var x = document.getElementById(id); | |
# console.log(x) | |
# if (x.style.display === "none") { | |
# x.style.display = "block"; | |
# } else { | |
# x.style.display = "none"; | |
# } | |
# }; | |
# function myFunction() { | |
# var element = document.body; | |
# element.classList.toggle("dark-mode"); | |
# } | |
# </script> | |
# <button onclick="myFunction()">Toggle dark mode</button> | |
# """ | |
# st.sidebar.markdown(dark_mode_toggle, unsafe_allow_html=True) | |
# Footer markup: a fixed, full-width, centered bar crediting the projects
# the app is built on. Rendered through the sidebar as raw HTML.
footer = """
<style>
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: white;
color: black;
text-align: center;
}
</style>
<div class="footer">
<p>Powered by <a href="https://huggingface.co/" >HuggingFace 🤗</a> and <a href="https://github.com/castorini/pyserini" >Pyserini 🦆</a></p>
</div>
"""
st.sidebar.markdown(body=footer, unsafe_allow_html=True)
def scisearch(query, corpus, num_results=10):
    """Send a search request to the backend and return its hits.

    The backend URL is read from the ``address_roots`` environment variable
    when ``corpus`` is ``"roots"`` and from ``address`` otherwise.

    Args:
        query: Query string; surrounding whitespace is ignored.
        corpus: Corpus identifier forwarded to the backend.
        num_results: Number of documents to request (sent as ``k``).

    Returns:
        A ``(results, highlight_terms)`` tuple taken from the JSON response,
        or ``None`` when the query is empty, no backend address is
        configured, or the request/response handling fails.
    """
    # Reject missing input *before* calling ``strip`` so ``None`` does not
    # have to travel through the exception handler below.
    if query is None:
        return None
    query = query.strip()
    if not query:
        return None

    print(query, corpus, num_results)

    address = os.environ.get("address_roots" if corpus == "roots" else "address")
    if not address:
        print("No search backend address configured for corpus", corpus)
        return None

    post_data = {"query": query, "corpus": corpus, "k": num_results, "lang": "all"}
    try:
        output = requests.post(
            address,
            headers={"Content-type": "application/json"},
            data=json.dumps(post_data),
            timeout=60,
        )
        payload = json.loads(output.text)
        return payload["results"], payload["highlight_terms"]
    except Exception as e:
        # Deliberate best-effort boundary: any network, decoding or schema
        # problem is logged and reported to the caller as ``None``.
        print(e)
        return None
# Tag names that, when preceded by PII_PREFIX in a text, are rewritten by
# ``process_pii`` into a highlighted "REDACTED <tag>" label.
PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
PII_PREFIX = "PI:"


def process_pii(text):
    """Replace each ``PI:<TAG>`` marker in *text* with a highlighted label.

    Every occurrence of ``PII_PREFIX + tag`` (for each tag in ``PII_TAGS``)
    is substituted with a bold ``<mark>`` element reading ``REDACTED <tag>``.
    The resulting string is returned.
    """
    redacted = text
    for pii_tag in PII_TAGS:
        marker = PII_PREFIX + pii_tag
        label = (
            '<b><mark style="background: Fuchsia; color: Lime;">'
            f"REDACTED {pii_tag}</mark></b>"
        )
        redacted = redacted.replace(marker, label)
    return redacted
def highlight_string(paragraph: str, highlight_terms: list) -> str:
    """Return *paragraph* as HTML with the highlight terms in bold.

    The paragraph is split on whitespace; each token is HTML-escaped so that
    markup contained in the corpus text is displayed literally instead of
    being interpreted, and tokens that exactly match one of
    ``highlight_terms`` are wrapped in ``<b>``. The tokens are re-joined with
    single spaces and passed through ``process_pii``.

    Args:
        paragraph: Raw document text.
        highlight_terms: Terms to emphasise (matched against whole tokens).

    Returns:
        The HTML fragment for the paragraph.
    """
    # A set gives O(1) membership tests instead of scanning the list for
    # every token.
    terms = set(highlight_terms)
    pieces = []
    for token in paragraph.split():
        # Match on the raw token, but always emit the escaped form.
        safe_token = html.escape(token)
        pieces.append(f"<b>{safe_token}</b>" if token in terms else safe_token)
    return process_pii(" ".join(pieces))
def extract_lang_from_docid(docid):
    """Return the second underscore-separated field of *docid*.

    Raises ``IndexError`` when *docid* contains no underscore.
    """
    # Only the first two separators matter for picking out field 1.
    fields = docid.split("_", 2)
    return fields[1]
def format_result(result, highlight_terms):
    """Render one search hit as an HTML paragraph.

    Reads ``result["text"]`` and ``result["docid"]``; the language shown is
    derived from the docid via ``extract_lang_from_docid`` and the text is
    highlighted via ``highlight_string``.
    """
    docid = result["docid"]
    body_html = highlight_string(result["text"], highlight_terms)
    language = extract_lang_from_docid(docid)

    # Header line (language + document id) followed by the highlighted text.
    inner_html = (
        "\n"
        "<span style='font-size:14px; font-family: Arial; color:MediumAquaMarine'>"
        f"Language: {language} | </span>\n"
        "<span style='font-size:14px; font-family: Arial; color:#7978FF; text-align: left;'>"
        f"Document ID: {docid} | </span><br>\n"
        f"<span style='font-family: Arial;'>{body_html}</span><br>\n"
        "<br>\n"
    )
    return f"<p>{inner_html}</p>"
def process_results(corpus: str, hits: Union[list, dict], highlight_terms: list) -> str:
    """Render the backend hits for *corpus* as one HTML fragment.

    Args:
        corpus: Backend corpus identifier. ``"roots"`` and ``"laion"`` get
            dedicated layouts; every other value uses the plain layout.
        hits: For ``"roots"``, a mapping of language to a list of results;
            otherwise an iterable of hit dicts with ``docid``, ``score`` and
            ``text`` keys (plus optional ``meta["docs"]`` entries carrying a
            ``URL`` for ``"laion"``).
        highlight_terms: Terms to emphasise in the rendered text.

    Returns:
        The HTML string for all hits.
    """
    if corpus == "roots":
        # One collapsible <details> section per language.
        sections = []
        for lang, results_for_lang in hits.items():
            print("Processing language", lang)
            if not results_for_lang:
                sections.append(
                    "<div style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>\n"
                    f"No results for language: <b>{lang}</b></div>"
                )
                continue
            results_for_lang_html = "".join(
                format_result(result, highlight_terms) for result in results_for_lang
            )
            sections.append(
                "\n<details>\n"
                "<summary style='font-family: Arial; color:MediumAquaMarine; text-align: left; line-height: 3em'>\n"
                f"Results for language: <b>{lang}</b>\n"
                "</summary>\n"
                f"{results_for_lang_html}\n"
                "</details>"
            )
        return "".join(sections)

    hit_list = []
    for hit in hits:
        # Values coming from the backend are escaped before being embedded.
        docid = html.escape(str(hit["docid"]))
        score = round(hit["score"], 2)
        text_html = highlight_string(hit["text"], highlight_terms)
        parts = [
            "\n"
            f'<p class="searchresult" style="color: #7978FF;">Document ID: {docid} | Score: {score}</p>\n'
        ]
        if corpus == "laion":
            parts.append(
                "\n"
                '<p style="color: #7978FF;">Caption:</p>\n'
                f"<p>{text_html}</p>\n"
            )
            docs = (hit.get("meta") or {}).get("docs") or []
            if docs:
                parts.append('<p style="color: #7978FF;"> Image links:</p><ul>')
                for subhit in docs:
                    # Quote and escape the URL so characters such as spaces,
                    # quotes or ">" cannot break out of the href attribute.
                    url = html.escape(str(subhit["URL"]), quote=True)
                    parts.append(
                        f'<li><a href="{url}" target="_blank" rel="noopener noreferrer" '
                        f'style="color:#ffcdf8; ">{url}</a></li>'
                    )
                parts.append("</ul>")
            parts.append("<hr>")
        else:
            # No wrapping <div> is opened for a hit, so none is closed here.
            parts.append(f"<p>{text_html}</p><hr>")
        hit_list.append("".join(parts))
    return " ".join(hit_list)
# Font link and CSS rules prepended to the rendered results inside the
# HTML component.
_RESULTS_HEAD = """
<head>
<link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
</head>
<style>
#searchresultsarea {
font-family: "Source Sans Pro", sans-serif;
}
#searchresultsnumber {
font-size: 0.8rem;
color: gray;
}
.searchresult h2 {
font-size: 19px;
line-height: 18px;
font-weight: normal;
color: rgb(7, 111, 222);
margin-bottom: 0px;
margin-top: 25px;
color: #7978FF;
}
.searchresult a {
font-size: 12px;
line-height: 12px;
color: green;
margin-bottom: 0px;
}
</style>
"""

submit_button = st.sidebar.button("Search", type="primary")

# Run a search when the button is pressed or a query is present.
if submit_button or query:
    query = query.strip()
    if not query:
        components.html(
            "<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>\n"
            "Please provide a non-empty query.\n"
            "</p><br><hr><br>"
        )
    else:
        corpus_id = corpus_name_map[corpus]
        search_output = scisearch(query, corpus_id, max_results)
        if search_output is None:
            # ``scisearch`` returns None when the request fails; show a
            # message instead of crashing while unpacking the result.
            components.html(
                "<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>\n"
                "The search could not be completed. Please try again later.\n"
                "</p><br><hr><br>"
            )
        else:
            hits, highlight_terms = search_output
            html_results = process_results(corpus_id, hits, highlight_terms)
            rendered_results = (
                "\n"
                '<div id="searchresultsarea">\n'
                "<br>\n"
                f'<p id="searchresultsnumber">About {max_results} results</p>\n'
                f"{html_results}\n"
                "</div>"
            )
            components.html(
                _RESULTS_HEAD + rendered_results,
                height=800,
                scrolling=True,
            )