Spaces:
Runtime error
Runtime error
File size: 6,938 Bytes
86a0c13 b0b9e1f 47cfe13 86a0c13 579c59e 86a0c13 cb5f8d1 47cfe13 21feb87 cb5f8d1 b0b9e1f 47cfe13 b0b9e1f 47cfe13 b0b9e1f 9fbe234 21feb87 9530865 21feb87 b0b9e1f 579c59e 47cfe13 9fbe234 9530865 b0b9e1f 47cfe13 86a0c13 cb5f8d1 86a0c13 cb5f8d1 47cfe13 9fbe234 cb5f8d1 47cfe13 cb5f8d1 47cfe13 21feb87 cb5f8d1 47cfe13 cb5f8d1 47cfe13 1bd7bf1 47cfe13 cb5f8d1 1bd7bf1 21feb87 cb5f8d1 47cfe13 21feb87 47cfe13 21feb87 47cfe13 21feb87 47cfe13 21feb87 47cfe13 21feb87 9530865 21feb87 9530865 cb5f8d1 47cfe13 21feb87 47cfe13 21feb87 47cfe13 21feb87 b0b9e1f 86a0c13 cb5f8d1 e3c61c8 47cfe13 21feb87 47cfe13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
from distutils.command.build import build
import streamlit as st # HF spaces at v1.2.0
from demo import load_model,generate,get_dataset,embed,make_meme
import streamlit.components.v1 as components
import io
import os
# Resolve the custom component's frontend build folder relative to the
# directory that contains this script.
root_dir = os.path.dirname(os.path.abspath(__file__))
build_dir = os.path.join(root_dir, "custom_component/frontend/build")

# Register that folder with Streamlit under the name "release_butterflies".
_component_func = components.declare_component(
    "release_butterflies",
    path=build_dir,
)
def release_butterflies(name, key=None):
    """Render the ``release_butterflies`` custom component.

    Forwards ``name`` and ``key`` to the declared component function together
    with ``default=0`` and returns whatever that call returns.
    """
    return _component_func(name=name, key=key, default=0)
# Sidebar tagline and logo, followed by the title of the main page.
st.sidebar.subheader(body="This butterfly does not exist! ")
st.sidebar.image(image="assets/logo.png", width=200)
st.title(body="ButterflyGAN")
@st.experimental_singleton
def load_model_intocache(model_name, model_version):
    """Load the GAN via ``load_model`` behind ``st.experimental_singleton``.

    Args:
        model_name: Model identifier handed to ``load_model``
            (for example ``'ceyda/butterfly_512_base'``).
        model_version: Version handed to ``load_model``.

    Returns:
        The object returned by ``load_model``.
    """
    return load_model(model_name, model_version)
@st.experimental_singleton
def load_dataset():
    """Return the result of ``get_dataset()``, wrapped in ``st.experimental_singleton``."""
    return get_dataset()
@st.experimental_singleton
def load_variables():
    """Read the two latent-walk code snippets that the UI displays.

    Wrapped in ``st.experimental_singleton`` so the files are not opened and
    read again on every script rerun.

    Returns:
        A ``(latent_walk_code, latent_walk_code_music)`` tuple holding the text
        of ``latent_walk.py`` and ``latent_walk_music.py`` respectively.

    Raises:
        OSError: If either snippet file cannot be opened.
    """
    # `with` closes each handle as soon as it has been read instead of leaving
    # it to the garbage collector; the encoding is explicit so the result does
    # not depend on the host locale (Python source files default to UTF-8).
    with open("assets/code_snippets/latent_walk.py", encoding="utf-8") as snippet_file:
        latent_walk_code = snippet_file.read()
    with open("assets/code_snippets/latent_walk_music.py", encoding="utf-8") as snippet_file:
        latent_walk_code_music = snippet_file.read()
    return latent_walk_code, latent_walk_code_music
def img2download(image):
    """Encode ``image`` as JPEG and return the encoded bytes.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method. When it
            also exposes a ``mode`` that JPEG cannot store (for example
            ``"RGBA"`` or ``"P"``), it is first converted with
            ``image.convert("RGB")``.

    Returns:
        The JPEG-encoded image as ``bytes``.
    """
    # JPEG has no alpha channel or palette; Pillow raises OSError when asked to
    # write such modes, so normalise anything it cannot write to RGB first.
    jpeg_writable_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    mode = getattr(image, "mode", None)
    if mode is not None and mode not in jpeg_writable_modes:
        image = image.convert("RGB")

    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return buffer.getvalue()
# Model identifier and version handed to `load_model_intocache`.
model_name = "ceyda/butterfly_cropped_uniq1K_512"
model_version = "57d36a15546909557d9f967f47713236c8288838"
# (The author left `model_version = None` here as a commented-out alternative.)

# Fetch the cached model, dataset and code snippets used by the screens below.
model = load_model_intocache(model_name, model_version)
dataset = load_dataset()
latent_walk_code, latent_walk_code_music = load_variables()
# Labels of the sidebar navigation entries. The label picked in the radio
# widget is compared against these names to decide which screen to draw.
# NOTE(review): the leading "π…" characters look like mis-encoded emoji --
# confirm against the original file before changing them.
generate_menu = "π¦ Make butterflies"
latent_walk_menu = "π§ Take a latent walk"
make_meme_menu = "π¦ Make a meme"
mosaic_menu = "π See the mosaic"
fun_menu = "π Release the butterflies"

