# Page-header residue (not part of the program): will33am's picture / update / 6d2581f
import gradio as gr
from datasets import load_dataset
# +
def get_methods_and_arch(dataset):
    """Collect the distinct method and architecture names found in column names.

    The first five columns are skipped. Every remaining column name is split
    on ``'_'``: the first token is taken as the method, and the tokens between
    the first and the last two are re-joined with ``'_'`` as the architecture.

    Args:
        dataset: Object exposing a ``column_names`` sequence of strings.

    Returns:
        A ``(methods, archs)`` tuple of lists. Each list holds unique names
        in sorted order, so the result is identical on every run (iterating a
        plain ``set`` of strings gives an order that can change between
        interpreter runs).
    """
    methods = set()
    archs = set()
    for column in dataset.column_names[5:]:
        parts = column.split('_')
        methods.add(parts[0])
        archs.add('_'.join(parts[1:-2]))
    return sorted(methods), sorted(archs)
def get_columns(arch, method, columns=None):
    """Return the first column whose name encodes exactly ``method`` and ``arch``.

    Column names are tokenised the same way ``get_methods_and_arch`` does it:
    split on ``'_'``, first token is the method, and the tokens between the
    first and the last two form the architecture. Comparing the parsed parts
    (instead of a substring test) keeps an architecture such as ``vit`` from
    matching a ``vit_base`` column, and a method from matching in the middle
    of another method's name.

    Args:
        arch: Architecture name to look for.
        method: Method name to look for.
        columns: Optional iterable of candidate column names. When omitted,
            the module-level ``dataset`` columns after the first five are
            searched.

    Returns:
        The matching column name, or ``None`` when no column matches.
    """
    if columns is None:
        columns = dataset.column_names[5:]
    for col in columns:
        parts = col.split('_')
        if parts[0] == method and '_'.join(parts[1:-2]) == arch:
            return col
    return None
def button_fn(arch, method):
    """Handle a click on the update button.

    Looks up the column for the selected architecture/method pair and returns
    the values for the four wired outputs: the column name, the default row
    index, and that row's input image and heatmap for the column.
    """
    heatmap_column = get_columns(arch, method)
    default_row = dataset[index_default]
    return (
        heatmap_column,
        index_default,
        default_row["image"],
        default_row[heatmap_column],
    )
def func_slider(index, column_textbox):
    """Return the input image and the heatmap of row ``index``.

    ``column_textbox`` is the name of the column the heatmap is read from.
    """
    row = dataset[index]
    input_image = row['image']
    heatmap = row[column_textbox]
    return input_image, heatmap
# -
# Row displayed at start-up and restored by the update button.
index_default = 0
# When True the app is launched with a share link and debug output.
DEMO = False

# Load the 'train' split and derive the dropdown choices from its columns.
dataset = load_dataset(
    "GazeLocation/stimuli_heatmaps",
    split='train',
)
METHODS, ARCHS = get_methods_and_arch(dataset)
if __name__ == '__main__':
    # The dropdowns, the column textbox and the initial heatmap must all refer
    # to the same model/method pair: the slider callback reads the textbox, so
    # a textbox that names another column would swap the heatmap on first use.
    default_arch = 'resnet50'
    default_method = 'gradcam'
    default_column = get_columns(default_arch, default_method)
    default_example = dataset[index_default]

    demo = gr.Blocks()
    with demo:
        gr.Markdown("# Heatmap Gaze Location")
        with gr.Row():
            dropdown_arch = gr.Dropdown(choices=ARCHS,
                                        value=default_arch,
                                        label='Model')
            dropdown_method = gr.Dropdown(choices=METHODS,
                                          value=default_method,
                                          label='Method')
        with gr.Row():
            # A Button's caption is set through ``value``, not ``label``.
            button = gr.Button(value='Update Heatmap Model - Method')
        with gr.Row():
            hf_slider = gr.Slider(minimum=0, maximum=len(dataset) - 1, step=1)
        with gr.Row():
            column_textbox = gr.Textbox(label='column name',
                                        value=default_column)
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image",
                                       value=default_example["image"])
            with gr.Column():
                image_output = gr.Image(label="Output",
                                        value=default_example[default_column])

        # Clicking the button selects a new column and resets the view to the
        # default row.
        button.click(fn=button_fn,
                     inputs=[dropdown_arch, dropdown_method],
                     outputs=[column_textbox, hf_slider, image_input, image_output])
        # Moving the slider shows another row of the currently selected column.
        hf_slider.change(func_slider,
                         inputs=[hf_slider, column_textbox],
                         outputs=[image_input, image_output])

    if DEMO:
        demo.launch(share=True, debug=True)
    else:
        demo.launch()