File size: 6,239 Bytes
b34130a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75f7430
b34130a
 
 
 
 
 
 
75f7430
b34130a
 
 
 
 
 
 
 
ee2c37d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b34130a
ee2c37d
b34130a
 
75f7430
b34130a
75f7430
b34130a
 
 
 
 
 
0432264
b34130a
 
 
 
 
 
 
715ff1a
91b91b4
b34130a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08c1b6c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import re
import os
import html

import gradio as gr
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth


def mark_tokens_bold(text, tokens):
    """Wrap every occurrence of the given tokens in ``text`` with ``<b>`` tags.

    Matching is literal (tokens are regex-escaped) and case-sensitive.

    Args:
        text: The string to annotate.
        tokens: Iterable of substrings to highlight. Empty strings and the
            single-character tokens ``<``, ``b``, ``/`` and ``>`` are ignored.

    Returns:
        ``text`` with each match replaced by ``<b>match</b>``; ``text``
        unchanged when no usable token remains.
    """
    # These single-character tokens were always excluded because they collide
    # with the characters of the inserted markup; keep excluding them.
    skipped = {"<", "b", "/", ">"}

    # Drop empty tokens (an empty pattern matches between every character)
    # and duplicates (which would nest the tags), preserving first-seen order.
    wanted = [tok for tok in dict.fromkeys(tokens) if tok and tok not in skipped]
    if not wanted:
        return text

    # Longest first, so a token that contains another one wins the match.
    wanted.sort(key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(tok) for tok in wanted))

    # One pass over the original text: the inserted <b> tags are never
    # re-scanned. A callable replacement keeps backslashes in the matched
    # text from being interpreted as replacement-template escapes.
    return pattern.sub(lambda found: "<b>" + found.group(0) + "</b>", text)


def process_results(results, query):
    """Render search results as an HTML fragment.

    Args:
        results: Sequence of dicts with the keys ``text``, ``repository``,
            ``commit_id``, ``path``, ``license`` and ``language``. ``text`` is
            inserted into the page as markup (the caller in this file passes
            it already HTML-escaped); all other fields are escaped here.
        query: The user's query. When it is wrapped in double quotes, the
            quotes are stripped before the words are used for highlighting.

    Returns:
        One HTML string with a header line and a scrollable ``<pre>`` block
        per result, or a "No results retrieved." notice when ``results`` is
        empty.
    """
    if not results:
        return """<br><p>No results retrieved.</p><br><hr>"""

    # The highlight words do not depend on the individual result, so compute
    # them once instead of once per result.
    if query.startswith('"') and query.endswith('"'):
        highlight_source = query[1:-1]
    else:
        highlight_source = query
    # Consecutive spaces (or a query that is only quotes) yield empty tokens;
    # an empty token would match between every character of the text.
    tokens = [token for token in highlight_source.split(" ") if token]

    rendered = []
    for result in results:
        text_html = mark_tokens_bold(result["text"], tokens)

        # Metadata comes from the index and ends up inside HTML text and an
        # href attribute, so escape it (including quotes).
        repository = html.escape(str(result["repository"]), quote=True)
        commit_id = html.escape(str(result["commit_id"]), quote=True)
        path = html.escape(str(result["path"]), quote=True)
        license_label = html.escape(str(result["license"]), quote=True)
        language = html.escape(str(result["language"]), quote=True)

        code_height = min(
            600, len(text_html.split("\n")) * 20
        )  # 20px per line, limited to a maximum height of 600px
        rendered.append(
            """\
        <p style='font-size:16px; text-align: left;'><b>Source: </b><a target="_blank" href="https://github.com/{}/blob/{}{}">{}</a>&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;<b>Language:</b> \
        <span style='color: #00134d;'>{}</span>&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;<b>Licenses: </b><span style='color: #00134d;'>{}</span></p>
        <pre style='height: {}px; overflow-y: scroll; overflow-x: hidden; color: #d9d9d9;border: 1px solid #e6b800; padding: 10px'><code>{}</code></pre>
        <hr>
        """.format(
                repository,
                commit_id,
                path,
                f"{repository}/blob/{commit_id}{path}",
                language,
                license_label,
                code_height,
                text_html,
            )
        )
    return "".join(rendered)


def match_query(es, query, num_results=10):
    """Run a ``match`` query on the ``content`` field.

    The index name is read from the ``INDEX`` environment variable.

    Args:
        es: Client object exposing ``search(index=..., body=...)``.
        query: Text to match against ``content``.
        num_results: Value sent as the request's ``size``.

    Returns:
        A list with the ``_source`` entry of every returned hit.
    """
    index_name = os.environ.get("INDEX")
    request = {
        "query": {"match": {"content": query}},
        "size": num_results,
    }
    raw_hits = es.search(index=index_name, body=request)["hits"]["hits"]

    sources = []
    for entry in raw_hits:
        sources.append(entry["_source"])
    return sources


