Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import math | |
import random | |
import gradio as gr | |
import torch | |
from PIL import Image, ImageOps | |
from diffusers import StableDiffusionPipeline | |
# Markdown shown beneath the controls (currently just a blank line).
help_text = "\n"

# Prompts the "Load Example" button picks from at random.
example_instructions = ["A river"]

# Model identifier handed to StableDiffusionPipeline.from_pretrained
# (an owner/name repo id).
model_id = "dimentox/heightmapstyle"
def main():
    """Build and launch the Gradio UI for the heightmap-style text-to-image model.

    Loads ``model_id`` into a ``StableDiffusionPipeline`` once, defines the
    button callbacks as closures over that pipeline, lays them out in a
    ``gr.Blocks`` app, then queues and launches it (``launch`` blocks).
    """
    # float16 kernels are only reliably available on CUDA; many CPU ops have
    # no Half implementation, so fall back to float32 when no GPU is present.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=dtype, safety_checker=None
    ).to(device)

    # Initial control values, shared by the widgets and by reset() so the two
    # cannot drift apart.
    default_steps = 50
    default_seed = 1371
    default_text_cfg = 7.5
    default_image_cfg = 1.5

    def load_example(
        steps: int,
        randomize_seed: bool,
        seed: int,
        randomize_cfg: bool,
        text_cfg_scale: float,
        image_cfg_scale: float,
    ):
        """Generate an image for a randomly chosen entry of ``example_instructions``.

        Returns:
            ``[instruction, seed, text_cfg_scale, image_cfg_scale, image]``:
            the chosen prompt followed by everything ``generate`` returns.
        """
        example_instruction = random.choice(example_instructions)
        return [example_instruction] + generate(
            example_instruction,
            steps,
            randomize_seed,
            seed,
            randomize_cfg,
            text_cfg_scale,
            image_cfg_scale,
        )

    def generate(
        instruction: str,
        steps: int,
        randomize_seed: bool,
        seed: int,
        randomize_cfg: bool,
        text_cfg_scale: float,
        image_cfg_scale: float,
    ):
        """Run the pipeline on ``instruction`` and return values for the UI.

        Args:
            instruction: Text prompt. A blank prompt skips generation.
            steps: Number of denoising steps (clamped to at least 1).
            randomize_seed: Radio index; truthy ("Randomize Seed") draws a
                fresh seed in [0, 100000] instead of using ``seed``.
            seed: Seed used when ``randomize_seed`` is falsy.
            randomize_cfg: Radio index; truthy ("Randomize CFG") draws random
                text/image CFG values.
            text_cfg_scale: Guidance scale passed to the pipeline.
            image_cfg_scale: Displayed only. ``StableDiffusionPipeline`` has
                no image-guidance parameter, so it does not affect the output.

        Returns:
            ``[seed, text_cfg_scale, image_cfg_scale, image]``, one value per
            output component. For a blank prompt the image slot is
            ``gr.update()``, which leaves the currently shown image unchanged.
        """
        seed = random.randint(0, 100000) if randomize_seed else seed
        text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2) if randomize_cfg else text_cfg_scale
        image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2) if randomize_cfg else image_cfg_scale

        if not instruction or not instruction.strip():
            # Gradio requires one return value per output component (4 here);
            # returning fewer raises an error in the UI.
            return [seed, text_cfg_scale, image_cfg_scale, gr.update()]

        # A dedicated generator makes results reproducible per seed without
        # reseeding torch's global RNG (as torch.manual_seed would).
        generator = torch.Generator(device="cpu").manual_seed(int(seed))
        image = pipe(
            prompt=instruction,
            guidance_scale=text_cfg_scale,
            num_inference_steps=max(1, int(steps)),
            generator=generator,
        ).images[0]
        return [seed, text_cfg_scale, image_cfg_scale, image]

    def reset():
        """Restore every control to its initial value and clear the output image."""
        # Order matches reset_button's outputs list below.
        return [
            default_steps,
            "Randomize Seed",
            default_seed,
            "Fix CFG",
            default_text_cfg,
            default_image_cfg,
            None,
        ]

    with gr.Blocks() as demo:
        gr.HTML("""
""")
        with gr.Row():
            with gr.Column(scale=1, min_width=100):
                generate_button = gr.Button("Generate")
            with gr.Column(scale=1, min_width=100):
                load_button = gr.Button("Load Example")
            with gr.Column(scale=1, min_width=100):
                reset_button = gr.Button("Reset")
            with gr.Column(scale=3):
                instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)

        with gr.Row():
            edited_image = gr.Image(label="Edited Image", type="pil", interactive=False)
            # NOTE(review): Component.style() is a Gradio 3.x API that was
            # removed in 4.x; pin gradio<4 or pass height/width to gr.Image.
            edited_image.style(height=512, width=512)

        with gr.Row():
            steps = gr.Number(value=default_steps, precision=0, label="Steps", interactive=True)
            randomize_seed = gr.Radio(
                ["Fix Seed", "Randomize Seed"],
                value="Randomize Seed",
                type="index",
                show_label=False,
                interactive=True,
            )
            seed = gr.Number(value=default_seed, precision=0, label="Seed", interactive=True)
            randomize_cfg = gr.Radio(
                ["Fix CFG", "Randomize CFG"],
                value="Fix CFG",
                type="index",
                show_label=False,
                interactive=True,
            )
            text_cfg_scale = gr.Number(value=default_text_cfg, label="Text CFG", interactive=True)
            image_cfg_scale = gr.Number(value=default_image_cfg, label="Image CFG", interactive=True)

        gr.Markdown(help_text)

        load_button.click(
            fn=load_example,
            inputs=[
                steps,
                randomize_seed,
                seed,
                randomize_cfg,
                text_cfg_scale,
                image_cfg_scale,
            ],
            outputs=[instruction, seed, text_cfg_scale, image_cfg_scale, edited_image],
        )
        generate_button.click(
            fn=generate,
            inputs=[
                instruction,
                steps,
                randomize_seed,
                seed,
                randomize_cfg,
                text_cfg_scale,
                image_cfg_scale,
            ],
            outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image],
        )
        reset_button.click(
            fn=reset,
            inputs=[],
            outputs=[steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale, edited_image],
        )

    # NOTE(review): queue(concurrency_count=...) is Gradio 3.x; Gradio 4.x
    # replaced it with default_concurrency_limit.
    demo.queue(concurrency_count=1)
    demo.launch(share=False)
# Script entry point: build and serve the UI only when run directly.
if __name__ == "__main__":
    main()

# NOTE: The module-level code that previously followed this guard was removed
# because it could never run successfully:
#   - it passed undefined components `txt` and `txt_2` to gr.Examples (NameError);
#   - gr.Examples was called outside any gr.Blocks context, with one example
#     row holding three values for two input components, and requested
#     cache_examples=True without an `fn` to compute the cached outputs;
#   - gr.load() was called without the required model/Space name.
# Its example prompts were "a lake with a river" and
# "a river running though flat planes"; add them to `example_instructions`
# if they should be offered by the "Load Example" button.
# sr_b64 = super_resolution(hmap_b64) | |