File size: 2,420 Bytes
8b1e96d
0cffd40
8b1e96d
0cffd40
 
 
 
 
8b1e96d
0cffd40
8b1e96d
 
 
 
 
 
 
ec35e66
 
 
 
 
 
8b1e96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cffd40
8b1e96d
0cffd40
8b1e96d
0cffd40
 
 
8b1e96d
0cffd40
8b1e96d
 
 
 
0cffd40
8b1e96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
import gradio as gr    
import torch
import PIL

# Module-level configuration and state.

# Hub repository of the base pipeline and of the DMD2 checkpoints.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "tianweiy/DMD2"

# UI label -> [checkpoint file name inside `repo`, number of inference steps].
checkpoints = {
    "1-Step": ["dmd2_sdxl_1step_unet.bin", 1],
    "4-Step": ["dmd2_sdxl_4step_unet.bin", 4],
}

# Step count of the checkpoint most recently loaded; None until the first load.
loaded = None

# Stylesheet text that caps the width of the Gradio container.
CSS = (
    "\n"
    ".gradio-container {\n"
    "  max-width: 690px !important;\n"
    "}\n"
)

# Build the SDXL pipeline once at import time, but only when a CUDA device is
# visible. No `unet=` override is passed: no UNet object exists at module scope
# at this point (the previous `unet=unet` raised NameError on import), so the
# pipeline starts out with the UNet that ships with `base`.
# NOTE: when CUDA is unavailable, `pipe` is left undefined.
if torch.cuda.is_available():
    pipe = DiffusionPipeline.from_pretrained(
        base,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")


# Function 
@spaces.GPU(enable_queue=True)
def generate_image(prompt, ckpt):
    """Generate one image for *prompt* with the selected DMD2 checkpoint.

    The checkpoint weights and the scheduler are (re)loaded only when the
    requested step count differs from the one recorded in the module-level
    ``loaded`` variable.

    Args:
        prompt: Text prompt passed to the pipeline.
        ckpt: Key of the module-level ``checkpoints`` table, e.g. ``"4-Step"``.

    Returns:
        The first image of the pipeline output (``results.images[0]``).

    Raises:
        gr.Error: If ``ckpt`` is not a key of ``checkpoints``.
    """
    global loaded
    print(prompt, ckpt)

    if ckpt not in checkpoints:
        # Surface an unsupported selection as a UI error instead of a bare KeyError.
        raise gr.Error(f"Unsupported checkpoint: {ckpt!r}")

    checkpoint, num_inference_steps = checkpoints[ckpt]

    if loaded != num_inference_steps:
        # Download the single checkpoint file that was selected (the file name,
        # not the whole `checkpoints` table).
        weights_path = hf_hub_download(repo, checkpoint)
        # NOTE(security): a ".bin" checkpoint is a pickle file; torch.load can
        # execute code from it, so only load checkpoints from a trusted repo.
        state_dict = torch.load(weights_path, map_location="cuda")
        # Load into the UNet the pipeline actually runs; weights loaded into a
        # separate, unattached UNet would never be used for inference.
        pipe.unet.load_state_dict(state_dict)
        pipe.scheduler = LCMScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type="sample" if num_inference_steps == 1 else "epsilon",
        )
        # Record the step count only after weights and scheduler are in place,
        # so a failed load is retried on the next call.
        loaded = num_inference_steps

    results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)

    return results.images[0]



# Gradio Interface

# Use the stylesheet text defined in this file (`CSS`); the literal "style.css"
# does not refer to anything defined here.
with gr.Blocks(css=CSS) as demo:
    gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
    gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>https://huggingface.co/tianweiy/DMD2</a> text-to-image generation</center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
            # Offer exactly the step counts that have an entry in `checkpoints`,
            # so every selectable value can be looked up by generate_image.
            ckpt = gr.Dropdown(label='Select inference steps', choices=list(checkpoints), value='4-Step', interactive=True)
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='DMD2 Generated Image')

    # Pressing Enter in the textbox and clicking the button both run generation.
    prompt.submit(fn=generate_image,
                 inputs=[prompt, ckpt],
                 outputs=img,
                 )
    submit.click(fn=generate_image,
                 inputs=[prompt, ckpt],
                 outputs=img,
                 )

demo.queue().launch()