File size: 1,012 Bytes
91dba4d
99ad741
5288696
 
6de72ef
 
5288696
 
 
 
 
 
 
 
 
 
 
 
 
 
6de72ef
 
91dba4d
 
5288696
6de72ef
 
 
5288696
6de72ef
 
 
91dba4d
 
 
6de72ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import functools

import gradio as gr
import torch
from minicons import cwe
import pandas as pd
import numpy as np

from model import FeatureNormPredictor


@functools.lru_cache(maxsize=4)
def _load_lm(hf_name):
    """Return a cached minicons ``CWE`` wrapper for the model *hf_name*.

    Loading a transformer is expensive; caching avoids re-loading it on
    every Gradio request.
    """
    return cwe.CWE(hf_name)


def predict(word, sentence, lm_name, layer, norm):
    """Validate the UI inputs, load the matching checkpoint and build a report.

    Args:
        word: target word; must be non-empty and occur in *sentence*
            (plain substring test).
        sentence: context sentence containing *word*.
        lm_name: language-model family selected in the UI (e.g. "bert").
        layer: layer index. Gradio's number box may deliver it as a float,
            so integral values such as ``3.0`` are accepted and converted
            to ``int``; non-integral or missing values are rejected.
        norm: feature-norm dataset selected in the UI (e.g. "McRae").

    Returns:
        A newline-joined string with one ``"<input>\\t<score>"`` line per
        input, or an ``"invalid input: ..."`` message when validation fails.
    """
    if not word or not sentence or word not in sentence:
        return "invalid input: word not in sentence"
    # An unselected gr.Radio yields None, which would otherwise crash when
    # building the checkpoint name.
    if lm_name is None or norm is None:
        return "invalid input: select a language model and a norm"

    # gr.Number returns floats by default; str(3.0) would produce the
    # checkpoint name "bert3.0_to_..." instead of "bert3_to_...".
    try:
        layer_idx = int(layer)
    except (TypeError, ValueError, OverflowError):
        return "invalid input: layer not in lm"
    if layer_idx != layer:
        return "invalid input: layer not in lm"

    # NOTE(review): the CWE model is fixed to bert-base-uncased regardless of
    # lm_name and is only used for the layer-count check here. Confirm
    # whether roberta/electra should load their own models.
    lm = _load_lm('bert-base-uncased')
    if layer_idx not in range(lm.layers):
        return "invalid input: layer not in lm"

    model_name = f"{lm_name}{layer_idx}_to_{norm}"
    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_name + '.ckpt',
        map_location=None,
    )
    model.eval()

    # NOTE(review): scores are random placeholders; ``model`` is loaded but
    # not yet used for prediction.
    inputs = [word, sentence, lm_name, str(layer_idx), norm]
    return "\n".join(f"{item}\t{np.random.randint(0, 100)}" for item in inputs)

# Gradio UI: text boxes for the target word and its context sentence, radios
# for the language model and the feature norm, and an integer layer index.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="word"),
        gr.Textbox(label="sentence"),
        # Defaults ensure predict() never receives None for an unselected radio.
        gr.Radio(["bert", "roberta", "electra"], value="bert", label="language model"),
        # precision=0 makes Gradio round the value and pass it as an int, so the
        # checkpoint name is built as e.g. "bert3_..." rather than "bert3.0_...".
        gr.Number(value=0, precision=0, label="layer"),
        gr.Radio(["Binder", "McRae", "Buchanan"], value="Binder", label="norm"),
    ],
    outputs=["text"],
)

# Guard the server start so importing this module has no side effect.
if __name__ == "__main__":
    demo.launch()