moonlightlane commited on
Commit
47b0e32
·
1 Parent(s): a4905ae

Add application file

Browse files
Files changed (1) hide show
  1. app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ this is for existing formulae retrieval
3
+ '''
4
+
5
+ import gradio as gr
6
+ import pickle
7
+ from pdb import set_trace
8
+ from retrieval_utils import *
9
+
10
+
11
+ ### for the webapp
12
+ title = "math equation retrieval demo using the OpenStax Calculus textbook"
13
+ description = "This is a demo for math equation retrieval based on research developed at Rice University. Click on one of the examples or type an equation of your own. Then click submit to see the retrieved equation along with its surrounding context. The retrieved equation is marked in red. Currently DOES NOT support single symbol retrieval. Demo uses the OpenStax calculus textbook. "
14
+ # article = "<b>Warning and disclaimer:</b> Currently we do not gaurantee the generated questions to be good 100 percent of the time. Also the model may generate content that is biased, hostile, or toxic; be extra careful about what you provide as input. We are not responsible for any impacts that the generated content may have.<p style='text-align: center'>Developed at <a href='https://rice.edu/'>Rice University</a> and <a href='https://openstax.org/'>OpenStax</a> with <a href='https://gradio.app/'>gradio</a> and <a href='https://beta.openai.com/'>OpenAI API</a></p>"
15
+ examples = [
16
+ ["""y=\\text{sin}\\phantom{\\rule{0.1em}{0ex}}x"""],
17
+ ["""\\left[a,b\\right]"""],
18
+ ["""20g"""],
19
+ ["""x=5"""],
20
+ ["""f\\left(x\\right)=\\frac{1}{\\sqrt{1+{x}^{2}}}"""],
21
+ ["""P\\left(x\\right)=30x-0.3{x}^{2}-250"""],
22
+ ["""\\epsilon =0.8;"""],
23
+ ["""{x}_{n}=\\frac{{x}_{n-1}}{2}+\\frac{3}{2{x}_{n-1}}"""],
24
+ ["""y=f\\left(x\\right),y=1+f\\left(x\\right),x=0,y=0"""],
25
+ ["""\\frac{1}{2}du=d\\theta"""]
26
+ ]
27
+
28
+
29
+
30
+ ######################## for the retrieval model #######################
31
+ ########################################################################
32
+ ## configs
33
+
34
+ def retrieval_fn(inp):
35
+ _top=1000 # top nodes as vocab
36
+ _min_nodes=3 #3, 5
37
+ _max_nodes=150 #150, 10
38
+ _max_children=8 # 1 is actually 2 children because indexing starts at 0
39
+ _max_depth=19
40
+ _simple = False # non typed nodes; all tail nodes are converted to UNKNOWN
41
+ simple = 'simple' if _simple else 'complex'
42
+ _w_unknown = True # typed nodes but with unknown tail for "other". cannot be used together with "simple"
43
+ _collection = 'OS_calculus' # 'ARQMath' # 'WikiTopiceqArxiv'
44
+ simple = False #if data_path.split('_')[-1] == 'simple' else False
45
+ w_unknown = True #if data_path.split('_')[-1] == 'unknown' else False # TODO
46
+ use_softmax = True # TODO
47
+ topN=5
48
+ result_path = '/mnt/math-text-embedding/math_embedding/results_retrieval/jack@51.143.92.116/2020-12-15_06-09-44_eqsWikiTopiceqArxiv-RNN-top_1000-minmax_nodes_3_150-max_children_8-max_depth_19-onehot2010_w_unknown_softmax'
49
+ data_path = '/mnt/math-text-embedding/math_embedding/openstax_retrieval_demo/model_input_data/eqsOS_calculus-top_1000-minmax_nodes_3_150-max_children_8-max_depth_19_w_unknown'
50
+ simple = False #if data_path.split('_')[-1] == 'simple' else False
51
+ w_unknown = True #if data_path.split('_')[-1] == 'unknown' else False # TODO
52
+ use_softmax = True # TODO
53
+
54
+ vocab, vocab_T, vocab_N, vocab_V, vocab_other, stoi, itos, eqs, encoder_src, encoder_pos, \
55
+ _, _, _, _, max_n_children = load_all_data(data_path)
56
+
57
+ with open('tree_to_eq_dict.pkl', 'rb') as f:
58
+ tree_to_eq_dict = pickle.load(f)
59
+ with open('eq_to_context_dict.pkl', 'rb') as f:
60
+ eq_to_context_dict = pickle.load(f)
61
+
62
+ ## load model - rnn
63
+ print('load model ...')
64
+ # Configure models
65
+ hidden_size = 500
66
+ n_layers = 2
67
+ dropout = 0.1
68
+ pad_idx = stoi[PAD]
69
+
70
+ # Initialize models
71
+ pos_embedding = PositionalEncoding(max_depth=20, max_children=10, mode='onehot')
72
+ # encoder = EncoderRNN(len(vocab), hidden_size, pos_embedding, n_layers, dropout=dropout, use_softmax=use_softmax)
73
+ encoder = EncoderRNN(4009, hidden_size, pos_embedding, n_layers, dropout=dropout, use_softmax=use_softmax)
74
+ encoder.load_state_dict(torch.load(os.path.join(result_path, 'encoder_best.pt')))
75
+ encoder.cuda();
76
+ encoder.eval();
77
+
78
+ # retrieval
79
+ retrieval_result_eqs, retrieval_result_context, cos_values = retrieval(inp, eqs, tree_to_eq_dict, eq_to_context_dict, encoder_src, encoder_pos, encoder,
80
+ stoi, vocab_N, vocab_V, vocab_T, vocab_other, topN,
81
+ _top, _min_nodes, _max_nodes, _max_children, _max_depth, _simple, _w_unknown)
82
+
83
+ return retrieval_result_context
84
+
85
+
86
+ def generate(inp):
87
+
88
+ retrieved = retrieval_fn(inp)
89
+ new_r = []
90
+ for r in retrieved:
91
+ p = r[0:250]
92
+ n = r[-250:]
93
+ m = r[250:-250]
94
+ new_r.append(('..... ' + p, None))
95
+ new_r.append((m, 'retrieved equation'))
96
+ new_r.append((n+' ......\r\n\r\n\r\n', None))
97
+
98
+ # output = '\n\n\n'.join(retrieval_result_contexts['$$' + inp + '$$'][0:10])
99
+
100
+ return new_r
101
+
102
+ counter = 0
103
+ def generate_html(inp):
104
+
105
+ global counter
106
+
107
+ retrieved = retrieval_fn(inp)
108
+ html_txt = ''
109
+ for r in retrieved:
110
+ p = r[0:250]
111
+ n = r[-250:]
112
+ m = r[250:-250]
113
+ html_txt += '<p>...... ' + p.replace('<', '&lt;').replace('>', '&gt;') + '<span style="color: #ff0000">{}</span>'.format(m.replace('<', '&lt;').replace('>', '&gt;')) + n.replace('<', '&lt;').replace('>', '&gt;') + ' ... </p><br><br>'
114
+
115
+ html = "<!DOCTYPE html><html><body>" + html_txt + "</body></html>"
116
+ html = html.replace('$$', '')
117
+ # print(html)
118
+
119
+ counter += 1
120
+ print(counter)
121
+
122
+ return html
123
+
124
+ # output_text = gr.outputs.Textbox()
125
+ # output_text = gr.outputs.HighlightedText()#color_map={"retrieved equation": "red"})
126
+ output_text = gr.outputs.HTML()#color_map={"retrieved equation": "red"})
127
+
128
+
129
+ gr.Interface(fn=generate_html,
130
+ inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
131
+ outputs=["html"],#output_text,
132
+ title=title, description=description,
133
+ # article=article,
134
+ examples=examples).launch(share=True)