"""Streamlit app that compares sentence-similarity models on a file of sentences.

The user uploads a text file (one sentence per line) or picks a bundled
example, selects a model and a "main" sentence, and the app lists every
sentence with its similarity score against the main sentence.

NOTE(review): several display strings below contain bare line breaks and a
dangling word "link" with no target; this looks like markup that was lost
before this file was reviewed. The text is kept exactly as found -- confirm
against the deployed version before changing it.
"""
import json
import pdb
import string
import time
from io import StringIO

import streamlit as st
import torch
from transformers import BertTokenizer, BertForMaskedLM

# These classes are looked up by name through globals() in load_model(), so
# the import must stay even though the names are not referenced directly.
from twc_embeddings import HFModel, SimCSEModel, SGPTModel

# Maximum number of sentences processed from an input file.
MAX_INPUT = 100

# One entry per selectable model. "name" is shown in the model selector,
# "model" is passed to init_model(), and "class" names the wrapper class
# (imported above) that load_model() instantiates.
model_names = [
    {
        "name": "sentence-transformers/all-MiniLM-L6-v2",
        "model": "sentence-transformers/all-MiniLM-L6-v2",
        "fork_url": "https://github.com/taskswithcode/sentence_similarity_hf_model",
        "orig_author_url": "https://github.com/UKPLab",
        "orig_author": "Ubiquitous Knowledge Processing Lab",
        "sota_info": {
            "task": "Over 3.8 million downloads from huggingface",
            "sota_link": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
        },
        "paper_url": "https://arxiv.org/abs/1908.10084",
        "mark": True,
        "class": "HFModel",
    },
    {
        "name": "sentence-transformers/paraphrase-MiniLM-L6-v2",
        "model": "sentence-transformers/paraphrase-MiniLM-L6-v2",
        "fork_url": "https://github.com/taskswithcode/sentence_similarity_hf_model",
        "orig_author_url": "https://github.com/UKPLab",
        "orig_author": "Ubiquitous Knowledge Processing Lab",
        "sota_info": {
            "task": "Over 2.4 million downloads from huggingface",
            # Was a copy-pasted link to the all-MiniLM-L6-v2 page.
            "sota_link": "https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2",
        },
        "paper_url": "https://arxiv.org/abs/1908.10084",
        "mark": True,
        "class": "HFModel",
    },
    {
        "name": "sentence-transformers/bert-base-nli-mean-tokens",
        "model": "sentence-transformers/bert-base-nli-mean-tokens",
        "fork_url": "https://github.com/taskswithcode/sentence_similarity_hf_model",
        "orig_author_url": "https://github.com/UKPLab",
        "orig_author": "Ubiquitous Knowledge Processing Lab",
        "sota_info": {
            "task": "Over 700,000 downloads from huggingface",
            # Was a copy-pasted link to the all-MiniLM-L6-v2 page.
            "sota_link": "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens",
        },
        "paper_url": "https://arxiv.org/abs/1908.10084",
        "mark": True,
        "class": "HFModel",
    },
    {
        "name": "sentence-transformers/all-mpnet-base-v2",
        "model": "sentence-transformers/all-mpnet-base-v2",
        "fork_url": "https://github.com/taskswithcode/sentence_similarity_hf_model",
        "orig_author_url": "https://github.com/UKPLab",
        "orig_author": "Ubiquitous Knowledge Processing Lab",
        "sota_info": {
            "task": "Over 500,000 downloads from huggingface",
            # Was a copy-pasted link to the all-MiniLM-L6-v2 page.
            "sota_link": "https://huggingface.co/sentence-transformers/all-mpnet-base-v2",
        },
        "paper_url": "https://arxiv.org/abs/1908.10084",
        "mark": True,
        "class": "HFModel",
    },
    {
        "name": "SGPT-125M",
        "model": "Muennighoff/SGPT-125M-weightedmean-nli-bitfit",
        "fork_url": "https://github.com/taskswithcode/sgpt",
        "orig_author_url": "https://github.com/Muennighoff",
        "orig_author": "Niklas Muennighoff",
        "sota_info": {
            "task": "#1 in multiple information retrieval & search tasks(smaller variant)",
            "sota_link": "https://paperswithcode.com/paper/sgpt-gpt-sentence-embeddings-for-semantic",
        },
        "paper_url": "https://arxiv.org/abs/2202.08904v5",
        "mark": True,
        "class": "SGPTModel",
    },
    {
        "name": "SGPT-1.3B",
        "model": "Muennighoff/SGPT-1.3B-weightedmean-msmarco-specb-bitfit",
        "fork_url": "https://github.com/taskswithcode/sgpt",
        "orig_author_url": "https://github.com/Muennighoff",
        "orig_author": "Niklas Muennighoff",
        "sota_info": {
            "task": "#1 in multiple information retrieval & search tasks(smaller variant)",
            "sota_link": "https://paperswithcode.com/paper/sgpt-gpt-sentence-embeddings-for-semantic",
        },
        "paper_url": "https://arxiv.org/abs/2202.08904v5",
        "Note": "If this large model takes too long or fails to load , try this ",
        "alt_url": "http://www.taskswithcode.com/sentence_similarity/",
        "mark": True,
        "class": "SGPTModel",
    },
    {
        "name": "SGPT-5.8B",
        "model": "Muennighoff/SGPT-5.8B-weightedmean-msmarco-specb-bitfit",
        "fork_url": "https://github.com/taskswithcode/sgpt",
        "orig_author_url": "https://github.com/Muennighoff",
        "orig_author": "Niklas Muennighoff",
        "Note": "If this large model takes too long or fails to load , try this ",
        "alt_url": "http://www.taskswithcode.com/sentence_similarity/",
        "sota_info": {
            "task": "#1 in multiple information retrieval & search tasks",
            "sota_link": "https://paperswithcode.com/paper/sgpt-gpt-sentence-embeddings-for-semantic",
        },
        "paper_url": "https://arxiv.org/abs/2202.08904v5",
        "mark": True,
        "class": "SGPTModel",
    },
    {
        "name": "SIMCSE-large",
        "model": "princeton-nlp/sup-simcse-roberta-large",
        "fork_url": "https://github.com/taskswithcode/SimCSE",
        "orig_author_url": "https://github.com/princeton-nlp",
        "orig_author": "Princeton Natural Language Processing",
        "sota_info": {
            "task": "Within top 10 in multiple semantic textual similarity tasks",
            "sota_link": "https://paperswithcode.com/paper/simcse-simple-contrastive-learning-of",
        },
        "paper_url": "https://arxiv.org/abs/2104.08821v4",
        "mark": True,
        "class": "SimCSEModel",
        "sota_link": "https://paperswithcode.com/sota/semantic-textual-similarity-on-sick",
    },
    {
        "name": "SIMCSE-base",
        "model": "princeton-nlp/sup-simcse-roberta-base",
        "fork_url": "https://github.com/taskswithcode/SimCSE",
        "orig_author_url": "https://github.com/princeton-nlp",
        "orig_author": "Princeton Natural Language Processing",
        "sota_info": {
            "task": "Within top 10 in multiple semantic textual similarity tasks(smaller variant)",
            "sota_link": "https://paperswithcode.com/paper/simcse-simple-contrastive-learning-of",
        },
        "paper_url": "https://arxiv.org/abs/2104.08821v4",
        "mark": True,
        "class": "SimCSEModel",
        "sota_link": "https://paperswithcode.com/sota/semantic-textual-similarity-on-sick",
    },
]

