taskswithcode committed on
Commit
e45aab0
1 Parent(s): 002e3e5
Files changed (2) hide show
  1. app.py +35 -18
  2. sim_app_models.json +5 -5
app.py CHANGED
@@ -59,7 +59,8 @@ def get_views(action):
59
 
60
  def construct_model_info_for_display(model_names):
61
  options_arr = []
62
- markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b></div>"
 
63
  for node in model_names:
64
  options_arr .append(node["name"])
65
  if (node["mark"] == "True"):
@@ -86,17 +87,15 @@ with col:
86
 
87
 
88
  @st.experimental_memo
89
- def load_model(model_name,model_names):
90
  try:
91
  ret_model = None
92
- for node in model_names:
93
- if (model_name.startswith(node["name"])):
94
- obj_class = globals()[node["class"]]
95
- ret_model = obj_class()
96
- ret_model.init_model(node["model"])
97
  assert(ret_model is not None)
98
  except Exception as e:
99
- st.error("Unable to load model:" + model_name + " " + str(e))
100
  pass
101
  return ret_model
102
 
@@ -115,18 +114,26 @@ def uncached_compute_similarity(sentences,_model,model_name,main_index):
115
  #st.success("Similarity computation complete")
116
  return results
117
 
 
118
  def get_model_info(model_names,model_name):
119
  for node in model_names:
120
  if (model_name == node["name"]):
121
- return node
 
122
 
123
- def run_test(model_names,model_name,sentences,display_area,main_index,user_uploaded):
124
  display_area.text("Loading model:" + model_name)
125
- model_info = get_model_info(model_names,model_name)
 
 
 
 
 
 
126
  if ("Note" in model_info):
127
  fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
128
  display_area.write(fail_link)
129
- model = load_model(model_name,model_names)
130
  display_area.text("Model " + model_name + " load complete")
131
  try:
132
  if (user_uploaded):
