Spaces:
Runtime error
Runtime error
import streamlit as st # HF spaces at v1.2.0 | |
from demo import load_model,generate,get_dataset,embed,make_meme | |
from PIL import Image | |
import numpy as np | |
import io | |
# TODO: add a short markdown readme / project intro.
# TODO: add a link to the wandb logs.

# Page header: tagline and logo go in the sidebar, the app title in the main area.
with st.sidebar:
    st.subheader("This butterfly does not exist! ")
    st.image("assets/logo.png", width=200)
st.title("ButterflyGAN")
def load_model_intocache(model_name, model_version):
    """Load the GAN once and reuse it across Streamlit reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    an uncached loader would rebuild the model each time. The function is
    wrapped below in Streamlit's process-wide resource cache, keyed on the
    arguments, so reruns and sessions share a single instance.

    Args:
        model_name: model id forwarded to ``load_model``
            (e.g. 'ceyda/butterfly_512_base').
        model_version: revision forwarded to ``load_model``; ``None`` is
            passed through unchanged.

    Returns:
        Whatever ``load_model`` returns (the object later passed to ``generate``).
    """
    return load_model(model_name, model_version)


# `st.cache_resource` on newer Streamlit; `st.experimental_singleton` on older
# releases such as the v1.2.0 this Space notes it runs on.
load_model_intocache = (
    getattr(st, "cache_resource", None) or st.experimental_singleton
)(load_model_intocache)
def load_dataset():
    """Return the training dataset used for nearest-neighbour lookups, built once.

    Without a cache, ``get_dataset()`` would run again on every Streamlit
    rerun (i.e. every widget interaction). The function is wrapped below in
    Streamlit's process-wide resource cache, so the same object is returned
    to every rerun and session.

    Returns:
        Whatever ``get_dataset`` returns. It is later queried with
        ``get_nearest_examples('beit_embeddings', ...)``.
    """
    return get_dataset()


# `st.cache_resource` on newer Streamlit; `st.experimental_singleton` on older
# releases such as the v1.2.0 this Space notes it runs on.
load_dataset = (
    getattr(st, "cache_resource", None) or st.experimental_singleton
)(load_dataset)
def load_variables():
    """Read the latent-walk code snippets into ``st.session_state`` once per session.

    Streamlit reruns the script on every interaction. Keys already present in
    the session state are therefore left alone, so the files are not reopened
    on each rerun. Each file is closed as soon as it has been read.
    """
    snippets = {
        'latent_walk_code': "assets/code_snippets/latent_walk.py",
        'latent_walk_code_music': "assets/code_snippets/latent_walk_music.py",
    }
    for key, path in snippets.items():
        if key in st.session_state:
            continue
        with open(path, encoding="utf-8") as fh:
            st.session_state[key] = fh.read()
def img2download(image):
    """Encode *image* as JPEG and return the raw bytes (e.g. for ``st.download_button``).

    JPEG cannot store an alpha channel or a palette, and Pillow raises
    ``OSError`` when asked to save such an image (e.g. mode ``RGBA`` or ``P``)
    as JPEG. Those images are converted to ``RGB`` first. Images already in a
    JPEG-compatible mode, or objects without a ``mode`` attribute, are saved
    unchanged.

    Args:
        image: a PIL-style image exposing ``save(fp, format=...)``.

    Returns:
        bytes: the JPEG-encoded image.
    """
    jpeg_modes = {"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"}
    mode = getattr(image, "mode", None)
    if mode is not None and mode not in jpeg_modes:
        # Pillow's RGBA->RGB conversion drops alpha; it does not composite.
        image = image.convert("RGB")
    with io.BytesIO() as buffer:
        image.save(buffer, format="JPEG")
        return buffer.getvalue()
# Checkpoint to load. `model_version` is forwarded to load_model as-is; an
# earlier pinned revision was '0edac54b81958b82ce9fd5c1f688c33ac8e4f223'.
# Which revision to pin is still TBD.
model_name, model_version = 'ceyda/butterfly_cropped_uniq1K_512', None

