lovodkin93 committed on
Commit
86c3edd
•
1 Parent(s): 411b109

Create app.py

Files changed (1)
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
+ import math
+ import os
+
+ import gradio as gr
+ from datasets import load_dataset
+
+ # WHOOPS! is a gated dataset: `auth_token` must hold a Hugging Face read
+ # token, supplied here via an environment variable (e.g. a Space secret).
+ auth_token = os.environ.get("auth_token")
+ whoops = load_dataset("nlphuji/whoops", use_auth_token=auth_token)['test']
+ dataset_size = len(whoops)
+
+ # Column names of the WHOOPS! dataset.
+ IMAGE = 'image'
+ IMAGE_DESIGNER = 'image_designer'
+ DESIGNER_EXPLANATION = 'designer_explanation'
+ CROWD_CAPTIONS = 'crowd_captions'
+ CROWD_EXPLANATIONS = 'crowd_explanations'
+ CROWD_UNDERSPECIFIED_CAPTIONS = 'crowd_underspecified_captions'
+ QA = 'question_answering_pairs'
+ IMAGE_ID = 'image_id'
+ SELECTED_CAPTION = 'selected_caption'
+ COMMONSENSE_CATEGORY = 'commonsense_category'
+
+ # The image is shown on the left; every other column goes into an accordion on the right.
+ left_side_columns = [IMAGE]
+ right_side_columns = [x for x in whoops.features.keys() if x not in left_side_columns]
+ # List-valued columns rendered as numbered multi-line strings.
+ enumerate_cols = [CROWD_CAPTIONS, CROWD_EXPLANATIONS, CROWD_UNDERSPECIFIED_CAPTIONS]
+ emoji_to_label = {IMAGE_DESIGNER: '🎨, 🧑‍🎨, 💻', DESIGNER_EXPLANATION: '💡, 🤔, 🧑‍🎨',
+                   CROWD_CAPTIONS: '👥, 💬, 📝', CROWD_EXPLANATIONS: '👥, 💡, 🤔', CROWD_UNDERSPECIFIED_CAPTIONS: '👥, 💬, 👎',
+                   QA: '❓, 🤔, 💡', IMAGE_ID: '🔍, 📄, 💾', COMMONSENSE_CATEGORY: '🤔, 📚, 💡', SELECTED_CAPTION: '📝, 👌, 💬'}
+ batch_size = 8  # examples shown per page
+ target_size = (1024, 1024)  # display size for the images
+
+
+ def func(index):
+     # Slider callback: gather the values for page `index` and return them as one
+     # flat list, which Gradio assigns positionally to the components in `outputs`.
+     index = int(index)
+     # Clamp the window so the last page still yields exactly `batch_size` examples.
+     start_index = min(index * batch_size, max(dataset_size - batch_size, 0))
+     end_index = start_index + batch_size
+     all_examples = [whoops[i] for i in range(start_index, end_index)]
+     values_lst = []
+     for example in all_examples:
+         values_lst += get_instance_values(example)
+     return values_lst
+
+
+ def get_instance_values(example):
+     # One display value per column, in the same order the components are laid out.
+     values = []
+     for k in left_side_columns + right_side_columns:
+         if k == IMAGE:
+             value = example[IMAGE].resize(target_size)
+         elif k in enumerate_cols:
+             value = list_to_string(example[k])
+         elif k == QA:
+             qa_list = [f"Q: {x[0]} A: {x[1]}" for x in example[k]]
+             value = list_to_string(qa_list)
+         else:
+             value = example[k]
+         values.append(value)
+     return values
+
+
+ def list_to_string(lst):
+     return '\n'.join('{}. {}'.format(i + 1, item) for i, item in enumerate(lst))
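+ # Example: list_to_string(["a cat", "a dog"]) returns "1. a cat\n2. a dog".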
+
+
+ demo = gr.Blocks()
+
+
+ def get_col(example):
+     # Build the components for a single example: the image on the left, all
+     # remaining columns inside a collapsed accordion.
+     instance_values = get_instance_values(example)
+     with gr.Column():
+         inputs_left = []
+         assert len(left_side_columns) == len(instance_values[:len(left_side_columns)])
+         for key, value in zip(left_side_columns, instance_values[:len(left_side_columns)]):
+             if key == IMAGE:
+                 img_resized = example[IMAGE].resize(target_size)
+                 input_k = gr.Image(value=img_resized)
+             else:
+                 label = key.capitalize().replace("_", " ")
+                 input_k = gr.Textbox(value=value, label=f"{label} {emoji_to_label[key]}")
+             inputs_left.append(input_k)
+         with gr.Accordion("Click for details", open=False):
+             text_inputs_right = []
+             assert len(right_side_columns) == len(instance_values[len(left_side_columns):])
+             for key, value in zip(right_side_columns, instance_values[len(left_side_columns):]):
+                 label = key.capitalize().replace("_", " ")
+                 text_input_k = gr.Textbox(value=value, label=f"{label} {emoji_to_label[key]}")
+                 text_inputs_right.append(text_input_k)
+     return inputs_left, text_inputs_right
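+ # Note: the component order produced here (left columns, then right columns,
+ # per example) must match the value order returned by get_instance_values,
+ # since slider.change assigns outputs positionally.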
+
+
+ with demo:
+     gr.Markdown("# Slide to iterate WHOOPS!")
+
+     with gr.Column():
+         num_batches = math.ceil(dataset_size / batch_size)
+         # Pages are 0-indexed, so the last valid page is num_batches - 1.
+         slider = gr.Slider(minimum=0, maximum=num_batches - 1, step=1,
+                            label=f'Page (out of {num_batches})')
+         with gr.Row():
+             # Render the first page (page 0) up front; slider changes then
+             # overwrite these components' values via func.
+             all_examples = [whoops[i] for i in range(batch_size)]
+             all_inputs_left_right = []
+             for example in all_examples:
+                 inputs_left, text_inputs_right = get_col(example)
+                 all_inputs_left_right += inputs_left + text_inputs_right
+
+     slider.change(func, inputs=[slider], outputs=all_inputs_left_right)
+
+ demo.launch()
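
The slider wiring above relies on one core Gradio pattern: the callback returns a flat list of values, and `slider.change` assigns them positionally to the components passed as `outputs`. A minimal standalone sketch of that pattern (toy data and names are illustrative, not part of this commit):

import gradio as gr

items = [f"item {i}" for i in range(16)]
page_size = 4

def show_page(page):
    # Return one value per Textbox in `outputs`; Gradio matches them by position.
    start = int(page) * page_size
    return items[start:start + page_size]

with gr.Blocks() as toy:
    slider = gr.Slider(minimum=0, maximum=len(items) // page_size - 1, step=1, label="Page")
    boxes = [gr.Textbox(value=items[i], label=f"slot {i}") for i in range(page_size)]
    slider.change(show_page, inputs=[slider], outputs=boxes)

toy.launch()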