|
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline, StableDiffusionInstructPix2PixPipeline |
|
from diffusers import EulerAncestralDiscreteScheduler |
|
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler |
|
from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector |
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation |
|
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering |
|
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation |
|
|
|
import os |
|
import random |
|
import torch |
|
import cv2 |
|
import re |
|
import uuid |
|
from PIL import Image, ImageOps, ImageDraw, ImageFont |
|
import numpy as np |
|
import math |
|
import inspect |
|
import tempfile |
|
|
|
from langchain.llms.openai import OpenAI |
|
|
|
|
|
import groundingdino.datasets.transforms as T |
|
from groundingdino.models import build_model |
|
from groundingdino.util import box_ops |
|
from groundingdino.util.slconfig import SLConfig |
|
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap |
|
|
|
|
|
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator |
|
import matplotlib.pyplot as plt |
|
import wget |
|
|
|
def prompts(name, description):
    """Build a decorator that tags a callable with tool metadata.

    Args:
        name: Value stored on the decorated callable as its ``name`` attribute.
        description: Value stored on the decorated callable as its
            ``description`` attribute.

    Returns:
        A decorator that sets both attributes on the callable it receives and
        returns that very same callable (no wrapper is created).
    """

    def decorator(func):
        # Attach the metadata in place; the callable itself is returned as-is.
        for attribute, value in (("name", name), ("description", description)):
            setattr(func, attribute, value)
        return func

    return decorator
|
|
|
def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100):
    """Blend ``old_image`` into the centre of ``new_image`` with feathered edges.

    A per-pixel weight mask of ``old_image``'s size is built: the interior
    (everything further than ``steps`` pixels from the border) has weight 1,
    the four edge bands ramp linearly between 0 at the border and 1 at the
    interior, and the four ``steps`` x ``steps`` corners use a separable
    Gaussian normalised by its value at the corner's innermost pixel.  The
    centred window of ``new_image`` is then replaced by
    ``mask * old + (1 - mask) * new``.

    Args:
        old_image: Image to paste; must expose ``.size`` as (width, height)
            and be convertible with ``np.array``.
        new_image: Larger canvas image with the same interface.
        sigma: Gaussian standard deviation as a fraction of each side length.
        steps: Width in pixels of the feathered border.

    Returns:
        A new ``PIL.Image`` built from the blended canvas array.

    Note:
        The mask is replicated to 3 channels, so both images are presumably
        3-channel (RGB) -- confirm against callers.
    """
    outer_size = new_image.size
    inner_size = old_image.size
    canvas = np.array(new_image)
    patch = np.array(old_image)
    inner_w, inner_h = inner_size[0], inner_size[1]
    # Top-left corner of the centred window inside the canvas.
    left_edge = (outer_size[0] - inner_w) // 2
    top_edge = (outer_size[1] - inner_h) // 2

    # Separable 2-D Gaussian: column profile (h, 1) times row profile (1, w).
    col_profile = cv2.getGaussianKernel(inner_h, inner_h * sigma)
    row_profile = cv2.getGaussianKernel(inner_w, inner_w * sigma)
    alpha = np.multiply(col_profile, np.transpose(row_profile))

    head = slice(None, steps)
    tail = slice(-steps, None)
    middle = slice(steps, -steps)

    alpha[middle, middle] = 1
    # Rescale each corner so that its innermost pixel gets weight exactly 1.
    corners = (
        (head, head, steps - 1, steps - 1),
        (head, tail, steps - 1, -steps),
        (tail, head, -steps, steps - 1),
        (tail, tail, -steps, -steps),
    )
    for rows, cols, ref_row, ref_col in corners:
        alpha[rows, cols] = alpha[rows, cols] / alpha[ref_row, ref_col]
    alpha = np.repeat(np.expand_dims(alpha, 2), 3, 2)

    def ramp(first, last, axis, count):
        # Linear ramp of length ``steps`` laid along the axis other than
        # ``axis``, repeated ``count`` times along ``axis``, then replicated
        # to 3 channels.
        band = np.expand_dims(np.linspace(first, last, steps), axis)
        band = np.repeat(band, count, axis)
        return np.repeat(np.expand_dims(band, 2), 3, 2)

    span_w = inner_w - 2 * steps
    span_h = inner_h - 2 * steps
    top = ramp(0, 1, 1, span_w)
    down = ramp(1, 0, 1, span_w)
    left = ramp(0, 1, 0, span_h)
    right = ramp(1, 0, 0, span_h)

    alpha[head, middle] = top
    alpha[tail, middle] = down
    alpha[middle, head] = left
    alpha[middle, tail] = right

    window = (
        slice(top_edge, top_edge + inner_h),
        slice(left_edge, left_edge + inner_w),
    )
    background = canvas[window]
    blended = alpha * patch + (1 - alpha) * background
    # Writing back into the canvas casts to the canvas dtype.
    canvas[window] = blended.astype(np.int64)
    return Image.fromarray(canvas)
|
|
|
def get_new_image_name(org_img_name, func_name="update"):
    """Derive the output path for an image produced from ``org_img_name``.

    Output files are named ``{id}_{func_name}_{previous}_{original}.png`` and
    placed in the same directory as the input.  ``id`` is the first four
    characters of a fresh ``uuid4``.  The input file's stem (the text before
    the first ``.``) is split on ``_``:

    * one field  -- a source image; it is used as both ``previous`` and
      ``original``;
    * four fields -- a file already produced by this function; its first
      field becomes ``previous`` and its last field is carried over as
      ``original``.

    Args:
        org_img_name: Path of the input image.
        func_name: Tag of the operation producing the new image.  It should
            not contain ``_``, otherwise the resulting name no longer splits
            into four fields.

    Returns:
        The new path, joined onto the input's directory.

    Raises:
        ValueError: If the input stem has neither one nor four
            ``_``-separated fields.  (This used to be an ``assert``, which is
            skipped under ``python -O``.)
    """
    head, tail = os.path.split(org_img_name)
    # Only the text before the first dot is treated as the stem.
    name_split = tail.partition('.')[0].split('_')
    if len(name_split) not in (1, 4):
        raise ValueError(
            f"Unexpected image file name {tail!r}: expected 1 or 4 "
            f"'_'-separated fields before the extension, got {len(name_split)}"
        )
    # For a single-field stem, first and last are the same element.
    recent_prev_file_name = name_split[0]
    most_org_file_name = name_split[-1]
    this_new_uuid = str(uuid.uuid4())[:4]
    new_file_name = (
        f'{this_new_uuid}_{func_name}_{recent_prev_file_name}_{most_org_file_name}.png'
    )
    return os.path.join(head, new_file_name)
|
|
|
def seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Integer seed handed to every generator.

    Returns:
        The same ``seed`` value, unchanged.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    return seed
