File size: 2,696 Bytes
448630d
 
 
 
 
 
 
 
 
 
 
6d2581f
448630d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60216b6
448630d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60216b6
 
 
 
448630d
 
6d2581f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
from datasets import load_dataset


# +
def get_methods_and_arch(dataset):
    """Collect the explanation methods and model architectures in ``dataset``.

    Every column after the first five is treated as a heatmap column whose
    name is underscore-separated: the first part is the method and everything
    between it and the last two parts is the architecture (which may itself
    contain underscores). The first five columns are skipped — presumably
    metadata such as the ``image`` column; TODO confirm against the dataset.

    Args:
        dataset: object exposing ``column_names`` (e.g. a ``datasets.Dataset``).

    Returns:
        ``(methods, archs)``: two lists of unique names. They are sorted so
        the order (and hence dropdown order and any ``[0]`` default) is stable
        across runs; iterating a ``set`` of strings is not, because string
        hashing is randomized per process.
    """
    methods = set()
    archs = set()
    for column in dataset.column_names[5:]:
        parts = column.split('_')
        methods.add(parts[0])
        archs.add('_'.join(parts[1:-2]))
    return sorted(methods), sorted(archs)

def get_columns(arch, method, columns=None):
    """Return the first heatmap column produced by ``method`` on ``arch``.

    Column names are parsed the same way ``get_methods_and_arch`` parses them:
    the first underscore-separated part is the method, and the parts between
    it and the last two are the architecture. Both must match exactly. A plain
    substring test would let e.g. arch ``resnet`` match a ``resnet50`` column,
    or method ``cam`` match a ``gradcam`` column.

    Args:
        arch: architecture name, as listed in ``ARCHS``.
        method: explanation method name, as listed in ``METHODS``.
        columns: optional iterable of column names to search. Defaults to the
            module-level ``dataset``'s columns after the first five.

    Returns:
        The matching column name, or ``None`` if no column matches.
    """
    if columns is None:
        columns = dataset.column_names[5:]
    for col in columns:
        parts = col.split('_')
        if parts[0] == method and '_'.join(parts[1:-2]) == arch:
            return col
    return None
def button_fn(arch,method):
    """Switch the viewer to the heatmap column for the chosen model/method.

    Resets the view to row ``index_default``. Returns, in the order expected
    by the ``button.click`` outputs: the column name, the slider index, the
    input image and the heatmap for that row.

    Args:
        arch: architecture selected in the model dropdown.
        method: explanation method selected in the method dropdown.
    """
    column_heatmap = get_columns(arch, method)
    # Index the dataset once instead of materializing the same row twice.
    example = dataset[index_default]
    return column_heatmap, index_default, example["image"], example[column_heatmap]

def func_slider(index,column_textbox):
    """Return the stimulus image and heatmap for the slider's current row.

    Args:
        index: slider position. Gradio documents the Slider input value as a
            float, and dataset row indexing requires an integer, so it is cast
            with ``int``.
        column_textbox: name of the heatmap column currently shown in the
            column-name textbox.

    Returns:
        ``(image, heatmap)`` for that row, matching the
        ``[image_input, image_output]`` outputs.
    """
    example = dataset[int(index)]
    return example['image'], example[column_textbox]


# -

# Row displayed on startup and restored whenever the model/method is switched.
index_default = 0
# When True, launch with a public share link and debug mode enabled.
DEMO = False

# Stimulus images plus one heatmap column per (method, architecture) pair.
dataset = load_dataset("GazeLocation/stimuli_heatmaps", split="train")
METHODS, ARCHS = get_methods_and_arch(dataset)

if __name__ == '__main__':
    # Single source of truth for the initial selection: the dropdowns, the
    # column-name textbox and the output image must all show the same column.
    default_arch = 'resnet50'
    default_method = 'gradcam'
    default_column = get_columns(default_arch, default_method)
    # Fetch the initial row once rather than indexing the dataset per widget.
    default_example = dataset[index_default]

    demo = gr.Blocks()
    with demo:
        gr.Markdown("# Heatmap Gaze Location")

        with gr.Row():
            dropdown_arch = gr.Dropdown(choices=ARCHS,
                                        value=default_arch,
                                        label='Model')

            dropdown_method = gr.Dropdown(choices=METHODS,
                                          value=default_method,
                                          label='Method')
        with gr.Row():
            # gr.Button takes its caption as `value`; it has no `label` argument.
            button = gr.Button(value='Update Heatmap Model - Method')

        with gr.Row():
            hf_slider = gr.Slider(minimum=0, maximum=len(dataset) - 1, step=1,
                                  value=index_default)
        with gr.Row():
            # Holds the active heatmap column; read by the slider callback.
            column_textbox = gr.Textbox(label='column name',
                                        value=default_column)
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image",
                                       value=default_example["image"])
            with gr.Column():
                image_output = gr.Image(label="Output",
                                        value=default_example[default_column])

        # Changing model/method: update the column name, reset the slider
        # and refresh both images.
        button.click(fn=button_fn,
                     inputs=[dropdown_arch, dropdown_method],
                     outputs=[column_textbox, hf_slider, image_input, image_output])

        # Moving the slider: show that row's image and its heatmap for the
        # column currently in the textbox.
        hf_slider.change(func_slider,
                         inputs=[hf_slider, column_textbox],
                         outputs=[image_input, image_output])
    if DEMO:
        demo.launch(share=True, debug=True)
    else:
        demo.launch()