|
import argparse |
|
import os |
|
|
|
import torch |
|
from PIL import Image, ImageFilter |
|
from transformers import CLIPTextModel |
|
|
|
from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel |
|
|
|
|
|
# Command-line interface of the inference script.
#
# --model_path        location the fine-tuned "unet" and "text_encoder" subfolders are loaded from
# --validation_image  image file that is opened and inpainted
# --validation_mask   mask image file that is opened and passed to the pipeline alongside it
# --output_dir        folder the generated PNG files are written to
# --seed              optional seed for the torch random generator
parser = argparse.ArgumentParser(description="Inference")
parser.add_argument(
    "--model_path",
    type=str,
    required=True,
    help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--validation_image",
    type=str,
    required=True,
    # The value is opened as a single image file, not scanned as a directory.
    help="Path to the validation image file",
)
parser.add_argument(
    "--validation_mask",
    type=str,
    required=True,
    # The value is opened as a single image file, not scanned as a directory.
    help="Path to the validation mask file",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default="./test-infer/",
    help="The output directory where predictions are saved",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible inference.")

# NOTE: arguments are parsed at module level; the rest of the script reads `args`.
args = parser.parse_args()
|
|
|
def main(cli_args):
    """Inpaint the validation image with the fine-tuned weights and save the results.

    The base inpainting pipeline is loaded from
    ``stabilityai/stable-diffusion-2-inpainting``; its UNet and text encoder are
    then replaced by the ones stored under ``cli_args.model_path``. Sixteen
    samples are generated for the prompt ``"a photo of sks"``, blended back onto
    the input image through a softened mask, and written to
    ``<output_dir>/<index>.png``.

    Args:
        cli_args: Parsed command-line namespace providing ``model_path``,
            ``validation_image``, ``validation_mask``, ``output_dir`` and ``seed``.

    Raises:
        ValueError: If the mask and the image do not have the same size.
    """
    os.makedirs(cli_args.output_dir, exist_ok=True)

    # Read the inputs before loading any model so a bad path or a mismatched mask
    # fails immediately. Image.composite requires both images to share a mode and
    # the mask to be "1", "L" or "RGBA", so normalise the modes up front; the
    # context managers release the file handles once the pixels are copied.
    with Image.open(cli_args.validation_image) as image_file:
        image = image_file.convert("RGB")
    with Image.open(cli_args.validation_mask) as mask_file:
        mask_image = mask_file.convert("L")
    if mask_image.size != image.size:
        raise ValueError(f"Mask size {mask_image.size} does not match image size {image.size}.")

    # Weights are loaded in float32, so fall back to the CPU when CUDA is missing
    # instead of crashing on `.to("cuda")`.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float32, revision=None
    )
    # Swap in the fine-tuned components.
    pipe.unet = UNet2DConditionModel.from_pretrained(
        cli_args.model_path,
        subfolder="unet",
        revision=None,
    )
    pipe.text_encoder = CLIPTextModel.from_pretrained(
        cli_args.model_path,
        subfolder="text_encoder",
        revision=None,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)

    generator = None
    if cli_args.seed is not None:
        generator = torch.Generator(device=device).manual_seed(cli_args.seed)

    results = pipe(
        ["a photo of sks"] * 16,
        image=image,
        mask_image=mask_image,
        num_inference_steps=25,
        guidance_scale=5,
        generator=generator,
    ).images

    # Image.composite takes `result` where the mask is white and `image` where it
    # is black. MaxFilter(3) grows the white region by one pixel and BoxBlur(1)
    # feathers its border, so the seam between generated and original pixels is
    # blended rather than hard-edged.
    blend_mask = mask_image.filter(ImageFilter.MaxFilter(3)).filter(ImageFilter.BoxBlur(1))

    for idx, result in enumerate(results):
        # The pipeline may return images at its own working resolution; composite
        # raises "images do not match" unless all three sizes agree.
        if result.size != image.size:
            result = result.resize(image.size)
        blended = Image.composite(result, image, blend_mask)
        blended.save(os.path.join(cli_args.output_dir, f"{idx}.png"))

    # Drop the pipeline and hand cached GPU memory back to the driver.
    del pipe
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main(args)
|
|