# Label shown in the "Example files" selector -> path of the bundled file.
example_file_names = {
    "Machine learning terms (30+ phrases)": "small_test.txt",
    "Customer feedback mixed with noise (50+ sentences)": "larger_test.txt",
}


def construct_model_info_for_display():
    """Build the model selector options and the footer text.

    Returns:
        A tuple ``(options_arr, markdown_str)``: the list of every model
        name in ``model_names`` (in order), and a text block describing each
        entry whose "mark" is truthy, followed by the upload notes.
    """
    options_arr = []
    parts = ["\n\nModels evaluated\n"]
    for node in model_names:
        options_arr.append(node["name"])
        if not node["mark"]:
            continue
        parts.append(
            f"\n • Model: {node['name']}"
            f"\n    Code released by: {node['orig_author']}"
            f"\n    Model info: {node['sota_info']['task']}\n"
        )
        if "Note" in node:
            # NOTE(review): "link" has no target here, whereas run_test()
            # renders node['alt_url'] as a markdown link -- confirm whether
            # an anchor tag was lost from this string.
            parts.append(f"\n    {node['Note']}link\n")
        parts.append("\n\n")

    parts.append(
        "\nNote:\n• Uploaded files are loaded into non-persistent memory "
        "for the duration of the computation. They are not saved\n"
    )
    limit = f"{MAX_INPUT:,}"
    parts.append(f"\n• User uploaded file has a maximum limit of {limit} sentences.\n")
    return options_arr, "".join(parts)