|
|
|
class InstructPix2Pix:
    """Tool that edits an image according to a text instruction.

    Wraps the ``timbrooks/instruct-pix2pix`` diffusers pipeline.
    """

    def __init__(self, device):
        """Load the pipeline onto ``device`` and install the Euler-ancestral scheduler."""
        print(f"Initializing InstructPix2Pix to {device}")
        self.device = device
        # Half precision is chosen whenever the device string mentions CUDA.
        if 'cuda' in device:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        loaded = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            "timbrooks/instruct-pix2pix",
            safety_checker=None,
            torch_dtype=self.torch_dtype,
        )
        self.pipe = loaded.to(device)
        scheduler_config = self.pipe.scheduler.config
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)

    @prompts(name="Instruct Image Using Text",
             description="useful when you want to the style of the image to be like the text. "
                         "like: make it look like a painting. or make it like a robot. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path and the text. ")
    def inference(self, inputs):
        """Apply a text instruction to an image and save the result.

        Args:
            inputs: ``"<image_path>,<instruction>"``; everything after the
                first comma is the instruction.

        Returns:
            Path of the saved output image.
        """
        print("===>Starting InstructPix2Pix Inference")
        # Split on the first comma only so the instruction may contain commas.
        image_path, _, text = inputs.partition(",")
        original_image = Image.open(image_path)
        result = self.pipe(
            text,
            image=original_image,
            num_inference_steps=40,
            image_guidance_scale=1.2,
        )
        edited = result.images[0]
        updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
        edited.save(updated_image_path)
        print(f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text: {text}, "
              f"Output Image: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class Text2Image:
    """Tool that generates an image from text with Stable Diffusion v1.5."""

    def __init__(self, device):
        """Load ``runwayml/stable-diffusion-v1-5`` and move it to ``device``."""
        print(f"Initializing Text2Image to {device}")
        self.device = device
        # Half precision is chosen whenever the device string mentions CUDA.
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                                            torch_dtype=self.torch_dtype)
        self.pipe.to(device)
        # Suffix appended to every user prompt, and the negative prompt.
        self.a_prompt = 'best quality, extremely detailed'
        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
                        'fewer digits, cropped, worst quality, low quality'

    @prompts(name="Generate Image From User Input Text",
             description="useful when you want to generate an image from a user input text and save it to a file. "
                         "like: generate an image of an object or something, or generate an image that includes some objects. "
                         "The input to this tool should be a string, representing the text used to generate image. ")
    def inference(self, text):
        """Generate an image for ``text`` and save it under ``image/``.

        Args:
            text: User prompt; ``self.a_prompt`` is appended to it.

        Returns:
            Path of the saved PNG (``image/<8 uuid chars>.png``).
        """
        image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
        # Make sure the output directory exists before the expensive
        # generation step, so the run cannot end in a failed save.
        os.makedirs(os.path.dirname(image_filename), exist_ok=True)
        prompt = text + ', ' + self.a_prompt
        image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
        image.save(image_filename)
        print(
            f"\nProcessed Text2Image, Input Text: {text}, Output Image: {image_filename}")
        return image_filename
|
|
|
|
|
class ImageCaptioning:
    """Tool that captions an image with ``Salesforce/blip-image-captioning-base``."""

    def __init__(self, device):
        """Load the BLIP processor and captioning model onto ``device``."""
        print(f"Initializing ImageCaptioning to {device}")
        self.device = device
        # Half precision is chosen whenever the device string mentions CUDA.
        if 'cuda' in device:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        checkpoint = "Salesforce/blip-image-captioning-base"
        self.processor = BlipProcessor.from_pretrained(checkpoint)
        captioner = BlipForConditionalGeneration.from_pretrained(
            checkpoint, torch_dtype=self.torch_dtype)
        self.model = captioner.to(self.device)

    @prompts(name="Get Photo Description",
             description="useful when you want to know what is inside the photo. receives image_path as input. "
                         "The input to this tool should be a string, representing the image_path. ")
    def inference(self, image_path):
        """Return the generated caption for the image at ``image_path``."""
        picture = Image.open(image_path)
        encoded = self.processor(picture, return_tensors="pt")
        # Move the processor outputs to the model's device and dtype.
        encoded = encoded.to(self.device, self.torch_dtype)
        generated = self.model.generate(**encoded)
        captions = self.processor.decode(generated[0], skip_special_tokens=True)
        print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
        return captions
