Spaces:
Sleeping
Sleeping
File size: 2,240 Bytes
91dba4d 99ad741 5288696 2ffd102 25004af 6de72ef 541d16e 5288696 25004af ef7044d d8a5bfe f50f2ed ccb6ea2 f50f2ed ccb6ea2 e0240cd ccb6ea2 ef7044d ccb6ea2 2ffd102 541d16e 2ffd102 541d16e 2ffd102 541d16e 25004af f497ee3 a227eaa f2e3727 25004af 6c606d3 d2572ec 00f165b b72fbb8 00f165b 33ddbbb 91dba4d 5288696 6de72ef d8a5bfe 63c7e9c 6de72ef 91dba4d 923ec13 541d16e 923ec13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import gradio as gr
import torch
from minicons import cwe
from huggingface_hub import hf_hub_download
import os
import pandas as pd
from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
def predict(Sentence, Word, LLM, Norm, Layer):
    """Predict semantic-feature values for ``Word`` as it appears in ``Sentence``.

    The contextual embedding of ``Word`` is extracted from the chosen language
    model at ``Layer`` and passed through a classifier checkpoint downloaded
    from the ``jwalanthi/semantic-feature-classifiers`` Hub repository.

    Args:
        Sentence: Text containing the target word.
        Word: Target word; must occur in ``Sentence`` as a substring.
        LLM: One of ``'BERT'``, ``'ALBERT'``, ``'RoBERTa'``.
        Norm: Feature-norm name (the UI offers Binder, McRae, Buchanan); its
            lower-cased form is used in the checkpoint file name.
        Layer: Layer index of the language model to read the embedding from.

    Returns:
        A string: either an ``"invalid input: ..."`` message, or a header line
        followed by one ``feature<TABS>value`` line per feature with a
        positive predicted value, sorted by value in descending order.

    Raises:
        KeyError: If the ``TOKEN`` environment variable is not set.
    """
    models = {
        'BERT': 'bert-base-uncased',
        'ALBERT': 'albert-xxlarge-v2',
        'RoBERTa': 'roberta-base',
    }

    # Validate everything up front so bad input never triggers a download.
    # Unselected radio buttons arrive as None, which would otherwise crash
    # on ``.lower()`` below.
    if not Word:
        return "invalid input: word is empty"
    if Word not in Sentence:
        return "invalid input: word not in sentence"
    if LLM not in models:
        return "invalid input: please select a language model"
    if not Norm:
        return "invalid input: please select a feature norm"

    # The layer number becomes part of a file name and a layer index, so it
    # must be a plain int (e.g. 5, never 5.0).
    layer = int(Layer)

    model_name_hf = LLM.lower()
    norm_name_hf = Norm.lower()
    lm = cwe.CWE(models[LLM])

    repo_id = "jwalanthi/semantic-feature-classifiers"
    subfolder = f"{model_name_hf}_models_all"
    name_hf = f"{model_name_hf}_to_{norm_name_hf}_layer{layer}"
    token = os.environ['TOKEN']
    model_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename=f"{name_hf}.ckpt", use_auth_token=token)
    label_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename=f"{name_hf}.txt", use_auth_token=token)

    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_path,
        map_location=None
    )
    model.eval()

    # One feature label per line; presumably in the same order as the
    # classifier's outputs -- TODO confirm against the training code.
    with open(label_path, "r", encoding="utf-8") as file:
        labels = [line.rstrip() for line in file]

    data = (Sentence, Word)
    emb = lm.extract_representation(data, layer=layer)
    with torch.no_grad():
        # ReLU clamps negative predictions to zero before reporting.
        pred = torch.nn.functional.relu(model(emb))
        pred_sq = pred.squeeze(0)
        pred_round = torch.round(pred_sq, decimals=2)
        pred_list = pred_round.detach().numpy().tolist()

    df = pd.DataFrame({'feature': labels, 'value': pred_list})
    # Copy the filtered frame so the assignment below does not write into a
    # slice of the original (avoids pandas' SettingWithCopyWarning).
    df = df[df['value'] > 0].copy()
    # Round again: widening float32 to Python floats leaves digits past the
    # second decimal.
    df['value'] = df['value'].round(2)
    df_sorted = df.sort_values(by='value', ascending=False)

    Output = [
        f"{feature}\t\t\t\t\t\t\t{value}"
        for feature, value in zip(df_sorted['feature'], df_sorted['value'])
    ]
    return "All Positive Predicted Values:\n" + "\n".join(Output)
# Gradio UI for `predict`. The input components are listed in the order of
# the function's positional parameters.
_predict_inputs = [
    "text",                                      # Sentence
    "text",                                      # Word
    gr.Radio(["BERT", "ALBERT", "RoBERTa"]),     # LLM
    gr.Radio(["Binder", "McRae", "Buchanan"]),   # Norm
    gr.Slider(0, 12, step=1),                    # Layer
]

demo = gr.Interface(fn=predict, inputs=_predict_inputs, outputs=["text"])
# Start the web server only when this file is executed as a script.
# Launching unconditionally as well would start a server on import and
# start a second one as soon as the first returned.
if __name__ == "__main__":
    demo.launch()