Joaquin Romero Flores commited on
Commit
8552e9f
·
1 Parent(s): 38e2ee6

laoding utils & app

Browse files
Files changed (2) hide show
  1. app.py +55 -0
  2. utils.py +27 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Streamlit Library
2
+ import streamlit as st
3
+
4
+ from utils import load_model, generate
5
+
6
+ # Main Page
7
+ # Title
8
+ st.title("Butterflies Generator")
9
+ # Sumamry Description Model
10
+ st.write("This is Light GAN trained model")
11
+
12
+ # Sidebar
13
+ st.sidebar.subheader("This butterfly does not exist!, could you believe it?")
14
+ # Logo folder
15
+ st.sidebar.image("assets/logo.png", width=200)
16
+ #
17
+ st.sidebar.caption("Demo live created")
18
+
19
+ # Loading & Model's Name
20
+ repo_id = "ceyda/butterfly_cropped_uniq1K_512"
21
+ gan_model = load_model(repo_id)
22
+
23
+ # Butterflies Generator (4)
24
+ n_butterflies = 4
25
+
26
+
27
+ def run():
28
+ with st.spinner("Loading, Please! be patient..."):
29
+ # Inner Processing
30
+ ims = generate(gan_model, n_butterflies)
31
+ # To save on
32
+ st.session_state["ims"]
33
+
34
+
35
+ if "ims" not in st.session_state:
36
+ st.session_state["ims"] = None
37
+ run()
38
+
39
+ # Extracting info from ims
40
+ ims = st.session_state["ims"]
41
+
42
+ #
43
+ run_button = st.button(
44
+ "Please! Generate buttlerflies",
45
+ on_click= run(),
46
+ help="we are in flight, fasten your seatbelt."
47
+ )
48
+
49
+ if ims is not None:
50
+ cols = st.columns(n_butterflies)
51
+ for j, im in enumerate(ims):
52
+ i = j % n_butterflies
53
+ cols[i].image(im, use_column_width=True)
54
+
55
+
utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Utils Standards
2
+
3
+ # Numerical Management
4
+ import numpy as np
5
+ # Pythorch
6
+ import torch
7
+ # A hugging face library
8
+ from huggan.pytorch.lightweight_gan.leightweight_gan import LightweightGAN
9
+
10
+ """Let's now set up the model functions"""
11
+
12
+ def load_model(model_name = "ceyda/butterfly_cropped_uniq1K_512", model_version=None):
13
+ # GAN set-up
14
+ gan = LightweightGAN.from_pretrained(model_name, versio=model_version)
15
+ # GAN Inference Evaluation
16
+ gan.eval()
17
+ return gan
18
+
19
+ # Let's set-up a second function for generation form
20
+
21
+ def generate(gan, batch_size=1):
22
+ with torch.no_grad():
23
+ # Cleaning process to properly fit it to the model
24
+ ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0.0, 1.0) * 255
25
+ #
26
+ ims = ims.permute(0, 2, 3, 1).deatch().cpu().numpy().astype(np.uint8)
27
+ return ims