# Importing required libraries
# NOTE: `spaces` is kept as the first import, matching the original order.
import spaces
import gradio as gr
import os
import random
import numpy as np
import cv2
from PIL import Image
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple
import torch
import google.generativeai as genai
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline, T5EncoderModel, CLIPTextModel
from diffusers import FluxTransformer2DModel, FluxInpaintPipeline
from diffusers.utils import load_image

MARKDOWN = """
# Add-It🎨
Add or Replace anything to any image by using a single Prompt and an Image.

Made using [Flux (Schnell)](https://huggingface.co/black-forest-labs/FLUX.1-schnell), [Grounding-DINO](https://huggingface.co/docs/transformers/main/en/model_doc/grounding-dino) and [SAM](https://huggingface.co/docs/transformers/en/model_doc/sam).
"""

# Gemini setup: used to turn the user's edit instruction into the name of the
# object/region to mask.
genai.configure(api_key=os.environ['Gemini_API'])
gemini_flash = genai.GenerativeModel(model_name='gemini-1.5-flash-002')


def gemini_predict(prompt):
    """Ask Gemini which object or region of the image the prompt refers to.

    Args:
        prompt: Free-text edit instruction from the user.

    Returns:
        Gemini's answer with surrounding whitespace removed. The prompt asks
        for a short label such as "Lips" or "Right wall".
    """
    system_message = f"""You are the best text analyser.
You have to analyse a user query and identify what the user wants to change, from a given user query.

Examples:
Query: Change Lipstick colour to blue
Response: Lips

Query: Add a nose stud
Response: Nose

Query: Add a wallpaper to the right wall
Response: Right wall

Query: Change the Sofa's colour to Purple
Response: Sofa

Your response should be in 1 or 2-3 words
Query : {prompt}
"""
    response = gemini_flash.generate_content(system_message)
    # strip() removes a trailing newline (or any surrounding whitespace)
    # without truncating the label when no newline is present.
    return str(response.text).strip()


MAX_SEED = np.iinfo(np.int32).max
SAM_device = "cuda"  # or "cpu"
DEVICE = "cuda"


### GroundingDINO & SAM setup

# Containers for Grounding DINO results.
@dataclass
class BoundingBox:
    """Axis-aligned box in pixel coordinates."""
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        """Return the box as [xmin, ymin, xmax, ymax]."""
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    """One detection: score, label, box and (after SAM) its binary mask."""
    score: float
    label: str
    box: BoundingBox
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        """Build a result from a zero-shot-object-detection pipeline entry.

        Expects the keys 'score', 'label' and 'box' (itself a dict with
        'xmin', 'ymin', 'xmax', 'ymax').
        """
        box = detection_dict['box']
        return cls(score=detection_dict['score'],
                   label=detection_dict['label'],
                   box=BoundingBox(xmin=box['xmin'],
                                   ymin=box['ymin'],
                                   xmax=box['xmax'],
                                   ymax=box['ymax']))


