svjack's picture
Upload 1392 files
43b7e92 verified
raw
history blame
2.49 kB
import argparse
import os
import torch
from PIL import Image, ImageFilter
from transformers import CLIPTextModel
from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
parser = argparse.ArgumentParser(description="Inference")
parser.add_argument(
"--model_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--validation_image",
type=str,
default=None,
required=True,
help="The directory of the validation image",
)
parser.add_argument(
"--validation_mask",
type=str,
default=None,
required=True,
help="The directory of the validation mask",
)
parser.add_argument(
"--output_dir",
type=str,
default="./test-infer/",
help="The output directory where predictions are saved",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible inference.")
args = parser.parse_args()
if __name__ == "__main__":
os.makedirs(args.output_dir, exist_ok=True)
generator = None
# create & load model
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float32, revision=None
)
pipe.unet = UNet2DConditionModel.from_pretrained(
args.model_path,
subfolder="unet",
revision=None,
)
pipe.text_encoder = CLIPTextModel.from_pretrained(
args.model_path,
subfolder="text_encoder",
revision=None,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
if args.seed is not None:
generator = torch.Generator(device="cuda").manual_seed(args.seed)
image = Image.open(args.validation_image)
mask_image = Image.open(args.validation_mask)
results = pipe(
["a photo of sks"] * 16,
image=image,
mask_image=mask_image,
num_inference_steps=25,
guidance_scale=5,
generator=generator,
).images
erode_kernel = ImageFilter.MaxFilter(3)
mask_image = mask_image.filter(erode_kernel)
blur_kernel = ImageFilter.BoxBlur(1)
mask_image = mask_image.filter(blur_kernel)
for idx, result in enumerate(results):
result = Image.composite(result, image, mask_image)
result.save(f"{args.output_dir}/{idx}.png")
del pipe
torch.cuda.empty_cache()