import cv2
import numpy as np
from diffusers.image_processor import VaeImageProcessor
from PIL import Image, ImageEnhance


def _to_uint8(array):
    """Round a float array to the nearest integer and cast it to uint8."""
    # A bare ``astype(np.uint8)`` truncates, so a value such as 0.99999994
    # (what pixel value 1 becomes after a float32 [-1, 1] round trip) would
    # turn into 0. Rounding first keeps the intended intensity level.
    return np.clip(np.rint(array), 0, 255).astype(np.uint8)


def BrushEdit_Pipeline(pipe,
                       prompts,
                       mask_np,
                       original_image,
                       generator,
                       num_inference_steps,
                       guidance_scale,
                       control_strength,
                       negative_prompt,
                       num_samples,
                       blending):
    """Run the inpainting pipeline on a masked image and optionally blend.

    The mask and the original image are resized to the size returned by
    ``VaeImageProcessor.get_default_height_width``, the masked image and the
    mask are handed to ``pipe``, and, when ``blending`` is true, every
    generated image is alpha-blended with the original using a
    Gaussian-blurred copy of the mask.

    Args:
        pipe: Callable pipeline exposing ``vae_scale_factor``; its result
            must expose ``.images``.
        prompts: Prompt, repeated ``num_samples`` times for the call.
        mask_np: 2-D or 3-D array that is divided by 255 here. A 2-D mask
            gets a trailing channel axis.
            NOTE(review): the code appears to assume a single-channel mask
            with values in [0, 255] - confirm against callers.
        original_image: Image array; it is passed to ``cv2.resize`` and
            ``Image.fromarray``.
            NOTE(review): presumably an H x W x 3 uint8 array - confirm.
        generator: Forwarded to ``pipe`` unchanged.
        num_inference_steps: Forwarded to ``pipe`` unchanged.
        guidance_scale: Forwarded to ``pipe`` unchanged.
        control_strength: Converted to ``float`` and passed as
            ``brushnet_conditioning_scale``.
        negative_prompt: Negative prompt, repeated ``num_samples`` times.
        num_samples: Number of images requested from ``pipe``.
        blending: When true, blend each generated image with the original.

    Returns:
        Tuple ``(image_all, mask_image, mask_np, init_image_np)``:
        the output images (blended PIL images, or the pipeline's own
        ``images`` when ``blending`` is false), the 3-channel PIL mask that
        was sent to the pipeline, the resized mask scaled by 1/255 with a
        trailing channel axis, and the resized original image as a uint8
        array after the image processor's preprocessing round trip.
    """
    if mask_np.ndim != 3:
        mask_np = mask_np[:, :, np.newaxis]

    mask_np = mask_np / 255
    height, width = mask_np.shape[0], mask_np.shape[1]

    # Resize the mask and the original image to one common size that is
    # divisible by vae_scale_factor.
    image_processor = VaeImageProcessor(
        vae_scale_factor=pipe.vae_scale_factor,
        do_convert_rgb=True,
    )
    height_new, width_new = image_processor.get_default_height_width(
        original_image, height, width
    )
    target_size = (width_new, height_new)  # cv2 expects (width, height)

    # The channel axis is re-added after each cv2 call (cv2 presumably
    # returns a 2-D array for a single-channel input - matches the original).
    mask_np = cv2.resize(mask_np, target_size)[:, :, np.newaxis]
    mask_blurred = cv2.GaussianBlur(mask_np * 255, (21, 21), 0) / 255
    mask_blurred = mask_blurred[:, :, np.newaxis]

    original_image = cv2.resize(original_image, target_size)

    # Pipeline inputs: the original with the masked region zeroed out, and
    # the mask replicated to three channels. These casts are left exactly
    # as they were so the model receives the same inputs as before.
    masked_input = original_image * (1 - mask_np)
    init_image = Image.fromarray(masked_input.astype(np.uint8)).convert("RGB")
    mask_image = Image.fromarray(
        (mask_np.repeat(3, -1) * 255).astype(np.uint8)
    ).convert("RGB")

    brushnet_conditioning_scale = float(control_strength)

    images = pipe(
        [prompts] * num_samples,
        init_image,
        mask_image,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=brushnet_conditioning_scale,
        negative_prompt=[negative_prompt] * num_samples,
        height=height_new,
        width=width_new,
    ).images

    # Convert the original to the VAE shape format (must be divisible by 8):
    # preprocess, then map the channel-first result back to H x W x C in
    # the 0..255 range.
    original_image_pil = Image.fromarray(original_image).convert("RGB")
    preprocessed = image_processor.preprocess(
        original_image_pil, height=height_new, width=width_new
    ).squeeze()
    init_image_np = np.array(preprocessed)
    init_image_np = ((init_image_np.transpose(1, 2, 0) + 1.) / 2.) * 255
    # Round instead of truncating so the round trip does not lose a level.
    init_image_np = _to_uint8(init_image_np)

    if blending:
        # Lift the blend weight from the blurred mask's range to
        # weight = 0.5 * mask + 0.5, so where the blurred mask is 0 the
        # result is still a 50/50 mix of generated and original pixels.
        mask_blurred = mask_blurred * 0.5 + 0.5
        image_all = []
        for generated in images:
            generated_np = np.array(generated)
            blended = (init_image_np * (1 - mask_blurred)
                       + mask_blurred * generated_np)
            # Round to nearest: truncation would darken pixels whose blend
            # lands just below an integer because of float error.
            image_all.append(Image.fromarray(_to_uint8(blended)))
    else:
        image_all = images

    return image_all, mask_image, mask_np, init_image_np