# Utility functions for mask generation
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Return the [x, y] vertices of the largest external contour in a binary mask.

    Returns an empty list when the mask has no foreground pixels (``max`` over
    an empty contour list would otherwise raise ValueError).
    """
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []
    # Keep only the largest region; smaller fragments are discarded.
    largest_contour = max(contours, key=cv2.contourArea)
    return largest_contour.reshape(-1, 2).tolist()


def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    """
    Convert a polygon to a segmentation mask.

    Args:
    - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
      An empty polygon yields an all-zero mask.
    - image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
    - np.ndarray: uint8 segmentation mask with the polygon filled with 255.
    """
    mask = np.zeros(image_shape, dtype=np.uint8)
    if len(polygon) == 0:
        return mask
    pts = np.array(polygon, dtype=np.int32)
    cv2.fillPoly(mask, [pts], color=(255,))
    return mask


def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    """Wrap all detection boxes as [[box, ...]], the nesting passed to SAM's
    processor as ``input_boxes`` for a single image."""
    return [[result.box.xyxy for result in results]]


def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    """Convert SAM's mask tensor into one uint8 numpy mask per box.

    NOTE(review): assumes ``masks`` is rank 4 (boxes, masks-per-box, H, W) as
    returned by ``post_process_masks`` — confirm. Dimension 1 is collapsed so
    a pixel is foreground if any of that box's masks is positive.

    Args:
        masks: Mask tensor for one image.
        polygon_refinement: If True, replace each mask with the filled largest
            contour (values become 0/255 instead of 0/1).

    Returns:
        List of 2-D uint8 arrays, one per box.
    """
    masks = masks.cpu().float()
    masks = masks.permute(0, 2, 3, 1)
    masks = masks.mean(dim=-1)
    masks = (masks > 0).int()
    masks = list(masks.numpy().astype(np.uint8))

    if polygon_refinement:
        for idx, mask in enumerate(masks):
            polygon = mask_to_polygon(mask)
            masks[idx] = polygon_to_mask(polygon, mask.shape)

    return masks


def get_alphacomp_mask(mask, image, random_color=True):
    """Alpha-composite a 2-D mask over the image and return an RGBA array.

    The mask is converted to RGBA with an opaque alpha channel, so the
    composite is effectively the mask rendered as RGBA.
    ``random_color`` is unused and kept only for call compatibility.
    """
    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray(mask).convert("RGBA")
    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))


# Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
detector_id = "IDEA-Research/grounding-dino-tiny"
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=SAM_device)

# Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.
segmenter_id = "facebook/sam-vit-base"
processor = AutoProcessor.from_pretrained(segmenter_id)
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(SAM_device)


def detect(image: Image.Image, labels: List[str], threshold: float = 0.3) -> List[DetectionResult]:
    """Run Grounding DINO on ``image`` for ``labels`` and keep detections above ``threshold``."""
    # Terminate each label with "." (Grounding DINO prompt convention).
    labels = [label if label.endswith(".") else label + "." for label in labels]
    with torch.no_grad():
        results = object_detector(image, candidate_labels=labels, threshold=threshold)
    torch.cuda.empty_cache()
    return [DetectionResult.from_dict(result) for result in results]


def segment_SAM(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False) -> List[DetectionResult]:
    """Compute a SAM mask for each detection's box and attach it as ``.mask``.

    Returns the same list (mutated in place). An empty list is returned
    unchanged, without calling SAM.
    """
    if not detection_results:
        # Nothing detected: avoid calling SAM with an empty box list.
        return detection_results

    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(SAM_device)

    with torch.no_grad():
        outputs = segmentator(**inputs)
    torch.cuda.empty_cache()

    # [0]: masks for the single input image.
    masks = processor.post_process_masks(masks=outputs.pred_masks,
                                         original_sizes=inputs.original_sizes,
                                         reshaped_input_sizes=inputs.reshaped_input_sizes)[0]
    masks = refine_masks(masks, polygon_refinement)

    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask
    return detection_results


def grounded_segmentation(image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False) -> Tuple[np.ndarray, List[DetectionResult]]:
    """Detect ``labels`` with Grounding DINO, then segment each detection with SAM.

    Args:
        image: PIL image, or a path/URL loaded with diffusers' ``load_image``.

    Returns:
        (image as numpy array, detections with masks attached).
    """
    if isinstance(image, str):
        image = load_image(image)
    detections = detect(image, labels, threshold)
    segmented = segment_SAM(image, detections, polygon_refinement)
    return np.array(image), segmented


def get_finalmask(image_array, detections):
    """Merge the masks of all detections into one RGBA mask array.

    Masks are combined with an element-wise maximum (a union). The previous
    ``+=`` on uint8 data wrapped around where masks overlapped.

    Raises:
        ValueError: if ``detections`` is empty.
    """
    if not detections:
        raise ValueError("No detections to build a mask from.")
    final_mask = None
    for detection in detections:
        layer = get_alphacomp_mask(detection.mask, image_array)
        final_mask = layer if final_mask is None else np.maximum(final_mask, layer)
    return final_mask


# Preprocessing mask
kernel = np.ones((3, 3), np.uint8)  # 3x3 structuring element for dilation


def preprocess_mask(pipe, inp_mask, expan_lvl, blur_lvl):
    """Optionally dilate and blur the mask; always returns a PIL image.

    Args:
        pipe: Pipeline whose ``mask_processor.blur`` is used for blurring.
        inp_mask: Mask as a numpy array or PIL image.
        expan_lvl: Number of 3x3 dilation iterations (0 disables).
        blur_lvl: Blur factor passed to ``mask_processor.blur`` (0 disables).
    """
    # mask_processor.blur operates on PIL images, so normalise up front; the
    # old code passed a numpy array to blur when expan_lvl was 0.
    if not isinstance(inp_mask, Image.Image):
        inp_mask = Image.fromarray(np.asarray(inp_mask))
    if expan_lvl > 0:
        inp_mask = Image.fromarray(cv2.dilate(np.array(inp_mask), kernel, iterations=expan_lvl))
    if blur_lvl > 0:
        inp_mask = pipe.mask_processor.blur(inp_mask, blur_factor=blur_lvl)
    return inp_mask


def generate_mask(inp_image, label, threshold):
    """Build the inpainting mask for ``label`` (a list of prompts) in ``inp_image``.

    Raises:
        gr.Error: if nothing matching ``label`` is detected.
    """
    image_array, segments = grounded_segmentation(image=inp_image, labels=label, threshold=threshold, polygon_refinement=True)
    if not segments:
        raise gr.Error(f"Could not find {', '.join(label)} in the image. Try lowering the mask threshold.")
    return get_finalmask(image_array, segments)


# Setting up Flux inpainting: the FLUX.1-schnell pipeline with the Flux-Dev2Pro
# transformer and separately loaded CLIP/T5 text encoders swapped in.
transformer_ = FluxTransformer2DModel.from_pretrained("ashen0209/Flux-Dev2Pro", torch_dtype=torch.bfloat16)
text_encoder_ = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16)
text_encoder_2_ = T5EncoderModel.from_pretrained("xlabs-ai/xflux_text_encoders", torch_dtype=torch.bfloat16)
inpaint_pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell",
                                                   transformer=transformer_,
                                                   text_encoder=text_encoder_,
                                                   text_encoder_2=text_encoder_2_,
                                                   torch_dtype=torch.bfloat16).to(DEVICE)

# Optional: uncomment the following 4 lines to fuse LoRA realism weights into the pipeline.
# inpaint_pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
# inpaint_pipe.set_adapters(["better"], adapter_weights=[2.6])
# inpaint_pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0)
# inpaint_pipe.unload_lora_weights()
# Alternative LoRA: inpaint_pipe.load_lora_weights("XLabs-AI/flux-RealismLora")


@spaces.GPU()
def process(input_image_editor, input_text, strength, seed, randomize_seed, num_inference_steps, guidance_scale, threshold, expan_lvl, blur_lvl, progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: locate the target region, mask it, and inpaint it.

    Returns:
        (inpainted image, mask used, seed used, Gemini-predicted item).

    Raises:
        gr.Error: on missing prompt/image, or if the target isn't detected.
    """
    if not input_text:
        raise gr.Error("Please enter a text prompt.")

    # Validate the image before spending a Gemini API call.
    image = (input_image_editor or {}).get('background')
    if image is None:
        raise gr.Error("Please upload an image.")

    # Object identification
    item = gemini_predict(input_text)

    # Downscale large inputs in place (aspect ratio preserved). The size must
    # be read AFTER this, because the pipeline renders at width x height.
    if max(image.size) > 1024:
        image.thumbnail((1024, 1024))
    width, height = image.size

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    # Generating mask
    gen_mask = generate_mask(image, [item], threshold)

    # Pre-processing mask, optional
    if expan_lvl > 0 or blur_lvl > 0:
        gen_mask = preprocess_mask(inpaint_pipe, gen_mask, expan_lvl, blur_lvl)
    # Hand the pipeline a PIL mask on every path (generate_mask returns a
    # numpy RGBA array when no preprocessing is applied).
    if not isinstance(gen_mask, Image.Image):
        gen_mask = Image.fromarray(gen_mask)

    # Inpainting
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    result = inpaint_pipe(prompt=input_text,
                          image=image,
                          mask_image=gen_mask,
                          width=width,
                          height=height,
                          strength=strength,
                          num_inference_steps=num_inference_steps,
                          generator=generator,
                          guidance_scale=guidance_scale).images[0]

    return result, gen_mask, seed, item


