Spaces:
Runtime error
Runtime error
"""Interface for labeling concepts in images. | |
""" | |
from typing import Optional | |
import gradio as gr | |
from src import global_variables | |
from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME | |
def get_image(
    step: int,
    split: str,
    index: str,
    filtered_indices: dict,
    profile: gr.OAuthProfile
):
    """Load the sample ``step`` positions away from ``index`` in ``split``.

    The index string is parsed as an int (falling back to 0 with a warning)
    and, after applying ``step``, clamped into the valid range of the split,
    with a warning when the user tries to move past either end.

    Args:
        step: Offset applied to the parsed index (the module builds callbacks
            with 1, -1 and 0).
        split: Key into ``global_variables.all_metadata``.
        index: Current index as typed in the Index textbox.
        filtered_indices: Per-split index state; returned unchanged.
        profile: OAuth profile of the logged-in user (injected by Gradio).

    Returns:
        A 9-tuple, in this order: image file path, concepts the user voted
        for, current-image id ``"<split>:<idx>"``, sample class, mapping of
        concept -> sample value, new index as a string, concepts the user has
        no recorded vote for, concepts whose sample value is ``None`` (ties),
        and ``filtered_indices``.

    Raises:
        gr.Error: If the split contains no samples.
    """
    username = profile.username
    samples = global_variables.all_metadata[split]
    n_samples = len(samples)
    if n_samples == 0:
        # Without this guard the clamping below yields -1 and indexing fails.
        raise gr.Error(f"No images in split {split!r}.")

    try:
        int_index = int(index)
    except (TypeError, ValueError):
        gr.Warning("Error parsing index using 0")
        int_index = 0

    sample_idx = int_index + step
    if sample_idx < 0:
        gr.Warning("No previous image.")
        sample_idx = 0
    elif sample_idx >= n_samples:
        gr.Warning("No next image.")
        sample_idx = n_samples - 1

    sample = samples[sample_idx]
    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"

    try:
        username_votes = global_variables.all_votes[sample["id"]][username]
    except KeyError:
        # No votes recorded for this sample, or none by this user.
        voted_concepts = []
        unseen_concepts = []
    else:
        voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)]
        unseen_concepts = [c for c in CONCEPTS if c not in username_votes]

    tie_concepts = [c for c in CONCEPTS if sample[c] is None]
    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        {c: sample[c] for c in CONCEPTS},
        str(sample_idx),
        unseen_concepts,
        tie_concepts,
        filtered_indices,
    )
def make_get_image(step):
    """Return a Gradio callback that loads the sample ``step`` positions away.

    ``step`` is captured in the closure; the returned callback takes the same
    (split, index, filtered_indices, profile) arguments as ``get_image`` minus
    ``step`` and simply forwards them.
    """
    def f(
        split: str,
        index: str,
        filtered_indices: dict,
        profile: gr.OAuthProfile
    ):
        # Forward by keyword so the mapping onto get_image's parameters is explicit.
        return get_image(
            step=step,
            split=split,
            index=index,
            filtered_indices=filtered_indices,
            profile=profile,
        )

    return f
# Navigation callbacks: +1 moves forward, -1 moves back, 0 reloads the typed index.
get_next_image, get_prev_image, get_current_image = (
    make_get_image(1),
    make_get_image(-1),
    make_get_image(0),
)
def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split: str,
    index: str,
    filtered_indices: dict,
    profile: gr.OAuthProfile
):
    """Record the user's concept votes for the current image, then advance.

    Args:
        voted_concepts: Concepts checked in the "Voted Concepts" group.
        current_image: Id of the displayed sample (the ``"<split>:<idx>"``
            string produced by ``get_image``), or ``None`` if nothing is loaded.
        split: Current dataset split.
        index: Current index as typed in the Index textbox.
        filtered_indices: Per-split index state, passed through.
        profile: OAuth profile of the logged-in user (injected by Gradio).

    Returns:
        The same 9-tuple shape as ``get_image``: the next image's data after a
        successful submit, or cleared values (keeping ``index`` and
        ``filtered_indices``) when no image is selected.
    """
    username = profile.username
    if current_image is None:
        gr.Warning("No image selected.")
        # Order must match the handler outputs: (image, voted, current_image,
        # class, concepts, index, unseen, tie, filtered_indices). Previously
        # `index` sat at position 7, which fed the tie-concepts group and blanked
        # the Index textbox.
        return None, None, None, None, None, index, None, None, filtered_indices
    global_variables.update_votes(username, current_image, voted_concepts)
    gr.Info("Submit success")
    return get_next_image(
        split,
        index,
        filtered_indices,
        profile
    )
def save_current_work(
    profile: gr.OAuthProfile,
):
    """Ask ``global_variables`` to save the logged-in user's work, then notify them."""
    global_variables.save_current_work(profile.username)
    gr.Info("Save success")
# Gradio UI: selection controls and label widgets on the left, image preview on the right.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "## # Image Selection",
                )
                split = gr.Radio(
                    label="Split",
                    choices=["train", "test"],
                    value="train",
                )
                index = gr.Textbox(
                    value="0",
                    label="Index",
                    max_lines=1,
                )
                with gr.Row():
                    prev_button = gr.Button(
                        value="Prev",
                    )
                    next_button = gr.Button(
                        value="Next",
                    )
                gr.LoginButton()
                submit_button = gr.Button(
                    value="Local Submit",
                )
                with gr.Row():
                    save_button = gr.Button(
                        value="Save",
                    )
            with gr.Group():
                voted_concepts = gr.CheckboxGroup(
                    label="Voted Concepts",
                    choices=CONCEPTS,
                )
                unseen_concepts = gr.CheckboxGroup(
                    label="Previously Unseen Concepts",
                    choices=CONCEPTS,
                )
                tie_concepts = gr.CheckboxGroup(
                    label="Tie Concepts",
                    choices=CONCEPTS,
                )
            with gr.Group():
                gr.Markdown(
                    "## # Image Info",
                )
                im_class = gr.Textbox(
                    label="Class",
                )
                im_concepts = gr.JSON(
                    label="Concepts",
                )
        with gr.Column():
            image = gr.Image(
                label="Image",
            )

    current_image = gr.State(None)
    # Named `split_name` so it is not confused with the `split` Radio component.
    filtered_indices = gr.State({
        split_name: list(range(len(global_variables.all_metadata[split_name])))
        for split_name in global_variables.all_metadata
    })

    # Order must match the 9-tuple returned by get_image / submit_label.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        index,
        unseen_concepts,
        tie_concepts,
        filtered_indices,
    ]
    common_input = [split, index, filtered_indices]

    prev_button.click(
        get_prev_image,
        inputs=common_input,
        outputs=common_output
    )
    next_button.click(
        get_next_image,
        inputs=common_input,
        outputs=common_output
    )
    submit_button.click(
        submit_label,
        inputs=[voted_concepts, current_image, split, index, filtered_indices],
        outputs=common_output
    )
    index.submit(
        get_current_image,
        inputs=common_input,
        outputs=common_output,
    )
    # save_current_work returns nothing; wiring it to `image` (as before) would
    # overwrite the preview with None, so no outputs are registered.
    save_button.click(
        save_current_work,
    )