import gradio as gr
from huggingface_hub import hf_hub_download
import pickle
import numpy as np
import subprocess
import shutil
import sys
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

# Fixed working files shared with src/test_saved_model.py: uploads are copied to
# the first two paths before the script runs; the next two are read back after it.
# NOTE(review): these are process-wide paths, so two concurrent requests would
# overwrite each other's inputs/outputs.
SAVED_TEST_DATASET = "train.txt"
SAVED_TEST_LABEL = "train_label.txt"
RESULT_PATH = "result.txt"
ROC_DATA_PATH = "roc_data.pkl"
PLOT_PATH = "plot.png"

# Fine-tuned checkpoint evaluated for each dropdown choice.
MODEL_CHECKPOINTS = {
    "FS": "ratio_proportion_change3/output/FS/bert_fine_tuned.model.ep32",
    "IS": "ratio_proportion_change3/output/IS/bert_fine_tuned.model.ep14",
    "CORRECTNESS": "ratio_proportion_change3/output/correctness/bert_fine_tuned.model.ep48",
    "EFFECTIVENESS": "ratio_proportion_change3/output/effectiveness/bert_fine_tuned.model.ep28",
}


def _upload_path(upload, what):
    """Return the filesystem path of a Gradio upload, or raise a user-facing error.

    Older Gradio versions hand the callback a tempfile wrapper exposing ``.name``;
    newer ones pass the path string directly. Both forms are accepted.
    """
    if upload is None:
        raise gr.Error(f"Please upload the {what} file.")
    return getattr(upload, "name", upload)


def _parse_results(path):
    """Parse ``key: value`` lines from *path* into a dict.

    ``epoch`` is kept as a string; every other value is converted to ``float``
    when possible and otherwise kept as the raw string. Blank lines are skipped.
    """
    result = {}
    with open(path, "r", encoding="utf-8") as fh:
        for raw_line in fh:
            line = raw_line.strip()
            if not line:
                continue
            key, value = line.split(": ", 1)
            if key == "epoch":
                result[key] = value
                continue
            try:
                result[key] = float(value)
            except ValueError:
                result[key] = value
    return result


def _plot_roc(model_name):
    """Plot the ROC curve stored in ROC_DATA_PATH, save it to PLOT_PATH, return that path."""
    # pickle must only be used on trusted data; this file is presumably written
    # locally by src/test_saved_model.py and is never a user upload.
    with open(ROC_DATA_PATH, "rb") as fh:
        fpr, tpr, _ = pickle.load(fh)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots()
    try:
        ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
        # Diagonal = performance of a random classifier, for reference.
        ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate',
               title=f'ROC Curve: {model_name}')
        ax.legend(loc="lower right")
        ax.grid()
        fig.savefig(PLOT_PATH)
    finally:
        # Always release the figure so repeated requests don't leak memory.
        plt.close(fig)
    return PLOT_PATH


def process_file(file, label, model_name):
    """Evaluate the selected fine-tuned checkpoint on the uploaded data.

    Args:
        file: Gradio upload of the data ``.txt`` file.
        label: Gradio upload of the matching label ``.txt`` file.
        model_name: one of ``models``; selects the checkpoint from MODEL_CHECKPOINTS.

    Returns:
        ``(text_output, plot_path)``: a text summary of the metrics parsed from
        RESULT_PATH, and the path of the saved ROC-curve image.

    Raises:
        gr.Error: an upload or the model choice is missing, or the evaluation
            script fails / produces no results.
    """
    data_path = _upload_path(file, "data")
    label_path = _upload_path(label, "label")
    checkpoint = MODEL_CHECKPOINTS.get(model_name)
    if checkpoint is None:
        raise gr.Error("Please select a model from the dropdown menu.")

    # Save the uploads where src/test_saved_model.py reads them.
    shutil.copyfile(data_path, SAVED_TEST_DATASET)
    shutil.copyfile(label_path, SAVED_TEST_LABEL)

    # Remove outputs of a previous run so a failed run can't show stale results.
    for stale in (RESULT_PATH, ROC_DATA_PATH):
        Path(stale).unlink(missing_ok=True)

    print(checkpoint)
    try:
        # sys.executable: run the script under the same interpreter/environment.
        subprocess.run(
            [sys.executable, "src/test_saved_model.py",
             "--finetuned_bert_checkpoint", checkpoint],
            check=True,
        )
    except subprocess.CalledProcessError as err:
        raise gr.Error(
            f"Evaluation failed for model {model_name} (exit code {err.returncode})."
        ) from err

    try:
        result = _parse_results(RESULT_PATH)
        plot_path = _plot_roc(model_name)
    except FileNotFoundError as err:
        raise gr.Error(f"Evaluation produced no output: {err.filename}") from err

    text_output = f"Model: {model_name}\nResult:\n{result}"
    return text_output, plot_path


# List of models for the dropdown menu
models = list(MODEL_CHECKPOINTS)

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# ASTRA")
    gr.Markdown("Upload a .txt file and select a model from the dropdown menu.")
    with gr.Row():
        file_input = gr.File(label="Upload a .txt file", file_types=['.txt'])
        label_input = gr.File(label="Upload the label .txt file", file_types=['.txt'])
        model_dropdown = gr.Dropdown(choices=models, label="Select a model")
    with gr.Row():
        output_text = gr.Textbox(label="Output Text")
        output_image = gr.Image(label="Output Plot")
    btn = gr.Button("Submit")
    btn.click(fn=process_file,
              inputs=[file_input, label_input, model_dropdown],
              outputs=[output_text, output_image])

# Launch the app
if __name__ == "__main__":
    demo.launch()