Federico Galatolo
Added models in dict
242d29c
import streamlit as st
import os
import json
def batch(iterable, n=1):
    """Yield consecutive chunks of ``iterable`` holding at most ``n`` items.

    Despite its name, ``iterable`` must support ``len()`` and slicing
    (list, tuple, str, ...). Every chunk has ``n`` items except possibly the
    last one, which holds whatever remains.
    """
    size = len(iterable)
    chunk_starts = range(0, size, n)
    yield from (iterable[start:min(start + n, size)] for start in chunk_starts)
# Root of the TeTIm-Eval-Mini dataset, relative to the working directory.
data_path = "./data/TeTIm-Eval-Mini"

# Human-readable category label (shown in the selectbox) -> dataset
# sub-directory name used under every model folder.
category_map = {
    "Digital art": "sampled_art_digital",
    "Digital sketches": "sampled_art_sketch",
    "Traditional art": "sampled_art_traditional",
    "Baroque paintings": "sampled_painting_baroque",
    "High renaissance paintings": "sampled_painting_high-renaissance",
    "Neoclassicism paintings": "sampled_painting_neoclassicism",
    "Animal photography": "sampled_photo_animal",
    "Food photography": "sampled_photo_food",
    "Landscape photography": "sampled_photo_landscape",
    "Portrait photography": "sampled_photo_person",
}

# Model folder name under data_path -> title shown above its images.
# Insertion order matters: the page lays models out two per row in this order.
models_map = {
    "real": "Real picture",
    "DALLE2": "DALL-E 2",
    "Stable Diffusion": "Stable Diffusion",
    "Latent Diffusion": "Latent Diffusion",
    "craiyon": "Craiyon (ex dalle-mini)",
    "GLIDE": "GLIDE",
}
# 1-based index of the sample being shown; st.session_state keeps it
# across Streamlit reruns (each button click reruns the whole script).
if "current" not in st.session_state:
    st.session_state["current"] = 1

st.header("TeTIm-Eval-mini showcase")
category = st.selectbox("Category", category_map.keys())

# Real images are named "<number>.<ext>"; sort numerically so "10" follows "9".
path = os.path.join(data_path, "real", category_map[category], "images")
real_images = sorted(os.listdir(path), key=lambda e: int(e.split(".")[0]))
num_samples = len(real_images)

# Navigation bar: « Previous | position counter | Next » (blank spacer columns).
coll, _, colc, _, colr = st.columns(5)
prev = coll.button("« Previous")
# Placeholder for the counter, filled only after the index has been updated.
state = colc.empty()
# Named next_clicked so the built-in next() is not shadowed.
next_clicked = colr.button("Next »")

current = st.session_state["current"]
if prev:
    current -= 1
if next_clicked:
    current += 1
# Always clamp into 1..num_samples: besides bounding the button moves, this
# repairs an index carried over from a previously selected category that had
# more samples than this one (which would otherwise raise IndexError below).
current = max(1, min(current, num_samples))
st.session_state["current"] = current
state.write(f"{current}/{num_samples}")
# Load the annotation for the current sample and show its caption.
# Annotation files are also named "<number>.<ext>" and are paired with the
# real images purely by their numerically sorted position.
annotations_dir = os.path.join(data_path, "real", category_map[category], "annotations")
annotations = sorted(os.listdir(annotations_dir), key=lambda e: int(e.split(".")[0]))
annotation_file = os.path.join(annotations_dir, annotations[st.session_state["current"] - 1])
# Explicit UTF-8 so non-ASCII captions decode the same on every platform,
# instead of depending on the locale's default codec.
with open(annotation_file, encoding="utf-8") as annotation_f:
    annotation = json.load(annotation_f)
st.subheader(annotation["caption"])
# Show each model's image for the current sample, two models per row:
# a row of titles followed by a row of images. zip() stops at the shorter
# sequence, so a trailing single model (odd number of entries in models_map)
# fills only the left column instead of failing to unpack.
for row in batch(list(models_map.items()), n=2):
    title_cols = st.columns(2)
    for col, (_, model_name) in zip(title_cols, row):
        col.subheader(model_name)

    image_cols = st.columns(2)
    for col, (model, _) in zip(image_cols, row):
        model_dir = os.path.join(data_path, model, category_map[category])
        # Only the "real" folder nests its pictures one level deeper, in
        # "images/", next to its "annotations/".
        if model == "real":
            model_dir = os.path.join(model_dir, "images")
        # Images are named "<number>.<ext>" and matched across models by
        # numerically sorted position.
        images = sorted(os.listdir(model_dir), key=lambda e: int(e.split(".")[0]))
        col.image(os.path.join(model_dir, images[st.session_state["current"] - 1]))