st.set_page_config(
    page_title='TWC - Compare popular/state-of-the-art models for Sentence Similarity task',
    page_icon="logo.jpg",
    layout='centered',
    initial_sidebar_state='auto',
    menu_items={
        'About': 'This app was created by taskswithcode. http://taskswithcode.com'
    },
)
col, pad = st.columns([85, 15])

with col:
    st.image("long_form_logo_with_icon.png")


@st.experimental_memo
def load_model(model_name):
    """Instantiate and initialise the model wrapper for ``model_name``.

    The first ``model_names`` entry whose "name" is a prefix of
    ``model_name`` is used; its "class" is resolved through ``globals()``.

    Returns:
        The initialised model object, or ``None`` if no entry matched or
        loading raised. In the failure case an error is shown via
        ``st.error``.
    """
    # NOTE(review): this function is memoized, so a None result may be
    # cached as well -- confirm whether failed loads should be retried.
    ret_model = None
    try:
        for node in model_names:
            if model_name.startswith(node["name"]):
                obj_class = globals()[node["class"]]
                candidate = obj_class()
                candidate.init_model(node["model"])
                # Only publish the object once init_model() has succeeded,
                # so a half-initialised model is never returned.
                ret_model = candidate
                break
        if ret_model is None:
            # Explicit raise instead of assert: asserts vanish under -O.
            raise ValueError("no matching entry in model_names")
    except Exception as e:
        st.error("Unable to load model:" + model_name + " " + str(e))
    return ret_model


@st.experimental_memo
def cached_compute_similarity(sentences, _model, model_name, main_index):
    """Compute similarity results with memoization.

    ``_model`` is underscore-prefixed, which Streamlit's memo decorator uses
    to exclude an argument from the cache key; ``model_name`` is otherwise
    unused in the body and presumably stands in for it in the key.
    """
    texts, embeddings = _model.compute_embeddings(sentences, is_file=False)
    results = _model.output_results(None, texts, embeddings, main_index)
    return results


def uncached_compute_similarity(sentences, _model, model_name, main_index):
    """Compute similarity results without caching, showing a spinner."""
    with st.spinner('Computing vectors for sentences'):
        texts, embeddings = _model.compute_embeddings(sentences, is_file=False)
        results = _model.output_results(None, texts, embeddings, main_index)
        return results


def get_model_info(model_name):
    """Return the ``model_names`` entry named exactly ``model_name``, or None."""
    return next(
        (node for node in model_names if node["name"] == model_name),
        None,
    )


def run_test(model_name, sentences, display_area, main_index, user_uploaded):
    """Load the model and compute similarity results for ``sentences``.

    Args:
        model_name: Name selected in the model selector.
        sentences: List of input sentences.
        display_area: Streamlit placeholder used for progress messages.
        main_index: Zero-based index of the main sentence.
        user_uploaded: True when the sentences come from an uploaded file;
            such input bypasses the memoized computation.

    Returns:
        Whatever the model's ``output_results`` returns. On failure an error
        is shown and the script run is stopped via ``st.stop()``.
    """
    display_area.text("Loading model:" + model_name)
    model_info = get_model_info(model_name)
    # get_model_info() yields None for an unknown name; guard before "in".
    if model_info is not None and "Note" in model_info:
        fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
        display_area.write(fail_link)
    model = load_model(model_name)
    if model is None:
        # load_model() has already reported the failure; do not claim the
        # load completed or attempt to compute with a missing model.
        st.stop()
    display_area.text("Model " + model_name + " load complete")
    try:
        if user_uploaded:
            results = uncached_compute_similarity(sentences, model, model_name, main_index)
        else:
            display_area.text("Computing vectors for sentences")
            results = cached_compute_similarity(sentences, model, model_name, main_index)
            display_area.text("Similarity computation complete")
        return results
    except Exception as e:
        st.error("Some error occurred during prediction" + str(e))
        st.stop()
    return {}


def display_results(orig_sentences, main_index, results, response_info):
    """Render the results and stash a JSON copy for the download button.

    Args:
        orig_sentences: The input sentences, in file order.
        main_index: Zero-based index of the main sentence.
        results: Mapping of sentence -> numeric score, iterated in its own
            order.
        response_info: Timing summary shown above the results.

    Each result line is prefixed with the 1-based position of the sentence's
    first occurrence in ``orig_sentences``.
    """
    # Build the position lookup once instead of list.index() per result.
    first_position = {}
    for position, sentence in enumerate(orig_sentences, start=1):
        first_position.setdefault(sentence, position)

    main_sent = (
        f"\n{response_info}\n\n"
        "\nResults sorted by cosine distance. Closest(1) to furthest(-1) away from main sentence\n"
        f"\nMain sentence:  {orig_sentences[main_index]}\n"
    )
    body_sent = []
    download_data = {}
    for key in results:
        index = first_position[key]
        score = results[key]
        body_sent.append(f"\n{index}] {key}   {score:.2f}\n")
        download_data[key] = f"{score:.2f}"
    main_sent = main_sent + "\n" + '\n'.join(body_sent)
    st.markdown(main_sent, unsafe_allow_html=True)
    st.session_state["download_ready"] = json.dumps(download_data, indent=4)


