basil-ahmad's picture
Update app.py
b4c4e57 verified
import streamlit as st
from helper import generate_img, DDPM
import torch
from cv2 import resize
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Noise schedule: `timesteps + 1` beta values spaced linearly from beta1 to
# beta2, built on the CPU and then moved to `device`.
timesteps = 500
beta1 = 1e-4
beta2 = 0.02
betas = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1) + beta1
betas = betas.to(device)
alpha = 1.0 - betas
# Cumulative product of alpha along the schedule.
alpha_bar = torch.cumprod(alpha, dim=0).to(device)


@st.cache_resource
def _load_model():
    """Load the serialized model from ``model.pt`` onto ``device``.

    Streamlit re-executes this script on every widget interaction; caching
    the load keeps the model in memory instead of re-reading and
    re-deserializing the file on each rerun. The cached object is shared
    between reruns and sessions rather than copied.

    NOTE(review): ``torch.load`` unpickles the file, so ``model.pt`` must come
    from a trusted source. Recent PyTorch releases default to
    ``weights_only=True``, which may refuse a fully pickled model object --
    confirm against the pinned torch version.
    """
    return torch.load("model.pt", map_location=device)


model = _load_model()
sampler = DDPM(betas)
# Category names offered to the user; each one maps to its position in this
# list.
_labels = ['hero', 'non-hero -not recommended-', 'food', 'spells & weapons', 'side-facing']
label_to_index = dict(zip(_labels, range(len(_labels))))

# Settings handed to generate_img when the button is pressed.
batch_size = 1
sampling_count = 300

# Inject CSS that hides the page's <header> element.
st.markdown("<style>header{display:none}</style>", unsafe_allow_html=True)

# Single-choice selector over the label names; `context` is the chosen label.
context = st.radio('Pick one:', label_to_index.keys())
# Generate and display one image on the rerun triggered by a button press.
if st.button("click"):
    # Look up the selected label's index and pass it as a one-element list.
    index = [label_to_index[context]]

    # Sampling is inference only: disable autograd so no graph is recorded
    # across the sampling steps (the later .detach() suggests gradients may
    # otherwise be tracked -- verify against helper.generate_img).
    with torch.no_grad():
        img = generate_img(
            model,
            sampler,
            betas,
            alpha,
            alpha_bar,
            batch_size,
            sampling_count,
            context=index,
        )

    # Move to the CPU, reorder the axes to (0, 2, 3, 1) and keep the first
    # item of the batch as a NumPy array.
    # NOTE(review): presumably (batch, channels, height, width) ->
    # channels-last for display -- confirm against generate_img's output.
    img = img.cpu().detach().permute(0, 2, 3, 1).numpy()[0]

    # Upscale to 320x320; interpolation=0 is cv2.INTER_NEAREST.
    img = resize(img, (320, 320), interpolation=0)

    st.write(context)
    # clamp=True clips pixel values into the displayable range instead of
    # raising on out-of-range data.
    st.image(img, clamp=True)