menu_options = [
    generate_menu,
    latent_walk_menu,
    make_meme_menu,
    mosaic_menu,
    fun_menu,
]
screen = st.sidebar.radio("Pick a destination", menu_options)
if screen == generate_menu:
    # "Make butterflies" screen: show a batch of generated images and, on
    # request, the nearest training-set images for any one of them.
    batch_size = 4  # generate 4 butterflies
    col_num = 4

    def run():
        """Generate a new batch of images and keep it in the session state."""
        with st.spinner("Generating..."):
            st.session_state['ims'] = generate(model, batch_size)

    # First visit in this session: create the slot, then fill it once.
    if 'ims' not in st.session_state:
        st.session_state['ims'] = None
        run()

    ims = st.session_state["ims"]
    st.write(
        "Light-GAN model trained on 1000 butterfly images taken from the Smithsonian Museum collection. \n "
        "Based on [paper:](https://openreview.net/forum?id=1Fqg133qRaI) *Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis*"
    )
    # `run` is registered as the click callback; the button's return value is
    # not needed.
    st.button("Generate", on_click=run, help="generated on the fly maybe slow")

    if ims is not None:
        cols = st.columns(col_num)

        # One "Find Nearest" button under every image. Images wrap around the
        # available columns, and `picks` grows with the images actually shown.
        picks = []
        for idx, im in enumerate(ims):
            col = cols[idx % col_num]
            col.image(im, use_column_width=True)
            picks.append(col.button("Find Nearest", key="pick_" + str(idx)))

        if any(picks):
            for idx, picked in enumerate(picks):
                if not picked:
                    continue
                # The returned scores are not displayed.
                _scores, retrieved_examples = dataset.get_nearest_examples(
                    'beit_embeddings', embed(ims[idx]), k=5
                )
                # Use the same wrapped column the image itself was drawn in, so
                # this also works when there are more images than columns.
                col = cols[idx % col_num]
                for neighbour in retrieved_examples["image"]:
                    col.image(neighbour, use_column_width=True)
            st.write("Nearest neighbors found in the training set according to L2 distance on 'microsoft/beit-base-patch16-224' embeddings")

    st.write(f"Latent dimension: {model.latent_dim}, image size:{model.image_size}")
elif screen == latent_walk_menu:
st.write("Take a latent walk :musical_note: with cute butterflies")
cols=st.columns(3)
cols[0].caption("A regular walk (no music)")
cols[0].video("assets/latent_walks/regular_walk.mp4")
cols[1].caption("Walk with music :butterfly:")
cols[1].video("assets/latent_walks/walk_happyrock.mp4")
cols[2].caption("Walk with music :butterfly:")
cols[2].video("assets/latent_walks/walk_cute.mp4")
st.caption("Royalty Free Music from Bensound")
st.write("π§Did those butterflies seem to be dancing to the music?!Here is the secret:")
with st.expander("See the Code Snippets"):
st.write("A regular latent walk:")
st.code(latent_walk_code, language='python')
st.write(":musical_note: latent walk with music:")
st.code(latent_walk_code_music, language='python')
elif screen == make_meme_menu:
if "pigeon" not in st.session_state:
st.session_state['pigeon'] = generate(model,1)[0]
def get_pigeon():
st.session_state['pigeon'] = generate(model,1)[0]
cols= st.columns(2)
cols[0].button("change pigeon",on_click=get_pigeon)
no_bg=cols[1].checkbox("Remove background?",True,help="Remove the background from pigeon")
show_text=cols[1].checkbox("Show text?",True)
meme_text=st.text_input("Enter text","Is this a pigeon?")
meme=make_meme(st.session_state['pigeon'],text=meme_text,show_text=show_text,remove_background=no_bg)
st.image(meme)
coly=st.columns(2)
coly[0].download_button("Download", img2download(meme),mime="image/jpeg")
coly[1].write("Made a cool one? [Share](https://twitter.com/intent/tweet?text=Check%20out%20the%20demo%20for%20Butterfly%20GAN%20%F0%9F%A6%8Bhttps%3A//huggingface.co/spaces/huggan/butterfly-gan%0Amade%20by%20%40ceyda_cinarel%20%26%20%40johnowhitaker%20) on Twitter")
elif screen == mosaic_menu:
cols=st.columns(2)
cols[0].markdown("These are all the butterflies in our [training set](https://huggingface.co/huggan/smithsonian_butterflies_subset)")
cols[0].image("assets/train_data_mosaic_lowres.jpg")
cols[0].write("π view the high-res version [here](https://www.easyzoom.com/imageaccess/0c77e0e716f14ea7bc235447e5a4c397)")
cols[1].markdown("These are the butterflies our model generated.")
cols[1].image("assets/gen_mosaic_lowres.jpg")
cols[1].write("π view the high-res version [here](https://www.easyzoom.com/imageaccess/cbb04e81106c4c54a9d9f9dbfb236eab)")
elif screen == fun_menu:
cols=st.columns([1,2])
cols[0].write("While working on this project")
cols[0].image("assets/butterflies_everywhere.jpg")
with cols[1]:
release_butterflies("Hello World")
# footer stuff
# The model link is built from `model_name` so it cannot drift from the model
# that is actually loaded; the rendered text is unchanged for the current value.
st.sidebar.caption(
    f"[Model](https://huggingface.co/{model_name}) & "
    "[Dataset](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) used"
)
# Link project repo( scripts etc )
# Credits
st.sidebar.caption("Made during the [huggan](https://github.com/huggingface/community-events) hackathon")
st.sidebar.caption("Contributors:")
st.sidebar.caption("[Ceyda Cinarel](https://github.com/cceyda) & [Jonathan Whitaker](https://datasciencecastnet.home.blog/)")
## Feel free to add more & change stuff ^