ru / Mask_Clip2Seg.py
Nirmal00001123's picture
Upload 66 files
28dc120 verified
import cv2
import numpy as np
import torch
import threading
from torchvision import transforms
from clip.clipseg import CLIPDensePredT
import numpy as np
from roop.typing import Frame
THREAD_LOCK_CLIP = threading.Lock()
class Mask_Clip2Seg():
plugin_options:dict = None
model_clip = None
processorname = 'clip2seg'
type = 'mask'
def Initialize(self, plugin_options:dict):
if self.plugin_options is not None:
if self.plugin_options["devicename"] != plugin_options["devicename"]:
self.Release()
self.plugin_options = plugin_options
if self.model_clip is None:
self.model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
self.model_clip.eval();
self.model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False)
device = torch.device(self.plugin_options["devicename"])
self.model_clip.to(device)
def Run(self, img1, keywords:str) -> Frame:
if keywords is None or len(keywords) < 1 or img1 is None:
return img1
source_image_small = cv2.resize(img1, (256,256))
img_mask = np.full((source_image_small.shape[0],source_image_small.shape[1]), 0, dtype=np.float32)
mask_border = 1
l = 0
t = 0
r = 1
b = 1
mask_blur = 5
clip_blur = 5
img_mask = cv2.rectangle(img_mask, (mask_border+int(l), mask_border+int(t)),
(256 - mask_border-int(r), 256-mask_border-int(b)), (255, 255, 255), -1)
img_mask = cv2.GaussianBlur(img_mask, (mask_blur*2+1,mask_blur*2+1), 0)
img_mask /= 255
input_image = source_image_small
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.Resize((256, 256)),
])
img = transform(input_image).unsqueeze(0)
thresh = 0.5
prompts = keywords.split(',')
with THREAD_LOCK_CLIP:
with torch.no_grad():
preds = self.model_clip(img.repeat(len(prompts),1,1,1), prompts)[0]
clip_mask = torch.sigmoid(preds[0][0])
for i in range(len(prompts)-1):
clip_mask += torch.sigmoid(preds[i+1][0])
clip_mask = clip_mask.data.cpu().numpy()
np.clip(clip_mask, 0, 1)
clip_mask[clip_mask>thresh] = 1.0
clip_mask[clip_mask<=thresh] = 0.0
kernel = np.ones((5, 5), np.float32)
clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur*2+1,clip_blur*2+1), 0)
img_mask *= clip_mask
img_mask[img_mask<0.0] = 0.0
return img_mask
def Release(self):
self.model_clip = None