"""Streamlit front-end for sampling images from a trained, context-conditioned DDPM.

The user picks one of the context labels below and presses the button; the
script then runs ``helper.generate_img`` with that label's index as context
and shows the first generated image upscaled to 320x320.
"""
import inspect

import streamlit as st
import torch
from cv2 import INTER_NEAREST, resize

from helper import DDPM, generate_img

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Linear noise schedule: timesteps + 1 betas evenly spaced from beta1 to beta2.
timesteps = 500
beta1 = 1e-4
beta2 = 0.02
betas = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1) + beta1
betas = betas.to(device)
alpha = 1.0 - betas
# Cumulative product of alphas; already on `device` because `betas` is.
alpha_bar = torch.cumprod(alpha, dim=0)


@st.cache_resource
def _load_model(path="model.pt"):
    """Load the model checkpoint at *path* once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the checkpoint would be re-read from disk (and moved to
    ``device``) on every click.

    Returns:
        The loaded model, switched to eval mode.
    """
    # The checkpoint is presumably a fully pickled nn.Module (it is used
    # directly as the model below), not a state dict. PyTorch >= 2.6 defaults
    # to weights_only=True, which refuses such files, so opt out explicitly
    # when the installed torch supports the keyword (older versions reject it).
    # NOTE(review): this unpickles arbitrary objects -- only load a trusted file.
    load_kwargs = {"map_location": device}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    net = torch.load(path, **load_kwargs)
    # Inference only: disable dropout / use stored batch-norm statistics.
    net.eval()
    return net


model = _load_model()
sampler = DDPM(betas)

label_to_index = {
    label: i
    for i, label in enumerate(
        ['hero', 'non-hero -not recommended-', 'food', 'spells & weapons', 'side-facing']
    )
}

# Empty markdown element kept as-is (presumably a stripped <style>/HTML block;
# it renders nothing in its current form).
st.markdown("", unsafe_allow_html=True)

sampling_count = 300
batch_size = 1

context = st.radio('Pick one:', list(label_to_index))

if st.button("click"):
    index = [label_to_index[context]]
    # Sampling is pure inference; unless generate_img already disables grad,
    # autograd would otherwise retain the graph across every denoising step.
    with torch.no_grad():
        img = generate_img(
            model, sampler, betas, alpha, alpha_bar,
            batch_size, sampling_count, context=index,
        )
    # Channels-first batch -> channels-last numpy (assumes NCHW output);
    # keep only the first image of the batch.
    img = img.detach().cpu().permute(0, 2, 3, 1).numpy()[0]
    # Nearest-neighbour interpolation avoids blurring when upscaling.
    img = resize(img, (320, 320), interpolation=INTER_NEAREST)
    st.write(context)
    st.image(img, clamp=True)