File size: 1,071 Bytes
53ef34c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4c4e57
450919d
53ef34c
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import streamlit as st
import torch
from cv2 import INTER_NEAREST, resize

from helper import DDPM, generate_img

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Noise schedule: timesteps + 1 betas spaced linearly from beta1 to beta2.
timesteps = 500
beta1 = 1e-4
beta2 = 0.02
betas = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1) + beta1
betas = betas.to(device)
alpha = 1.0 - betas
# Cumulative product of alpha; already on `device` because `alpha` is.
alpha_bar = torch.cumprod(alpha, dim=0)


@st.cache_resource
def _load_model(path="model.pt"):
    """Load the checkpoint at *path* once per Streamlit server process.

    Streamlit re-executes this whole script on every widget interaction, so an
    uncached ``torch.load`` would re-read the checkpoint from disk each time the
    user changes the radio selection or clicks the button. The cached object is
    shared by all sessions.

    Assumes *path* holds a fully pickled ``nn.Module`` (not a state_dict), as the
    app passes the loaded object straight to ``generate_img`` — TODO confirm.

    Returns:
        The loaded model with its storages mapped to ``device``, in eval mode.
    """
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which may refuse a
    # fully pickled module; pass weights_only=False only for a trusted file.
    net = torch.load(path, map_location=device)
    # Inference only: disable train-time behaviour (e.g. dropout) if the module has any.
    net.eval()
    return net


model = _load_model()
sampler = DDPM(betas)

# Radio-button label -> integer class index used as the sampling context.
# The list order defines the indices, so it must match the order used in training.
label_to_index = {l:i for i, l in enumerate(['hero', 'non-hero -not recommended-', 'food', 'spells & weapons', 'side-facing'])}

# Hide Streamlit's default top header bar by injecting a CSS rule.
_HIDE_HEADER_CSS = "<style>header{display:none}</style>"
st.markdown(_HIDE_HEADER_CSS, unsafe_allow_html=True)

# Sampling settings forwarded to generate_img when the button is clicked.
batch_size = 1
sampling_count = 300

# The selected label (a key of label_to_index) conditions the generation.
context = st.radio("Pick one:", label_to_index.keys())


if st.button("click"):
    # Class index for the chosen label, wrapped in a list to pass as the
    # sampling context (presumably one entry per batch item — confirm in helper).
    index = [label_to_index[context]]
    # Sampling is pure inference. Without no_grad, autograd may record every model
    # call across all sampling steps (when the parameters require grad), keeping
    # their activations alive and inflating memory use.
    with torch.no_grad():
        img = generate_img(model, sampler, betas, alpha, alpha_bar, batch_size, sampling_count, context=index)
    # Move axis 1 to the end and take the first image of the batch; assumes the
    # output is (N, C, H, W) — TODO confirm against generate_img.
    img = img.cpu().detach().permute(0, 2, 3, 1).numpy()[0]
    # Nearest-neighbour upscaling keeps hard pixel edges instead of blurring them.
    img = resize(img, (320, 320), interpolation=INTER_NEAREST)
    st.write(context)
    st.image(img, clamp=True)