'''Gradio demo for existing-formulae retrieval over the OpenStax Calculus textbook.

Given an equation typed by the user, an RNN tree encoder embeds it and the
closest equations from the textbook are returned together with their
surrounding context; the retrieved equation is highlighted in red.
'''
import functools
import os
import pickle
from html import escape as _escape_html
from pdb import set_trace  # noqa: F401  (debugging aid, currently unused)

import gradio as gr
import torch

# Names used below but not defined in this file (load_all_data, retrieval,
# EncoderRNN, PositionalEncoding, ...) are expected to come from this import.
from retrieval_utils import *

### for the webapp
title = "math equation retrieval demo using the OpenStax Calculus textbook"
description = "This is a demo for math equation retrieval based on research developed at Rice University. Click on one of the examples or type an equation of your own. Then click submit to see the retrieved equation along with its surrounding context. The retrieved equation is marked in red. Currently DOES NOT support single symbol retrieval. Demo uses the OpenStax calculus textbook. "
# article = "Warning and disclaimer: Currently we do not guarantee the generated questions to be good 100 percent of the time. Also the model may generate content that is biased, hostile, or toxic; be extra careful about what you provide as input. We are not responsible for any impacts that the generated content may have. Developed at Rice University and OpenStax with gradio and OpenAI API"
examples = [
    ["""y=\\text{sin}\\phantom{\\rule{0.1em}{0ex}}x"""],
    ["""\\left[a,b\\right]"""],
    ["""20g"""],
    ["""x=5"""],
    ["""f\\left(x\\right)=\\frac{1}{\\sqrt{1+{x}^{2}}}"""],
    ["""P\\left(x\\right)=30x-0.3{x}^{2}-250"""],
    ["""\\epsilon =0.8;"""],
    ["""{x}_{n}=\\frac{{x}_{n-1}}{2}+\\frac{3}{2{x}_{n-1}}"""],
    ["""y=f\\left(x\\right),y=1+f\\left(x\\right),x=0,y=0"""],
    ["""\\frac{1}{2}du=d\\theta"""],
]

######################## for the retrieval model #######################
## configs
RESULT_PATH = '/mnt/math-text-embedding/math_embedding/results_retrieval/jack@51.143.92.116/2020-12-15_06-09-44_eqsWikiTopiceqArxiv-RNN-top_1000-minmax_nodes_3_150-max_children_8-max_depth_19-onehot2010_w_unknown_softmax'
DATA_PATH = '/mnt/math-text-embedding/math_embedding/openstax_retrieval_demo/model_input_data/eqsOS_calculus-top_1000-minmax_nodes_3_150-max_children_8-max_depth_19_w_unknown'
TOP_VOCAB = 1000      # top nodes as vocab
MIN_NODES = 3         # 3, 5
MAX_NODES = 150       # 150, 10
MAX_CHILDREN = 8      # 1 is actually 2 children because indexing starts at 0
MAX_DEPTH = 19
SIMPLE = False        # non typed nodes; all tail nodes are converted to UNKNOWN
W_UNKNOWN = True      # typed nodes but with unknown tail for "other"; cannot be combined with SIMPLE
USE_SOFTMAX = True    # TODO
TOP_N = 5             # number of retrieved equations shown per query
CONTEXT_WINDOW = 250  # characters of context on each side of the retrieved equation

# Encoder hyper-parameters; must match the saved checkpoint.
HIDDEN_SIZE = 500
N_LAYERS = 2
DROPOUT = 0.1
# The checkpoint in RESULT_PATH was trained on a different collection
# (WikiTopiceqArxiv) than DATA_PATH (OS_calculus), so its input size is
# hard-coded instead of len(vocab) -- presumably the training vocab size; confirm.
CHECKPOINT_VOCAB_SIZE = 4009


