# InstantIR / app_with_diffusers.py
# Hugging Face Space file by fffiloni — commit 29dbe4a ("Update app_with_diffusers.py", verified), 2.19 kB.
from huggingface_hub import hf_hub_download
# Fetch the three InstantIR checkpoints from the Hub. With local_dir="." each
# file keeps its repo-relative "models/..." path, i.e. it lands under ./models/.
for _ckpt_file in (
    "models/adapter.pt",
    "models/aggregator.pt",
    "models/previewer_lora_weights.bin",
):
    hf_hub_download(repo_id="InstantX/InstantIR", filename=_ckpt_file, local_dir=".")
del _ckpt_file  # don't leave the loop variable behind as a module global
import torch
from PIL import Image
from diffusers import DDPMScheduler
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
from module.ip_adapter.utils import load_adapter_to_pipe
from pipelines.sdxl_instantir import InstantIRPipeline
# Directory holding the downloaded InstantIR checkpoints (hf_hub_download with
# local_dir="." keeps the repo-relative "models/..." layout).
instantir_path = './models'

# Base SDXL pipeline, loaded in fp16.
pipe = InstantIRPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)

# Attach the InstantIR adapter, with DINOv2-large as the image encoder.
load_adapter_to_pipe(
    pipe,
    f"{instantir_path}/adapter.pt",
    image_encoder_or_path='facebook/dinov2-large',
)

# Load the previewer LoRA weights from the checkpoint directory.
pipe.prepare_previewers(instantir_path)

# The main sampler is DDPM. A single-step LCM scheduler is derived from the
# same config; it is passed to the pipeline call as `previewer_scheduler`.
pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)

# Aggregator weights.
# - map_location="cpu" deserialises on the CPU regardless of which device the
#   checkpoint was saved from.
# - load_state_dict copies the values into the module's own parameters, and
#   the module is moved to CUDA below, so the loaded tensors' device does not
#   matter.
# NOTE(review): torch.load unpickles the file, and this one comes from the Hub.
# Consider weights_only=True (the default since torch 2.6) once the minimum
# supported torch version allows it.
pretrained_state_dict = torch.load(f"{instantir_path}/aggregator.pt", map_location="cpu")
pipe.aggregator.load_state_dict(pretrained_state_dict)
# The values now live in pipe.aggregator; drop the extra host-side copy.
del pretrained_state_dict

# Send everything to the GPU in fp16. This also casts the adapter, previewer
# and aggregator weights that were loaded after from_pretrained(). The
# aggregator is moved explicitly as well, presumably because pipe.to() does not
# reach it — TODO confirm.
pipe.to(device='cuda', dtype=torch.float16)
pipe.aggregator.to(device='cuda', dtype=torch.float16)
def infer(input_image):
    """Restore a degraded image with the InstantIR pipeline.

    Args:
        input_image: Path to the uploaded image (the Gradio input component is
            declared with ``type="filepath"``), or ``None`` when nothing was
            uploaded.

    Returns:
        The first entry of the pipeline output's ``images``.

    Raises:
        ValueError: If no image was provided.
    """
    # Gradio passes None when the button is clicked with an empty input.
    # Fail with a readable message instead of an opaque PIL error;
    # launch(show_error=True) shows it in the UI.
    if not input_image:
        raise ValueError("Please upload a low-quality image first.")

    # Use a context manager so the file handle is closed promptly.
    # convert() loads the pixel data before the file is closed.
    with Image.open(input_image) as raw_image:
        low_quality_image = raw_image.convert("RGB")

    # InstantIR restoration; the previewer uses the single-step LCM scheduler.
    return pipe(
        image=low_quality_image,
        previewer_scheduler=lcm_scheduler,
    ).images[0]
import gradio as gr
# Layout: the upload widget and trigger button are stacked in one column, with
# the restored result placed beside them in the same row.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                lq_img = gr.Image(label="Low-quality image", type="filepath")
                submit_btn = gr.Button("InstantIR magic!")
            output_img = gr.Image(label="InstantIR restored")

    # Clicking the button sends lq_img's file path to infer() and shows the
    # returned image in output_img.
    submit_btn.click(fn=infer, inputs=[lq_img], outputs=[output_img])

# show_error=True surfaces exceptions raised inside infer() in the browser.
demo.launch(show_error=True)