|
|
|
|
|
class Image2Canny:
    """Tool that produces an edge image using ``cv2.Canny``."""

    def __init__(self, device):
        """Set the Canny thresholds; ``device`` is accepted but not used."""
        print("Initializing Image2Canny")
        self.low_threshold, self.high_threshold = 100, 200

    @prompts(name="Edge Detection On Image",
             description="useful when you want to detect the edge of the image. "
                         "like: detect the edges of this image, or canny detection on image, "
                         "or perform edge detection on this image, or detect the canny image of this image. "
                         "The input to this tool should be a string, representing the image_path")
    def inference(self, inputs):
        """Run Canny on the image at ``inputs`` and save a 3-channel edge map.

        Returns:
            Path of the saved edge image.
        """
        pixels = np.array(Image.open(inputs))
        edges = cv2.Canny(pixels, self.low_threshold, self.high_threshold)
        # Add a trailing axis and replicate it three times along that axis.
        edges = edges[:, :, None]
        edges_rgb = np.concatenate([edges] * 3, axis=2)
        canny = Image.fromarray(edges_rgb)
        updated_image_path = get_new_image_name(inputs, func_name="edge")
        canny.save(updated_image_path)
        print(f"\nProcessed Image2Canny, Input Image: {inputs}, Output Text: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class CannyText2Image:
    """Tool that generates an image from text plus a canny image via ControlNet."""

    def __init__(self, device):
        """Load the canny ControlNet and the SD v1.5 ControlNet pipeline onto ``device``."""
        print(f"Initializing CannyText2Image to {device}")
        # Half precision is chosen whenever the device string mentions CUDA.
        if 'cuda' in device:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        self.controlnet = ControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-canny",
            torch_dtype=self.torch_dtype,
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=self.torch_dtype,
        )
        scheduler_config = self.pipe.scheduler.config
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(scheduler_config)
        self.pipe.to(device)
        # Overwritten with a fresh random seed on every inference call.
        self.seed = -1
        self.a_prompt = 'best quality, extremely detailed'
        self.n_prompt = ('longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, '
                         'fewer digits, cropped, worst quality, low quality')

    @prompts(name="Generate Image Condition On Canny Image",
             description="useful when you want to generate a new real image from both the user description and a canny image."
                         " like: generate a real image of a object or something from this canny image,"
                         " or generate a new real image of a object or something from this edge image. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path and the user description. ")
    def inference(self, inputs):
        """Generate an image conditioned on a canny image and a description.

        Args:
            inputs: ``"<image_path>,<description>"``; everything after the
                first comma is the description.

        Returns:
            Path of the saved output image.
        """
        image_path, _, instruct_text = inputs.partition(",")
        condition = Image.open(image_path)
        self.seed = random.randint(0, 65535)
        seed_everything(self.seed)
        prompt = '{}, {}'.format(instruct_text, self.a_prompt)
        generation = self.pipe(
            prompt,
            condition,
            num_inference_steps=20,
            eta=0.0,
            negative_prompt=self.n_prompt,
            guidance_scale=9.0,
        )
        image = generation.images[0]
        updated_image_path = get_new_image_name(image_path, func_name="canny2image")
        image.save(updated_image_path)
        print(f"\nProcessed CannyText2Image, Input Canny: {image_path}, Input Text: {instruct_text}, "
              f"Output Text: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class Image2Line:
    """Tool that runs the ``MLSDdetector`` on an image to get a line image."""

    def __init__(self, device):
        """Load the detector from ``lllyasviel/ControlNet``; ``device`` is not used."""
        print("Initializing Image2Line")
        self.detector = MLSDdetector.from_pretrained('lllyasviel/ControlNet')

    @prompts(name="Line Detection On Image",
             description="useful when you want to detect the straight line of the image. "
                         "like: detect the straight lines of this image, or straight line detection on image, "
                         "or perform straight line detection on this image, or detect the straight line image of this image. "
                         "The input to this tool should be a string, representing the image_path")
    def inference(self, inputs):
        """Run the detector on the image at ``inputs`` and save its output.

        Returns:
            Path of the saved line image.
        """
        source = Image.open(inputs)
        line_map = self.detector(source)
        updated_image_path = get_new_image_name(inputs, func_name="line-of")
        line_map.save(updated_image_path)
        print(f"\nProcessed Image2Line, Input Image: {inputs}, Output Line: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class LineText2Image:
    """Tool that generates an image from text plus a straight-line image via ControlNet."""

    def __init__(self, device):
        """Load the MLSD ControlNet and the SD v1.5 ControlNet pipeline onto ``device``."""
        print(f"Initializing LineText2Image to {device}")
        # Half precision is chosen whenever the device string mentions CUDA.
        if 'cuda' in device:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        self.controlnet = ControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-mlsd",
            torch_dtype=self.torch_dtype,
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=self.torch_dtype,
        )
        scheduler_config = self.pipe.scheduler.config
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(scheduler_config)
        self.pipe.to(device)
        # Overwritten with a fresh random seed on every inference call.
        self.seed = -1
        self.a_prompt = 'best quality, extremely detailed'
        self.n_prompt = ('longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, '
                         'fewer digits, cropped, worst quality, low quality')

    @prompts(name="Generate Image Condition On Line Image",
             description="useful when you want to generate a new real image from both the user description "
                         "and a straight line image. "
                         "like: generate a real image of a object or something from this straight line image, "
                         "or generate a new real image of a object or something from this straight lines. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path and the user description. ")
    def inference(self, inputs):
        """Generate an image conditioned on a line image and a description.

        Args:
            inputs: ``"<image_path>,<description>"``; everything after the
                first comma is the description.

        Returns:
            Path of the saved output image.
        """
        image_path, _, instruct_text = inputs.partition(",")
        condition = Image.open(image_path)
        self.seed = random.randint(0, 65535)
        seed_everything(self.seed)
        prompt = '{}, {}'.format(instruct_text, self.a_prompt)
        generation = self.pipe(
            prompt,
            condition,
            num_inference_steps=20,
            eta=0.0,
            negative_prompt=self.n_prompt,
            guidance_scale=9.0,
        )
        image = generation.images[0]
        updated_image_path = get_new_image_name(image_path, func_name="line2image")
        image.save(updated_image_path)
        print(f"\nProcessed LineText2Image, Input Line: {image_path}, Input Text: {instruct_text}, "
              f"Output Text: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class Image2Hed:
    """Tool that runs the ``HEDdetector`` on an image to get a soft boundary image."""

    def __init__(self, device):
        """Load the detector from ``lllyasviel/ControlNet``; ``device`` is not used."""
        print("Initializing Image2Hed")
        self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')

    @prompts(name="Hed Detection On Image",
             description="useful when you want to detect the soft hed boundary of the image. "
                         "like: detect the soft hed boundary of this image, or hed boundary detection on image, "
                         "or perform hed boundary detection on this image, or detect soft hed boundary image of this image. "
                         "The input to this tool should be a string, representing the image_path")
    def inference(self, inputs):
        """Run the detector on the image at ``inputs`` and save its output.

        Returns:
            Path of the saved boundary image.
        """
        source = Image.open(inputs)
        boundary = self.detector(source)
        updated_image_path = get_new_image_name(inputs, func_name="hed-boundary")
        boundary.save(updated_image_path)
        print(f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class HedText2Image:
    """Tool that generates an image from text plus a soft HED boundary image via ControlNet."""

    def __init__(self, device):
        """Load the HED ControlNet and the SD v1.5 ControlNet pipeline onto ``device``."""
        print(f"Initializing HedText2Image to {device}")
        # Half precision is chosen whenever the device string mentions CUDA.
        if 'cuda' in device:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        self.controlnet = ControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-hed",
            torch_dtype=self.torch_dtype,
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=self.torch_dtype,
        )
        scheduler_config = self.pipe.scheduler.config
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(scheduler_config)
        self.pipe.to(device)
        # Overwritten with a fresh random seed on every inference call.
        self.seed = -1
        self.a_prompt = 'best quality, extremely detailed'
        self.n_prompt = ('longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, '
                         'fewer digits, cropped, worst quality, low quality')

    @prompts(name="Generate Image Condition On Soft Hed Boundary Image",
             description="useful when you want to generate a new real image from both the user description "
                         "and a soft hed boundary image. "
                         "like: generate a real image of a object or something from this soft hed boundary image, "
                         "or generate a new real image of a object or something from this hed boundary. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path and the user description")
    def inference(self, inputs):
        """Generate an image conditioned on a HED boundary image and a description.

        Args:
            inputs: ``"<image_path>,<description>"``; everything after the
                first comma is the description.

        Returns:
            Path of the saved output image.
        """
        image_path, _, instruct_text = inputs.partition(",")
        condition = Image.open(image_path)
        self.seed = random.randint(0, 65535)
        seed_everything(self.seed)
        prompt = '{}, {}'.format(instruct_text, self.a_prompt)
        generation = self.pipe(
            prompt,
            condition,
            num_inference_steps=20,
            eta=0.0,
            negative_prompt=self.n_prompt,
            guidance_scale=9.0,
        )
        image = generation.images[0]
        updated_image_path = get_new_image_name(image_path, func_name="hed2image")
        image.save(updated_image_path)
        print(f"\nProcessed HedText2Image, Input Hed: {image_path}, Input Text: {instruct_text}, "
              f"Output Image: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class Image2Scribble:
    """Tool that runs the ``HEDdetector`` in scribble mode on an image."""

    def __init__(self, device):
        """Load the detector from ``lllyasviel/ControlNet``; ``device`` is not used."""
        print("Initializing Image2Scribble")
        self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')

    @prompts(name="Sketch Detection On Image",
             description="useful when you want to generate a scribble of the image. "
                         "like: generate a scribble of this image, or generate a sketch from this image, "
                         "detect the sketch from this image. "
                         "The input to this tool should be a string, representing the image_path")
    def inference(self, inputs):
        """Run the detector with ``scribble=True`` on the image at ``inputs``.

        Returns:
            Path of the saved scribble image.
        """
        source = Image.open(inputs)
        sketch = self.detector(source, scribble=True)
        updated_image_path = get_new_image_name(inputs, func_name="scribble")
        sketch.save(updated_image_path)
        print(f"\nProcessed Image2Scribble, Input Image: {inputs}, Output Scribble: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class ScribbleText2Image:
    """Tool that generates an image from text plus a scribble/sketch image via ControlNet."""

    def __init__(self, device):
        """Load the scribble ControlNet and the SD v1.5 ControlNet pipeline onto ``device``."""
        print(f"Initializing ScribbleText2Image to {device}")
        # Half precision is chosen whenever the device string mentions CUDA.
        if 'cuda' in device:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        self.controlnet = ControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-scribble",
            torch_dtype=self.torch_dtype,
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=self.torch_dtype,
        )
        scheduler_config = self.pipe.scheduler.config
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(scheduler_config)
        self.pipe.to(device)
        # Overwritten with a fresh random seed on every inference call.
        self.seed = -1
        self.a_prompt = 'best quality, extremely detailed'
        self.n_prompt = ('longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, '
                         'fewer digits, cropped, worst quality, low quality')

    @prompts(name="Generate Image Condition On Sketch Image",
             description="useful when you want to generate a new real image from both the user description and "
                         "a scribble image or a sketch image. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path and the user description")
    def inference(self, inputs):
        """Generate an image conditioned on a scribble image and a description.

        Args:
            inputs: ``"<image_path>,<description>"``; everything after the
                first comma is the description.

        Returns:
            Path of the saved output image.
        """
        image_path, _, instruct_text = inputs.partition(",")
        condition = Image.open(image_path)
        self.seed = random.randint(0, 65535)
        seed_everything(self.seed)
        prompt = '{}, {}'.format(instruct_text, self.a_prompt)
        generation = self.pipe(
            prompt,
            condition,
            num_inference_steps=20,
            eta=0.0,
            negative_prompt=self.n_prompt,
            guidance_scale=9.0,
        )
        image = generation.images[0]
        updated_image_path = get_new_image_name(image_path, func_name="scribble2image")
        image.save(updated_image_path)
        print(f"\nProcessed ScribbleText2Image, Input Scribble: {image_path}, Input Text: {instruct_text}, "
              f"Output Image: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class Image2Pose:
    """Tool that runs the ``OpenposeDetector`` on an image to get a pose image."""

    def __init__(self, device):
        """Load the detector from ``lllyasviel/ControlNet``; ``device`` is not used."""
        print("Initializing Image2Pose")
        self.detector = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')

    @prompts(name="Pose Detection On Image",
             description="useful when you want to detect the human pose of the image. "
                         "like: generate human poses of this image, or generate a pose image from this image. "
                         "The input to this tool should be a string, representing the image_path")
    def inference(self, inputs):
        """Run the detector on the image at ``inputs`` and save its output.

        Returns:
            Path of the saved pose image.
        """
        source = Image.open(inputs)
        skeleton = self.detector(source)
        updated_image_path = get_new_image_name(inputs, func_name="human-pose")
        skeleton.save(updated_image_path)
        print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class PoseText2Image:
    """Tool that generates an image from text plus a human-pose image via ControlNet.

    ``num_inference_steps`` and ``unconditional_guidance_scale`` are instance
    settings that ``inference`` passes to the pipeline; their defaults (20 and
    9.0) match the values that used to be hard-coded in the call.
    """

    def __init__(self, device):
        """Load the openpose ControlNet and the SD v1.5 ControlNet pipeline onto ``device``."""
        print(f"Initializing PoseText2Image to {device}")
        # Half precision is chosen whenever the device string mentions CUDA.
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-openpose",
                                                          torch_dtype=self.torch_dtype)
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
            torch_dtype=self.torch_dtype)
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.to(device)
        self.num_inference_steps = 20
        # Overwritten with a fresh random seed on every inference call.
        self.seed = -1
        self.unconditional_guidance_scale = 9.0
        self.a_prompt = 'best quality, extremely detailed'
        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
                        ' fewer digits, cropped, worst quality, low quality'

    @prompts(name="Generate Image Condition On Pose Image",
             description="useful when you want to generate a new real image from both the user description "
                         "and a human pose image. "
                         "like: generate a real image of a human from this human pose image, "
                         "or generate a new real image of a human from this pose. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path and the user description")
    def inference(self, inputs):
        """Generate an image conditioned on a pose image and a description.

        Args:
            inputs: ``"<image_path>,<description>"``; everything after the
                first comma is the description.

        Returns:
            Path of the saved output image.
        """
        # Split on the first comma only so the description may contain commas.
        image_path, _, instruct_text = inputs.partition(",")
        image = Image.open(image_path)
        self.seed = random.randint(0, 65535)
        seed_everything(self.seed)
        prompt = f'{instruct_text}, {self.a_prompt}'
        # Use the instance settings instead of repeating the literals, so the
        # attributes configured in __init__ actually take effect.
        image = self.pipe(prompt, image, num_inference_steps=self.num_inference_steps, eta=0.0,
                          negative_prompt=self.n_prompt,
                          guidance_scale=self.unconditional_guidance_scale).images[0]
        updated_image_path = get_new_image_name(image_path, func_name="pose2image")
        image.save(updated_image_path)
        print(f"\nProcessed PoseText2Image, Input Pose: {image_path}, Input Text: {instruct_text}, "
              f"Output Image: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class SegText2Image:
    """Tool that generates an image from text plus a segmentation image via ControlNet."""

    def __init__(self, device):
        """Load the seg ControlNet and the SD v1.5 ControlNet pipeline onto ``device``."""
        print(f"Initializing SegText2Image to {device}")
        # Half precision is chosen whenever the device string mentions CUDA.
        if 'cuda' in device:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        self.controlnet = ControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-seg",
            torch_dtype=self.torch_dtype,
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=self.torch_dtype,
        )
        scheduler_config = self.pipe.scheduler.config
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(scheduler_config)
        self.pipe.to(device)
        # Overwritten with a fresh random seed on every inference call.
        self.seed = -1
        self.a_prompt = 'best quality, extremely detailed'
        self.n_prompt = ('longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,'
                         ' fewer digits, cropped, worst quality, low quality')

    @prompts(name="Generate Image Condition On Segmentations",
             description="useful when you want to generate a new real image from both the user description and segmentations. "
                         "like: generate a real image of a object or something from this segmentation image, "
                         "or generate a new real image of a object or something from these segmentations. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path and the user description")
    def inference(self, inputs):
        """Generate an image conditioned on a segmentation image and a description.

        Args:
            inputs: ``"<image_path>,<description>"``; everything after the
                first comma is the description.

        Returns:
            Path of the saved output image.
        """
        image_path, _, instruct_text = inputs.partition(",")
        condition = Image.open(image_path)
        self.seed = random.randint(0, 65535)
        seed_everything(self.seed)
        prompt = '{}, {}'.format(instruct_text, self.a_prompt)
        generation = self.pipe(
            prompt,
            condition,
            num_inference_steps=20,
            eta=0.0,
            negative_prompt=self.n_prompt,
            guidance_scale=9.0,
        )
        image = generation.images[0]
        updated_image_path = get_new_image_name(image_path, func_name="segment2image")
        image.save(updated_image_path)
        print(f"\nProcessed SegText2Image, Input Seg: {image_path}, Input Text: {instruct_text}, "
              f"Output Image: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class Image2Depth:
    """Tool that predicts a depth image with the transformers ``depth-estimation`` pipeline."""

    def __init__(self, device):
        """Create the depth-estimation pipeline; ``device`` is accepted but not used."""
        print("Initializing Image2Depth")
        self.depth_estimator = pipeline('depth-estimation')

    @prompts(name="Predict Depth On Image",
             description="useful when you want to detect depth of the image. like: generate the depth from this image, "
                         "or detect the depth map on this image, or predict the depth for this image. "
                         "The input to this tool should be a string, representing the image_path")
    def inference(self, inputs):
        """Estimate depth for the image at ``inputs`` and save a 3-channel depth image.

        Returns:
            Path of the saved depth image.
        """
        source = Image.open(inputs)
        estimate = self.depth_estimator(source)
        # Add a trailing axis and replicate it three times along that axis.
        channel = np.array(estimate['depth'])[:, :, None]
        stacked = np.concatenate([channel] * 3, axis=2)
        depth = Image.fromarray(stacked)
        updated_image_path = get_new_image_name(inputs, func_name="depth")
        depth.save(updated_image_path)
        print(f"\nProcessed Image2Depth, Input Image: {inputs}, Output Depth: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class DepthText2Image:
    """Tool that generates an image from text plus a depth image via ControlNet."""

    def __init__(self, device):
        """Load the depth ControlNet and the SD v1.5 ControlNet pipeline onto ``device``."""
        print(f"Initializing DepthText2Image to {device}")
        # Half precision is chosen whenever the device string mentions CUDA.
        if 'cuda' in device:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        self.controlnet = ControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-depth",
            torch_dtype=self.torch_dtype,
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=self.torch_dtype,
        )
        scheduler_config = self.pipe.scheduler.config
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(scheduler_config)
        self.pipe.to(device)
        # Overwritten with a fresh random seed on every inference call.
        self.seed = -1
        self.a_prompt = 'best quality, extremely detailed'
        self.n_prompt = ('longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,'
                         ' fewer digits, cropped, worst quality, low quality')

    @prompts(name="Generate Image Condition On Depth",
             description="useful when you want to generate a new real image from both the user description and depth image. "
                         "like: generate a real image of a object or something from this depth image, "
                         "or generate a new real image of a object or something from the depth map. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path and the user description")
    def inference(self, inputs):
        """Generate an image conditioned on a depth image and a description.

        Args:
            inputs: ``"<image_path>,<description>"``; everything after the
                first comma is the description.

        Returns:
            Path of the saved output image.
        """
        image_path, _, instruct_text = inputs.partition(",")
        condition = Image.open(image_path)
        self.seed = random.randint(0, 65535)
        seed_everything(self.seed)
        prompt = '{}, {}'.format(instruct_text, self.a_prompt)
        generation = self.pipe(
            prompt,
            condition,
            num_inference_steps=20,
            eta=0.0,
            negative_prompt=self.n_prompt,
            guidance_scale=9.0,
        )
        image = generation.images[0]
        updated_image_path = get_new_image_name(image_path, func_name="depth2image")
        image.save(updated_image_path)
        print(f"\nProcessed DepthText2Image, Input Depth: {image_path}, Input Text: {instruct_text}, "
              f"Output Image: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class Image2Normal:
    """Tool that derives a normal-map image from a predicted depth map.

    Depth comes from the ``Intel/dpt-hybrid-midas`` depth-estimation pipeline;
    Sobel gradients of that depth form the x/y components of the normals.
    """

    def __init__(self, device):
        """Create the depth pipeline and set the background threshold; ``device`` is not used."""
        print("Initializing Image2Normal")
        self.depth_estimator = pipeline("depth-estimation", model="Intel/dpt-hybrid-midas")
        # Pixels whose min-max-normalised depth is below this get zero x/y gradient.
        self.bg_threhold = 0.4

    @prompts(name="Predict Normal Map On Image",
             description="useful when you want to detect norm map of the image. "
                         "like: generate normal map from this image, or predict normal map of this image. "
                         "The input to this tool should be a string, representing the image_path")
    def inference(self, inputs):
        """Compute and save a normal-map image for the image at ``inputs``.

        Returns:
            Path of the saved normal-map image, resized to the input's size.
        """
        source = Image.open(inputs)
        original_size = source.size
        depth_map = self.depth_estimator(source)['predicted_depth'][0]
        depth_map = depth_map.numpy()

        # Min-max normalise a copy; it is only used to find background pixels.
        scaled = depth_map.copy()
        scaled -= np.min(scaled)
        scaled /= np.max(scaled)

        grad_x = cv2.Sobel(depth_map, cv2.CV_32F, 1, 0, ksize=3)
        grad_x[scaled < self.bg_threhold] = 0
        grad_y = cv2.Sobel(depth_map, cv2.CV_32F, 0, 1, ksize=3)
        grad_y[scaled < self.bg_threhold] = 0
        grad_z = np.ones_like(grad_x) * np.pi * 2.0

        normals = np.stack([grad_x, grad_y, grad_z], axis=2)
        # Scale every (x, y, z) vector to unit length.
        normals /= np.sum(normals ** 2.0, axis=2, keepdims=True) ** 0.5
        # Map [-1, 1] to [0, 255].
        encoded = (normals * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
        image = Image.fromarray(encoded).resize(original_size)
        updated_image_path = get_new_image_name(inputs, func_name="normal-map")
        image.save(updated_image_path)
        print(f"\nProcessed Image2Normal, Input Image: {inputs}, Output Depth: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class NormalText2Image:
    """Tool that generates an image from text plus a normal-map image via ControlNet."""

    def __init__(self, device):
        """Load the normal ControlNet and the SD v1.5 ControlNet pipeline onto ``device``."""
        print(f"Initializing NormalText2Image to {device}")
        # Half precision is chosen whenever the device string mentions CUDA.
        if 'cuda' in device:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        self.controlnet = ControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-normal",
            torch_dtype=self.torch_dtype,
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=self.torch_dtype,
        )
        scheduler_config = self.pipe.scheduler.config
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(scheduler_config)
        self.pipe.to(device)
        # Overwritten with a fresh random seed on every inference call.
        self.seed = -1
        self.a_prompt = 'best quality, extremely detailed'
        self.n_prompt = ('longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,'
                         ' fewer digits, cropped, worst quality, low quality')

    @prompts(name="Generate Image Condition On Normal Map",
             description="useful when you want to generate a new real image from both the user description and normal map. "
                         "like: generate a real image of a object or something from this normal map, "
                         "or generate a new real image of a object or something from the normal map. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path and the user description")
    def inference(self, inputs):
        """Generate an image conditioned on a normal-map image and a description.

        Args:
            inputs: ``"<image_path>,<description>"``; everything after the
                first comma is the description.

        Returns:
            Path of the saved output image.
        """
        image_path, _, instruct_text = inputs.partition(",")
        condition = Image.open(image_path)
        self.seed = random.randint(0, 65535)
        seed_everything(self.seed)
        prompt = '{}, {}'.format(instruct_text, self.a_prompt)
        generation = self.pipe(
            prompt,
            condition,
            num_inference_steps=20,
            eta=0.0,
            negative_prompt=self.n_prompt,
            guidance_scale=9.0,
        )
        image = generation.images[0]
        updated_image_path = get_new_image_name(image_path, func_name="normal2image")
        image.save(updated_image_path)
        print(f"\nProcessed NormalText2Image, Input Normal: {image_path}, Input Text: {instruct_text}, "
              f"Output Image: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class VisualQuestionAnswering:
    """Tool that answers a free-form question about an image with BLIP VQA."""

    def __init__(self, device):
        """Load the BLIP VQA processor and model.

        Args:
            device: Device string the model is moved to. Half precision is
                used when the string contains ``'cuda'``.
        """
        print(f"Initializing VisualQuestionAnswering to {device}")
        if 'cuda' in device:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        self.device = device

        checkpoint = "Salesforce/blip-vqa-base"
        self.processor = BlipProcessor.from_pretrained(checkpoint)
        model = BlipForQuestionAnswering.from_pretrained(checkpoint, torch_dtype=self.torch_dtype)
        self.model = model.to(self.device)

    @prompts(name="Answer Question About The Image",
             description="useful when you need an answer for a question based on an image. "
                         "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
                         "The input to this tool should be a comma separated string of two, representing the image_path and the question")
    def inference(self, inputs):
        """Answer a question about an image.

        Args:
            inputs: ``"<image_path>,<question>"``. Only the first comma
                separates the fields, so the question may contain commas.

        Returns:
            The decoded answer string.
        """
        image_path, _, question = inputs.partition(",")
        raw_image = Image.open(image_path).convert('RGB')

        encoded = self.processor(raw_image, question, return_tensors="pt")
        encoded = encoded.to(self.device, self.torch_dtype)
        generated = self.model.generate(**encoded)
        answer = self.processor.decode(generated[0], skip_special_tokens=True)

        print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
              f"Output Answer: {answer}")
        return answer
|
|
|
|
|
class Segmenting:
    """Segment Anything (SAM) wrapper offering box-prompted and automatic masks."""

    def __init__(self, device):
        """Fetch the SAM checkpoint if needed and build the predictor objects.

        Args:
            device: Device string the SAM model is moved to.
        """
        print(f"Inintializing Segmentation to {device}")
        self.device = device
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.model_checkpoint_path = os.path.join("checkpoints", "sam")

        self.download_parameters()
        self.sam = build_sam(checkpoint=self.model_checkpoint_path).to(device)
        self.sam_predictor = SamPredictor(self.sam)
        self.mask_generator = SamAutomaticMaskGenerator(self.sam)

    def download_parameters(self):
        """Download the SAM checkpoint unless it already exists on disk."""
        url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
        if not os.path.exists(self.model_checkpoint_path):
            # The target directory must exist before the file can be written.
            parent = os.path.dirname(self.model_checkpoint_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            wget.download(url, out=self.model_checkpoint_path)

    def show_mask(self, mask, ax, random_color=False):
        """Draw ``mask`` on ``ax`` as an RGBA overlay.

        Args:
            mask: Array whose last two dimensions are height and width.
            ax: Matplotlib axes to draw on.
            random_color: Use a random RGB colour instead of the fixed blue.
        """
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([1])], axis=0)
        else:
            color = np.array([30 / 255, 144 / 255, 255 / 255, 1])
        h, w = mask.shape[-2:]
        # Broadcast the (h, w, 1) mask against the RGBA colour -> (h, w, 4).
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)

    def show_box(self, box, ax, label):
        """Draw a green rectangle for ``box`` (x0, y0, x1, y1) with ``label``."""
        x0, y0 = box[0], box[1]
        w, h = box[2] - box[0], box[3] - box[1]
        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
        ax.text(x0, y0, label)

    def get_mask_with_boxes(self, image_pil, image, boxes_filt):
        """Predict one SAM mask per box.

        The predictor must already hold the image (``set_image``) before this
        is called.

        Args:
            image_pil: PIL image; only its ``size`` is read.
            image: Image array; ``image.shape[:2]`` is passed to the box
                transform.
            boxes_filt: ``(N, 4)`` tensor of boxes as (cx, cy, w, h),
                presumably normalised to [0, 1] since they are scaled by the
                image size here. The tensor is NOT modified.

        Returns:
            The masks returned by ``SamPredictor.predict_torch``.
        """
        W, H = image_pil.size
        # Multiplying allocates a new tensor, so the caller's boxes stay
        # untouched and can be reused for another call.
        boxes_xyxy = boxes_filt * torch.Tensor([W, H, W, H])
        # (cx, cy, w, h) -> (x0, y0, x1, y1)
        boxes_xyxy[:, :2] -= boxes_xyxy[:, 2:] / 2
        boxes_xyxy[:, 2:] += boxes_xyxy[:, :2]

        boxes_xyxy = boxes_xyxy.cpu()
        transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(
            boxes_xyxy, image.shape[:2]).to(self.device)

        masks, _, _ = self.sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        return masks

    def segment_image_with_boxes(self, image_pil, image_path, boxes_filt, pred_phrases):
        """Segment the boxed regions of an image and save the visualisation.

        Args:
            image_pil: PIL version of the image (used for its size).
            image_path: Path of the image on disk; read again with OpenCV.
            boxes_filt: Boxes forwarded to ``get_mask_with_boxes``.
            pred_phrases: Accepted for interface compatibility; not used.

        Returns:
            Path of the saved visualisation.
        """
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        self.sam_predictor.set_image(image)

        masks = self.get_mask_with_boxes(image_pil, image, boxes_filt)

        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(image)
            for mask in masks:
                self.show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)

            updated_image_path = get_new_image_name(image_path, func_name="segmentation")
            plt.axis('off')
            plt.savefig(
                updated_image_path,
                bbox_inches="tight", dpi=300, pad_inches=0.0
            )
        finally:
            # pyplot keeps every figure alive until closed; without this each
            # request would leak one figure.
            plt.close(fig)
        return updated_image_path

    @prompts(name="Segment the Image",
             description="useful when you want to segment all the part of the image, but not segment a certain object."
                         "like: segment all the object in this image, or generate segmentations on this image, "
                         "or segment the image,"
                         "or perform segmentation on this image, "
                         "or segment all the object in this image."
                         "The input to this tool should be a string, representing the image_path")
    def inference_all(self, image_path):
        """Segment everything in the image and save a coloured overlay.

        Args:
            image_path: Path of the image to segment.

        Returns:
            Path of the saved visualisation, or ``None`` when the mask
            generator produces no masks.
        """
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        masks = self.mask_generator.generate(image)
        if len(masks) == 0:
            # Nothing to draw; bail out before a figure is allocated.
            return None

        fig = plt.figure(figsize=(20, 20))
        try:
            plt.imshow(image)
            ax = plt.gca()
            ax.set_autoscale_on(False)
            # Largest regions first so smaller ones are painted on top.
            for ann in sorted(masks, key=lambda x: x['area'], reverse=True):
                m = ann['segmentation']
                img = np.ones((m.shape[0], m.shape[1], 3))
                color_mask = np.random.random((1, 3)).tolist()[0]
                for i in range(3):
                    img[:, :, i] = color_mask[i]
                # The mask itself becomes the alpha channel of the overlay.
                ax.imshow(np.dstack((img, m)))

            updated_image_path = get_new_image_name(image_path, func_name="segment-image")
            plt.axis('off')
            plt.savefig(
                updated_image_path,
                bbox_inches="tight", dpi=300, pad_inches=0.0
            )
        finally:
            plt.close(fig)
        return updated_image_path
|
|
|
|
|
class Text2Box:
    """Text-prompted object detector backed by a GroundingDINO checkpoint."""

    def __init__(self, device):
        """Fetch checkpoint and config if needed and load the model.

        Args:
            device: Device string the model is moved to.
        """
        print(f"Initializing ObjectDetection to {device}")
        self.device = device
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.model_checkpoint_path = os.path.join("checkpoints", "groundingdino")
        self.model_config_path = os.path.join("checkpoints", "grounding_config.py")
        self.download_parameters()
        # Minimum best-token score for a box to be kept.
        self.box_threshold = 0.3
        # Minimum per-token score for a token to join the predicted phrase.
        self.text_threshold = 0.25
        self.grounding = (self.load_model()).to(self.device)

    def download_parameters(self):
        """Download the checkpoint and config files that are not on disk yet."""
        downloads = (
            ("https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
             self.model_checkpoint_path),
            ("https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py",
             self.model_config_path),
        )
        for url, target in downloads:
            if os.path.exists(target):
                continue
            # The target directory must exist before the file can be written.
            parent = os.path.dirname(target)
            if parent:
                os.makedirs(parent, exist_ok=True)
            wget.download(url, out=target)

    def load_image(self, image_path):
        """Open an image and build the tensor fed to the detector.

        Returns:
            ``(image_pil, image)``: the RGB PIL image and its resized,
            normalised tensor.
        """
        image_pil = Image.open(image_path).convert("RGB")

        transform = T.Compose(
            [
                T.RandomResize([512], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image, _ = transform(image_pil, None)
        return image_pil, image

    def load_model(self):
        """Build the model from the config file and load the checkpoint weights."""
        args = SLConfig.fromfile(self.model_config_path)
        args.device = self.device
        model = build_model(args)
        # NOTE(review): torch.load unpickles the downloaded file; only use
        # checkpoints from a trusted source.
        checkpoint = torch.load(self.model_checkpoint_path, map_location="cpu")
        load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
        print(load_res)
        model.eval()
        return model

    def get_grounding_boxes(self, image, caption, with_logits=True):
        """Detect the boxes that match ``caption``.

        Args:
            image: Image tensor as produced by ``load_image``.
            caption: Text describing the objects; lower-cased, stripped and
                terminated with ``.`` before being sent to the model.
            with_logits: Append the box's best score (first four characters
                of its string form) to each phrase.

        Returns:
            ``(boxes_filt, pred_phrases)``: the boxes whose best token score
            exceeds ``box_threshold`` and one phrase per box.
        """
        caption = caption.lower().strip()
        if not caption.endswith("."):
            caption = caption + "."
        image = image.to(self.device)
        with torch.no_grad():
            outputs = self.grounding(image[None], captions=[caption])
        logits = outputs["pred_logits"].cpu().sigmoid()[0]
        boxes = outputs["pred_boxes"].cpu()[0]

        # Boolean-mask indexing returns fresh tensors, so the model outputs
        # are not aliased by the returned boxes.
        filt_mask = logits.max(dim=1)[0] > self.box_threshold
        logits_filt = logits[filt_mask]
        boxes_filt = boxes[filt_mask]

        tokenizer = self.grounding.tokenizer
        tokenized = tokenizer(caption)

        pred_phrases = []
        for logit in logits_filt:
            pred_phrase = get_phrases_from_posmap(logit > self.text_threshold, tokenized, tokenizer)
            if with_logits:
                pred_phrase = pred_phrase + f"({str(logit.max().item())[:4]})"
            pred_phrases.append(pred_phrase)

        return boxes_filt, pred_phrases

    def plot_boxes_to_image(self, image_pil, tgt):
        """Draw labelled boxes onto ``image_pil`` (in place) and build a box mask.

        Args:
            image_pil: PIL image to draw on; it is modified.
            tgt: Dict with ``"size"`` as ``[H, W]``, ``"boxes"`` as
                (cx, cy, w, h) rows scaled by the image size here, and
                ``"labels"`` with one entry per box.

        Returns:
            ``(image_pil, mask)`` where ``mask`` is an ``"L"`` image with the
            box areas filled with 255.
        """
        H, W = tgt["size"]
        boxes = tgt["boxes"]
        labels = tgt["labels"]
        assert len(boxes) == len(labels), "boxes and labels must have same length"

        draw = ImageDraw.Draw(image_pil)
        mask = Image.new("L", image_pil.size, 0)
        mask_draw = ImageDraw.Draw(mask)
        # The font does not depend on the box; load it once.
        font = ImageFont.load_default()

        for box, label in zip(boxes, labels):
            # Scale to pixels, then (cx, cy, w, h) -> (x0, y0, x1, y1).
            box = box * torch.Tensor([W, H, W, H])
            box[:2] -= box[2:] / 2
            box[2:] += box[:2]

            color = tuple(np.random.randint(0, 255, size=3).tolist())

            x0, y0, x1, y1 = box
            x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)

            draw.rectangle([x0, y0, x1, y1], outline=color, width=6)

            # Fonts without getbbox fall back to the draw.textsize API.
            if hasattr(font, "getbbox"):
                bbox = draw.textbbox((x0, y0), str(label), font)
            else:
                w, h = draw.textsize(str(label), font)
                bbox = (x0, y0, w + x0, y0 + h)

            draw.rectangle(bbox, fill=color)
            draw.text((x0, y0), str(label), fill="white")

            mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=2)

        return image_pil, mask

    @prompts(name="Detect the Give Object",
             description="useful when you only want to detect or find out given objects in the picture"
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path, the text description of the object to be found")
    def inference(self, inputs):
        """Detect the described objects and save an annotated copy of the image.

        Args:
            inputs: ``"<image_path>,<description>"``. Only the first comma
                separates the fields, so the description may contain commas.

        Returns:
            Path of the saved annotated image.
        """
        # Split on the first comma only; unpacking a plain split(",") raised
        # ValueError whenever the description itself contained a comma.
        image_path, _, det_prompt = inputs.partition(",")
        print(f"image_path={image_path}, text_prompt={det_prompt}")
        image_pil, image = self.load_image(image_path)

        boxes_filt, pred_phrases = self.get_grounding_boxes(image, det_prompt)

        size = image_pil.size
        pred_dict = {
            "boxes": boxes_filt,
            "size": [size[1], size[0]],
            "labels": pred_phrases,
        }

        # Drawing happens on image_pil itself, so the result already has the
        # original size.
        image_with_box = self.plot_boxes_to_image(image_pil, pred_dict)[0]

        updated_image_path = get_new_image_name(image_path, func_name="detect-something")
        image_with_box.save(updated_image_path)
        print(
            f"\nProcessed ObejectDetecting, Input Image: {image_path}, Object to be Detect {det_prompt}, "
            f"Output Image: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class Inpainting:
    """Callable wrapper around the Stable Diffusion inpainting pipeline."""

    def __init__(self, device):
        """Load the inpainting pipeline onto ``device``.

        Args:
            device: Device string. When it contains ``'cuda'`` the ``fp16``
                revision and half precision are used; otherwise the default
                revision and full precision.
        """
        self.device = device
        if 'cuda' in self.device:
            self.revision = 'fp16'
            self.torch_dtype = torch.float16
        else:
            self.revision = None
            self.torch_dtype = torch.float32

        pipeline = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            revision=self.revision,
            torch_dtype=self.torch_dtype,
        )
        self.inpaint = pipeline.to(device)

    def __call__(self, prompt, image, mask_image, height=512, width=512, num_inference_steps=50):
        """Inpaint the masked region of ``image`` according to ``prompt``.

        Both ``image`` and ``mask_image`` are resized to ``(width, height)``
        before being handed to the pipeline.

        Returns:
            The first image of the pipeline output.
        """
        target_size = (width, height)
        result = self.inpaint(
            prompt=prompt,
            image=image.resize(target_size),
            mask_image=mask_image.resize(target_size),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
        )
        return result.images[0]
|
|
|
|
|
class InfinityOutPainting:
    """Tool that grows an image to a target resolution by repeated outpainting."""

    # Class-level flag; presumably read by loader code outside this class to
    # recognise tools that are composed from other tools -- TODO confirm.
    template_model = True

    def __init__(self, ImageCaptioning, Inpainting, VisualQuestionAnswering):
        """Store the collaborating tools.

        Args:
            ImageCaptioning: Object exposing ``processor``, ``model``,
                ``device`` and ``torch_dtype`` (used for captions).
            Inpainting: Callable used as
                ``inpaint(prompt=..., image=..., mask_image=..., height=...,
                width=..., num_inference_steps=...)``.
            VisualQuestionAnswering: Object exposing ``processor``,
                ``model``, ``device`` and ``torch_dtype`` (used for VQA).
        """
        self.ImageCaption = ImageCaptioning
        self.inpaint = Inpainting
        self.ImageVQA = VisualQuestionAnswering
        self.a_prompt = 'best quality, extremely detailed'
        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
                        'fewer digits, cropped, worst quality, low quality'

    def get_BLIP_vqa(self, image, question):
        """Answer ``question`` about ``image`` with the VQA tool's model."""
        inputs = self.ImageVQA.processor(image, question, return_tensors="pt").to(self.ImageVQA.device,
                                                                                  self.ImageVQA.torch_dtype)
        out = self.ImageVQA.model.generate(**inputs)
        answer = self.ImageVQA.processor.decode(out[0], skip_special_tokens=True)
        print(f"\nProcessed VisualQuestionAnswering, Input Question: {question}, Output Answer: {answer}")
        return answer

    def get_BLIP_caption(self, image):
        """Caption ``image`` with the captioning tool's model."""
        inputs = self.ImageCaption.processor(image, return_tensors="pt").to(self.ImageCaption.device,
                                                                            self.ImageCaption.torch_dtype)
        out = self.ImageCaption.model.generate(**inputs)
        BLIP_caption = self.ImageCaption.processor.decode(out[0], skip_special_tokens=True)
        return BLIP_caption

    def check_prompt(self, prompt):
        """Return the user-supplied prompt unchanged.

        ``dowhile`` routes user prompts through this hook, but the class did
        not define it, so that path raised ``AttributeError``. The hook is
        kept as a pass-through so the call site works and can be overridden.
        """
        return prompt

    def get_imagine_caption(self, image, imagine):
        """Return the prompt used to outpaint around ``image``.

        Args:
            image: Image to caption.
            imagine: Accepted for interface compatibility; not used.
        """
        caption = self.get_BLIP_caption(image)
        print(f'Prompt: {caption}')
        return caption

    def resize_image(self, image, max_size=1000000, multiple=8):
        """Resize ``image`` to roughly ``max_size`` pixels, keeping aspect ratio.

        Width and height are rounded down to a multiple of ``multiple``.
        """
        aspect_ratio = image.size[0] / image.size[1]
        new_width = int(math.sqrt(max_size * aspect_ratio))
        new_height = int(new_width / aspect_ratio)
        new_width, new_height = new_width - (new_width % multiple), new_height - (new_height % multiple)
        return image.resize((new_width, new_height))

    def dowhile(self, original_img, tosize, expand_ratio, imagine, usr_prompt):
        """Outpaint ``original_img`` step by step until it reaches ``tosize``.

        Each step crops 15 px from every side along the axes that are not yet
        at the target, centres the result on a white canvas of up to
        ``expand_ratio`` times its size (capped at ``tosize``), inpaints the
        uncovered area and blends the previous image back in.

        Args:
            original_img: Starting PIL image.
            tosize: Target ``(width, height)``.
            expand_ratio: Per-step growth factor of the canvas.
            imagine: Forwarded to ``get_imagine_caption``.
            usr_prompt: Prompt to use instead of a generated caption; any
                falsy value selects the caption.

        Returns:
            The final image of size ``tosize``.
        """
        # PIL sizes are tuples; normalise so a list target can terminate the loop.
        tosize = tuple(tosize)
        old_img = original_img
        while old_img.size != tosize:
            prompt = self.check_prompt(usr_prompt) if usr_prompt else self.get_imagine_caption(old_img, imagine)
            crop_w = 15 if old_img.size[0] != tosize[0] else 0
            crop_h = 15 if old_img.size[1] != tosize[1] else 0
            old_img = ImageOps.crop(old_img, (crop_w, crop_h, crop_w, crop_h))
            temp_canvas_size = (min(expand_ratio * old_img.width, tosize[0]),
                                min(expand_ratio * old_img.height, tosize[1]))
            temp_canvas = Image.new("RGB", temp_canvas_size, color="white")
            # White marks the area to paint; the pasted image area is black.
            temp_mask = Image.new("L", temp_canvas_size, color="white")
            x, y = (temp_canvas.width - old_img.width) // 2, (temp_canvas.height - old_img.height) // 2
            temp_canvas.paste(old_img, (x, y))
            temp_mask.paste(0, (x, y, x + old_img.width, y + old_img.height))
            resized_temp_canvas, resized_temp_mask = self.resize_image(temp_canvas), self.resize_image(temp_mask)
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
            image = self.inpaint(prompt=prompt, image=resized_temp_canvas, mask_image=resized_temp_mask,
                                 height=resized_temp_canvas.height, width=resized_temp_canvas.width,
                                 num_inference_steps=50).resize(
                (temp_canvas.width, temp_canvas.height), Image.LANCZOS)
            image = blend_gt2pt(old_img, image)
            old_img = image
        return old_img

    @prompts(name="Extend An Image",
             description="useful when you need to extend an image into a larger image."
                         "like: extend the image into a resolution of 2048x1024, extend the image into 2048x1024. "
                         "The input to this tool should be a comma separated string of two, representing the image_path and the resolution of widthxheight")
    def inference(self, inputs):
        """Extend an image to the requested resolution.

        Args:
            inputs: ``"<image_path>,<width>x<height>"``.

        Returns:
            Path of the saved, extended image.

        Raises:
            ValueError: If ``inputs`` or the resolution is malformed.
        """
        image_path, resolution = inputs.split(',')
        width, height = resolution.split('x')
        tosize = (int(width), int(height))
        image = Image.open(image_path)
        # Trim a 10 px border before outpainting.
        image = ImageOps.crop(image, (10, 10, 10, 10))
        out_painted_image = self.dowhile(image, tosize, 4, True, False)
        updated_image_path = get_new_image_name(image_path, func_name="outpainting")
        out_painted_image.save(updated_image_path)
        print(f"\nProcessed InfinityOutPainting, Input Image: {image_path}, Input Resolution: {resolution}, "
              f"Output Image: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class ObjectSegmenting:
    """Tool that segments the objects matching a text description."""

    # Class-level flag; presumably read by loader code outside this class to
    # recognise tools that are composed from other tools -- TODO confirm.
    template_model = True

    def __init__(self, Text2Box: Text2Box, Segmenting: Segmenting):
        """Store the detector and the segmenter this tool chains together.

        Args:
            Text2Box: Detector providing ``load_image`` and
                ``get_grounding_boxes``.
            Segmenting: Segmenter providing ``segment_image_with_boxes``.
        """
        self.grounding = Text2Box
        self.sam = Segmenting

    @prompts(name="Segment the given object",
             description="useful when you only want to segment the certain objects in the picture"
                         "according to the given text"
                         "like: segment the cat,"
                         "or can you segment an obeject for me"
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path, the text description of the object to be found")
    def inference(self, inputs):
        """Detect the described objects, segment them and save the result.

        Args:
            inputs: ``"<image_path>,<description>"``. Only the first comma
                separates the fields, so the description may contain commas.

        Returns:
            Path of the saved segmentation image.
        """
        # Split on the first comma only; unpacking a plain split(",") raised
        # ValueError whenever the description itself contained a comma.
        image_path, _, det_prompt = inputs.partition(",")
        print(f"image_path={image_path}, text_prompt={det_prompt}")
        image_pil, image = self.grounding.load_image(image_path)
        boxes_filt, pred_phrases = self.grounding.get_grounding_boxes(image, det_prompt)
        updated_image_path = self.sam.segment_image_with_boxes(image_pil, image_path, boxes_filt, pred_phrases)
        print(
            f"\nProcessed ObejectSegmenting, Input Image: {image_path}, Object to be Segment {det_prompt}, "
            f"Output Image: {updated_image_path}")
        return updated_image_path
|
|
|
|
|
class ImageEditing: |
|
template_model = True |
|
|
|
def __init__(self, Text2Box: Text2Box, Segmenting: Segmenting, Inpainting: Inpainting):
    """Store the detector, segmenter and inpainter used for editing.

    Args:
        Text2Box: Detector, kept as ``self.grounding``.
        Segmenting: Segmenter, kept as ``self.sam``.
        Inpainting: Inpainting callable, kept as ``self.inpaint``.
    """
    print("Initializing ImageEditing")
    self.grounding = Text2Box
    self.sam = Segmenting
    self.inpaint = Inpainting
|
|
|
def pad_edge(self, mask, padding):
    """Grow the True region of ``mask`` by ``padding`` cells in every direction.

    Args:
        mask: Object with a ``numpy()`` method returning the mask array
            (e.g. a CPU tensor); non-zero entries count as set.
        padding: Number of cells added on each side of every set entry.

    Returns:
        ``uint8`` array of the same shape, 255 inside the grown region and
        0 elsewhere.
    """
    source = mask.numpy()
    grown = np.zeros_like(source, dtype=bool)

    set_positions = np.argwhere(source)
    # Window bounds for every set entry; the lower bound is clamped at 0 so
    # it never wraps around to the far edge of the array.
    lower_bounds = np.maximum(set_positions - padding, 0)
    upper_bounds = set_positions + padding + 1

    # One slice assignment per set entry.
    for lows, highs in zip(lower_bounds, upper_bounds):
        window = tuple(slice(lo, hi) for lo, hi in zip(lows, highs))
        grown[window] = True

    return grown.astype(np.uint8) * np.uint8(255)
|
|
|
@prompts(name="Remove Something From The Photo",
         description="useful when you want to remove and object or something from the photo "
                     "from its description or location. "
                     "The input to this tool should be a comma separated string of two, "
                     "representing the image_path and the object need to be removed. ")
def inference_remove(self, inputs):
    """Remove an object by replacing it with ``background``.

    Args:
        inputs: ``"<image_path>,<object description>"``; only the first
            comma separates the two fields.

    Returns:
        Whatever ``inference_replace_sam`` returns for the rewritten request.
    """
    image_path, _, to_be_removed_txt = inputs.partition(",")
    # Removal is expressed as a replacement with the word "background".
    replace_request = ",".join((image_path, to_be_removed_txt, "background"))
    return self.inference_replace_sam(replace_request)
|
|
|
@prompts(name="Replace Something From The Photo",
         description="useful when you want to replace an object from the object description or "
                     "location with another object from its description. "
                     "The input to this tool should be a comma separated string of three, "
                     "representing the image_path, the object to be replaced, the object to be replaced with ")
def inference_replace_sam(self, inputs):
    """Replace the described object in an image via detection, SAM and inpainting.

    Args:
        inputs: ``"<image_path>,<object to replace>,<replacement>"``. The
            first comma ends the path and the last comma starts the
            replacement, so the middle description may contain commas.

    Returns:
        Path of the saved, edited image.

    Raises:
        ValueError: If ``inputs`` has fewer than three comma-separated fields.
    """
    # A plain three-way unpack of split(",") raised ValueError as soon as
    # the object description contained a comma (including descriptions
    # forwarded by inference_remove).
    image_path, remainder = inputs.split(",", 1)
    to_be_replaced_txt, replace_with_txt = remainder.rsplit(",", 1)

    print(f"image_path={image_path}, to_be_replaced_txt={to_be_replaced_txt}")
    image_pil, image = self.grounding.load_image(image_path)
    boxes_filt, pred_phrases = self.grounding.get_grounding_boxes(image, to_be_replaced_txt)

    # The SAM predictor is given the image as an RGB array read with OpenCV.
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    self.sam.sam_predictor.set_image(image)
    masks = self.sam.get_mask_with_boxes(image_pil, image, boxes_filt)

    # Union of all per-box masks, then drop the leading singleton dimension.
    mask = (torch.sum(masks, dim=0) > 0).squeeze(0).cpu()

    # Grow the mask so the inpainting also covers the object's edges.
    mask = self.pad_edge(mask, padding=20)
    mask_image = Image.fromarray(mask)

    updated_image = self.inpaint(prompt=replace_with_txt, image=image_pil,
                                 mask_image=mask_image)
    updated_image_path = get_new_image_name(image_path, func_name="replace-something")
    # Resize the inpainting result back to the size of the input image.
    updated_image = updated_image.resize(image_pil.size)
    updated_image.save(updated_image_path)
    print(
        f"\nProcessed ImageEditing, Input Image: {image_path}, Replace {to_be_replaced_txt} to {replace_with_txt}, "
        f"Output Image: {updated_image_path}")
    return updated_image_path