# Scraped from a Hugging Face Spaces page (Space status: "Sleeping").
"""Streamlit app that samples an image from a pretrained diffusion model.

The user picks a conditioning category and clicks the button. The app then
calls ``helper.generate_img`` with a linear beta noise schedule and shows the
first generated image upscaled to 320x320.
"""
import streamlit as st
from helper import generate_img, DDPM
import torch
from cv2 import INTER_NEAREST, resize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Linear noise schedule: timesteps + 1 betas evenly spaced from beta1 to beta2.
timesteps = 500
beta1 = 1e-4
beta2 = 0.02
betas = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1) + beta1
betas = betas.to(device)
alpha = 1.0 - betas
alpha_bar = torch.cumprod(alpha, dim=0).to(device)  # alpha_bar[t] = prod(alpha[:t + 1])

# Streamlit re-executes this whole script on every widget interaction.
# Caching the loader keeps one model per server process instead of
# re-reading model.pt from disk on each rerun. st.cache_resource exists in
# Streamlit >= 1.18; on older versions, fall back to an uncached loader.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model(path="model.pt"):
    """Load the object saved in *path* (used below as the model) onto ``device``."""
    # NOTE(review): torch.load unpickles arbitrary objects, so only load
    # trusted files. If model.pt holds a pickled nn.Module rather than a
    # state_dict, torch >= 2.6 (weights_only=True by default) will reject it
    # unless weights_only=False is passed -- confirm against the pinned torch.
    return torch.load(path, map_location=device)


model = _load_model()
sampler = DDPM(betas)

labels = ['hero', 'non-hero -not recommended-', 'food', 'spells & weapons', 'side-facing']
label_to_index = {label: i for i, label in enumerate(labels)}

# Hide Streamlit's default page header.
st.markdown("<style>header{display:none}</style>", unsafe_allow_html=True)

sampling_count = 300
batch_size = 1

# list(...) gives st.radio an indexable sequence on every Streamlit version.
context = st.radio('Pick one:', list(label_to_index))

if st.button("click"):
    index = [label_to_index[context]]
    # Sampling is inference only. no_grad stops autograd from recording the
    # whole sampling loop. This assumes generate_img does not itself need
    # gradients -- confirm; it is a no-op if generate_img already disables them.
    with torch.no_grad():
        img = generate_img(model, sampler, betas, alpha, alpha_bar,
                           batch_size, sampling_count, context=index)
    # Assumes a (batch, C, H, W) tensor: move channels last, keep the first image.
    img = img.detach().cpu().permute(0, 2, 3, 1).numpy()[0]
    # Nearest-neighbour upscaling keeps hard pixel edges instead of blurring.
    img = resize(img, (320, 320), interpolation=INTER_NEAREST)
    st.write(context)
    st.image(img, clamp=True)