import gradio as gr
import numpy as np
import torch
import cv2
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from skimage.measure import label, regionprops

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def create_mask(image, image_mask, alpha=0.7):
    # Broadcast the single-channel mask to all three color channels of the image.
    mask = np.zeros_like(image)
    for i in range(3):
        mask[:, :, i] = image_mask.copy()
    # Blend the mask onto the image.
    overlay_image = cv2.addWeighted(mask, alpha, image, 1 - alpha, 0)
    return overlay_image


def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
    # Convert a bounding box from the model's output resolution back to the original image size.
    bbox = np.asarray(bbox) / model_shape
    y1, y2 = bbox[::2] * orig_image_shape[0]
    x1, x2 = bbox[1::2] * orig_image_shape[1]
    return [int(y1), int(x1), int(y2), int(x2)]


def detect_using_clip(image, prompts=[], threshold=0.4):
    model_detections = dict()
    predicted_images = dict()
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # disable gradient computation for inference
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)

    for i, prompt in enumerate(prompts):
        # Threshold the sigmoid activation map into a binary mask.
        predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, 255, 0).astype(np.uint8)
        # Extract connected components and their bounding boxes from the mask.
        lbl_0 = label(predicted_image)
        props = regionprops(lbl_0)
        prompt = prompt.lower()
        model_detections[prompt] = [
            rescale_bbox(prop.bbox, orig_image_shape=image.shape[:2], model_shape=predicted_image.shape[0])
            for prop in props
        ]
        predicted_images[prompt] = predicted_image
    return model_detections, predicted_images


def visualize_images(image, detections, predicted_images, prompt):
    alpha = 0.7
    prompt = prompt.lower()
    image_resize = cv2.resize(image, (352, 352))
    if prompt not in detections:
        print("Selected category was not among the prompts; returning the resized image.")
        return image_resize
    mask_image = create_mask(image=image_resize, image_mask=predicted_images[prompt])
    final_image = cv2.addWeighted(image_resize, alpha, mask_image, 1 - alpha, 0)
    return final_image


def shot(image, labels_text, selected_category):
    prompts = [prompt.strip() for prompt in labels_text.split(",")]
    model_detections, predicted_images = detect_using_clip(image, prompts=prompts)
    category_image = visualize_images(
        image=image,
        detections=model_detections,
        predicted_images=predicted_images,
        prompt=selected_category,
    )
    return category_image


iface = gr.Interface(
    fn=shot,
    inputs=["image", "text", "text"],
    outputs="image",
    description="Upload an image and a comma-separated list of categories to detect, then name the category to visualize.",
    title="Zero-shot Image Classification with Prompt",
    examples=[
        ["images/room.jpg", "bed, table, plant, light, window", "plant"],
        ["images/image2.png", "banner, building, door, sign", "sign"],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()
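# --- Usage sketch (comments only, not executed) -------------------------------
# A minimal example of calling the functions above directly, without the Gradio
# UI. The image path and prompts are assumptions borrowed from the `examples`
# list; substitute any local image. Note that cv2.imread returns BGR, so the
# array is converted to RGB to match what the Gradio "image" input provides.
#
#   img = cv2.cvtColor(cv2.imread("images/room.jpg"), cv2.COLOR_BGR2RGB)
#   detections, masks = detect_using_clip(img, prompts=["bed", "plant"], threshold=0.4)
#   overlay = visualize_images(img, detections, masks, prompt="plant")
#   cv2.imwrite("plant_overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))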