import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
import torch
import arxiv


def main():
    id_provided = True
    st.set_page_config(
        layout="wide",
        initial_sidebar_state="auto",
        page_title="Title Generator!",
        page_icon=None,
    )
    st.title("Title Generator: Generate a title from the abstract of a paper")
    st.text("")
    st.text("")

    example = st.text_area(
        "Provide the link/id for an arxiv paper",
        """https://arxiv.org/abs/2111.10339""",
    )
    # st.selectbox("Provide the link/id for an arxiv paper", example_prompts)

    # Take the message which needs to be processed
    message = st.text_area("...or paste a paper's abstract to generate a title")

    if len(message) < 1:
        # No abstract pasted: fall back to the arXiv link/ID and fetch the
        # abstract (and the original title) through the arxiv API.
        message = example
        id_provided = True
        ids = message.split("/")[-1]
        search = arxiv.Search(id_list=[ids])
        for result in search.results():
            message = result.summary
            title = result.title
    else:
        id_provided = False

    st.text("")
    models_to_choose = [
        "AryanLala/autonlp-Scientific_Title_Generator-34558227",
        "shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full",
    ]
    BASE_MODEL = st.selectbox("Choose a model to generate the title", models_to_choose)

    def preprocess(text):
        # The models expect a list of strings as input.
        if BASE_MODEL in models_to_choose:
            return [text]
        else:
            st.error("Please select a model first")

    @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
    def load_model():
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
        return model, tokenizer

    def get_summary(text):
        with st.spinner(text="Processing your request"):
            model, tokenizer = load_model()
            preprocessed = preprocess(text)
            inputs = tokenizer(
                preprocessed, truncation=True, padding="longest", return_tensors="pt"
            )
            output = model.generate(
                **inputs,
                max_length=60,
                num_beams=10,
                num_return_sequences=1,
                temperature=1.5,
            )
            target_text = tokenizer.batch_decode(output, skip_special_tokens=True)
            return target_text[0]

    # Define the function to run when Submit is clicked
    def submit(message):
        if len(message) > 0:
            summary = get_summary(message)
            if id_provided:
                html_str = f"""
<style>
p.a {{ font: 20px Courier; }}
</style>
<p class="a"><b>Title Generated:</b> {summary}</p>
<p class="a"><b>Original Title:</b> {title}</p>
"""
            else:
                html_str = f"""
<style>
p.a {{ font: 20px Courier; }}
</style>
<p class="a"><b>Title Generated:</b> {summary}</p>
"""
            st.markdown(html_str, unsafe_allow_html=True)
        else:
            st.error("The text can't be empty")

    # Run the algorithm when the Submit button is clicked
    if st.button("Submit"):
        submit(message)

    with st.expander("Additional Information"):
        st.markdown(
            """
The models used were fine-tuned on a subset of data from the
[Arxiv Dataset](https://huggingface.co/datasets/arxiv_dataset).
The task of the models is to suggest an appropriate title from the abstract of a scientific paper.

The model [AryanLala/autonlp-Scientific_Title_Generator-34558227](https://huggingface.co/AryanLala/autonlp-Scientific_Title_Generator-34558227)
was trained on data from the cs.AI (Artificial Intelligence) category of papers.
The model [shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full](https://huggingface.co/shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full)
was trained on the categories cs.AI, cs.LG, cs.NI, cs.GR, cs.CL and cs.CV
(Artificial Intelligence, Machine Learning, Networking and Internet Architecture,
Graphics, Computation and Language, Computer Vision and Pattern Recognition).

Also, <b>Thank you to arXiv for use of its open access interoperability.</b>
It allows us to pull the required abstracts from the provided IDs.
""",
            unsafe_allow_html=True,
        )

    st.text("\n")
    st.text("\n")
    st.markdown(
        """<span style="color:blue; font-size:10px">App created by
[@akshay7](https://huggingface.co/akshay7),
[@AryanLala](https://huggingface.co/AryanLala) and
[@shamikbose89](https://huggingface.co/shamikbose89)</span>""",
        unsafe_allow_html=True,
    )


if __name__ == "__main__":
    main()
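# Usage sketch: Streamlit apps are launched with the Streamlit CLI rather than
# plain `python`. Assuming this script is saved as `app.py` (the filename is an
# assumption) and the dependencies are installed, it can be started with:
#
#   pip install streamlit transformers torch arxiv
#   streamlit run app.py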