Spaces:
Running
Running
Commit
β’
0e05863
1
Parent(s):
fabd282
remove flask server
Browse files- pages/__init__.py +0 -2
- pages/document.py +0 -21
- pages/search_engine.py +0 -157
- server/api.py β st_utils.py +53 -33
pages/__init__.py
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
from .search_engine import page as search_engine_page
|
2 |
-
from .document import page as document_page
|
|
|
|
|
|
pages/document.py
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import re
|
3 |
-
import json
|
4 |
-
import datetime
|
5 |
-
import itertools
|
6 |
-
import requests
|
7 |
-
from PIL import Image
|
8 |
-
import base64
|
9 |
-
import streamlit as st
|
10 |
-
|
11 |
-
def page():
|
12 |
-
record = st.session_state.get("selected_record")
|
13 |
-
st.set_page_config(
|
14 |
-
page_title=f"Record {record['filename']}",
|
15 |
-
page_icon="π¨ββοΈ",
|
16 |
-
layout="wide",
|
17 |
-
initial_sidebar_state="collapsed",
|
18 |
-
)
|
19 |
-
st.button("Back", on_click=lambda: set_record(None))
|
20 |
-
|
21 |
-
st.write(record)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/search_engine.py
DELETED
@@ -1,157 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import re
|
3 |
-
import json
|
4 |
-
import datetime
|
5 |
-
import itertools
|
6 |
-
import requests
|
7 |
-
from PIL import Image
|
8 |
-
import base64
|
9 |
-
import streamlit as st
|
10 |
-
from huggingface_hub import ModelSearchArguments
|
11 |
-
import webbrowser
|
12 |
-
from numerize.numerize import numerize
|
13 |
-
|
14 |
-
def paginator(label, articles, articles_per_page=10, on_sidebar=True):
|
15 |
-
# https://gist.github.com/treuille/2ce0acb6697f205e44e3e0f576e810b7
|
16 |
-
"""Lets the user paginate a set of article.
|
17 |
-
Parameters
|
18 |
-
----------
|
19 |
-
label : str
|
20 |
-
The label to display over the pagination widget.
|
21 |
-
article : Iterator[Any]
|
22 |
-
The articles to display in the paginator.
|
23 |
-
articles_per_page: int
|
24 |
-
The number of articles to display per page.
|
25 |
-
on_sidebar: bool
|
26 |
-
Whether to display the paginator widget on the sidebar.
|
27 |
-
|
28 |
-
Returns
|
29 |
-
-------
|
30 |
-
Iterator[Tuple[int, Any]]
|
31 |
-
An iterator over *only the article on that page*, including
|
32 |
-
the item's index.
|
33 |
-
"""
|
34 |
-
|
35 |
-
# Figure out where to display the paginator
|
36 |
-
if on_sidebar:
|
37 |
-
location = st.sidebar.empty()
|
38 |
-
else:
|
39 |
-
location = st.empty()
|
40 |
-
|
41 |
-
# Display a pagination selectbox in the specified location.
|
42 |
-
articles = list(articles)
|
43 |
-
n_pages = (len(articles) - 1) // articles_per_page + 1
|
44 |
-
page_format_func = lambda i: f"Results {i*10} to {i*10 +10 -1}"
|
45 |
-
page_number = location.selectbox(label, range(n_pages), format_func=page_format_func)
|
46 |
-
|
47 |
-
# Iterate over the articles in the page to let the user display them.
|
48 |
-
min_index = page_number * articles_per_page
|
49 |
-
max_index = min_index + articles_per_page
|
50 |
-
|
51 |
-
return itertools.islice(enumerate(articles), min_index, max_index)
|
52 |
-
|
53 |
-
def page():
|
54 |
-
### SIDEBAR
|
55 |
-
search_backend = st.sidebar.selectbox(
|
56 |
-
"Search method",
|
57 |
-
["semantic", "bm25", "hfapi"],
|
58 |
-
format_func=lambda x: {"hfapi": "Keyword search", "bm25": "BM25 search", "semantic": "Semantic Search"}[x],
|
59 |
-
)
|
60 |
-
limit_results = st.sidebar.number_input("Limit results", min_value=0, value=10)
|
61 |
-
|
62 |
-
st.sidebar.markdown("# Filters")
|
63 |
-
args = ModelSearchArguments()
|
64 |
-
library = st.sidebar.multiselect(
|
65 |
-
"Library", args.library.values(), format_func=lambda x: {v: k for k, v in args.library.items()}[x]
|
66 |
-
)
|
67 |
-
task = st.sidebar.multiselect(
|
68 |
-
"Task", args.pipeline_tag.values(), format_func=lambda x: {v: k for k, v in args.pipeline_tag.items()}[x]
|
69 |
-
)
|
70 |
-
|
71 |
-
### MAIN PAGE
|
72 |
-
st.markdown(
|
73 |
-
"<h1 style='text-align: center; '>ππ€ HF Search Engine</h1>",
|
74 |
-
unsafe_allow_html=True,
|
75 |
-
)
|
76 |
-
|
77 |
-
# Search bar
|
78 |
-
search_query = st.text_input(
|
79 |
-
"Search for a model in HuggingFace", value="", max_chars=None, key=None, type="default"
|
80 |
-
)
|
81 |
-
|
82 |
-
# Search API
|
83 |
-
endpoint = "http://localhost:5000"
|
84 |
-
headers = {
|
85 |
-
"Content-Type": "application/json",
|
86 |
-
"api-key": "password",
|
87 |
-
}
|
88 |
-
search_url = f"{endpoint}/{search_backend}/search"
|
89 |
-
filters = {
|
90 |
-
"library": library,
|
91 |
-
"task": task,
|
92 |
-
}
|
93 |
-
search_body = {
|
94 |
-
"query": search_query,
|
95 |
-
"filters": json.dumps(filters, default=str),
|
96 |
-
"limit": limit_results,
|
97 |
-
}
|
98 |
-
|
99 |
-
if search_query != "":
|
100 |
-
response = requests.post(search_url, headers=headers, json=search_body).json()
|
101 |
-
|
102 |
-
hit_list = []
|
103 |
-
_ = [
|
104 |
-
hit_list.append(
|
105 |
-
{
|
106 |
-
"modelId": hit["modelId"],
|
107 |
-
"tags": hit["tags"],
|
108 |
-
"downloads": hit["downloads"],
|
109 |
-
"likes": hit["likes"],
|
110 |
-
"readme": hit.get("readme", None),
|
111 |
-
}
|
112 |
-
)
|
113 |
-
for hit in response.get("value")
|
114 |
-
]
|
115 |
-
|
116 |
-
|
117 |
-
if hit_list:
|
118 |
-
st.write(f'Search results ({response.get("count")}):')
|
119 |
-
|
120 |
-
if response.get("count") > 100:
|
121 |
-
shown_results = 100
|
122 |
-
else:
|
123 |
-
shown_results = response.get("count")
|
124 |
-
|
125 |
-
for i, hit in paginator(
|
126 |
-
f"Select results (showing {shown_results} of {response.get('count')} results)",
|
127 |
-
hit_list,
|
128 |
-
):
|
129 |
-
col1, col2, col3 = st.columns([5,1,1])
|
130 |
-
col1.metric("Model", hit["modelId"])
|
131 |
-
col2.metric("NΒ° downloads", numerize(hit["downloads"]))
|
132 |
-
col3.metric("NΒ° likes", numerize(hit["likes"]))
|
133 |
-
st.button(f"View model on π€", on_click=lambda hit=hit: webbrowser.open(f"https://huggingface.co/{hit['modelId']}"), key=hit["modelId"])
|
134 |
-
st.write(f"**Tags:** {' β’ '.join(hit['tags'])}")
|
135 |
-
|
136 |
-
if hit["readme"]:
|
137 |
-
with st.expander("See README"):
|
138 |
-
st.write(hit["readme"])
|
139 |
-
|
140 |
-
# TODO: embed huggingface spaces
|
141 |
-
# import streamlit.components.v1 as components
|
142 |
-
# components.html(
|
143 |
-
# f"""
|
144 |
-
# <link rel="stylesheet" href="https://gradio.s3-us-west-2.amazonaws.com/2.6.2/static/bundle.css">
|
145 |
-
# <div id="target"></div>
|
146 |
-
# <script src="https://gradio.s3-us-west-2.amazonaws.com/2.6.2/static/bundle.js"></script>
|
147 |
-
# <script>
|
148 |
-
# launchGradioFromSpaces("abidlabs/question-answering", "#target")
|
149 |
-
# </script>
|
150 |
-
# """,
|
151 |
-
# height=400,
|
152 |
-
# )
|
153 |
-
|
154 |
-
st.markdown("---")
|
155 |
-
|
156 |
-
else:
|
157 |
-
st.write(f"No Search results, please try again with different keywords")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server/api.py β st_utils.py
RENAMED
@@ -1,23 +1,12 @@
|
|
1 |
-
from flask import Flask, request
|
2 |
import json
|
3 |
from huggingface_hub import HfApi, ModelFilter, DatasetFilter, ModelSearchArguments
|
4 |
from pprint import pprint
|
5 |
from hf_search import hf_search
|
|
|
|
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
@app.route("/hello")
|
11 |
-
def hello():
|
12 |
-
return "<h1 style='color:blue'>Hello There!</h1>"
|
13 |
-
|
14 |
-
|
15 |
-
@app.route("/hfapi/search", methods=["POST"])
|
16 |
-
def hf_api():
|
17 |
-
request_data = request.get_json()
|
18 |
-
query = request_data.get("query")
|
19 |
-
filters = json.loads(request_data.get("filters"))
|
20 |
-
limit = request_data.get("limit", 5)
|
21 |
print("query", query)
|
22 |
print("filters", filters)
|
23 |
print("limit", limit)
|
@@ -43,15 +32,11 @@ def hf_api():
|
|
43 |
if len(hits) > limit:
|
44 |
hits = hits[:limit]
|
45 |
pprint(hits)
|
46 |
-
return
|
47 |
|
48 |
|
49 |
-
@
|
50 |
-
def semantic_search():
|
51 |
-
request_data = request.get_json()
|
52 |
-
query = request_data.get("query")
|
53 |
-
filters = json.loads(request_data.get("filters"))
|
54 |
-
limit = request_data.get("limit", 5)
|
55 |
print("query", query)
|
56 |
print("filters", filters)
|
57 |
print("limit", limit)
|
@@ -67,14 +52,11 @@ def semantic_search():
|
|
67 |
}
|
68 |
for hit in hits
|
69 |
]
|
70 |
-
return
|
|
|
71 |
|
72 |
-
@
|
73 |
-
def bm25_search():
|
74 |
-
request_data = request.get_json()
|
75 |
-
query = request_data.get("query")
|
76 |
-
filters = json.loads(request_data.get("filters"))
|
77 |
-
limit = request_data.get("limit", 5)
|
78 |
print("query", query)
|
79 |
print("filters", filters)
|
80 |
print("limit", limit)
|
@@ -91,9 +73,47 @@ def bm25_search():
|
|
91 |
}
|
92 |
for hit in hits
|
93 |
]
|
94 |
-
hits = [
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
|
|
|
|
|
|
97 |
|
98 |
-
|
99 |
-
app.run(host="localhost", port=5000)
|
|
|
|
|
1 |
import json
|
2 |
from huggingface_hub import HfApi, ModelFilter, DatasetFilter, ModelSearchArguments
|
3 |
from pprint import pprint
|
4 |
from hf_search import hf_search
|
5 |
+
import streamlit as st
|
6 |
+
import itertools
|
7 |
|
8 |
+
@st.cache
|
9 |
+
def hf_api(query, limit=5, filters={}):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
print("query", query)
|
11 |
print("filters", filters)
|
12 |
print("limit", limit)
|
|
|
32 |
if len(hits) > limit:
|
33 |
hits = hits[:limit]
|
34 |
pprint(hits)
|
35 |
+
return {"hits": hits, "count": count}
|
36 |
|
37 |
|
38 |
+
@st.cache
|
39 |
+
def semantic_search(query, limit=5, filters={}):
|
|
|
|
|
|
|
|
|
40 |
print("query", query)
|
41 |
print("filters", filters)
|
42 |
print("limit", limit)
|
|
|
52 |
}
|
53 |
for hit in hits
|
54 |
]
|
55 |
+
return {"hits": hits, "count": len(hits)}
|
56 |
+
|
57 |
|
58 |
+
@st.cache
|
59 |
+
def bm25_search(query, limit=5, filters={}):
|
|
|
|
|
|
|
|
|
60 |
print("query", query)
|
61 |
print("filters", filters)
|
62 |
print("limit", limit)
|
|
|
73 |
}
|
74 |
for hit in hits
|
75 |
]
|
76 |
+
hits = [
|
77 |
+
hits[i] for i in range(len(hits)) if hits[i]["modelId"] not in [h["modelId"] for h in hits[:i]]
|
78 |
+
] # unique hits
|
79 |
+
return {"hits": hits, "count": len(hits)}
|
80 |
+
|
81 |
+
|
82 |
+
def paginator(label, articles, articles_per_page=10, on_sidebar=True):
|
83 |
+
# https://gist.github.com/treuille/2ce0acb6697f205e44e3e0f576e810b7
|
84 |
+
"""Lets the user paginate a set of article.
|
85 |
+
Parameters
|
86 |
+
----------
|
87 |
+
label : str
|
88 |
+
The label to display over the pagination widget.
|
89 |
+
article : Iterator[Any]
|
90 |
+
The articles to display in the paginator.
|
91 |
+
articles_per_page: int
|
92 |
+
The number of articles to display per page.
|
93 |
+
on_sidebar: bool
|
94 |
+
Whether to display the paginator widget on the sidebar.
|
95 |
+
|
96 |
+
Returns
|
97 |
+
-------
|
98 |
+
Iterator[Tuple[int, Any]]
|
99 |
+
An iterator over *only the article on that page*, including
|
100 |
+
the item's index.
|
101 |
+
"""
|
102 |
+
|
103 |
+
# Figure out where to display the paginator
|
104 |
+
if on_sidebar:
|
105 |
+
location = st.sidebar.empty()
|
106 |
+
else:
|
107 |
+
location = st.empty()
|
108 |
+
|
109 |
+
# Display a pagination selectbox in the specified location.
|
110 |
+
articles = list(articles)
|
111 |
+
n_pages = (len(articles) - 1) // articles_per_page + 1
|
112 |
+
page_format_func = lambda i: f"Results {i*10} to {i*10 +10 -1}"
|
113 |
+
page_number = location.selectbox(label, range(n_pages), format_func=page_format_func)
|
114 |
|
115 |
+
# Iterate over the articles in the page to let the user display them.
|
116 |
+
min_index = page_number * articles_per_page
|
117 |
+
max_index = min_index + articles_per_page
|
118 |
|
119 |
+
return itertools.islice(enumerate(articles), min_index, max_index)
|
|