Spaces:
Runtime error
Runtime error
Joaquin Romero Flores
commited on
Commit
·
8552e9f
1
Parent(s):
38e2ee6
laoding utils & app
Browse files
app.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Streamlit Library
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
from utils import load_model, generate
|
5 |
+
|
6 |
+
# Main Page
|
7 |
+
# Title
|
8 |
+
st.title("Butterflies Generator")
|
9 |
+
# Sumamry Description Model
|
10 |
+
st.write("This is Light GAN trained model")
|
11 |
+
|
12 |
+
# Sidebar
|
13 |
+
st.sidebar.subheader("This butterfly does not exist!, could you believe it?")
|
14 |
+
# Logo folder
|
15 |
+
st.sidebar.image("assets/logo.png", width=200)
|
16 |
+
#
|
17 |
+
st.sidebar.caption("Demo live created")
|
18 |
+
|
19 |
+
# Loading & Model's Name
|
20 |
+
repo_id = "ceyda/butterfly_cropped_uniq1K_512"
|
21 |
+
gan_model = load_model(repo_id)
|
22 |
+
|
23 |
+
# Butterflies Generator (4)
|
24 |
+
n_butterflies = 4
|
25 |
+
|
26 |
+
|
27 |
+
def run():
|
28 |
+
with st.spinner("Loading, Please! be patient..."):
|
29 |
+
# Inner Processing
|
30 |
+
ims = generate(gan_model, n_butterflies)
|
31 |
+
# To save on
|
32 |
+
st.session_state["ims"]
|
33 |
+
|
34 |
+
|
35 |
+
if "ims" not in st.session_state:
|
36 |
+
st.session_state["ims"] = None
|
37 |
+
run()
|
38 |
+
|
39 |
+
# Extracting info from ims
|
40 |
+
ims = st.session_state["ims"]
|
41 |
+
|
42 |
+
#
|
43 |
+
run_button = st.button(
|
44 |
+
"Please! Generate buttlerflies",
|
45 |
+
on_click= run(),
|
46 |
+
help="we are in flight, fasten your seatbelt."
|
47 |
+
)
|
48 |
+
|
49 |
+
if ims is not None:
|
50 |
+
cols = st.columns(n_butterflies)
|
51 |
+
for j, im in enumerate(ims):
|
52 |
+
i = j % n_butterflies
|
53 |
+
cols[i].image(im, use_column_width=True)
|
54 |
+
|
55 |
+
|
utils.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Utils Standards
|
2 |
+
|
3 |
+
# Numerical Management
|
4 |
+
import numpy as np
|
5 |
+
# Pythorch
|
6 |
+
import torch
|
7 |
+
# A hugging face library
|
8 |
+
from huggan.pytorch.lightweight_gan.leightweight_gan import LightweightGAN
|
9 |
+
|
10 |
+
"""Let's now set up the model functions"""
|
11 |
+
|
12 |
+
def load_model(model_name = "ceyda/butterfly_cropped_uniq1K_512", model_version=None):
|
13 |
+
# GAN set-up
|
14 |
+
gan = LightweightGAN.from_pretrained(model_name, versio=model_version)
|
15 |
+
# GAN Inference Evaluation
|
16 |
+
gan.eval()
|
17 |
+
return gan
|
18 |
+
|
19 |
+
# Let's set-up a second function for generation form
|
20 |
+
|
21 |
+
def generate(gan, batch_size=1):
|
22 |
+
with torch.no_grad():
|
23 |
+
# Cleaning process to properly fit it to the model
|
24 |
+
ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0.0, 1.0) * 255
|
25 |
+
#
|
26 |
+
ims = ims.permute(0, 2, 3, 1).deatch().cpu().numpy().astype(np.uint8)
|
27 |
+
return ims
|