from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts
from htbuilder.units import percent, px
from htbuilder.funcs import rgba, rgb
import streamlit as st
import os
import sys
import argparse
import clip
import numpy as np
from PIL import Image
from dalle.models import Dalle
from dalle.utils.utils import set_seed, clip_score
import cv2
import subprocess
import signal
import functools
import threading
import uuid


def signal_handler(sig, frame):
    """Exit the process cleanly when the user presses Ctrl+C."""
    print('You pressed Ctrl+C!')
    sys.exit(0)


@functools.lru_cache(maxsize=None)  # keyed only by device, so the cache stays tiny
def _load_models(device):
    """Load minDALL-E and CLIP onto *device* once and reuse them across calls.

    Returns:
        (dalle_model, clip_model, clip_preprocess)
    """
    # This will automatically download the pretrained model on first use.
    model = Dalle.from_pretrained('minDALL-E/1.3B')
    model.to(device=device)
    model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
    model_clip.to(device=device)
    return model, model_clip, preprocess_clip


def _to_bgr_uint8(image):
    """Convert one channels-last RGB float image to 8-bit BGR for cv2.imwrite.

    NOTE(review): assumes values lie in [0, 1] and there are 3 RGB channels,
    matching minDALL-E's own sampling example (which multiplies by 255 before
    saving) -- confirm if the model or sampler changes.
    """
    # Clamp first so any slight overshoot cannot wrap around in the uint8 cast.
    pixels = np.rint(np.clip(image, 0.0, 1.0) * 255.0).astype(np.uint8)
    # OpenCV writes BGR; the sampler's channel order is RGB.
    return cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)


def generate(prompt, crazy, num_candidates=3, device='cpu', out_dir='temp'):
    """Sample images for *prompt* with minDALL-E, rank them with CLIP and save them.

    Args:
        prompt: Text prompt to condition generation on.
        crazy: Softmax temperature for sampling; higher values give more varied output.
        num_candidates: Number of images to sample.
        device: Torch device string the models run on.
        out_dir: Directory the JPEGs are written to (created if missing).

    Returns:
        The sampled images as a channels-last array, ordered best-first by
        CLIP score.

    Raises:
        OSError: If an image could not be written to *out_dir*.
    """
    print("-------------------")
    model, model_clip, preprocess_clip = _load_models(device)

    set_seed(np.random.randint(0, 10000))

    # Sampling; the sampler returns channels-first, so move channels last.
    images = model.sampling(prompt=prompt,
                            top_k=2048,
                            top_p=None,
                            softmax_temperature=crazy,
                            num_candidates=num_candidates,
                            device=device).cpu().numpy()
    images = np.transpose(images, (0, 2, 3, 1))

    # CLIP re-ranking: reorder candidates best-first so callers and the
    # saved files reflect the ranking instead of discarding it.
    rank = clip_score(prompt=prompt,
                      images=images,
                      model_clip=model_clip,
                      preprocess_clip=preprocess_clip,
                      device=device)
    images = images[rank]

    # cv2.imwrite fails silently (returns False) when the directory is missing.
    os.makedirs(out_dir, exist_ok=True)
    for image in images:
        # uuid avoids the collisions (and silent overwrites) of a small random int.
        path = os.path.join(out_dir, uuid.uuid4().hex + '.jpeg')
        if not cv2.imwrite(path, _to_bgr_uint8(image)):
            raise OSError("failed to write image to " + path)

    return images


if __name__ == "__main__":
    # signal.signal() raises ValueError outside the main thread, so only
    # install the Ctrl+C handler when we actually are on it.
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGINT, signal_handler)
    generate("a pink house", 0.75)