def phrase_query(es, query, num_results=10):
    """Run a ``match_phrase`` query on the ``content`` field.

    The index name is read from the ``INDEX`` environment variable.

    Args:
        es: Client object exposing ``search(index=..., body=...)``.
        query: Phrase to look for in ``content``.
        num_results: Value sent as the request's ``size``.

    Returns:
        A list with the ``_source`` entry of every returned hit.
    """
    phrase_clause = {"match_phrase": {"content": query}}
    answer = es.search(
        index=os.environ.get("INDEX"),
        body={"query": phrase_clause, "size": num_results},
    )
    found = answer["hits"]["hits"]
    return [item["_source"] for item in found]


def search(query, num_results=10):
    """Search the index for ``query`` and return the results as HTML.

    A new OpenSearch client is created on every call. It is signed with the
    ``ACCESS_KEY`` / ``SECRET_KEY`` environment values and connects to the
    ``HOST`` environment endpoint on port 443. The query is cut to its first
    200 characters; when the truncated query starts and ends with a double
    quote, the text between the quotes is sent as a phrase query, otherwise
    the query is sent as a match query.

    Args:
        query: The user's search text.
        num_results: Maximum number of hits to request.

    Returns:
        The HTML produced by ``process_results`` for the retrieved hits.
    """
    credentials = AWS4Auth(
        os.environ.get("ACCESS_KEY"),
        os.environ.get("SECRET_KEY"),
        "us-east-1",
        "es",
    )
    endpoint = {"host": os.environ.get("HOST"), "port": 443}
    es = OpenSearch(
        hosts=[endpoint],
        http_auth=credentials,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
        http_compress=True,
        timeout=200,
    )
    # The ping result is only printed; it does not gate the search below.
    print(es.ping())

    query = query[:200]
    is_phrase = query.startswith('"') and query.endswith('"')
    if is_phrase:
        hits = phrase_query(es, query[1:-1], num_results=num_results)
    else:
        hits = match_query(es, query, num_results=num_results)

    results = []
    for hit in hits:
        entry = {
            # Escaped here because the text is later embedded in HTML.
            "text": html.escape(hit["content"]),
            "repository": hit["repository"],
            "commit_id": hit["commit_id"],
            "path": hit["path"],
        }
        # Fall back to the scancode licenses when no usable gh_license is set.
        if hit["gh_license"] is None or hit["gh_license"] == "NOASSERTION":
            entry["license"] = ", ".join(hit["scancode_licenses"])
        else:
            entry["license"] = hit["gh_license"]
        entry["language"] = hit["language"]
        results.append(entry)

    return process_results(results, query)


# Markdown/HTML banner rendered at the top of the page.
description = """# <p style="text-align: center;"><span style='color: #e6b800;'>StarCoder2:</span> Dataset Search 🔍 </p>
<span>When using <a href="https://huggingface.co/bigcode/starcoder2-15b" style="color: #e6b800;">StarCoder2</a> to generate code, it might produce close or exact copies of code in the pretraining dataset. Identifying such cases can provide important context, and help credit the original developer of the code. With this search tool, our aim is to help in identifying if the code belongs to an existing repository. For exact matches, enclose your query in double quotes. <br><br><i>This first iteration of the search tool truncates queries down to 200 characters, so as not to overwhelm the server it is currently running on. Please note that this is not a production-ready app, but rather a research tool that we make available as a proof-of-concept. If you need a reliable search app for your business or research, we would advise you to index the dataset yourself.</i></span>"""

# Gradio theme for the app: Monochrome base with custom hues, small corner
# radius, and Open Sans (with generic sans-serif fallbacks) as the font.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
# Render the query textarea (elem_id "q-input") in a monospace font.
monospace_css = """
#q-input textarea {
    font-family: monospace, 'Consolas', Courier, monospace;
}
"""

# Stylesheet handed to gr.Blocks. An earlier assignment
# (`.generating {visibility: hidden}`) was overwritten here before it was ever
# used, so it has been removed; the effective value is unchanged.
css = monospace_css + ".gradio-container {color: black}"


if __name__ == "__main__":
    # Build and launch the UI only when the file is run as a script.
    demo = gr.Blocks(
        theme=theme,
        css=css,
    )

    with demo:
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            query = gr.Textbox(
                lines=5,
                placeholder="Type your query here...",
                label="Query",
                elem_id="q-input",  # targeted by the monospace CSS rule
            )
        with gr.Row():
            k = gr.Slider(1, 100, value=10, step=1, label="Max Results")
        with gr.Row():
            submit_btn = gr.Button("Submit")
        with gr.Row():
            results = gr.HTML(label="Results", value="")

        def submit(query, k, lang="en"):
            """Run a search for ``query`` and return the update for ``results``.

            Args:
                query: Text from the query box; may be empty or ``None``.
                k: Maximum number of results, taken from the slider.
                lang: Unused; kept so the signature stays as before.

            Returns:
                A dict mapping the ``results`` component to its new HTML: an
                empty string for a blank query, the search output otherwise.
            """
            # Normalise before stripping so a None value cannot raise.
            query = (query or "").strip()
            if not query:
                # Exactly one output component is wired up, so return exactly
                # one value (previously a 2-tuple was returned here).
                return {results: ""}
            return {
                results: search(query, int(k)),
            }

        # Pressing Enter in the textbox and clicking the button do the same.
        query.submit(fn=submit, inputs=[query, k], outputs=[results])
        submit_btn.click(submit, inputs=[query, k], outputs=[results])

    demo.queue(max_size=32).launch()