# Spaces: Running on Zero
# Authors: Hui Ren (rhfeiyang.github.io) | |
import os | |
import gradio as gr | |
from diffusers import DiffusionPipeline | |
import matplotlib.pyplot as plt | |
import torch | |
from PIL import Image | |
# Select the execution device: CUDA when torch reports an available GPU, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the Art-Free Diffusion pipeline (model repo id below) and move it to the
# selected device. This runs once at import time; every handler reuses `pipe`.
pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1")
pipe = pipe.to(device)
from inference import get_lora_network, inference, get_validation_dataloader | |
# Artist display name -> directory name of that artist's LoRA adapter under
# data/Art_adapters/. The "None" entry is a sentinel meaning "no adapter": its
# value is the string "None", which is not expanded into a file path.
lora_map = dict(
    [
        ("None", "None"),
        ("Andre Derain", "andre-derain_subset1"),
        ("Vincent van Gogh", "van_gogh_subset1"),
        ("Andy Warhol", "andy_subset1"),
        ("Walter Battiss", "walter-battiss_subset2"),
        ("Camille Corot", "camille-corot_subset1"),
        ("Claude Monet", "monet_subset2"),
        ("Pablo Picasso", "picasso_subset1"),
        ("Jackson Pollock", "jackson-pollock_subset1"),
        ("Gerhard Richter", "gerhard-richter_subset1"),
        ("M.C. Escher", "m.c.-escher_subset1"),
        ("Albert Gleizes", "albert-gleizes_subset1"),
        ("Hokusai", "katsushika-hokusai_subset1"),
        ("Wassily Kandinsky", "kandinsky_subset1"),
        ("Gustav Klimt", "klimt_subset3"),
        ("Roy Lichtenstein", "roy-lichtenstein_subset1"),
        ("Henri Matisse", "henri-matisse_subset1"),
        ("Joan Miro", "joan-miro_subset2"),
    ]
)
def demo_inference_gen(adapter_choice: str, prompt: str, samples: int = 1, seed: int = 0, steps=50, guidance_scale=7.5):
    """Generate 512x512 images for a prompt, optionally through an artist LoRA adapter.

    Args:
        adapter_choice: Display name of the adapter. It must be a key of
            ``lora_map`` (a ``KeyError`` is raised otherwise); ``"None"`` selects
            no adapter.
        prompt: Text prompt; it is repeated ``samples`` times to form the batch.
        samples: Number of copies of the prompt to generate.
        seed: Seed forwarded to ``inference``.
        steps: Number of sampling steps forwarded to ``inference``.
        guidance_scale: Guidance scale forwarded to ``inference``.

    Returns:
        The value stored under scale ``1.0`` in the first element of
        ``inference``'s result. This is presumably the list of generated images
        shown in the gallery; confirm against ``inference``.
    """
    adapter_name = lora_map[adapter_choice]
    if adapter_name in (None, "None"):
        # Sentinel: pass it through unchanged so no adapter file is resolved.
        adapter_path = adapter_name
    else:
        adapter_path = f"data/Art_adapters/{adapter_name}/adapter_alpha1.0_rank1_all_up_1000steps.pt"

    loader = get_validation_dataloader([prompt] * samples)
    lora_network = get_lora_network(pipe.unet, adapter_path)["network"]

    # start_noise=-1 with from_scratch=True is presumably full text-to-image
    # sampling (no source image). TODO confirm against inference().
    outputs = inference(
        lora_network,
        pipe.tokenizer,
        pipe.text_encoder,
        pipe.vae,
        pipe.unet,
        pipe.scheduler,
        loader,
        height=512,
        width=512,
        scales=[1.0],
        save_dir=None,
        seed=seed,
        steps=steps,
        guidance_scale=guidance_scale,
        start_noise=-1,
        show=False,
        style_prompt="sks art",
        no_load=True,
        from_scratch=True,
    )
    return outputs[0][1.0]
