File size: 6,164 Bytes
b34130a 0432264 b34130a 715ff1a 91b91b4 b34130a 7cec85f b34130a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import re
import os
import html
import gradio as gr
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth
# Request signer for the search cluster. The credentials come from the
# environment; AWS4Auth takes its arguments positionally
# (access key, secret key, region, service).
awsauth = AWS4Auth(
    os.getenv("ACCESS_KEY"),
    os.getenv("SECRET_KEY"),
    "us-east-1",
    "es",
)

# Module-wide client shared by every query helper below. The host name is
# also read from the environment; traffic goes over TLS on port 443 with
# certificate verification and compressed request bodies.
es = OpenSearch(
    hosts=[dict(host=os.getenv("HOST"), port=443)],
    connection_class=RequestsHttpConnection,
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    http_compress=True,
    timeout=200,
)
def mark_tokens_bold(text, tokens):
    """Wrap every occurrence of the given tokens in ``<b>...</b>`` tags.

    Args:
        text: The string to annotate. Nothing is escaped here; the caller
            decides whether ``text`` is already markup.
        tokens: Iterable of literal substrings to highlight. Matching is
            case-sensitive and is not limited to word boundaries.

    Returns:
        ``text`` with each matched token surrounded by ``<b>`` and ``</b>``.
        If no usable token is supplied, ``text`` is returned unchanged.
    """
    # The characters that make up the highlight markup itself are never
    # highlighted. Empty tokens are dropped as well: an empty pattern matches
    # at every position and would insert a tag pair between all characters.
    skipped = {"", "<", "b", "/", ">"}
    wanted = {token for token in tokens if token not in skipped}
    if not wanted:
        return text

    # Longest first, so the alternation prefers a token over its own prefixes;
    # the secondary key only makes the ordering deterministic.
    ordered = sorted(wanted, key=lambda token: (-len(token), token))
    pattern = re.compile("|".join(re.escape(token) for token in ordered))

    # One pass over the text: tags inserted for one token can never be matched
    # by another token. A callable replacement is used so that backslashes in
    # the matched text are emitted literally instead of being interpreted as
    # escape sequences or group references.
    return pattern.sub(lambda match: "<b>" + match.group(0) + "</b>", text)
def process_results(results, query):
    """Render search hits as an HTML fragment.

    Args:
        results: List of dicts with the keys ``text``, ``repository``,
            ``commit_id``, ``path``, ``license`` and ``language``. ``text`` is
            inserted as markup and must therefore already be HTML-escaped;
            the remaining fields are escaped here.
        query: The user's query. When it starts and ends with a double quote,
            the quotes are stripped before its words are highlighted.

    Returns:
        An HTML string with one header line and one scrollable code box per
        result, or a short "no results" message when ``results`` is empty.
    """
    if not results:
        return """<br><p>No results retrieved.</p><br><hr>"""

    def _escape(value):
        # str() keeps the behaviour of str.format for non-string values.
        return html.escape(str(value))

    # The highlight tokens depend only on the query, so compute them once.
    # Splitting on any whitespace avoids empty tokens from repeated spaces,
    # and escaping makes them comparable with the already-escaped text
    # (e.g. a query token ``a<b`` must match ``a&lt;b``).
    quoted = query.startswith('"') and query.endswith('"')
    phrase = query[1:-1] if quoted else query
    tokens = [html.escape(token) for token in phrase.split()]

    row_template = """\
<p style='font-size:16px; text-align: left;'><b>Source: </b><a target="_blank" href="https://github.com/{}/blob/{}{}">{}</a> | <b>Language:</b> \
<span style='color: #00134d;'>{}</span> | <b>Licenses: </b><span style='color: #00134d;'>{}</span></p>
<pre style='height: {}px; overflow-y: scroll; overflow-x: hidden; color: #d9d9d9;border: 1px solid #e6b800; padding: 10px'><code>{}</code></pre>
<hr>
"""

    fragments = []
    for result in results:
        text_html = mark_tokens_bold(result["text"], tokens)

        # Index metadata ends up inside an attribute value and element text,
        # so it is escaped (including quotes) before interpolation.
        # NOTE(review): the path is HTML-escaped but not URL-quoted; if the
        # index stores raw paths, names containing '#' or '?' would still
        # produce broken links - confirm against the indexed data.
        repository = _escape(result["repository"])
        commit_id = _escape(result["commit_id"])
        path = _escape(result["path"])
        license_name = _escape(result["license"])
        language = _escape(result["language"])

        # 20px per line of code, capped at 600px.
        code_height = min(600, (text_html.count("\n") + 1) * 20)

        fragments.append(
            row_template.format(
                repository,
                commit_id,
                path,
                f"{repository}/blob/{commit_id}{path}",
                language,
                license_name,
                code_height,
                text_html,
            )
        )
    return "".join(fragments)
def match_query(query, num_results=10):
    """Run a ``match`` query for ``query`` against the ``content`` field.

    The index name is taken from the ``INDEX`` environment variable.

    Args:
        query: Text to search for.
        num_results: Maximum number of hits requested from the cluster.

    Returns:
        A list with the ``_source`` document of every returned hit.
    """
    search_body = {
        "query": {"match": {"content": query}},
        "size": num_results,
    }
    index_name = os.environ.get("INDEX")
    raw = es.search(index=index_name, body=search_body)
    return [entry["_source"] for entry in raw["hits"]["hits"]]
