import gradio as gr
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionControlNetInpaintPipeline, ControlNetModel
from PIL import Image
from compel import Compel

_Image = image_module.Image
_ImageFormat = image_frame.ImageFormat

# Constants for mask colors
BG_COLOR = (0, 0, 0, 255)          # black with full opacity
MASK_COLOR = (255, 255, 255, 255)  # white with full opacity

# Options used to create the MediaPipe ImageSegmenter (hair segmentation model)
base_options = python.BaseOptions(model_asset_path='emirhan.tflite')
options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)

# Initialize the ControlNet inpainting pipeline
controlnet = ControlNetModel.from_pretrained(
    'lllyasviel/control_v11p_sd15_inpaint',
    torch_dtype=torch.float16,
).to("cuda")
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5',
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")


def segment_hair(image):
    """Segment the hair region and return an RGBA mask (white hair on black)."""
    # Gradio delivers RGB arrays, so convert RGB -> RGBA and clear the alpha channel
    rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
    rgba_image[:, :, 3] = 0

    # Create an MP Image object from the numpy array
    mp_image = _Image(image_format=_ImageFormat.SRGBA, data=rgba_image)

    # Create the image segmenter and retrieve the category mask
    with vision.ImageSegmenter.create_from_options(options) as segmenter:
        segmentation_result = segmenter.segment(mp_image)
        category_mask = segmentation_result.category_mask

        # Generate solid-color images for the output segmentation mask
        image_data = mp_image.numpy_view()
        fg_image = np.zeros(image_data.shape, dtype=np.uint8)
        fg_image[:] = MASK_COLOR
        bg_image = np.zeros(image_data.shape, dtype=np.uint8)
        bg_image[:] = BG_COLOR

        condition = np.stack((category_mask.numpy_view(),) * 4, axis=-1) > 0.2
        output_image = np.where(condition, fg_image, bg_image)

        return output_image  # RGBA mask


def inpaint_hair(image, prompt):
    """Inpaint the hair area of an RGB image according to the prompt."""
    # Segment hair to get the mask
    mask = segment_hair(image)

    # Convert to PIL images for the inpainting pipeline (the input is already RGB)
    image_pil = Image.fromarray(image)
    mask_pil = Image.fromarray(cv2.cvtColor(mask, cv2.COLOR_RGBA2RGB))

    # Prepare the inpainting control image: masked pixels are set to -1
    image_np = np.array(image_pil).astype(np.float32) / 255.0
    mask_np = np.array(mask_pil.convert("L")).astype(np.float32) / 255.0
    image_np[mask_np > 0.5] = -1.0  # mark masked pixels
    inpaint_condition = torch.from_numpy(
        np.expand_dims(image_np, 0).transpose(0, 3, 1, 2)
    ).to("cuda")

    # Generate the inpainted image
    generator = torch.Generator("cuda").manual_seed(42)
    output = pipe(
        prompt=prompt,
        image=image_pil,
        mask_image=mask_pil,
        control_image=inpaint_condition,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=generator,
    ).images[0]

    return np.array(output)


# Gradio interface
iface = gr.Interface(
    fn=inpaint_hair,
    inputs=[
        gr.Image(type="numpy"),
        gr.Textbox(label="Prompt", placeholder="Describe the desired inpainting result..."),
    ],
    outputs=gr.Image(type="numpy"),
    title="Hair Inpainting with ControlNet",
    description="Upload an image and provide a prompt to inpaint the hair area using ControlNet.",
    examples=[["example.jpeg", "dreadlocks"]],
)

if __name__ == "__main__":
    iface.launch()
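
# A minimal sketch of calling the pipeline directly, without the Gradio UI. It assumes
# an RGB image exists at "example.jpeg" (the path listed in `examples` above) and reuses
# the "dreadlocks" prompt; both are placeholders to adjust. Uncomment to run as a one-off.
#
# img = np.array(Image.open("example.jpeg").convert("RGB"))
# result = inpaint_hair(img, "dreadlocks")
# Image.fromarray(result).save("inpainted_example.jpg")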