Spaces:
Runtime error
Runtime error
import gradio as gr | |
from datasets import load_dataset | |
# + | |
def get_methods_and_arch(dataset):
    """Collect the distinct method and architecture names found in column names.

    Every column from index 5 onward is parsed by splitting on ``'_'``:
    the first token is taken as the method, and the tokens between the first
    one and the last two are re-joined with ``'_'`` as the architecture
    (the last two tokens are ignored).

    Args:
        dataset: object exposing a ``column_names`` sequence
            (here a ``datasets.Dataset``).

    Returns:
        ``(methods, archs)``: two lists of unique names. They are sorted so
        that their order (and therefore ``METHODS[0]`` / ``ARCHS[0]`` and the
        dropdown order) is the same on every run; ``list(set(...))`` order
        varies with the string hash seed.
    """
    methods = set()
    archs = set()
    for column in dataset.column_names[5:]:
        parts = column.split('_')
        methods.add(parts[0])
        archs.add('_'.join(parts[1:-2]))
    return sorted(methods), sorted(archs)
def get_columns(arch, method, columns=None):
    """Return the heatmap column name for an architecture / method pair.

    A column whose name parses exactly to ``method`` and ``arch`` (first
    ``'_'`` token == method, tokens between the first and the last two ==
    arch, the same rule ``get_methods_and_arch`` uses) is preferred. Only if
    no such column exists does this fall back to the original rule: the first
    column containing ``f'{method}_{arch}'`` as a substring. The exact match
    avoids returning e.g. ``'xgradcam_resnet50_...'`` when ``'gradcam'`` was
    requested.

    Args:
        arch: architecture name.
        method: method name.
        columns: optional iterable of column names to search. Defaults to the
            columns of the module-level ``dataset`` from index 5 onward.

    Returns:
        The matching column name, or ``None`` when nothing matches.
    """
    if columns is None:
        columns = dataset.column_names[5:]
    # Materialise once: the columns are scanned up to twice below.
    columns = list(columns)
    for col in columns:
        parts = col.split('_')
        if parts[0] == method and '_'.join(parts[1:-2]) == arch:
            return col
    # Backward-compatible fallback for names that do not follow the parse rule.
    key = f'{method}_{arch}'
    for col in columns:
        if key in col:
            return col
    return None
def button_fn(arch, method):
    """Switch the viewer to the heatmap column of the chosen model / method.

    Args:
        arch: architecture chosen in the "Model" dropdown.
        method: method chosen in the "Method" dropdown.

    Returns:
        A 4-tuple matching the ``outputs`` wired to ``button.click``:
        the selected column name, the slider index reset to ``index_default``,
        the input image of that row, and that row's value for the column.

    Raises:
        ValueError: if no column exists for the ``arch`` / ``method`` pair.
            The two dropdowns are independent, so any combination can be
            submitted; previously this surfaced as an opaque ``KeyError: None``.
    """
    column_heatmap = get_columns(arch, method)
    if column_heatmap is None:
        raise ValueError(f"No heatmap column for method '{method}' and model '{arch}'")
    # Fetch the row once instead of twice; a row access materialises all its columns.
    row = dataset[index_default]
    return column_heatmap, index_default, row["image"], row[column_heatmap]
def func_slider(index, column_textbox):
    """Return the input image and heatmap for the dataset row picked on the slider.

    Args:
        index: row index from the slider. It is cast to ``int`` because the
            slider may deliver a float (e.g. ``3.0``), while
            ``datasets.Dataset`` expects an integer key for a single row.
        column_textbox: name of the heatmap column to read (the textbox value).

    Returns:
        ``(image, heatmap)`` for that row.
    """
    example = dataset[int(index)]
    return example['image'], example[column_textbox]
# - | |
# Row shown at start-up; button_fn also resets the slider to it.
index_default = 0
# When True, launch with share=True and debug=True instead of a plain launch().
DEMO = False

# Heatmap dataset (train split) and the method / architecture names parsed
# from its column names.
dataset = load_dataset(
    "GazeLocation/stimuli_heatmaps",
    split="train",
)
METHODS, ARCHS = get_methods_and_arch(dataset)
if __name__ == '__main__':
    # A single set of defaults drives every initial value, so the dropdowns,
    # the column textbox (read by func_slider) and the displayed heatmap agree
    # before the button is ever clicked. Previously the textbox used
    # ARCHS[0] / METHODS[0] while the image used 'resnet50' / 'gradcam'.
    default_arch = 'resnet50'
    default_method = 'gradcam'
    default_column = get_columns(default_arch, default_method)
    default_row = dataset[index_default]

    demo = gr.Blocks()
    with demo:
        gr.Markdown("# Heatmap Gaze Location")
        with gr.Row():
            dropdown_arch = gr.Dropdown(choices=ARCHS,
                                        value=default_arch,
                                        label='Model')
            dropdown_method = gr.Dropdown(choices=METHODS,
                                          value=default_method,
                                          label='Method')
        with gr.Row():
            # A Button's caption is its `value`; `label` does not set it and
            # newer Gradio releases reject the keyword.
            button = gr.Button(value='Update Heatmap Model - Method')
        with gr.Row():
            # Start on the same row the images show.
            hf_slider = gr.Slider(minimum=0, maximum=len(dataset) - 1, step=1,
                                  value=index_default)
        with gr.Row():
            column_textbox = gr.Textbox(label='column name',
                                        value=default_column)
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image",
                                       value=default_row["image"])
            with gr.Column():
                image_output = gr.Image(label="Output",
                                        value=default_row[default_column])
        button.click(fn=button_fn,
                     inputs=[dropdown_arch, dropdown_method],
                     outputs=[column_textbox, hf_slider, image_input, image_output])
        hf_slider.change(func_slider,
                         inputs=[hf_slider, column_textbox],
                         outputs=[image_input, image_output])
    if DEMO:
        demo.launch(share=True, debug=True)
    else:
        demo.launch()