@functools.lru_cache(maxsize=None)
def _load_resources():
    """Load the equation data, lookup tables and trained encoder exactly once.

    The result is cached, so only the first request pays the (large) cost of
    reading the data and moving the model to the GPU.

    Returns:
        dict with the objects `retrieval` needs (see `retrieval_fn`).
    """
    (vocab, vocab_T, vocab_N, vocab_V, vocab_other, stoi, itos, eqs,
     encoder_src, encoder_pos, _, _, _, _, max_n_children) = load_all_data(DATA_PATH)

    # NOTE: pickle executes code on load; these are local, trusted artifacts
    # and must never be pointed at user-supplied files.
    with open('tree_to_eq_dict.pkl', 'rb') as f:
        tree_to_eq_dict = pickle.load(f)
    with open('eq_to_context_dict.pkl', 'rb') as f:
        eq_to_context_dict = pickle.load(f)

    ## load model - rnn
    print('load model ...')
    pos_embedding = PositionalEncoding(max_depth=20, max_children=10, mode='onehot')
    encoder = EncoderRNN(CHECKPOINT_VOCAB_SIZE, HIDDEN_SIZE, pos_embedding, N_LAYERS,
                         dropout=DROPOUT, use_softmax=USE_SOFTMAX)
    encoder.load_state_dict(torch.load(os.path.join(RESULT_PATH, 'encoder_best.pt')))
    encoder.cuda()
    encoder.eval()

    return {
        'eqs': eqs,
        'tree_to_eq_dict': tree_to_eq_dict,
        'eq_to_context_dict': eq_to_context_dict,
        'encoder_src': encoder_src,
        'encoder_pos': encoder_pos,
        'encoder': encoder,
        'stoi': stoi,
        'vocab_N': vocab_N,
        'vocab_V': vocab_V,
        'vocab_T': vocab_T,
        'vocab_other': vocab_other,
    }


def retrieval_fn(inp):
    """Retrieve the TOP_N textbook equations closest to `inp`.

    Args:
        inp: the query equation (LaTeX source, as typed in the UI).

    Returns:
        The retrieved contexts, as returned by `retrieval` (each is used below
        as a string containing the equation plus surrounding text).
    """
    res = _load_resources()
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        _, retrieval_result_context, _ = retrieval(
            inp, res['eqs'], res['tree_to_eq_dict'], res['eq_to_context_dict'],
            res['encoder_src'], res['encoder_pos'], res['encoder'], res['stoi'],
            res['vocab_N'], res['vocab_V'], res['vocab_T'], res['vocab_other'],
            TOP_N, TOP_VOCAB, MIN_NODES, MAX_NODES, MAX_CHILDREN, MAX_DEPTH,
            SIMPLE, W_UNKNOWN)
    return retrieval_result_context


def _split_context(text, window=CONTEXT_WINDOW):
    """Split a retrieved context into (before, equation, after).

    Assumes the context carries `window` characters on each side of the
    equation -- TODO confirm against `retrieval`. For texts shorter than
    2 * window the before/after parts overlap and the equation part is empty.
    """
    return text[:window], text[window:-window], text[-window:]


def generate(inp):
    """Return retrieved contexts as (text, label) pairs for a HighlightedText output."""
    segments = []
    for context in retrieval_fn(inp):
        before, equation, after = _split_context(context)
        segments.append(('..... ' + before, None))
        segments.append((equation, 'retrieved equation'))
        segments.append((after + ' ......\r\n\r\n\r\n', None))
    return segments


counter = 0  # number of HTML requests served (printed for monitoring)


def generate_html(inp):
    """Render the retrieved contexts as HTML with the equation highlighted in red.

    All retrieved text is HTML-escaped before being embedded in the page.
    """
    global counter
    paragraphs = []
    for context in retrieval_fn(inp):
        before, equation, after = _split_context(context)
        paragraphs.append(
            '<p>...... ' + _escape_html(before, quote=False)
            + '<span style="color:red">' + _escape_html(equation, quote=False) + '</span>'
            + _escape_html(after, quote=False) + ' ...</p><br>'
        )
    # Drop the $$ display-math delimiters so the equation shows as plain text.
    page = ('<div>' + ''.join(paragraphs) + '</div>').replace('$$', '')
    counter += 1
    print(counter)
    return page


demo = gr.Interface(
    fn=generate_html,
    inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
    outputs=["html"],
    title=title,
    description=description,
    examples=examples,
)
demo.launch(share=True)