"""Gradio demo: predict semantic-feature norm values for a word in context.

The UI takes a sentence, a target word, a language model, a feature-norm
dataset and a layer index.  ``predict`` downloads the matching classifier
checkpoint and label file from the Hugging Face Hub, runs the classifier on
the word's contextual embedding and lists every feature with a positive
predicted value.
"""
import os

import gradio as gr
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from minicons import cwe

# FFNModule / FFNParams / TrainingParams are not referenced by name below.
# NOTE(review): presumably they must be importable for checkpoint loading;
# kept as in the original file -- confirm before removing.
from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams

# UI label of each language model -> Hugging Face model identifier.
_HF_MODEL_NAMES = {
    'BERT': 'bert-base-uncased',
    'ALBERT': 'albert-xxlarge-v2',
    'RoBERTa': 'roberta-base',
}

# Hub repository holding the classifier checkpoints and their label files.
_REPO_ID = "jwalanthi/semantic-feature-classifiers"

# Separator placed between a feature name and its value in the text output.
_COLUMN_GAP = '\t\t\t\t\t\t\t'


def predict(Sentence, Word, LLM, Norm, Layer):
    """Return the positive predicted feature values for ``Word`` in ``Sentence``.

    Parameter names are capitalised on purpose: ``gr.Interface`` receives this
    function directly, so the names are part of the block's interface.

    Args:
        Sentence: Text containing the target word.
        Word: Target word; must be a non-empty substring of ``Sentence``.
        LLM: One of the keys of ``_HF_MODEL_NAMES`` ("BERT", "ALBERT",
            "RoBERTa").
        Norm: Feature-norm name selected in the UI ("Binder", "McRae",
            "Buchanan"); lower-cased to build the checkpoint filename.
        Layer: Layer index passed to the embedding extractor and used in the
            checkpoint filename.

    Returns:
        A string.  For invalid input, a message starting with
        ``"invalid input:"``.  Otherwise a header line followed by one
        ``feature<tabs>value`` line per feature whose predicted value is
        greater than zero, sorted by value in descending order.

    Raises:
        KeyError: If the ``TOKEN`` environment variable is not set.
    """
    # Validate before any expensive download / model load.  An empty Word
    # would otherwise pass the substring test ("" is in every string).
    if not Word or Word not in Sentence:
        return "invalid input: word not in sentence"
    # A radio button left unselected arrives as None; report it instead of
    # crashing on None.lower().
    if LLM not in _HF_MODEL_NAMES:
        return "invalid input: no language model selected"
    if not Norm:
        return "invalid input: no feature norm selected"

    # The layer is interpolated into a filename, so force an integer
    # rendering ("layer5", never "layer5.0").
    layer = int(Layer)

    model_name_hf = LLM.lower()
    norm_name_hf = Norm.lower()
    subfolder = f"{model_name_hf}_models_all"
    name_hf = f"{model_name_hf}_to_{norm_name_hf}_layer{layer}"

    # NOTE(review): recent huggingface_hub releases prefer ``token=`` over
    # ``use_auth_token=``; kept as-is -- confirm against the pinned version.
    auth_token = os.environ['TOKEN']
    model_path = hf_hub_download(
        repo_id=_REPO_ID,
        subfolder=subfolder,
        filename=f"{name_hf}.ckpt",
        use_auth_token=auth_token,
    )
    label_path = hf_hub_download(
        repo_id=_REPO_ID,
        subfolder=subfolder,
        filename=f"{name_hf}.txt",
        use_auth_token=auth_token,
    )

    lm = cwe.CWE(_HF_MODEL_NAMES[LLM])
    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_path,
        map_location=None,
    )
    model.eval()

    # One feature label per line of the label file.
    with open(label_path, "r", encoding="utf-8") as file:
        labels = [line.rstrip() for line in file]

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        emb = lm.extract_representation((Sentence, Word), layer=layer)
        pred = torch.nn.functional.relu(model(emb)).squeeze(0)
    pred_list = pred.detach().numpy().tolist()

    df = pd.DataFrame({'feature': labels, 'value': pred_list})
    positive = df[df['value'] > 0].sort_values(by='value', ascending=False)

    lines = [
        feature + _COLUMN_GAP + str(value)
        for feature, value in zip(
            positive['feature'].tolist(), positive['value'].tolist()
        )
    ]
    return "All Positive Predicted Values:\n" + "\n".join(lines)


demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(["BERT", "ALBERT", "RoBERTa"]),
        gr.Radio(["Binder", "McRae", "Buchanan"]),
        gr.Slider(0, 12, step=1),
    ],
    outputs=["text"],
)

# Launch only when executed as a script; the original also launched
# unconditionally at import time, starting the server twice.
if __name__ == "__main__":
    demo.launch()