Spaces:
Runtime error
Runtime error
achterbrain
commited on
Commit
•
748cc87
1
Parent(s):
1914d22
Transferred files from streamlit cloud repo
Browse files- Dashboard.py +156 -0
- Data/Prompt_dir_221128.csv +62 -0
- Graphics/IL_Logo.png +0 -0
- pages/1_⚙️Manual assessment.py +142 -0
- pages/2_🤖Automated assessment.py +101 -0
- pages/3_📊Assessment summary.py +160 -0
- pages/Functions/Assessment_functions.py +208 -0
- pages/Functions/coco-labels-paper.txt +91 -0
- requirements.txt +61 -0
Dashboard.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
@st.cache
|
6 |
+
def prompt_to_csv(df):
|
7 |
+
# IMPORTANT: Cache the conversion to prevent computation on every rerun
|
8 |
+
df_download = df
|
9 |
+
df_download['Filename']='p'+df_download['ID'].astype('str')+'_1.png'
|
10 |
+
df_download = df[['Prompt','Filename']].drop_duplicates(subset='Filename')
|
11 |
+
return df_download.to_csv().encode('utf-8')
|
12 |
+
|
13 |
+
|
14 |
+
# Setup
|
15 |
+
## Load prompt directory
|
16 |
+
prompt_dir = pd.read_csv('Data/Prompt_dir_221128.csv') #second version of prompt_dir
|
17 |
+
st.session_state['prompt_dir'] = prompt_dir
|
18 |
+
## Create lists of prompts for manual and automated assessments
|
19 |
+
st.session_state['automated_tasks'] = ['Multiple object types', 'Single object','Negation']
|
20 |
+
automated_prompts = prompt_dir.loc[
|
21 |
+
(prompt_dir['Auto_assessment']==True)&
|
22 |
+
(prompt_dir['Task']).isin(st.session_state['automated_tasks'])].ID.tolist()
|
23 |
+
manual_prompts = prompt_dir.ID.tolist()
|
24 |
+
|
25 |
+
# Page
|
26 |
+
st.title('Generative Image Benchmark')
|
27 |
+
st.write('This is an evaluation platform to assess the performance of image generation algorithms developed by Intel Labs. This is the alpha version of the platform.')
|
28 |
+
st.subheader('User guide')
|
29 |
+
st.write('To assess a generative image algorithm, download a set of prompts using the prompt downloader below. Generate one image per prompt and use the file names provided to name your images. Upload these generated images in the data upload section below. The pages for manual assessment and automated assessment allow you to systematically assess the generated images. The results will be presented and ready for download on the assessment summary page.')
|
30 |
+
st.sidebar.image('Graphics/IL_Logo.png')
|
31 |
+
|
32 |
+
|
33 |
+
# Add prompt downloading functions
|
34 |
+
prompt_download_dict = {}
|
35 |
+
## Count how many prompts are in database to allow for max value in selection
|
36 |
+
prompt_task_count = prompt_dir.Task.value_counts(sort=False)
|
37 |
+
prompt_task_count = prompt_task_count.drop(index='Single object')
|
38 |
+
prompt_task_select = prompt_task_count.copy()
|
39 |
+
## Hide downloader in box
|
40 |
+
with st.expander("Prompt downloader"):
|
41 |
+
st.write('Select the number of prompts you want to download for each task category. The set of prompts will automatically also include all single objects appearing in the selected prompts.')
|
42 |
+
# Create numerical selector for every task in prompt directory
|
43 |
+
for i_task in prompt_task_select.index:
|
44 |
+
prompt_task_select[i_task] = st.number_input(
|
45 |
+
i_task,
|
46 |
+
value = prompt_task_count[i_task],
|
47 |
+
max_value=prompt_task_count[i_task],
|
48 |
+
min_value=0,
|
49 |
+
step = 1)
|
50 |
+
|
51 |
+
# Create df with selected number of prompts per task
|
52 |
+
for i_task in prompt_task_select.index:
|
53 |
+
temp_df = prompt_dir.loc[prompt_dir['Task']==i_task][0:prompt_task_select[i_task]]
|
54 |
+
if len(temp_df)>0:
|
55 |
+
prompt_download_dict[i_task]=temp_df
|
56 |
+
|
57 |
+
# Concat all tasks to dataframe
|
58 |
+
prompt_download = pd.concat(prompt_download_dict.values())
|
59 |
+
|
60 |
+
# Add relevant single object prompts
|
61 |
+
single_object_ids = prompt_download.Linked_prompts.str.split(',').explode().unique().astype('int')
|
62 |
+
prompt_download = pd.concat([
|
63 |
+
prompt_download,
|
64 |
+
prompt_dir.loc[prompt_dir['ID'].isin(single_object_ids)]
|
65 |
+
])
|
66 |
+
|
67 |
+
# Add download button for prompts
|
68 |
+
st.download_button(
|
69 |
+
label="Download prompts",
|
70 |
+
data=prompt_to_csv(prompt_download),
|
71 |
+
file_name='prompt_list.csv',
|
72 |
+
mime='text/csv',
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
# Generate empty dataset for results, if it does not exist yet
|
78 |
+
try:
|
79 |
+
num_uploaded_images = st.session_state['eval_df'].shape[0]
|
80 |
+
except KeyError:
|
81 |
+
st.session_state['eval_df'] = pd.DataFrame(
|
82 |
+
columns=['File_name','Prompt_no','automated_eval','manual_eval','manual_eval_completed','manual_eval_task_score'])
|
83 |
+
st.session_state['uploaded_img'] = []
|
84 |
+
|
85 |
+
# Create dic for automated asssssment if it does not excist yet
|
86 |
+
try:
|
87 |
+
test_dict = st.session_state['results_dict']
|
88 |
+
except KeyError:
|
89 |
+
st.session_state['results_dict'] = {}
|
90 |
+
|
91 |
+
# Data upload setup
|
92 |
+
st.subheader('Data upload')
|
93 |
+
#uploaded_files = st.file_uploader('Upload generated images', accept_multiple_files=True)
|
94 |
+
with st.form("my-form", clear_on_submit=True):
|
95 |
+
uploaded_files = st.file_uploader('Select images for upload', accept_multiple_files=True)
|
96 |
+
|
97 |
+
man_assessment_share = st.selectbox(
|
98 |
+
'Select share of uploaded images to be used for manual assessment.',
|
99 |
+
('100%', '50%'))
|
100 |
+
|
101 |
+
submitted = st.form_submit_button("Add images")
|
102 |
+
st.session_state['uploaded_img'] = st.session_state['uploaded_img']+uploaded_files
|
103 |
+
|
104 |
+
|
105 |
+
# Add new uploaded images to session state
|
106 |
+
## Try to append it to pre-existing list, else create new list in session state
|
107 |
+
## Always reset uploaded files to empty list after they have been added to state
|
108 |
+
if len(uploaded_files) != 0:
|
109 |
+
try:
|
110 |
+
# Extract prompts of uploaded files
|
111 |
+
file_names = [x.name for x in uploaded_files]
|
112 |
+
files_prompts = [x.split('_')[0][1:] for x in file_names]
|
113 |
+
|
114 |
+
# Create manual evaluation df
|
115 |
+
df_dict = {'File_name':file_names, 'Prompt_no':files_prompts}
|
116 |
+
eval_df = pd.DataFrame(df_dict)
|
117 |
+
eval_df['automated_eval'] = eval_df['Prompt_no'].astype('int').isin(automated_prompts)
|
118 |
+
eval_df['manual_eval'] = eval_df['Prompt_no'].astype('int').isin(manual_prompts)
|
119 |
+
eval_df['manual_eval_completed'] = False
|
120 |
+
eval_df['manual_eval_task_score'] = np.nan
|
121 |
+
|
122 |
+
# Exclude given percentage of uploaded images from manual assessment; with random selection
|
123 |
+
if man_assessment_share == '50%':
|
124 |
+
reassign_number = int(len(eval_df)/2)
|
125 |
+
manual_eval_reassign = eval_df['manual_eval']
|
126 |
+
random_image_indices = np.random.choice(len(manual_eval_reassign),reassign_number, replace=False)
|
127 |
+
manual_eval_reassign.iloc[random_image_indices]=False
|
128 |
+
eval_df['manual_eval'] = manual_eval_reassign
|
129 |
+
|
130 |
+
# Join new uploaded df with existing df
|
131 |
+
joint_eval_df = pd.concat([st.session_state['eval_df'], eval_df], ignore_index=True)
|
132 |
+
|
133 |
+
# Add task name to eval_df
|
134 |
+
Prompt_no_task_dict = dict(zip(prompt_dir.ID.astype('str').to_list(),prompt_dir.Task.to_list()))
|
135 |
+
joint_eval_df['Task'] = joint_eval_df.Prompt_no.map(Prompt_no_task_dict)
|
136 |
+
|
137 |
+
# Save eval_df to session state
|
138 |
+
st.session_state['eval_df'] = joint_eval_df
|
139 |
+
|
140 |
+
except KeyError:
|
141 |
+
st.session_state['uploaded_img'] = uploaded_files
|
142 |
+
|
143 |
+
|
144 |
+
eval_df = st.session_state['eval_df']
|
145 |
+
if eval_df.shape[0]!=0:
|
146 |
+
# Print current state of uploaded data
|
147 |
+
st.write("{0} images uploaded. Reload the page to reset the image upload.".format(str(eval_df.shape[0])))
|
148 |
+
st.write("- Available for manual assessment: ", str(sum(eval_df.manual_eval)))
|
149 |
+
manual_eval_available = sum(eval_df.manual_eval)
|
150 |
+
st.write("- Available for automated assessment: ", str(sum(eval_df.automated_eval)))
|
151 |
+
else:
|
152 |
+
st.write("Upload files to start the assessment.")
|
153 |
+
|
154 |
+
#st.write(eval_df)
|
155 |
+
#st.write(prompt_dir)
|
156 |
+
#st.session_state['eval_df']
|
Data/Prompt_dir_221128.csv
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ID,Task,Prompt,Representations,Task_specific_label,Auto_assessment,Linked_prompts
|
2 |
+
1,Directed Actions,Girl pointing at a dog,"girl, dog",pointing at,FALSE,"29,30"
|
3 |
+
2,Directed Actions,Monkey hitting a can,"monkey, can",hitting,FALSE,"31,32"
|
4 |
+
3,Directed Actions,Man petting a cat,"man, cat",petting,FALSE,"33,34"
|
5 |
+
4,Negation,A vase with no flowers in it,vase,flower,TRUE,35
|
6 |
+
5,Negation,A street without a car,street,car,TRUE,36
|
7 |
+
6,Negation,Man without a beard,man,beard,FALSE,33
|
8 |
+
7,Spatial relationships,A book on top of a chair,"book, chair",on top,FALSE,"37,38"
|
9 |
+
8,Spatial relationships,Car behind a bus,"car, bus",behind,FALSE,"39,61"
|
10 |
+
9,Spatial relationships,Cup in front of a vase,"cup, vase",in front,FALSE,"40,35"
|
11 |
+
10,Numbers (multiple objects),Five apples,apple,5,TRUE,41
|
12 |
+
11,Numbers (multiple objects),Eight cars,car,8,TRUE,39
|
13 |
+
12,Numbers (multiple objects),Three cups,cup ,3,TRUE,40
|
14 |
+
13,Basic Arithmetic,Two apples and twice as many oranges,"apple, orange","2,4",TRUE,"41,42"
|
15 |
+
14,Basic Arithmetic,One cat and three times as a many dogs,"cat, dog","1,3",TRUE,"34,30"
|
16 |
+
15,Basic Arithmetic,Five spoons and two less knifes,"spoon, knife","5,3",TRUE,"43,44"
|
17 |
+
16,Multiple object types,"A cow, an apple and a cat","cow, apple, cat",NA,TRUE,"45,41,34"
|
18 |
+
17,Multiple object types,"A cup, a knife, a plate and a hat","cup, knife, plate, hat",NA,TRUE,"40,44,46,47"
|
19 |
+
18,Multiple object types,"A balloon, a lion, a cow, a monkey and a surfboard","balloon, lion, cow, monkey, surfboard",NA,TRUE,"48,49,45,51"
|
20 |
+
19,Chimeras,A cow with the tail of a snake,"cow, snake",NA,FALSE,"45,52"
|
21 |
+
20,Chimeras,A fish with the wings of a duck,"fish, duck",NA,FALSE,"53,54"
|
22 |
+
21,Chimeras,A lion with the head of a goat,"lion, goat",NA,FALSE,"49,55"
|
23 |
+
22, Conditional generation,A green dinosaur or a red cup. Only render green things.,dinosaur,cup,TRUE,56
|
24 |
+
23, Conditional generation,A snail or a horse. Only render mammals.,horse,snail,TRUE,57
|
25 |
+
24, Conditional generation,A girl with a dog or a boy with a cat. Only show females.,"girl, dog","boy, cat",FALSE,"25,29,30"
|
26 |
+
25,Multiple object types,A girl with a dog,"girl, dog",NA,TRUE,"29,30"
|
27 |
+
26,Text rendering,"A book titled ""Probabilistic Models""",book ,Probabilistic Models,TRUE,37
|
28 |
+
27,Text rendering,"A bar sign with the name ""Wish you were beer""",bar sign,Wish you were beer,TRUE,59
|
29 |
+
28,Text rendering,"A shirt that says ""Sorry I am late. I did not want to come.""",shirt,Sorry I am late. I did not want to come.,TRUE,60
|
30 |
+
29,Single object,A girl,girl,NA,TRUE,
|
31 |
+
30,Single object,A dog,dog,NA,TRUE,
|
32 |
+
31,Single object,A monkey,monkey,NA,TRUE,
|
33 |
+
32,Single object,A can,can,NA,TRUE,
|
34 |
+
33,Single object,A man,man,NA,TRUE,
|
35 |
+
34,Single object,A cat,cat,NA,TRUE,
|
36 |
+
35,Single object,A vase,vase,NA,TRUE,
|
37 |
+
36,Single object,A street,street,NA,TRUE,
|
38 |
+
37,Single object,A book,book,NA,TRUE,
|
39 |
+
38,Single object,A chair,chair,NA,TRUE,
|
40 |
+
39,Single object,A car,car,NA,TRUE,
|
41 |
+
40,Single object,A cup,cup,NA,TRUE,
|
42 |
+
41,Single object,An apple,apple,NA,TRUE,
|
43 |
+
42,Single object,An orange,orange,NA,TRUE,
|
44 |
+
43,Single object,A spoon,spoon,NA,TRUE,
|
45 |
+
44,Single object,A knife,knife,NA,TRUE,
|
46 |
+
45,Single object,A cow,cow,NA,TRUE,
|
47 |
+
46,Single object,A plate,plate,NA,TRUE,
|
48 |
+
47,Single object,A hat,hat,NA,TRUE,
|
49 |
+
48,Single object,A balloon,balloon,NA,TRUE,
|
50 |
+
49,Single object,A lion,lion,NA,TRUE,
|
51 |
+
50,Single object,A monkey,monkey,NA,TRUE,
|
52 |
+
51,Single object,A surfboard,surfboard,NA,TRUE,
|
53 |
+
52,Single object,A snake,snake,NA,TRUE,
|
54 |
+
53,Single object,A fish,fish,NA,TRUE,
|
55 |
+
54,Single object,A duck,duck,NA,TRUE,
|
56 |
+
55,Single object,A goat,goat,NA,TRUE,
|
57 |
+
56,Single object,A dinosaur,dinosaur,NA,TRUE,
|
58 |
+
57,Single object,A horse,horse,NA,TRUE,
|
59 |
+
58,Single object,A girl,girl,NA,TRUE,
|
60 |
+
59,Single object,A bar sign,bar sign,NA,TRUE,
|
61 |
+
60,Single object,A shirt,shirt,NA,TRUE,
|
62 |
+
61,Single object,A bus,bus,NA,TRUE,
|
Graphics/IL_Logo.png
ADDED
pages/1_⚙️Manual assessment.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
st.title('Manual assessment')
|
7 |
+
st.write('On this page you can rate all uploaded images with regards to how good they match their respective prompts. You can see the outcome of your assessment on the summary page.')
|
8 |
+
st.write(' ')
|
9 |
+
side_image = Image.open('Graphics/IL_Logo.png')
|
10 |
+
st.sidebar.image(side_image)
|
11 |
+
# Create placeholders for key elements
|
12 |
+
assessment_progress = st.empty()
|
13 |
+
|
14 |
+
# Extract how many images are available for manual assessment in entire uploaded dataset
|
15 |
+
## Set to zero if the dataset has not been created yet due to starting the app on an assessment page
|
16 |
+
manual_eval_available = 0
|
17 |
+
try:
|
18 |
+
curr_eval_df = st.session_state['eval_df']
|
19 |
+
curr_eval_df['Picture_index']=curr_eval_df.index.values
|
20 |
+
curr_manual_eval = curr_eval_df.loc[(curr_eval_df['manual_eval']==True)&(curr_eval_df['manual_eval_completed']==False)]
|
21 |
+
manual_eval_available = len(curr_manual_eval)
|
22 |
+
curr_prompt_dir = st.session_state['prompt_dir']
|
23 |
+
except KeyError:
|
24 |
+
manual_eval_available = 0
|
25 |
+
st.session_state['uploaded_img'] = [] #safety if program is started on manual assesssment page and not desktop
|
26 |
+
|
27 |
+
|
28 |
+
# Main rating loop
|
29 |
+
## If images are available for rating this creates a from to submit ratings to database
|
30 |
+
## If subprompt option is selected, it expands the form to include these as well
|
31 |
+
## If no images are available it prints situation specific instructions
|
32 |
+
if manual_eval_available > 0:
|
33 |
+
# Let user choose whether subprompts should be presented
|
34 |
+
include_subprompts = st.checkbox('Show related subprompts if available (uploaded subprompts may not be shown if images have been assessed already).', value=True)
|
35 |
+
|
36 |
+
# Update the progress statement
|
37 |
+
assessment_progress.write('{0} images ready / left for assessment.'.format(manual_eval_available))
|
38 |
+
|
39 |
+
# Extract first example for manual assessment which is not rated yet (first meaning the lowest index, for lowest prompt number)
|
40 |
+
## Also extract relevant metadata of this example
|
41 |
+
curr_eval_df = st.session_state['eval_df']
|
42 |
+
lowest_prompt_no = curr_eval_df.loc[(curr_eval_df['manual_eval']==True)&(curr_eval_df['manual_eval_completed']==False)].Prompt_no.astype('int').min()
|
43 |
+
curr_picture_index = curr_eval_df.loc[
|
44 |
+
(curr_eval_df['manual_eval']==True)&
|
45 |
+
(curr_eval_df['manual_eval_completed']==False)&
|
46 |
+
(curr_eval_df['Prompt_no']==str(lowest_prompt_no))].Picture_index.min()
|
47 |
+
curr_manual_eval_row = curr_eval_df.iloc[[curr_picture_index]]
|
48 |
+
curr_prompt_ID = int(curr_manual_eval_row.Prompt_no.item())
|
49 |
+
curr_prompt_row =st.session_state['prompt_dir'].loc[st.session_state['prompt_dir']['ID']==curr_prompt_ID]
|
50 |
+
|
51 |
+
# Extract information about linked subprompts
|
52 |
+
curr_linked_prompts = curr_prompt_row.Linked_prompts.item()
|
53 |
+
|
54 |
+
# Set it to nan if the user chose to hide subprompts in evaluation
|
55 |
+
if include_subprompts == False:
|
56 |
+
curr_linked_prompts = float('nan')
|
57 |
+
|
58 |
+
# Split the subprompt string to get actual list of subprompt IDs
|
59 |
+
if pd.notna(curr_linked_prompts):
|
60 |
+
curr_linked_prompts = curr_linked_prompts.split(',')
|
61 |
+
|
62 |
+
# Create form to collect assessment
|
63 |
+
## First create main prompt inputs, then render subprompts if subprompt list found
|
64 |
+
## The submit button writes assessment to database
|
65 |
+
form_loc = st.empty()
|
66 |
+
with form_loc.form("multi_form",clear_on_submit=True):
|
67 |
+
|
68 |
+
# Write main prompt
|
69 |
+
st.write('Prompt: {0}'.format(
|
70 |
+
curr_prompt_dir.loc[curr_prompt_dir['ID']==int(curr_manual_eval_row.Prompt_no.item())]['Prompt'].item()
|
71 |
+
))
|
72 |
+
# Exclude prompt from rating if user chooses to
|
73 |
+
include_prompt = st.checkbox('Include this prompt in assessment summary', value=True)
|
74 |
+
|
75 |
+
# Show image of current prompt and rating
|
76 |
+
st.image(st.session_state['uploaded_img'][curr_manual_eval_row.Picture_index.item()],width=350)
|
77 |
+
curr_manual_eval_row['manual_eval_task_score'] = st.radio(
|
78 |
+
"Does the image match the prompt?",('Yes', 'No'), horizontal=True, key='base')
|
79 |
+
|
80 |
+
st.write(' ') # Create whitespace
|
81 |
+
st.write(' ') # Create whitespace
|
82 |
+
|
83 |
+
# If there are linked prompts, create df with info
|
84 |
+
# Else create emtpy df which will automatically skip the rating creation for these prompts
|
85 |
+
# Here we do not test for (curr_eval_df['manual_eval']==True) as the curr_linked_prompts is already testing for valid prompt number and we want to ignore the exclusion for subprompts
|
86 |
+
if type(curr_linked_prompts)==list:
|
87 |
+
curr_linked_rows = curr_eval_df.loc[
|
88 |
+
(curr_eval_df['manual_eval_completed']==False)&
|
89 |
+
(curr_eval_df['Prompt_no'].isin(curr_linked_prompts))]
|
90 |
+
curr_linked_rows = curr_linked_rows.groupby('Prompt_no').first()
|
91 |
+
else:
|
92 |
+
curr_linked_rows = pd.DataFrame()
|
93 |
+
|
94 |
+
# Create rating for subprompts if a df for subprompt info was created
|
95 |
+
for row in curr_linked_rows.itertuples():
|
96 |
+
# Prompt
|
97 |
+
st.write('Prompt: {0}'.format(
|
98 |
+
curr_prompt_dir.loc[curr_prompt_dir['ID']==int(row.Index)]['Prompt'].item()
|
99 |
+
))
|
100 |
+
# Image
|
101 |
+
st.image(st.session_state['uploaded_img'][row.Picture_index],width=350)
|
102 |
+
# Rating
|
103 |
+
curr_linked_rows.loc[curr_linked_rows['Picture_index']==row.Picture_index,'manual_eval_task_score'] = st.radio(
|
104 |
+
"Does the image match the prompt?",('Yes', 'No'), horizontal=True, key=row.Picture_index)
|
105 |
+
st.write(' ')
|
106 |
+
st.write(' ')
|
107 |
+
|
108 |
+
|
109 |
+
# Submit assessments to database
|
110 |
+
submitted = st.form_submit_button("Submit")
|
111 |
+
if submitted:
|
112 |
+
# First add main prompt assessment
|
113 |
+
st.session_state['eval_df'].loc[
|
114 |
+
curr_picture_index,'manual_eval']=include_prompt
|
115 |
+
st.session_state['eval_df'].loc[
|
116 |
+
curr_picture_index,'manual_eval_completed']=True
|
117 |
+
st.session_state['eval_df'].loc[
|
118 |
+
curr_picture_index,'manual_eval_task_score']=curr_manual_eval_row['manual_eval_task_score'].item()
|
119 |
+
|
120 |
+
# Add subprompt assessment if dataset was created for subprompts
|
121 |
+
# This stage will automatically be skipped if the df for linked prompts is empty
|
122 |
+
for row in curr_linked_rows.itertuples():
|
123 |
+
st.session_state['eval_df'].loc[
|
124 |
+
row.Picture_index,'manual_eval']=include_prompt
|
125 |
+
st.session_state['eval_df'].loc[
|
126 |
+
row.Picture_index,'manual_eval_completed']=True
|
127 |
+
st.session_state['eval_df'].loc[
|
128 |
+
row.Picture_index,'manual_eval_task_score']=row.manual_eval_task_score
|
129 |
+
|
130 |
+
# Reset page after ratings were submitted
|
131 |
+
st.experimental_rerun()
|
132 |
+
# If no files are uploaded
|
133 |
+
elif len(st.session_state['uploaded_img'])==0:
|
134 |
+
assessment_progress.write('Upload files on dashboard starting page to start manual assessment.')
|
135 |
+
# If files are uploaded but all ratings are completed
|
136 |
+
else:
|
137 |
+
assessment_progress.write('You finished assessing the current batch of uploaded images. Upload more pictures of generate your results on the summary page.')
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
#st.session_state['eval_df'].loc[curr_manual_eval,'manual_eval_completed']=True
|
142 |
+
#st.write(st.session_state['eval_df'])
|
pages/2_🤖Automated assessment.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
from itertools import compress
|
4 |
+
from PIL import Image
|
5 |
+
#from pages.Functions.Assessment_functions import Empty_DSwrapper
|
6 |
+
|
7 |
+
@st.cache
|
8 |
+
def Empty_DSwrapper(img, representations = None, Task_specific_label = None):
|
9 |
+
'''
|
10 |
+
Dashboard wrapper of DETR_multi_object_counting
|
11 |
+
'''
|
12 |
+
img_input = img
|
13 |
+
|
14 |
+
return True
|
15 |
+
|
16 |
+
# Create dictionary to hold functions
|
17 |
+
fun_dict = {
|
18 |
+
'Multiple object types':Empty_DSwrapper,
|
19 |
+
'Single object':Empty_DSwrapper,
|
20 |
+
'Negation':Empty_DSwrapper}
|
21 |
+
|
22 |
+
|
23 |
+
st.title('Automated Assessment')
|
24 |
+
st.write('On this page you can use automated assessment algorithms to assess how good uploaded images match their respective prompts.')
|
25 |
+
st.write(' ')
|
26 |
+
side_image = Image.open('Graphics/IL_Logo.png')
|
27 |
+
st.sidebar.image(side_image)
|
28 |
+
|
29 |
+
try:
|
30 |
+
# Create necessary variables
|
31 |
+
prompt_dir = st.session_state['prompt_dir']
|
32 |
+
curr_eval_df = st.session_state['eval_df']
|
33 |
+
curr_eval_df['Picture_index']=curr_eval_df.index.values
|
34 |
+
|
35 |
+
# Assess how many images are available for automatic assessment
|
36 |
+
automated_eval_available = sum(curr_eval_df['automated_eval'])
|
37 |
+
|
38 |
+
# Add task name to eval_df
|
39 |
+
temp_prompt_dir=prompt_dir[['ID','Representations','Task_specific_label']]
|
40 |
+
temp_prompt_dir['Prompt_no']=temp_prompt_dir['ID'].astype('str')
|
41 |
+
curr_eval_df = curr_eval_df.merge(temp_prompt_dir,on='Prompt_no')
|
42 |
+
except KeyError:
|
43 |
+
automated_eval_available = 0
|
44 |
+
|
45 |
+
|
46 |
+
# If images for assessment available: create form to start assessment
|
47 |
+
# Else: Note to upload images for assessment
|
48 |
+
if automated_eval_available > 0:
|
49 |
+
|
50 |
+
with st.form("auto_assessment_form",clear_on_submit=True):
|
51 |
+
# Form info statment
|
52 |
+
st.write('Select tasks to assess with the automated assessment:')
|
53 |
+
|
54 |
+
# Add selection for available categories
|
55 |
+
assess_multi_object = st.checkbox(
|
56 |
+
'Multiple object types ({0} images available)'.format(
|
57 |
+
len(curr_eval_df.loc[
|
58 |
+
(curr_eval_df['automated_eval']==True)&
|
59 |
+
(curr_eval_df['Task']=='Multiple object types')])
|
60 |
+
))
|
61 |
+
assess_single_object = st.checkbox(
|
62 |
+
'Single object type ({0} images available)'.format(
|
63 |
+
len(curr_eval_df.loc[
|
64 |
+
(curr_eval_df['automated_eval']==True)&
|
65 |
+
(curr_eval_df['Task']=='Single object')])
|
66 |
+
))
|
67 |
+
|
68 |
+
negation = st.checkbox(
|
69 |
+
'Negation ({0} images available)'.format(
|
70 |
+
len(curr_eval_df.loc[
|
71 |
+
(curr_eval_df['automated_eval']==True)&
|
72 |
+
(curr_eval_df['Task']=='Negation')])
|
73 |
+
))
|
74 |
+
|
75 |
+
submitted = st.form_submit_button("Start automated assessment")
|
76 |
+
if submitted:
|
77 |
+
# Create list for tasks which were selected for assessment
|
78 |
+
selected_tasks = list(
|
79 |
+
compress(
|
80 |
+
['Multiple object types','Single object','Negation'],
|
81 |
+
[assess_multi_object,assess_single_object,negation]))
|
82 |
+
# Create dataset to loop over with assessment
|
83 |
+
assessed_df = curr_eval_df.loc[
|
84 |
+
(curr_eval_df['automated_eval']==True)&
|
85 |
+
(curr_eval_df['Task'].isin(selected_tasks))]
|
86 |
+
results_column = []
|
87 |
+
|
88 |
+
for row in assessed_df.itertuples():
|
89 |
+
# Apply task based classifier and safe in list
|
90 |
+
temp_image = Image.open(st.session_state['uploaded_img'][row.Picture_index])
|
91 |
+
temp_result = fun_dict[row.Task](
|
92 |
+
temp_image,row.Representations,row.Task_specific_label)
|
93 |
+
results_column.append(temp_result)
|
94 |
+
|
95 |
+
assessed_df['Score']=results_column
|
96 |
+
st.session_state['auto_eval_df']=assessed_df[['File_name','Prompt_no','Picture_index','Task','Score']]
|
97 |
+
st.write('Completed assessment. Access results on the summary page.')
|
98 |
+
else:
|
99 |
+
st.write('Upload files on dashboard starting page to start automated assessment.')
|
100 |
+
|
101 |
+
#st.write(st.session_state['auto_eval_df'])
|
pages/3_📊Assessment summary.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import seaborn as sns
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
from PIL import Image
|
6 |
+
side_image = Image.open('Graphics/IL_Logo.png')
|
7 |
+
st.sidebar.image(side_image)
|
8 |
+
|
9 |
+
@st.cache
|
10 |
+
def convert_df_to_csv(df):
|
11 |
+
# IMPORTANT: Cache the conversion to prevent computation on every rerun
|
12 |
+
return df[['File_name','Prompt_no','Task','Score']].to_csv().encode('utf-8')
|
13 |
+
|
14 |
+
def plot_style_simple(results_df):
|
15 |
+
eval_sum = results_df.groupby('Task')['Score'].sum()
|
16 |
+
eval_count = results_df.groupby('Task')['Score'].count()
|
17 |
+
eval_share = (eval_sum/eval_count)*100
|
18 |
+
|
19 |
+
fig = plt.figure(figsize=(12, 3))
|
20 |
+
sns.barplot(x=eval_share.index, y=eval_share.values, palette='GnBu')
|
21 |
+
plt.xticks(rotation=-45)
|
22 |
+
plt.ylabel('Percentage correct')
|
23 |
+
plt.xlabel(' ')
|
24 |
+
return fig
|
25 |
+
|
26 |
+
def plot_style_combined(results_df, uploaded_df = None):
|
27 |
+
# Create joined dataframe of results and uploadd_df
|
28 |
+
uploaded_results_df = pd.read_csv(uploaded_df)
|
29 |
+
manual_results_df['Model']='Current'
|
30 |
+
uploaded_results_df['Model']='Uploaded'
|
31 |
+
results_df = pd.concat([manual_results_df,uploaded_results_df])
|
32 |
+
|
33 |
+
# Create scores for plot
|
34 |
+
eval_sum = results_df.groupby(['Model','Task'])['Score'].sum()
|
35 |
+
eval_count = results_df.groupby(['Model','Task'])['Score'].count()
|
36 |
+
eval_share = (eval_sum/eval_count)*100
|
37 |
+
eval_share = eval_share.reset_index()
|
38 |
+
|
39 |
+
# Create plot
|
40 |
+
fig = plt.figure(figsize=(12, 3))
|
41 |
+
sns.barplot(data=eval_share,x='Task',y='Score',hue='Model', palette='GnBu')
|
42 |
+
plt.xticks(rotation=-45)
|
43 |
+
plt.ylabel('Percentage correct')
|
44 |
+
plt.xlabel(' ')
|
45 |
+
return fig
|
46 |
+
|
47 |
+
assessment_result_frames = {}
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
st.title('Assessment Summary')
|
52 |
+
|
53 |
+
st.header('Manual assessment')
|
54 |
+
|
55 |
+
|
56 |
+
try:
|
57 |
+
if sum(st.session_state['eval_df']['manual_eval_completed'])>0:
|
58 |
+
# Display file uploader
|
59 |
+
manual_file_upload = st.file_uploader("Upload .csv with saved manual assessment for model comparison")
|
60 |
+
|
61 |
+
# Create dataset for manual summary plots
|
62 |
+
manual_eval_df = st.session_state['eval_df']
|
63 |
+
manual_eval_df['Score'] = manual_eval_df['manual_eval_task_score'].map({'Yes':True, 'No':False})
|
64 |
+
manual_results_df = manual_eval_df.loc[
|
65 |
+
(manual_eval_df['manual_eval']==True)&
|
66 |
+
(manual_eval_df['manual_eval_completed']==True)]
|
67 |
+
|
68 |
+
assessment_result_frames['Manual assessment'] = manual_results_df
|
69 |
+
|
70 |
+
# If df was uploaded for comparison, we create comparison plot, else simple plot
|
71 |
+
if manual_file_upload == None:
|
72 |
+
fig = plot_style_simple(manual_results_df)
|
73 |
+
st.pyplot(fig)
|
74 |
+
else:
|
75 |
+
fig = plot_style_combined(manual_results_df,manual_file_upload)
|
76 |
+
st.pyplot(fig)
|
77 |
+
|
78 |
+
st.download_button(
|
79 |
+
label="Download manual assessment data",
|
80 |
+
data=convert_df_to_csv(manual_results_df),
|
81 |
+
file_name='manual_assessment.csv',
|
82 |
+
mime='text/csv',
|
83 |
+
)
|
84 |
+
else:
|
85 |
+
st.write('Complete manual assessment to generate summary.')
|
86 |
+
except KeyError:
|
87 |
+
st.write('Complete automated assessment to generate summary.')
|
88 |
+
|
89 |
+
|
90 |
+
st.write(' ')
|
91 |
+
st.header('Automated assessment')
|
92 |
+
try:
|
93 |
+
# Create dataset for automated summary plots
|
94 |
+
auto_eval_df = st.session_state['auto_eval_df']
|
95 |
+
assessment_result_frames['Automated assessment'] = auto_eval_df
|
96 |
+
|
97 |
+
# Display file uploader
|
98 |
+
auto_file_upload = st.file_uploader("Upload .csv with saved automated assessment for model comparison")
|
99 |
+
|
100 |
+
|
101 |
+
# If df was uploaded for comparison, we create comparison plot, else simple plot
|
102 |
+
if auto_file_upload == None:
|
103 |
+
fig = plot_style_simple(auto_eval_df)
|
104 |
+
st.pyplot(fig)
|
105 |
+
else:
|
106 |
+
fig = plot_style_combined(auto_eval_df,auto_file_upload)
|
107 |
+
st.pyplot(fig)
|
108 |
+
|
109 |
+
st.download_button(
|
110 |
+
label="Download automated assessment data",
|
111 |
+
data=convert_df_to_csv(auto_eval_df),
|
112 |
+
file_name='automated_assessment.csv',
|
113 |
+
mime='text/csv',
|
114 |
+
)
|
115 |
+
except KeyError:
|
116 |
+
st.write('Complete automated assessment to generate summary.')
|
117 |
+
|
118 |
+
|
119 |
+
try:
|
120 |
+
# Start gallery
|
121 |
+
st.header('Assessment gallery')
|
122 |
+
|
123 |
+
assessment_method_selected = st.selectbox(
|
124 |
+
'Select generation method',
|
125 |
+
assessment_result_frames.keys())
|
126 |
+
|
127 |
+
if len(assessment_result_frames.keys())<1:
|
128 |
+
st.write('Complete manual or automated assessment to access images in the gallery.')
|
129 |
+
|
130 |
+
# Create needed info frames
|
131 |
+
gallery_df = assessment_result_frames[assessment_method_selected]
|
132 |
+
curr_prompt_dir = st.session_state['prompt_dir']
|
133 |
+
|
134 |
+
# Select task
|
135 |
+
tasks_available = gallery_df.Task.unique().tolist()
|
136 |
+
task_selected = st.selectbox('Select task type',tasks_available)
|
137 |
+
# Select image type
|
138 |
+
type_selected = st.selectbox(
|
139 |
+
'Select image type',
|
140 |
+
('Correctly generated images', 'Incorrectly generated images'))
|
141 |
+
type_selected_dict = {'Correctly generated images':True, 'Incorrectly generated images':False}
|
142 |
+
# Create df for presented images
|
143 |
+
gallery_df_print = gallery_df.loc[
|
144 |
+
(gallery_df['Score']==type_selected_dict[type_selected])&
|
145 |
+
(gallery_df['Task']==task_selected)]
|
146 |
+
# Select presented image and prompt
|
147 |
+
generation_number = st.number_input('Generation number',min_value=1, max_value=len(gallery_df_print), step=1)
|
148 |
+
gallery_row_print = gallery_df_print.iloc[int(generation_number-1)]
|
149 |
+
curr_Prompt_no = gallery_row_print.Prompt_no
|
150 |
+
curr_Prompt = curr_prompt_dir[curr_prompt_dir['ID']==int(curr_Prompt_no)].Prompt
|
151 |
+
curr_Picture_index = gallery_row_print.Picture_index.item()
|
152 |
+
# Plot prompt and image
|
153 |
+
st.write('Prompt: '+curr_Prompt.item())
|
154 |
+
st.image(st.session_state['uploaded_img'][curr_Picture_index],width=350)
|
155 |
+
|
156 |
+
#st.write(auto_df_print)
|
157 |
+
except IndexError:
|
158 |
+
st.write('There is no image availabe in your selected category.')
|
159 |
+
except KeyError:
|
160 |
+
pass
|
pages/Functions/Assessment_functions.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import pandas as pd
|
5 |
+
from transformers import CLIPProcessor, CLIPModel, DetrFeatureExtractor, DetrForObjectDetection
|
6 |
+
from PIL import Image
|
7 |
+
CLIPmodel_import = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
8 |
+
CLIPprocessor_import = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
9 |
+
DetrFeatureExtractor_import = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
|
10 |
+
DetrModel_import = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
|
11 |
+
|
12 |
+
# Import list of coco example objects
|
13 |
+
script_path = os.path.dirname(__file__)
|
14 |
+
coco_objects = open(script_path+"/coco-labels-paper.txt", "r")
|
15 |
+
coco_objects = coco_objects.read()
|
16 |
+
coco_objects = coco_objects.split("\n")
|
17 |
+
|
18 |
+
# Example image
|
19 |
+
#test_image = Image.open('pages/Functions/test_image.png')
|
20 |
+
#test_image = Image.open('pages/Functions/test_imageIV.png')
|
21 |
+
|
22 |
+
###### Empty function for debugging
|
23 |
+
|
24 |
+
def Empty_DSwrapper(img, representations = None, Task_specific_label = None):
|
25 |
+
'''
|
26 |
+
Dashboard wrapper of DETR_multi_object_counting
|
27 |
+
'''
|
28 |
+
img_input = img
|
29 |
+
|
30 |
+
return True
|
31 |
+
|
32 |
+
|
33 |
+
###### Helper functions
|
34 |
+
def Coco_object_set(included_object, set_length=6):
|
35 |
+
'''
|
36 |
+
Creates set of object based on coco objects and the currently correct object.
|
37 |
+
'''
|
38 |
+
curr_object_set = set([included_object])
|
39 |
+
|
40 |
+
while len(curr_object_set)<set_length:
|
41 |
+
temp_object = random.choice(coco_objects)
|
42 |
+
curr_object_set.add(temp_object)
|
43 |
+
|
44 |
+
return list(curr_object_set)
|
45 |
+
|
46 |
+
|
47 |
+
def Object_set_creator(included_object, list_of_all_objects = coco_objects, excluded_objects_list = [], set_length=6):
|
48 |
+
'''
|
49 |
+
Creates set of object based on list_of_all_objects.
|
50 |
+
The included object will always be in the list.
|
51 |
+
Optional list of objects to be excluded from the set.
|
52 |
+
'''
|
53 |
+
curr_object_set = set([included_object])
|
54 |
+
|
55 |
+
# Check that the included object is not contained in the excluded objects
|
56 |
+
if included_object in excluded_objects_list:
|
57 |
+
raise ValueError('The included_object can not be part of the excluded_objects list.')
|
58 |
+
|
59 |
+
while len(curr_object_set)<set_length:
|
60 |
+
temp_object = random.choice(list_of_all_objects)
|
61 |
+
if temp_object not in excluded_objects_list:
|
62 |
+
curr_object_set.add(temp_object)
|
63 |
+
|
64 |
+
return list(curr_object_set)
|
65 |
+
|
66 |
+
|
67 |
+
###### Single object recognition
|
68 |
+
|
69 |
+
def CLIP_single_object_classifier(img, object_class, task_specific_label=None):
|
70 |
+
'''
|
71 |
+
Test presence of object in image by using the "red herring strategy" and CLIP algorithm.
|
72 |
+
|
73 |
+
Note that the task_specific_label is not used for this classifier.
|
74 |
+
'''
|
75 |
+
# Define model and parameters
|
76 |
+
word_list = Coco_object_set(object_class)
|
77 |
+
inputs = CLIPprocessor_import(text=word_list, images=img, return_tensors="pt", padding=True)
|
78 |
+
# Run inference
|
79 |
+
outputs = CLIPmodel_import(**inputs)
|
80 |
+
# Get image-text similarity score
|
81 |
+
logits_per_image = outputs.logits_per_image
|
82 |
+
# Get probabilities
|
83 |
+
probs = logits_per_image.softmax(dim=1)
|
84 |
+
# Return true if the highest prob value is recognised
|
85 |
+
if word_list[probs.argmax().item()]==object_class:
|
86 |
+
return True
|
87 |
+
else:
|
88 |
+
return False
|
89 |
+
|
90 |
+
|
91 |
+
def CLIP_object_recognition(img, object_class, tested_classes):
|
92 |
+
'''
|
93 |
+
More general CLIP object recogntintion implementation
|
94 |
+
'''
|
95 |
+
if object_class not in tested_classes:
|
96 |
+
raise ValueError('The object_class has to be part of the tested_classes list.')
|
97 |
+
|
98 |
+
# Define model and parameters
|
99 |
+
inputs = CLIPprocessor_import(text=tested_classes, images=img, return_tensors="pt", padding=True)
|
100 |
+
# Run inference
|
101 |
+
outputs = CLIPmodel_import(**inputs)
|
102 |
+
# Get image-text similarity score
|
103 |
+
logits_per_image = outputs.logits_per_image
|
104 |
+
# Get probabilities
|
105 |
+
probs = logits_per_image.softmax(dim=1)
|
106 |
+
# Return true if the highest prob value is recognised
|
107 |
+
if tested_classes[probs.argmax().item()]==object_class:
|
108 |
+
return True
|
109 |
+
else:
|
110 |
+
return False
|
111 |
+
|
112 |
+
|
113 |
+
###### Multi object recognition
|
114 |
+
#list_of_objects = ['cat','apple','cow']
|
115 |
+
|
116 |
+
def CLIP_multi_object_recognition(img, list_of_objects):
|
117 |
+
'''
|
118 |
+
Algorithm based on CLIP to test presence of multiple objects.
|
119 |
+
|
120 |
+
Currently has a debugging print call in.
|
121 |
+
'''
|
122 |
+
# Loop over list of objects, test for presence of each inidividually, making sure that non of the other objects is part of test set
|
123 |
+
for i_object in list_of_objects:
|
124 |
+
# Create list with objects not in test set (all objects which arent i_object)
|
125 |
+
untested_objects = [x for x in list_of_objects if x!= i_object]
|
126 |
+
# Create set going into clip object recogniser and test this set using standard recognition function
|
127 |
+
CLIP_test_classes = Object_set_creator(included_object=i_object, excluded_objects_list=untested_objects)
|
128 |
+
i_object_present = CLIP_object_recognition(img, i_object, CLIP_test_classes)
|
129 |
+
print(i_object+str(i_object_present))
|
130 |
+
# Stop loop and return false if one of the objects is not recognised by CLIP
|
131 |
+
if i_object_present == False:
|
132 |
+
return False
|
133 |
+
|
134 |
+
# Return true if all objects were recognised
|
135 |
+
return True
|
136 |
+
|
137 |
+
def CLIP_multi_object_recognition_DSwrapper(img, representations, task_specific_label=None):
|
138 |
+
'''
|
139 |
+
Dashboard wrapper of CLIP_multi_object_recognition
|
140 |
+
|
141 |
+
Note that the task_specific_label is not used for this classifier.
|
142 |
+
'''
|
143 |
+
list_of_objects = representations.split(', ')
|
144 |
+
return CLIP_multi_object_recognition(img,list_of_objects)
|
145 |
+
|
146 |
+
###### Negation
|
147 |
+
def CLIP_object_negation(img, present_object, absent_object):
|
148 |
+
'''
|
149 |
+
Algorithm based on CLIP to test negation prompts
|
150 |
+
'''
|
151 |
+
# Create sets of objects for present and absent object
|
152 |
+
tested_classes_present = Object_set_creator(
|
153 |
+
included_object=present_object, excluded_objects_list=[absent_object])
|
154 |
+
tested_classes_absent = Object_set_creator(
|
155 |
+
included_object=absent_object, excluded_objects_list=[present_object],set_length=10)
|
156 |
+
|
157 |
+
# Use CLIP object recognition to test for objects.
|
158 |
+
presence_test = CLIP_object_recognition(img, present_object, tested_classes_present)
|
159 |
+
absence_test = CLIP_object_recognition(img, absent_object, tested_classes_absent)
|
160 |
+
|
161 |
+
if presence_test==True and absence_test==False:
|
162 |
+
return True
|
163 |
+
else:
|
164 |
+
return False
|
165 |
+
|
166 |
+
###### Counting / arithmetic
|
167 |
+
'''
|
168 |
+
test_image = Image.open('pages/Functions/test_imageIII.jpeg')
|
169 |
+
object_classes = ['cat','remote']
|
170 |
+
object_counts = [2,2]
|
171 |
+
'''
|
172 |
+
|
173 |
+
def DETR_multi_object_counting(img, object_classes, object_counts, confidence_treshold=0.5):
|
174 |
+
# Apply Detr to image
|
175 |
+
inputs = DetrFeatureExtractor_import(images=img, return_tensors="pt")
|
176 |
+
outputs = DetrModel_import(**inputs)
|
177 |
+
|
178 |
+
# Convert outputs (bounding boxes and class logits) to COCO API
|
179 |
+
target_sizes = torch.tensor([img.size[::-1]])
|
180 |
+
results = DetrFeatureExtractor_import.post_process_object_detection(
|
181 |
+
outputs, threshold=confidence_treshold, target_sizes=target_sizes)[0]
|
182 |
+
|
183 |
+
# Create dict with value_counts
|
184 |
+
count_dict = pd.Series(results['labels'].numpy())
|
185 |
+
count_dict = count_dict.map(DetrModel_import.config.id2label)
|
186 |
+
count_dict = count_dict.value_counts().to_dict()
|
187 |
+
|
188 |
+
# Create dict for correct response
|
189 |
+
label_dict = dict(zip(object_classes, object_counts))
|
190 |
+
|
191 |
+
# Return False is the count for a given label does not match
|
192 |
+
for i_item in label_dict.items():
|
193 |
+
if int(count_dict[i_item[0]])==int(i_item[1]): # Adding type control for comparison due to str read in
|
194 |
+
print(str(i_item)+'_true')
|
195 |
+
else:
|
196 |
+
print(str(i_item)+'_false')
|
197 |
+
print("oberserved: "+str(count_dict[i_item[0]]))
|
198 |
+
return False
|
199 |
+
# If all match, return true
|
200 |
+
return True
|
201 |
+
|
202 |
+
def DETR_multi_object_counting_DSwrapper(img, representations, Task_specific_label):
|
203 |
+
'''
|
204 |
+
Dashboard wrapper of DETR_multi_object_counting
|
205 |
+
'''
|
206 |
+
list_of_objects = representations.split(', ')
|
207 |
+
object_counts = Task_specific_label.split(', ')
|
208 |
+
return DETR_multi_object_counting(img,list_of_objects, object_counts, confidence_treshold=0.5)
|
pages/Functions/coco-labels-paper.txt
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
person
|
2 |
+
bicycle
|
3 |
+
car
|
4 |
+
motorcycle
|
5 |
+
airplane
|
6 |
+
bus
|
7 |
+
train
|
8 |
+
truck
|
9 |
+
boat
|
10 |
+
traffic light
|
11 |
+
fire hydrant
|
12 |
+
street sign
|
13 |
+
stop sign
|
14 |
+
parking meter
|
15 |
+
bench
|
16 |
+
bird
|
17 |
+
cat
|
18 |
+
dog
|
19 |
+
horse
|
20 |
+
sheep
|
21 |
+
cow
|
22 |
+
elephant
|
23 |
+
bear
|
24 |
+
zebra
|
25 |
+
giraffe
|
26 |
+
hat
|
27 |
+
backpack
|
28 |
+
umbrella
|
29 |
+
shoe
|
30 |
+
eye glasses
|
31 |
+
handbag
|
32 |
+
tie
|
33 |
+
suitcase
|
34 |
+
frisbee
|
35 |
+
skis
|
36 |
+
snowboard
|
37 |
+
sports ball
|
38 |
+
kite
|
39 |
+
baseball bat
|
40 |
+
baseball glove
|
41 |
+
skateboard
|
42 |
+
surfboard
|
43 |
+
tennis racket
|
44 |
+
bottle
|
45 |
+
plate
|
46 |
+
wine glass
|
47 |
+
cup
|
48 |
+
fork
|
49 |
+
knife
|
50 |
+
spoon
|
51 |
+
bowl
|
52 |
+
banana
|
53 |
+
apple
|
54 |
+
sandwich
|
55 |
+
orange
|
56 |
+
broccoli
|
57 |
+
carrot
|
58 |
+
hot dog
|
59 |
+
pizza
|
60 |
+
donut
|
61 |
+
cake
|
62 |
+
chair
|
63 |
+
couch
|
64 |
+
potted plant
|
65 |
+
bed
|
66 |
+
mirror
|
67 |
+
dining table
|
68 |
+
window
|
69 |
+
desk
|
70 |
+
toilet
|
71 |
+
door
|
72 |
+
tv
|
73 |
+
laptop
|
74 |
+
mouse
|
75 |
+
remote
|
76 |
+
keyboard
|
77 |
+
cell phone
|
78 |
+
microwave
|
79 |
+
oven
|
80 |
+
toaster
|
81 |
+
sink
|
82 |
+
refrigerator
|
83 |
+
blender
|
84 |
+
book
|
85 |
+
clock
|
86 |
+
vase
|
87 |
+
scissors
|
88 |
+
teddy bear
|
89 |
+
hair drier
|
90 |
+
toothbrush
|
91 |
+
hair brush
|
requirements.txt
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
altair==4.2.0
|
2 |
+
attrs==22.1.0
|
3 |
+
blinker==1.5
|
4 |
+
cachetools==5.2.0
|
5 |
+
charset-normalizer==2.1.1
|
6 |
+
click==8.1.3
|
7 |
+
commonmark==0.9.1
|
8 |
+
contourpy==1.0.6
|
9 |
+
cycler==0.11.0
|
10 |
+
decorator==5.1.1
|
11 |
+
entrypoints==0.4
|
12 |
+
filelock==3.8.0
|
13 |
+
fonttools==4.38.0
|
14 |
+
gitdb==4.0.10
|
15 |
+
GitPython==3.1.29
|
16 |
+
huggingface-hub==0.11.1
|
17 |
+
idna==3.4
|
18 |
+
importlib-metadata==5.1.0
|
19 |
+
Jinja2==3.1.2
|
20 |
+
jsonschema==4.17.3
|
21 |
+
kiwisolver==1.4.4
|
22 |
+
MarkupSafe==2.1.1
|
23 |
+
matplotlib==3.6.2
|
24 |
+
numpy==1.23.5
|
25 |
+
packaging==21.3
|
26 |
+
pandas==1.5.2
|
27 |
+
Pillow==9.3.0
|
28 |
+
protobuf==3.20.3
|
29 |
+
pyarrow==10.0.1
|
30 |
+
pydeck==0.8.0
|
31 |
+
Pygments==2.13.0
|
32 |
+
Pympler==1.0.1
|
33 |
+
pyparsing==3.0.9
|
34 |
+
pyrsistent==0.19.2
|
35 |
+
python-dateutil==2.8.2
|
36 |
+
pytz==2022.6
|
37 |
+
pytz-deprecation-shim==0.1.0.post0
|
38 |
+
PyYAML==6.0
|
39 |
+
regex==2022.10.31
|
40 |
+
requests==2.28.1
|
41 |
+
rich==12.6.0
|
42 |
+
seaborn==0.12.1
|
43 |
+
semver==2.13.0
|
44 |
+
six==1.16.0
|
45 |
+
smmap==5.0.0
|
46 |
+
streamlit==1.15.1
|
47 |
+
timm==0.6.12
|
48 |
+
tokenizers==0.13.2
|
49 |
+
toml==0.10.2
|
50 |
+
toolz==0.12.0
|
51 |
+
torch==1.13.0
|
52 |
+
torchvision==0.14.0
|
53 |
+
tornado==6.2
|
54 |
+
tqdm==4.64.1
|
55 |
+
transformers==4.24.0
|
56 |
+
typing_extensions==4.4.0
|
57 |
+
tzdata==2022.7
|
58 |
+
tzlocal==4.2
|
59 |
+
urllib3==1.26.13
|
60 |
+
validators==0.20.0
|
61 |
+
zipp==3.11.0
|