# BrushEdit/app/src/brushedit_all_in_one_pipeline.py
# Hub page header: Yw22, "brushedit demo", revision b2682d8, 2.79 kB.
from PIL import Image, ImageEnhance
from diffusers.image_processor import VaeImageProcessor
import numpy as np
import cv2
def BrushEdit_Pipeline(pipe,
                       prompts,
                       mask_np,
                       original_image,
                       generator,
                       num_inference_steps,
                       guidance_scale,
                       control_strength,
                       negative_prompt,
                       num_samples,
                       blending):
    """Run ``pipe`` on a masked copy of ``original_image`` and optionally blend.

    The mask and the image are resized to a size accepted by the pipeline's
    VAE, the masked-out image and an RGB mask image are handed to ``pipe``,
    and the generated samples are (optionally) blended back onto the original
    image using a Gaussian-blurred version of the mask.

    Args:
        pipe: Callable pipeline exposing ``vae_scale_factor``; its return
            value must have an ``images`` attribute that is iterable.
        prompts: Prompt replicated ``num_samples`` times before being passed
            to ``pipe`` (presumably a single string -- confirm with caller).
        mask_np: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` array. It is divided by
            255, so values are presumably in ``0..255``; non-zero marks the
            region removed from the image given to ``pipe``. Only the first
            channel of a 3-D mask is used.
        original_image: Image as a NumPy array (it is passed to
            ``cv2.resize`` and ``Image.fromarray``; presumably ``uint8``).
        generator: Forwarded unchanged to ``pipe``.
        num_inference_steps: Forwarded unchanged to ``pipe``.
        guidance_scale: Forwarded unchanged to ``pipe``.
        control_strength: Converted to ``float`` and forwarded as
            ``brushnet_conditioning_scale``.
        negative_prompt: Replicated ``num_samples`` times and forwarded.
        num_samples: Number of prompt copies, i.e. requested samples.
        blending: If truthy, blend every generated image with the original
            image; otherwise return the pipeline images untouched.

    Returns:
        A tuple ``(image_all, mask_image, mask_np, init_image_np)``:
        ``image_all`` is a list of blended PIL images (or ``pipe``'s own
        ``images`` when ``blending`` is falsy), ``mask_image`` is the RGB PIL
        mask given to ``pipe``, ``mask_np`` is the resized float mask of
        shape ``(H, W, 1)`` and ``init_image_np`` is the resized original
        image as a ``uint8`` array after a round trip through the image
        processor.
    """
    # Work on a single-channel (H, W) mask. The code below relies on
    # cv2.resize / cv2.GaussianBlur handing back 2-D arrays so that exactly
    # one channel axis is appended afterwards; a multi-channel mask would
    # otherwise turn into a 4-D array and break the arithmetic further down.
    if mask_np.ndim == 3:
        mask_np = mask_np[:, :, 0]
    mask_np = mask_np / 255
    height, width = mask_np.shape[:2]

    # Resize the mask and the original image to the same size, which must be
    # divisible by vae_scale_factor.
    image_processor = VaeImageProcessor(
        vae_scale_factor=pipe.vae_scale_factor,
        do_convert_rgb=True,
    )
    height_new, width_new = image_processor.get_default_height_width(
        original_image, height, width
    )

    # cv2 takes sizes as (width, height).
    mask_np = cv2.resize(mask_np, (width_new, height_new))[:, :, np.newaxis]
    mask_blurred = cv2.GaussianBlur(mask_np * 255, (21, 21), 0) / 255
    mask_blurred = mask_blurred[:, :, np.newaxis]

    original_image = cv2.resize(original_image, (width_new, height_new))

    # Zero out the masked region before handing the image to the pipeline.
    init_image = original_image * (1 - mask_np)
    init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
    mask_image = Image.fromarray(
        (mask_np.repeat(3, -1) * 255).astype(np.uint8)
    ).convert("RGB")

    brushnet_conditioning_scale = float(control_strength)

    images = pipe(
        [prompts] * num_samples,
        init_image,
        mask_image,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=brushnet_conditioning_scale,
        negative_prompt=[negative_prompt] * num_samples,
        height=height_new,
        width=width_new,
    ).images

    # Convert the original image to the VAE shape format (must be divisible
    # by 8). The arithmetic below assumes preprocess() returns a
    # channel-first array scaled to [-1, 1] -- TODO confirm.
    original_image_pil = Image.fromarray(original_image).convert("RGB")
    init_image_np = np.array(
        image_processor.preprocess(
            original_image_pil, height=height_new, width=width_new
        ).squeeze()
    )
    init_image_np = ((init_image_np.transpose(1, 2, 0) + 1.) / 2.) * 255
    # Round before casting: a bare astype(uint8) truncates, so float error
    # from the round trip (e.g. 0.99999) would darken pixels by one level.
    init_image_np = np.clip(np.round(init_image_np), 0, 255).astype(np.uint8)

    if blending:
        # Blend weight lies in [0.5, 1]: generated content dominates inside
        # the mask and is mixed 50/50 with the original far away from it.
        mask_blurred = mask_blurred * 0.5 + 0.5
        image_all = []
        for image_i in images:
            image_np = np.array(image_i)
            image_pasted = (
                init_image_np * (1 - mask_blurred) + mask_blurred * image_np
            )
            image_pasted = image_pasted.astype(np.uint8)
            image_all.append(Image.fromarray(image_pasted))
    else:
        image_all = images

    return image_all, mask_image, mask_np, init_image_np