# Hugging Face Spaces status: Runtime error
import streamlit as st | |
import os | |
import json | |
def batch(iterable, n=1):
    """Yield consecutive slices of *iterable* holding at most *n* items each.

    The final slice is shorter than *n* when ``len(iterable)`` is not a
    multiple of *n*. Each slice has the same type as the input (slicing a
    list yields lists, a string yields strings, ...).

    Args:
        iterable: A sized, sliceable sequence (it must support ``len()`` and
            slicing; plain iterators/generators are not accepted).
        n: Maximum number of items per slice; must be at least 1.

    Yields:
        Successive slices of *iterable*, in order.

    Raises:
        ValueError: If *n* is smaller than 1 (raised on first iteration).
    """
    if n < 1:
        # A non-positive step would otherwise raise for 0 and silently yield
        # nothing for negatives; reject both explicitly.
        raise ValueError(f"batch size must be >= 1, got {n}")
    length = len(iterable)
    for start in range(0, length, n):
        # Slicing clamps past the end, so no explicit min() is needed.
        yield iterable[start:start + n]
# Root folder of the dataset; every model/category path is joined onto it.
data_path = "./data/TeTIm-Eval-Mini"

# Category label shown in the UI -> dataset sub-folder name.
# All folder names share the "sampled_" prefix, so only the distinctive
# suffix is listed. Insertion order is kept and is the order the labels
# are offered in the "Category" select box.
category_map = {
    label: "sampled_" + suffix
    for label, suffix in (
        ("Digital art", "art_digital"),
        ("Digital sketches", "art_sketch"),
        ("Traditional art", "art_traditional"),
        ("Baroque paintings", "painting_baroque"),
        ("High renaissance paintings", "painting_high-renaissance"),
        ("Neoclassicism paintings", "painting_neoclassicism"),
        ("Animal photography", "photo_animal"),
        ("Food photography", "photo_food"),
        ("Landscape photography", "photo_landscape"),
        ("Portrait photography", "photo_person"),
    )
}

# Dataset folder of each image source -> heading displayed above its image.
# "real" holds the reference pictures; the rest hold generated images.
# Entries are rendered two per row, in this order.
models_map = {
    "real": "Real picture",
    "DALLE2": "DALL-E 2",
    "Stable Diffusion": "Stable Diffusion",
    "Latent Diffusion": "Latent Diffusion",
    "craiyon": "Craiyon (ex dalle-mini)",
    "GLIDE": "GLIDE",
}
# 1-based index of the sample being shown; kept in session state so it
# survives Streamlit's rerun on every widget interaction.
if "current" not in st.session_state:
    st.session_state["current"] = 1

st.header("TeTIm-Eval-mini showcase")
category = st.selectbox("Category", category_map.keys())

path = os.path.join(data_path, "real", category_map[category], "images")
# File names start with a number ("<n>.<ext>"), so sort numerically, not lexically.
real_images = sorted(os.listdir(path), key=lambda e: int(e.split(".")[0]))
n_images = len(real_images)
if n_images == 0:
    st.error(f"No images found in {path}")
    st.stop()

# The counter is shared by all categories; after switching to a category
# with fewer images it could point past the end, so clamp it first.
st.session_state["current"] = min(st.session_state["current"], n_images)

coll, _, colc, _, colr = st.columns(5)
prev_clicked = coll.button("« Previous")
# Placeholder reserved in the middle column and filled only after the
# button clicks below have updated the counter.
state = colc.empty()
next_clicked = colr.button("Next »")  # not `next`: avoid shadowing the builtin
if prev_clicked:
    st.session_state["current"] = max(st.session_state["current"] - 1, 1)
if next_clicked:
    st.session_state["current"] = min(st.session_state["current"] + 1, n_images)
state.write(f"{st.session_state['current']}/{n_images}")
# Caption for the selected sample, read from the real picture's JSON
# annotation. Annotations are paired with images by position in numerically
# sorted order.
path = os.path.join(data_path, "real", category_map[category], "annotations")
annotations = sorted(os.listdir(path), key=lambda e: int(e.split(".")[0]))
annotation_file = os.path.join(path, annotations[st.session_state["current"] - 1])
# JSON text is UTF-8 (RFC 8259); do not depend on the platform's locale
# encoding, which would garble or reject non-ASCII captions.
with open(annotation_file, encoding="utf-8") as annotation_f:
    annotation = json.load(annotation_f)
st.subheader(annotation["caption"])
# Image grid: for every pair of models, one row of headings followed by one
# row of images. An odd number of models is fine - the last row then only
# fills its first column.
image_index = st.session_state["current"] - 1
for model_pair in batch(list(models_map.items()), n=2):
    # Both rows are created up front; Streamlit places them in creation
    # order, so headings still render above the images.
    title_cols = st.columns(2)
    image_cols = st.columns(2)
    for title_col, image_col, (model, model_name) in zip(title_cols, image_cols, model_pair):
        title_col.subheader(model_name)
        model_path = os.path.join(data_path, model, category_map[category])
        # Real pictures keep their files in an "images" sub-folder (next to
        # "annotations"); generated images sit directly in the category folder.
        if model == "real":
            model_path = os.path.join(model_path, "images")
        # File names start with a number, so sort numerically, not lexically.
        images = sorted(os.listdir(model_path), key=lambda e: int(e.split(".")[0]))
        if 0 <= image_index < len(images):
            image_col.image(os.path.join(model_path, images[image_index]))
        else:
            # A model folder with fewer files would otherwise crash the page.
            image_col.warning(f"No image #{image_index + 1} for {model_name}")