def demo_inference_stylization(adapter_path: str, prompts: list, image: list, start_noise=800, seed: int = 0):
    """Run adapter-based stylization of existing images at 512x512.

    Unlike ``demo_inference_gen``, ``adapter_path`` is used as given (no
    ``lora_map`` lookup). The network is built with the ``"all_up"`` option, and
    sampling uses 20 steps at guidance scale 7.5.

    Args:
        adapter_path: Path handed directly to ``get_lora_network``.
        prompts: Prompts passed to ``get_validation_dataloader``.
        image: Images passed to ``get_validation_dataloader`` alongside the
            prompts. Presumably one per prompt; confirm against the dataloader.
        start_noise: Forwarded to ``inference``. Presumably the noise level or
            timestep that sampling starts from; TODO confirm.
        seed: Seed forwarded to ``inference``.

    Returns:
        ``inference``'s result for scales ``[0.0, 1.0]``, unmodified. Presumably
        0.0 is the un-adapted baseline and 1.0 is full adapter strength; confirm.
    """
    loader = get_validation_dataloader(prompts, image)
    lora_network = get_lora_network(pipe.unet, adapter_path, "all_up")["network"]
    return inference(
        lora_network,
        pipe.tokenizer,
        pipe.text_encoder,
        pipe.vae,
        pipe.unet,
        pipe.scheduler,
        loader,
        height=512,
        width=512,
        scales=[0.0, 1.0],
        save_dir=None,
        seed=seed,
        steps=20,
        guidance_scale=7.5,
        start_noise=start_noise,
        show=True,
        style_prompt="sks art",
        no_load=True,
        from_scratch=False,
    )
# def infer(prompt, samples, steps, scale, seed): | |
# generator = torch.Generator(device=device).manual_seed(seed) | |
# images_list = pipe( # type: ignore | |
# [prompt] * samples, | |
# num_inference_steps=steps, | |
# guidance_scale=scale, | |
# generator=generator, | |
# ) | |
# images = [] | |
# safe_image = Image.open(r"data/unsafe.png") | |
# print(images_list) | |
# for i, image in enumerate(images_list["images"]): # type: ignore | |
# if images_list["nsfw_content_detected"][i]: # type: ignore | |
# images.append(safe_image) | |
# else: | |
# images.append(image) | |
# return images | |
def _toggle_advanced_options(visible):
    """Flip the advanced-options visibility flag.

    Args:
        visible: Current visibility of the advanced-options row, held in a
            ``gr.State``.

    Returns:
        ``(new_visible, row_update)``: the flipped flag to store back in the
        state, and a Gradio update that shows or hides the row to match.
    """
    shown = not visible
    return shown, gr.update(visible=shown)


# Build the demo UI: a prompt box with a Run button, an output gallery, and a
# toggleable row of generation settings. Pressing Enter in the prompt box or
# clicking Run calls demo_inference_gen and shows its result in the gallery.
block = gr.Blocks()
# Direct infer
with block:
    with gr.Group():
        gr.Markdown(" # Art-Free Diffusion Demo")
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt",
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
                value="Park with cherry blossom trees, picnicker’s and a clear blue pond.",
            )
            btn = gr.Button("Run", scale=0)
        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=[2],
        )

        advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
        # Whether the advanced-options row is currently shown; it starts visible.
        advanced_visible = gr.State(True)
        with gr.Row(elem_id="advanced-options") as advanced_options:
            # Derive the choices from lora_map (same order as before). Every
            # selectable name then has an entry, and demo_inference_gen's
            # lora_map lookup cannot raise KeyError for a dropdown value.
            adapter_choice = gr.Dropdown(
                label="Choose adapter",
                choices=list(lora_map),
                value="None",
            )
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2147483647,
                step=1,
                randomize=True,
            )

    # Input order must match demo_inference_gen's positional parameters:
    # (adapter_choice, prompt, samples, seed, steps, guidance_scale).
    gr.on(
        [text.submit, btn.click],
        demo_inference_gen,
        inputs=[adapter_choice, text, samples, seed, steps, scale],
        outputs=gallery,
    )
    # The previous handler had neither a Python fn nor JS, so the button did
    # nothing. Toggle the advanced-options row's visibility instead.
    advanced_button.click(
        _toggle_advanced_options,
        inputs=advanced_visible,
        outputs=[advanced_visible, advanced_options],
    )

block.launch()