@@ -146,9 +153,10 @@ def run_test(model_names,model_name,sentences,display_area,main_index,user_uploa
146
 
147
 
148
 
149
- def display_results(orig_sentences,main_index,results,response_info,app_mode):
150
  main_sent = f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">{response_info}<br/><br/></div>"
151
- score_text = "cosine_distance" if app_mode == SEM_SIMILARITY else "cosine_distance/score"
 
152
  pivot_name = "main sentence" if app_mode == SEM_SIMILARITY else "query"
153
  main_sent += f"<div style=\"font-size:14px; color: #6f6f6f; text-align: left\">Results sorted by {score_text}. Closest to furthest away from {pivot_name}</div>"
154
  pivot_name = pivot_name[0].upper() + pivot_name[1:]
@@ -183,7 +191,8 @@ def app_main(app_mode,example_files,model_name_files):
183
  with open(model_name_files) as fp:
184
  model_names = json.load(fp)
185
  curr_use_case = use_case[app_mode].split(".")[0]
186
- st.markdown("<h5 style='text-align: center;'>Compare popular/state-of-the-art models for tasks using sentence embeddings</h5>", unsafe_allow_html=True)
 
187
  st.markdown(f"<div style='color: #4f4f4f; text-align: left'>Use cases for sentence embeddings<br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;{use_case['1']}<br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;<a href=\'{use_case_url['2']}\' target='_blank'>{use_case['2']}</a><br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;{use_case['3']}<br/><i>This app illustrates <b>'{curr_use_case}'</b> use case</i></div>", unsafe_allow_html=True)
188
  st.markdown(f"<div style='color: #9f9f9f; text-align: right'>views:&nbsp;{get_views('init')}</div>", unsafe_allow_html=True)
189
 
@@ -206,6 +215,9 @@ def app_main(app_mode,example_files,model_name_files):
206
  selected_model = st.selectbox(label=selection_label,
207
  options = options_arr, index=0, key = "twc_model")
208
  st.write("")
 
 
 
209
  if (app_mode == SEM_SIMILARITY):
210
  main_index = st.number_input('Step 3. Enter index of sentence in file to make it the main sentence',value=1,min_value = 1)
211
  else:
@@ -231,14 +243,19 @@ def app_main(app_mode,example_files,model_name_files):
231
  if (len(sentences) > MAX_INPUT):
232
  st.info(f"Input sentence count exceeds maximum sentence limit. First {MAX_INPUT} out of {len(sentences)} sentences chosen")
233
  sentences = sentences[:MAX_INPUT]
234
- st.session_state["model_name"] = selected_model
 
 
 
 
235
  st.session_state["main_index"] = main_index
236
- results = run_test(model_names,selected_model,sentences,display_area,main_index - 1,(uploaded_file is not None))
 
237
  display_area.empty()
238
  with display_area.container():
239
  device = 'GPU' if torch.cuda.is_available() else 'CPU'
240
  response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
241
- display_results(sentences,main_index - 1,results,response_info,app_mode)
242
  #st.json(results)
243
  st.download_button(
244
  label="Download results as json",
 
59
 
60
  def construct_model_info_for_display(model_names):
61
  options_arr = []
62
+ markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>These are either state-of-the-art or the most downloaded models on Huggingface</i></div>"
63
+ markdown_str += f"<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>"
64
  for node in model_names:
65
  options_arr .append(node["name"])
66
  if (node["mark"] == "True"):
 
87
 
88
 
89
  @st.experimental_memo
90
+ def load_model(model_name,model_class,load_model_name):
91
  try:
92
  ret_model = None
93
+ obj_class = globals()[model_class]
94
+ ret_model = obj_class()
95
+ ret_model.init_model(load_model_name)
 
 
96
  assert(ret_model is not None)
97
  except Exception as e:
98
+ st.error("Unable to load model:" + model_name + " " + load_model_name + " " + str(e))
99
  pass
100
  return ret_model
101
 
 
114
  #st.success("Similarity computation complete")
115
  return results
116
 
117
+ DEFAULT_HF_MODEL = "sentence-transformers/paraphrase-MiniLM-L6-v2"
118
  def get_model_info(model_names,model_name):
119
  for node in model_names:
120
  if (model_name == node["name"]):
121
+ return node,model_name
122
+ return get_model_info(model_names,DEFAULT_HF_MODEL)
123
 
124
+ def run_test(model_names,model_name,sentences,display_area,main_index,user_uploaded,custom_model):
125
  display_area.text("Loading model:" + model_name)
126
+ #Note. model_name may get mapped to new name in the call below for custom models
127
+ orig_model_name = model_name
128
+ model_info,model_name = get_model_info(model_names,model_name)
129
+ if (model_name != orig_model_name):
130
+ load_model_name = orig_model_name
131
+ else:
132
+ load_model_name = model_info["model"]
133
  if ("Note" in model_info):
134
  fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
135
  display_area.write(fail_link)
136
+ model = load_model(model_name,model_info["class"],load_model_name)
137
  display_area.text("Model " + model_name + " load complete")
138
  try:
139
  if (user_uploaded):
 
153
 
154
 
155
 
156
+ def display_results(orig_sentences,main_index,results,response_info,app_mode,model_name):
157
  main_sent = f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">{response_info}<br/><br/></div>"
158
+ main_sent += f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">Showing results for model:&nbsp;<b>{model_name}</b></div>"
159
+ score_text = "cosine distance" if app_mode == SEM_SIMILARITY else "cosine distance/score"
160
  pivot_name = "main sentence" if app_mode == SEM_SIMILARITY else "query"
161
  main_sent += f"<div style=\"font-size:14px; color: #6f6f6f; text-align: left\">Results sorted by {score_text}. Closest to furthest away from {pivot_name}</div>"
162
  pivot_name = pivot_name[0].upper() + pivot_name[1:]
 
191
  with open(model_name_files) as fp:
192
  model_names = json.load(fp)
193
  curr_use_case = use_case[app_mode].split(".")[0]
194
+ st.markdown("<h5 style='text-align: center;'>Compare state-of-the-art/popular models for tasks using sentence embeddings</h5>", unsafe_allow_html=True)
195
+ st.markdown(f"<p style='font-size:14px; color: #4f4f4f; text-align: center'><i>Or compare your own model with state-of-the-art/popular models</p>", unsafe_allow_html=True)
196
  st.markdown(f"<div style='color: #4f4f4f; text-align: left'>Use cases for sentence embeddings<br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;{use_case['1']}<br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;<a href=\'{use_case_url['2']}\' target='_blank'>{use_case['2']}</a><br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;{use_case['3']}<br/><i>This app illustrates <b>'{curr_use_case}'</b> use case</i></div>", unsafe_allow_html=True)
197
  st.markdown(f"<div style='color: #9f9f9f; text-align: right'>views:&nbsp;{get_views('init')}</div>", unsafe_allow_html=True)
198
 
 
215
  selected_model = st.selectbox(label=selection_label,
216
  options = options_arr, index=0, key = "twc_model")
217
  st.write("")
218
+ custom_model_selection = st.text_input("Model not listed on above? Type any Huggingface sentence similarity model name ", "",key="custom_model")
219
+ hf_link_str = "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><a href='https://huggingface.co/models?pipeline_tag=sentence-similarity' target = '_blank'>List of Huggingface sentence similarity models</a><br/><br/><br/></div>"
220
+ st.markdown(hf_link_str, unsafe_allow_html=True)
221
  if (app_mode == SEM_SIMILARITY):
222
  main_index = st.number_input('Step 3. Enter index of sentence in file to make it the main sentence',value=1,min_value = 1)
223
  else:
 
243
  if (len(sentences) > MAX_INPUT):
244
  st.info(f"Input sentence count exceeds maximum sentence limit. First {MAX_INPUT} out of {len(sentences)} sentences chosen")
245
  sentences = sentences[:MAX_INPUT]
246
+ if (len(custom_model_selection) != 0):
247
+ run_model = custom_model_selection
248
+ else:
249
+ run_model = selected_model
250
+ st.session_state["model_name"] = run_model
251
  st.session_state["main_index"] = main_index
252
+
253
+ results = run_test(model_names,run_model,sentences,display_area,main_index - 1,(uploaded_file is not None),(len(custom_model_selection) != 0))
254
  display_area.empty()
255
  with display_area.container():
256
  device = 'GPU' if torch.cuda.is_available() else 'CPU'
257
  response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
258
+ display_results(sentences,main_index - 1,results,response_info,app_mode,run_model)
259
  #st.json(results)
260
  st.download_button(
261
  label="Download results as json",
sim_app_models.json CHANGED
@@ -6,7 +6,7 @@
6
  "orig_author_url":"https://github.com/UKPLab",
7
  "orig_author":"Ubiquitous Knowledge Processing Lab",
8
  "sota_info": {
9
- "task":"Over 3.8 million downloads from huggingface",
10
  "sota_link":"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2"
11
  },
12
  "paper_url":"https://arxiv.org/abs/1908.10084",
@@ -18,7 +18,7 @@
18
  "orig_author_url":"https://github.com/UKPLab",
19
  "orig_author":"Ubiquitous Knowledge Processing Lab",
20
  "sota_info": {
21
- "task":"Over 2 million downloads from huggingface",
22
  "sota_link":"https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2"
23
  },
24
  "paper_url":"https://arxiv.org/abs/1908.10084",
@@ -30,7 +30,7 @@
30
  "orig_author_url":"https://github.com/UKPLab",
31
  "orig_author":"Ubiquitous Knowledge Processing Lab",
32
  "sota_info": {
33
- "task":"Over 700,000 downloads from huggingface",
34
  "sota_link":"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens"
35
  },
36
  "paper_url":"https://arxiv.org/abs/1908.10084",
@@ -42,7 +42,7 @@
42
  "orig_author_url":"https://github.com/UKPLab",
43
  "orig_author":"Ubiquitous Knowledge Processing Lab",
44
  "sota_info": {
45
- "task":"Over 500,000 downloads from huggingface",
46
  "sota_link":"https://huggingface.co/sentence-transformers/all-mpnet-base-v2"
47
  },
48
  "paper_url":"https://arxiv.org/abs/1908.10084",
@@ -54,7 +54,7 @@
54
  "orig_author_url":"https://github.com/UKPLab",
55
  "orig_author":"Ubiquitous Knowledge Processing Lab",
56
  "sota_info": {
57
- "task":"Over 500,000 downloads from huggingface",
58
  "sota_link":"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2"
59
  },
60
  "paper_url":"https://arxiv.org/abs/1908.10084",
 
6
  "orig_author_url":"https://github.com/UKPLab",
7
  "orig_author":"Ubiquitous Knowledge Processing Lab",
8
  "sota_info": {
9
+ "task":"Over 3.8 million downloads from Huggingface",
10
  "sota_link":"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2"
11
  },
12
  "paper_url":"https://arxiv.org/abs/1908.10084",
 
18
  "orig_author_url":"https://github.com/UKPLab",
19
  "orig_author":"Ubiquitous Knowledge Processing Lab",
20
  "sota_info": {
21
+ "task":"Over 2 million downloads from Huggingface",
22
  "sota_link":"https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2"
23
  },
24
  "paper_url":"https://arxiv.org/abs/1908.10084",
 
30
  "orig_author_url":"https://github.com/UKPLab",
31
  "orig_author":"Ubiquitous Knowledge Processing Lab",
32
  "sota_info": {
33
+ "task":"Over 700,000 downloads from Huggingface",
34
  "sota_link":"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens"
35
  },
36
  "paper_url":"https://arxiv.org/abs/1908.10084",
 
42
  "orig_author_url":"https://github.com/UKPLab",
43
  "orig_author":"Ubiquitous Knowledge Processing Lab",
44
  "sota_info": {
45
+ "task":"Over 500,000 downloads from Huggingface",
46
  "sota_link":"https://huggingface.co/sentence-transformers/all-mpnet-base-v2"
47
  },
48
  "paper_url":"https://arxiv.org/abs/1908.10084",
 
54
  "orig_author_url":"https://github.com/UKPLab",
55
  "orig_author":"Ubiquitous Knowledge Processing Lab",
56
  "sota_info": {
57
+ "task":"Over 500,000 downloads from Huggingface",
58
  "sota_link":"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2"
59
  },
60
  "paper_url":"https://arxiv.org/abs/1908.10084",