with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column(scale=1):
            input_image_component = gr.ImageEditor(
                label='Image',
                type='pil',
                sources=["upload", "webcam"],
                image_mode='RGB',
                layers=False)
            input_text_component = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            with gr.Accordion("Advanced Settings", open=False):
                strength_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.96,
                    step=0.01,
                    label="Strength"
                )
                num_inference_steps = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=16,
                    step=1,
                    label="Number of inference steps"
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=5,
                )
                seed_number = gr.Number(
                    label="Seed",
                    value=26,
                    precision=0
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
            with gr.Accordion("Mask Settings", open=False):
                # Passed to Grounding DINO as the detection score threshold.
                SAM_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.4,
                    step=0.01,
                    label="Threshold"
                )
                expansion_level = gr.Slider(
                    minimum=0,
                    maximum=10,
                    value=2,
                    step=1,
                    label="Mask Expansion level"
                )
                blur_level = gr.Slider(
                    minimum=0,
                    maximum=5,
                    step=1,
                    value=0,
                    label="Mask Blur level"
                )
            submit_button_component = gr.Button(value='Inpaint', variant='primary')
        with gr.Column(scale=1):
            output_image_component = gr.Image(type='pil', image_mode='RGB', label='Generated Image')
            output_mask_component = gr.Image(type='pil', image_mode='RGB', label='Generated Mask')
            with gr.Accordion("Debug Info", open=False):
                output_seed = gr.Number(label="Used Seed")
                identified_item = gr.Textbox(label="Gemini predicted item")

    submit_button_component.click(
        fn=process,
        inputs=[
            input_image_component,
            input_text_component,
            strength_slider,
            seed_number,
            randomize_seed,
            num_inference_steps,
            guidance_scale,
            SAM_threshold,
            expansion_level,
            blur_level,
        ],
        outputs=[
            output_image_component,
            output_mask_component,
            output_seed,
            identified_item,
        ]
    )

demo.launch(debug=False, show_error=True)