# Shared resources used by the screens below.
model = load_model_intocache(model_name, model_version)
dataset = load_dataset()
load_variables()
# Sidebar navigation. The same constants are compared against the radio
# selection below, so each label only has to be unique.
# NOTE: the emoji prefixes were mis-encoded in this copy (UTF-8 emoji bytes
# decoded as ISO-8859-7 showed up as "π¦", "π§", "π"). They are restored to
# emoji consistent with those bytes; TODO confirm against the intended design.
generate_menu = "🦋 Make butterflies"
latent_walk_menu = "🎧 Take a latent walk"
make_meme_menu = "🐦 Make a meme"
mosaic_menu = "👀 See the mosaic"
fun_menu = "Release the butterflies"  # not offered in the radio options yet
screen = st.sidebar.radio(
    "Pick a destination",
    [generate_menu, latent_walk_menu, make_meme_menu, mosaic_menu],
)
if screen == generate_menu:
    batch_size = 4  # butterflies generated per click
    col_num = 4     # grid columns; images wrap to new rows when batch_size > col_num

    def run():
        """Generate a fresh batch and keep it in the session for later reruns."""
        with st.spinner("Generating..."):
            st.session_state['ims'] = generate(model, batch_size)

    # First visit in this session: generate an initial batch so the page isn't empty.
    if 'ims' not in st.session_state:
        run()
    ims = st.session_state["ims"]

    st.write(
        "Light-GAN model trained on 1000 butterfly images taken from the Smithsonian Museum collection. \n "
        "Based on [paper:](https://openreview.net/forum?id=1Fqg133qRaI) *Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis*"
    )
    st.button("Generate", on_click=run, help="generated on the fly maybe slow")

    if ims is not None:
        cols = st.columns(col_num)
        picks = [False] * len(ims)
        for j, im in enumerate(ims):
            col = cols[j % col_num]
            col.image(im)
            picks[j] = col.button("Find Nearest", key=f"pick_{j}")

        if any(picks):
            for j, picked in enumerate(picks):
                if not picked:
                    continue
                # k=5 nearest training images for this sample, shown under it
                # in the same grid column it was drawn in.
                _, retrieved_examples = dataset.get_nearest_examples('beit_embeddings', embed(ims[j]), k=5)
                for r in retrieved_examples["image"]:
                    cols[j % col_num].image(r)
            st.write("Nearest neighbors found in the training set according to L2 distance on 'microsoft/beit-base-patch16-224' embeddings")
    st.write(f"Latent dimension: {model.latent_dim}, image size:{model.image_size}")
elif screen == latent_walk_menu: | |
latent_walk_code=open("assets/code_snippets/latent_walk.py").read() | |
latent_walk_music_code=open("assets/code_snippets/latent_walk_music.py").read() | |
st.write("Take a latent walk :musical_note: with cute butterflies") | |
cols=st.columns(3) | |
cols[0].caption("A regular walk (no music)") | |
cols[0].video("assets/latent_walks/regular_walk.mp4") | |
cols[1].caption("Walk with music :butterfly:") | |
cols[1].video("assets/latent_walks/walk_happyrock.mp4") | |
cols[2].caption("Walk with music :butterfly:") | |
cols[2].video("assets/latent_walks/walk_cute.mp4") | |
st.caption("Royalty Free Music from Bensound") | |
st.write("π§Did those butterflies seem to be dancing to the music?!Here is the secret:") | |
with st.expander("See the Code Snippets"): | |
st.write("A regular latent walk:") | |
st.code(st.session_state['latent_walk_code'], language='python') | |
st.write(":musical_note: latent walk with music:") | |
st.code(st.session_state['latent_walk_code_music'], language='python') | |
elif screen == make_meme_menu: | |
if "pigeon" not in st.session_state: | |
st.session_state['pigeon'] = generate(model,1)[0] | |
def get_pigeon(): | |
st.session_state['pigeon'] = generate(model,1)[0] | |
cols= st.columns(2) | |
cols[0].button("change pigeon",on_click=get_pigeon) | |
no_bg=cols[1].checkbox("Remove background?",True,help="Remove the background from pigeon") | |
show_text=cols[1].checkbox("Show text?",True) | |
meme_text=st.text_input("Enter text","Is this a pigeon?") | |
meme=make_meme(st.session_state['pigeon'],text=meme_text,show_text=show_text,remove_background=no_bg) | |
st.image(meme) | |
coly=st.columns(2) | |
coly[0].download_button("Download", img2download(meme),mime="image/jpeg") | |
coly[1].write("Made a cool one? [Share](https://twitter.com/intent/tweet?text=Check%20out%20the%20demo%20for%20Butterfly%20GAN%20%F0%9F%A6%8Bhttps%3A//huggingface.co/spaces/huggan/butterfly-gan%0Amade%20by%20%40ceyda_cinarel%20%26%20%40johnowhitaker%20) on Twitter") | |
elif screen == mosaic_menu: | |
cols=st.columns(2) | |
cols[0].markdown("These are all the butterflies in our [training set](https://huggingface.co/huggan/smithsonian_butterflies_subset)") | |
cols[0].image("assets/train_data_mosaic_lowres.jpg") | |
cols[0].write("π view the high-res version [here](https://www.easyzoom.com/imageaccess/0c77e0e716f14ea7bc235447e5a4c397)") | |
cols[1].markdown("These are the butterflies our model generated.") | |
cols[1].image("assets/gen_mosaic_lowres.jpg") | |
cols[1].write("π view the high-res version [here](https://www.easyzoom.com/imageaccess/cbb04e81106c4c54a9d9f9dbfb236eab)") | |
# Footer: links to the model and dataset, then credits.
# The model link is built from `model_name` so it always matches the loaded checkpoint.
st.sidebar.caption(
    f"[Model](https://huggingface.co/{model_name}) & "
    "[Dataset](https://huggingface.co/huggan/smithsonian_butterflies_subset) used"
)
# TODO: link the project repo (scripts etc.)
# Credits
st.sidebar.caption("Made during the [huggan](https://github.com/huggingface/community-events) hackathon")
st.sidebar.caption("Contributors:")
st.sidebar.caption("[Ceyda Cinarel](https://github.com/cceyda) & [Jonathan Whitaker](https://datasciencecastnet.home.blog/)")
## Feel free to add more & change stuff ^