"""Gradio demo for InstantIR image restoration.

Downloads the InstantIR weights from the Hugging Face Hub, assembles the
SDXL-based InstantIR pipeline in fp16 on CUDA, and serves a small web UI
that restores an uploaded low-quality image.
"""
from typing import Optional

import gradio as gr
import torch
from diffusers import DDPMScheduler
from huggingface_hub import hf_hub_download
from PIL import Image

from module.ip_adapter.utils import load_adapter_to_pipe
from pipelines.sdxl_instantir import InstantIRPipeline
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler

INSTANTIR_REPO = "InstantX/InstantIR"
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"

# The repo files live under "models/", so downloading with local_dir="."
# places them in ./models, which is where the loaders below look.
for weight_file in (
    "models/adapter.pt",
    "models/aggregator.pt",
    "models/previewer_lora_weights.bin",
):
    hf_hub_download(repo_id=INSTANTIR_REPO, filename=weight_file, local_dir=".")

# prepare models under ./models
instantir_path = "./models"

# load pretrained models
pipe = InstantIRPipeline.from_pretrained(SDXL_REPO, torch_dtype=torch.float16)

# load adapter
load_adapter_to_pipe(
    pipe,
    f"{instantir_path}/adapter.pt",
    image_encoder_or_path="facebook/dinov2-large",
)

# load previewer lora
pipe.prepare_previewers(instantir_path)
pipe.scheduler = DDPMScheduler.from_pretrained(SDXL_REPO, subfolder="scheduler")
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)

# load aggregator weights
# map_location="cpu" keeps deserialization independent of the device the
# checkpoint was saved from; the aggregator is moved to CUDA explicitly below.
# NOTE(review): torch.load unpickles the file. The checkpoint comes from a
# remote repo, so consider weights_only=True if the installed torch supports
# it and the file is a plain tensor state dict -- confirm before enabling.
pretrained_state_dict = torch.load(
    f"{instantir_path}/aggregator.pt",
    map_location="cpu",
)
pipe.aggregator.load_state_dict(pretrained_state_dict)

# send to GPU and fp16
pipe.to(device="cuda", dtype=torch.float16)
pipe.aggregator.to(device="cuda", dtype=torch.float16)


def infer(prompt: str, input_image: Optional[str]):
    """Restore a low-quality image with the InstantIR pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        input_image: Path of the image file to restore. The UI supplies it
            from a ``gr.Image`` configured with ``type="filepath"``, which
            yields ``None`` when nothing has been uploaded.

    Returns:
        The first entry of the pipeline output's ``images``.

    Raises:
        gr.Error: If no input image was provided.
    """
    if not input_image:
        # Without this guard Image.open(None) fails with an opaque error.
        raise gr.Error("Please upload a low-quality image first.")

    # load a broken image
    # convert() returns a new, fully loaded image, so the underlying file
    # handle can be released as soon as the with-block exits.
    with Image.open(input_image) as source_image:
        low_quality_image = source_image.convert("RGB")

    # InstantIR restoration
    result = pipe(
        prompt=prompt,
        image=low_quality_image,
        previewer_scheduler=lcm_scheduler,
    )
    return result.images[0]


# NOTE(review): the component nesting below was reconstructed from a source
# whose line breaks were lost; output_img is assumed to sit in the row next
# to the input column -- confirm against the intended layout.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                lq_img = gr.Image(label="Low-quality image", type="filepath")
                prompt = gr.Textbox(label="Prompt", value="")
                submit_btn = gr.Button("InstantIR magic!")
            output_img = gr.Image(label="InstantIR restored")

    # Event listeners are wired while the Blocks context is still open.
    submit_btn.click(
        fn=infer,
        inputs=[prompt, lq_img],
        outputs=[output_img],
    )

demo.launch(show_error=True)