File size: 4,684 Bytes
bd199cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83e18ee
bd199cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83e18ee
 
 
 
 
 
 
 
 
c4f7e78
 
83e18ee
 
 
 
 
 
 
c4f7e78
83e18ee
c4f7e78
83e18ee
 
bd199cf
c4f7e78
83e18ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
import torch
from torchvision import transforms
from SDXL.diff_pipe import StableDiffusionXLDiffImg2ImgPipeline
from diffusers import DPMSolverMultistepScheduler

NUM_INFERENCE_STEPS = 50
device = "cuda"

base = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to(device)

refiner = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to(device)

base.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config)
refiner.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config)


def preprocess_image(image):
    image = image.convert("RGB")
    image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image)
    image = transforms.ToTensor()(image)
    image = image * 2 - 1
    image = image.unsqueeze(0).to(device)
    return image


def preprocess_map(map):
    map = map.convert("L")
    map = transforms.CenterCrop((map.size[1] // 64 * 64, map.size[0] // 64 * 64))(map)
    # convert to tensor
    map = transforms.ToTensor()(map)
    map = map.to(device)
    return map


def inference(image, map, gs, prompt, negative_prompt):
    validate_inputs(image, map)
    image = preprocess_image(image)
    map = preprocess_map(map)
    edited_images = base(prompt=prompt, original_image=image, image=image, strength=1, guidance_scale=gs,
                         num_images_per_prompt=1,
                         negative_prompt=negative_prompt,
                         map=map,
                         num_inference_steps=NUM_INFERENCE_STEPS, denoising_end=0.8, output_type="latent").images

    edited_images = refiner(prompt=prompt, original_image=image, image=edited_images, strength=1, guidance_scale=7.5,
                            num_images_per_prompt=1,
                            negative_prompt=negative_prompt,
                            map=map,
                            num_inference_steps=NUM_INFERENCE_STEPS, denoising_start=0.8).images[0]
    return edited_images


def validate_inputs(image, map):
    if image is None:
        raise gr.Error("Missing image")
    if map is None:
        raise gr.Error("Missing map")


example1 = ["assets/input2.jpg", "assets/map2.jpg", 17.5,
            "Tree of life under the sea, ethereal, glittering, lens flares, cinematic lighting, artwork by Anna Dittmann & Carne Griffiths, 8k, unreal engine 5, hightly detailed, intricate detailed",
            "bad anatomy, poorly drawn face, out of frame, gibberish, lowres, duplicate, morbid, darkness, maniacal, creepy, fused, blurry background, crosseyed, extra limbs, mutilated, dehydrated, surprised, poor quality, uneven, off-centered, bird illustration, painting, cartoons"]
example2 = ["assets/input3.jpg", "assets/map4.png", 21,
            "overgrown atrium, nature, ancient black marble columns and terracotta tile floors, waterfall, ultra-high quality, octane render, corona render, UHD, 64k",
            "Two bodies, Two heads, doll, extra nipples, bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated, out of frame, cloned face, watermark, text, lowres, disfigured, ostentatious, ugly, oversaturated, grain, low resolution, blurry, bad anatomy, poorly drawn face, mutant, mutated,  blurred, out of focus, long neck, long body, ugly, disgusting, bad drawing, childish"]
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(label="Input Image", type="pil")
                change_map = gr.Image(label="Change Map", type="pil")
            gs = gr.Slider(0, 28, value=7.5, label="Guidance Scale")
            prompt = gr.Textbox(label="Prompt")
            neg_prompt = gr.Textbox(label="Negative Prompt")
            with gr.Row():
                clr_btn=gr.ClearButton(components=[input_image, change_map, gs, prompt, neg_prompt])
                run_btn = gr.Button("Run",variant="primary")

        output = gr.Image(label="Output Image")
    gr.Examples(examples=[example1, example2],inputs=[input_image, change_map, gs, prompt, neg_prompt])
    gr.Markdown("Differential Diffusion with SDXL; Thanks to the community for the prompts in the examples.")
    run_btn.click(inference, inputs=[input_image, change_map, gs, prompt, neg_prompt], outputs=output)
    clr_btn.add(output)
if __name__ == "__main__":
    demo.launch()