def init_session():
    """Reset the session-state keys used to build the download button."""
    st.session_state["download_ready"] = None
    st.session_state["model_name"] = "ss_test"
    st.session_state["main_index"] = 1
    st.session_state["file_name"] = "default"


def main():
    """Render the form, run the selected model on submit, and show results."""
    init_session()
    st.markdown(
        "\nCompare popular/state-of-the-art models for Sentence Similarity task\n",
        unsafe_allow_html=True,
    )

    # Built before the try block so the footer text is always bound; the
    # function makes no Streamlit calls, so moving it does not alter layout.
    options_arr, markdown_str = construct_model_info_for_display()

    try:
        with st.form('twc_form'):
            uploaded_file = st.file_uploader(
                "Step 1. Upload text file(one sentence in a line) or choose an example text file below.",
                type=".txt",
            )
            selected_file_index = st.selectbox(
                label='Example files ',
                options=list(example_file_names),
                index=0,
                key="twc_file",
            )
            st.write("")
            selected_model = st.selectbox(
                label='Step 2. Select Model',
                options=options_arr,
                index=0,
                key="twc_model",
            )
            st.write("")
            main_index = st.number_input(
                'Step 3. Enter index of sentence in file to make it the main sentence:',
                value=1,
                min_value=1,
            )
            st.write("")
            submit_button = st.form_submit_button('Run')

            # Unused placeholder kept so the page layout is unchanged.
            input_status_area = st.empty()
            display_area = st.empty()
            if submit_button:
                start = time.time()
                if uploaded_file is not None:
                    st.session_state["file_name"] = uploaded_file.name
                    sentences = uploaded_file.getvalue().decode("utf-8")
                else:
                    example_path = example_file_names[selected_file_index]
                    st.session_state["file_name"] = example_path
                    # Context manager closes the handle; UTF-8 matches the
                    # decoding used for uploaded files.
                    with open(example_path, encoding="utf-8") as example_file:
                        sentences = example_file.read()

                # splitlines() keeps the last sentence even when the file
                # has no trailing newline (split("\n")[:-1] dropped it).
                sentences = sentences.splitlines()
                if not sentences:
                    st.error("No sentences found in the input file")
                    st.stop()

                # Truncate BEFORE clamping the index: otherwise an index
                # between MAX_INPUT and the original length would point
                # past the end of the truncated list.
                if len(sentences) > MAX_INPUT:
                    st.info(
                        "Input sentence count exceeds maximum sentence limit. "
                        f"First {MAX_INPUT} out of {len(sentences)} sentences chosen"
                    )
                    sentences = sentences[:MAX_INPUT]
                if len(sentences) < main_index:
                    main_index = len(sentences)
                    st.info(
                        "Selected sentence index is larger than number of sentences in file. Truncating to "
                        + str(main_index)
                    )

                st.session_state["model_name"] = selected_model
                st.session_state["main_index"] = main_index

                results = run_test(
                    selected_model,
                    sentences,
                    display_area,
                    main_index - 1,
                    (uploaded_file is not None),
                )
                display_area.empty()
                with display_area.container():
                    device = 'gpu' if torch.cuda.is_available() else 'cpu'
                    response_info = (
                        f"Computation time on {device}: {time.time() - start:.2f} secs "
                        f"for {len(sentences)} sentences"
                    )
                    display_results(sentences, main_index - 1, results, response_info)

        # Placed after the form block: rendered on every run, and disabled
        # until display_results() has stored a payload.
        download_payload = st.session_state["download_ready"]
        has_results = download_payload is not None
        download_name = (
            st.session_state["model_name"]
            + "_"
            + str(st.session_state["main_index"])
            + "_"
            + '_'.join(st.session_state["file_name"].split(".")[:-1])
            + ".json"
        ).replace("/", "_")
        st.download_button(
            label="Download results as json",
            data=download_payload if has_results else "",
            disabled=not has_results,
            file_name=download_name,
            mime='text/json',
            key="download",
        )

    except Exception as e:
        st.error("Some error occurred during loading" + str(e))
        st.stop()

    st.markdown(markdown_str, unsafe_allow_html=True)


if __name__ == "__main__":
    main()