File size: 2,240 Bytes
91dba4d
99ad741
5288696
2ffd102
 
25004af
6de72ef
541d16e
5288696
 
25004af
ef7044d
 
d8a5bfe
f50f2ed
ccb6ea2
 
 
f50f2ed
ccb6ea2
e0240cd
ccb6ea2
ef7044d
ccb6ea2
 
2ffd102
541d16e
2ffd102
541d16e
 
 
 
2ffd102
541d16e
 
 
25004af
f497ee3
a227eaa
f2e3727
 
 
25004af
 
6c606d3
d2572ec
00f165b
b72fbb8
00f165b
33ddbbb
 
91dba4d
 
5288696
6de72ef
 
 
d8a5bfe
63c7e9c
 
6de72ef
91dba4d
 
 
923ec13
541d16e
 
923ec13
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import torch
from minicons import cwe
from huggingface_hub import hf_hub_download
import os
import pandas as pd

from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams


def predict(Sentence, Word, LLM, Norm, Layer):
    """Predict semantic feature norms for a word in context.

    Extracts the contextual embedding of ``Word`` within ``Sentence`` from
    the chosen language model, feeds it through a feature-norm regression
    head downloaded from the Hugging Face Hub, and reports every feature
    with a positive predicted value, sorted descending.

    Args:
        Sentence: Context sentence containing the target word.
        Word: Target word; must appear in ``Sentence`` (substring check).
        LLM: One of 'BERT', 'ALBERT', 'RoBERTa'.
        Norm: One of 'Binder', 'McRae', 'Buchanan' (feature-norm dataset).
        Layer: Hidden-layer index of the LM to read the embedding from.

    Returns:
        A newline-separated report of positive feature predictions, or an
        error message string when ``Word`` is not found in ``Sentence``.
    """
    models = {'BERT': 'bert-base-uncased',
              'ALBERT': 'albert-xxlarge-v2',
              'RoBERTa': 'roberta-base'}
    # NOTE(review): substring check — "cat" also matches "catalog"; kept to
    # preserve the original contract.
    if Word not in Sentence: return "invalid input: word not in sentence"

    # Gradio sliders can deliver floats (e.g. 3.0); the checkpoint filename
    # needs an integer suffix like "layer3", not "layer3.0".
    Layer = int(Layer)

    model_name_hf = LLM.lower()
    norm_name_hf = Norm.lower()
    lm = cwe.CWE(models[LLM])

    repo_id = "jwalanthi/semantic-feature-classifiers"
    subfolder = f"{model_name_hf}_models_all"
    name_hf = f"{model_name_hf}_to_{norm_name_hf}_layer{Layer}"

    # Fall back to anonymous hub access when TOKEN is unset instead of
    # crashing with KeyError (the original indexed os.environ directly).
    token = os.environ.get('TOKEN')
    model_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder,
                                 filename=f"{name_hf}.ckpt", use_auth_token=token)
    label_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder,
                                 filename=f"{name_hf}.txt", use_auth_token=token)

    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_path,
        map_location=None
    )
    model.eval()

    # One feature label per line, in the same order as the model's outputs.
    with open(label_path, "r") as file:
        labels = [line.rstrip() for line in file]

    data = (Sentence, Word)
    emb = lm.extract_representation(data, layer=Layer)
    with torch.no_grad():
        # ReLU clamps negatives so only non-negative norms are reported.
        pred = torch.nn.functional.relu(model(emb))
    # Round once here; the original also re-rounded the DataFrame column,
    # which was redundant.
    pred_round = torch.round(pred.squeeze(0), decimals=2)
    pred_list = pred_round.detach().numpy().tolist()

    df = pd.DataFrame({'feature': labels, 'value': pred_list})
    df = df[df['value'] > 0]
    # drop=True avoids leaving the old row index behind as a stray column.
    df_sorted = df.sort_values(by='value', ascending=False).reset_index(drop=True)

    Output = [row['feature']+'\t\t\t\t\t\t\t'+str(row['value']) for _, row in df_sorted.iterrows()]
    return "All Positive Predicted Values:\n"+"\n".join(Output)

# Gradio app wiring: the five inputs map positionally onto predict()'s
# (Sentence, Word, LLM, Norm, Layer) parameters.
demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(["BERT", "ALBERT", "RoBERTa"]),
        gr.Radio(["Binder", "McRae", "Buchanan"]),
        gr.Slider(0, 12, step=1),
    ],
    outputs=["text"],
)

# Launch only when run as a script. The original also called demo.launch()
# unconditionally at module level, which started the server a second time
# after the guarded launch returned.
if __name__ == "__main__":
    demo.launch()