jwalanthi commited on
Commit
5288696
·
1 Parent(s): 6de72ef

checks input

Browse files
Files changed (2) hide show
  1. .gitignore +4 -0
  2. app.py +22 -5
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # git ignore all the csvs they are too big
2
+ models/
3
+
4
+ __pycache__/
app.py CHANGED
@@ -1,18 +1,35 @@
1
  import gradio as gr
2
-
 
 
 
3
  import numpy as np
4
 
5
- def greet (word, sentence, model, layer, norm):
6
- inputs = [word, sentence, model, str(layer), norm]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  outputs = [input+'\t'+str(np.random.randint(0,100, size=1)[0]) for input in inputs]
8
  return "\n".join(outputs)
9
 
10
  demo = gr.Interface(
11
- fn=greet,
12
  inputs=[
13
  "text",
14
  "text",
15
- gr.Radio(["bert", "gpt2"]),
16
  "number",
17
  gr.Radio(["Binder", "McRae", "Buchanan"]),
18
  ],
 
1
  import gradio as gr
2
+ import torch
3
+ import lightning
4
+ from minicons import cwe
5
+ import pandas as pd
6
  import numpy as np
7
 
8
+ from model import FeatureNormPredictor
9
+
10
+ import sys
11
+ sys.path.insert(0, '/home/jjr4354/semantic-features')
12
+
13
+ def predict (word, sentence, lm_name, layer, norm):
14
+ if word not in sentence: return "invalid input: word not in sentence"
15
+ model_name = lm_name + str(layer) + '_to_' + norm
16
+ lm = cwe.CWE('bert-base-uncased')
17
+ if layer not in range (lm.layers): return "invalid input: layer not in lm"
18
+ model = FeatureNormPredictor.load_from_checkpoint(
19
+ checkpoint_path=model_name+'.ckpt',
20
+ map_location=None
21
+ )
22
+ model.eval()
23
+ inputs = [word, sentence, lm_name, str(layer), norm]
24
  outputs = [input+'\t'+str(np.random.randint(0,100, size=1)[0]) for input in inputs]
25
  return "\n".join(outputs)
26
 
27
  demo = gr.Interface(
28
+ fn=predict,
29
  inputs=[
30
  "text",
31
  "text",
32
+ gr.Radio(["bert", "roberta", "electra"]),
33
  "number",
34
  gr.Radio(["Binder", "McRae", "Buchanan"]),
35
  ],