def phrase_query(query, num_results=10):
    """Run a ``match_phrase`` query for ``query`` against the ``content`` field.

    The index name is taken from the ``INDEX`` environment variable.

    Args:
        query: Phrase to search for.
        num_results: Maximum number of hits requested from the cluster.

    Returns:
        A list with the ``_source`` document of every returned hit.
    """
    search_body = {
        "query": {"match_phrase": {"content": query}},
        "size": num_results,
    }
    index_name = os.environ.get("INDEX")
    raw = es.search(index=index_name, body=search_body)
    return [entry["_source"] for entry in raw["hits"]["hits"]]
def search(query, num_results=10, max_query_length=200):
    """Search the index for ``query`` and return the hits rendered as HTML.

    A query that starts and ends with a double quote is run as a phrase
    query on the text between the quotes; anything else is run as a regular
    match query.

    Args:
        query: The user's query string.
        num_results: Maximum number of hits to request.
        max_query_length: Upper bound on the number of query characters that
            are used (including the surrounding quotes of a phrase query).

    Returns:
        The HTML produced by ``process_results`` for the retrieved hits.
    """
    if query.startswith('"') and query.endswith('"'):
        # Decide on the query type *before* truncating: cutting the raw string
        # first would drop the closing quote of a long quoted query and turn
        # it into a match query. The quotes are put back so that the string
        # handed to the renderer is still recognisable as a phrase.
        phrase = query[1:-1][: max(max_query_length - 2, 0)]
        query = '"' + phrase + '"'
        hits = phrase_query(phrase, num_results=num_results)
    else:
        query = query[:max_query_length]
        hits = match_query(query, num_results=num_results)

    results = []
    for hit in hits:
        # Use the detected licenses when the repository-level license is
        # missing or is the "NOASSERTION" placeholder.
        gh_license = hit["gh_license"]
        if gh_license is None or gh_license == "NOASSERTION":
            license_text = ", ".join(hit["scancode_licenses"])
        else:
            license_text = gh_license

        results.append(
            {
                # Escaped here because the renderer inserts it as markup.
                "text": html.escape(hit["content"]),
                "repository": hit["repository"],
                "commit_id": hit["commit_id"],
                "path": hit["path"],
                "license": license_text,
                "language": hit["language"],
            }
        )
    return process_results(results, query)
# Header text for the page (Markdown with inline HTML), rendered through
# gr.Markdown in the UI definition below. It explains the purpose of the tool,
# the double-quote syntax for exact matches, and the 200-character query limit.
# NOTE(review): the "π" at the end of the heading looks like a mis-encoded
# emoji - confirm against the deployed page before changing it.
description = """# <p style="text-align: center;"><span style='color: #e6b800;'>StarCoder2:</span> Dataset Search π </p>
<span>When using <a href="https://huggingface.co/bigcode/starcoder2-15b" style="color: #e6b800;">StarCoder2</a> to generate code, it might produce close or exact copies of code in the pretraining dataset. Identifying such cases can provide important context, and help credit the original developer of the code. With this search tool, our aim is to help in identifying if the code belongs to an existing repository. For exact matches, enclose your query in double quotes. <br><br><i>This first iteration of the search tool truncates queries down to 200 characters, so as not to overwhelm the server it is currently running on. Please note that this is not a production-ready app, but rather a research tool that we make available as a proof-of-concept. If you need a reliable search app for your business or research, we would advise you to index the dataset yourself.</i></span>"""
# Visual theme for the app: Gradio's Monochrome theme with custom hues,
# small corner radii, and "Open Sans" as the preferred font followed by
# generic sans-serif fallbacks.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
# Render the query box (element id "q-input") in a monospace font.
monospace_css = """
#q-input textarea {
font-family: monospace, 'Consolas', Courier, monospace;
}
"""

# Stylesheet handed to gr.Blocks.
# NOTE(review): a ".generating {visibility: hidden}" rule used to be assigned
# to `css` and was then overwritten before use, so it never took effect. It is
# left out here to keep the rendered page unchanged; append it to this string
# if hiding that element is actually wanted.
css = monospace_css + ".gradio-container {color: black}"
if __name__ == "__main__":
    # Build the page: description, query box, result-count slider, submit
    # button and the HTML pane that receives the rendered results.
    demo = gr.Blocks(
        theme=theme,
        css=css,
    )

    with demo:
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            query = gr.Textbox(
                lines=5,
                placeholder="Type your query here...",
                label="Query",
                elem_id="q-input",
            )
        with gr.Row():
            k = gr.Slider(1, 100, value=10, step=1, label="Max Results")
        with gr.Row():
            submit_btn = gr.Button("Submit")
        with gr.Row():
            results = gr.HTML(label="Results", value="")

        def submit(query, k):
            """Run a search for the textbox content and fill the results pane.

            Args:
                query: Current text of the query box.
                k: Current value of the "Max Results" slider.

            Returns:
                A mapping from the results component to the HTML to show;
                an empty string when there is nothing to search for.
            """
            # Test for None before touching the string, and always return
            # exactly one value because there is exactly one output component.
            if query is None or not query.strip():
                return {results: ""}
            return {results: search(query.strip(), k)}

        # Pressing Enter in the textbox and clicking the button both search.
        query.submit(fn=submit, inputs=[query, k], outputs=[results])
        submit_btn.click(submit, inputs=[query, k], outputs=[results])

    demo.queue()
    demo.launch(debug=True)
|