# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
# pip install mmcv

from torchvision.utils import save_image
from PIL import Image
import subprocess
from collections import OrderedDict
import numpy as np
import cv2
import textwrap
import torch
import os
from annotator.util import resize_image, HWC3
import mmcv
import random

# device = "cuda" if torch.cuda.is_available() else "cpu" # > 15GB GPU memory required
device = "cpu"
use_blip = True
use_gradio = True

if device == 'cpu':
    data_type = torch.float32
else:
    data_type = torch.float16
# Diffusion init using diffusers.

# diffusers==0.14.0 required.
from diffusers.utils import load_image

base_model_path = "stabilityai/stable-diffusion-2-inpainting"
config_dict = OrderedDict([
    ('SAM Pretrained(v0-1): Good Natural Sense', 'shgao/edit-anything-v0-1-1'),
    ('LAION Pretrained(v0-3): Good Face', 'shgao/edit-anything-v0-3'),
    ('SD Inpainting: Not keep position', 'stabilityai/stable-diffusion-2-inpainting'),
])

# Segment-Anything init.
# pip install git+https://github.com/facebookresearch/segment-anything.git
try:
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
except ImportError:
    print('segment_anything not installed')
    result = subprocess.run(['pip', 'install', 'git+https://github.com/facebookresearch/segment-anything.git'], check=True)
    print(f'Installed segment_anything: {result}')
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
if not os.path.exists('./models/sam_vit_h_4b8939.pth'):
    result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', '-P', 'models'], check=True)
    print(f'Downloaded sam_vit_h_4b8939.pth: {result}')
sam_checkpoint = "models/sam_vit_h_4b8939.pth"
model_type = "default"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
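# Note: each record returned by mask_generator.generate() is a dict; the keys used
# below ('segmentation': HxW boolean mask, 'area': region size in pixels, 'bbox':
# box in XYWH format) follow the segment_anything documentation.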


# BLIP2 init.
if use_blip:
    # need the latest transformers
    # pip install git+https://github.com/huggingface/transformers.git
    from transformers import AutoProcessor, Blip2ForConditionalGeneration
    processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", torch_dtype=data_type)
    # Move BLIP2 to the selected device so it matches the inputs prepared in region_classify_w_blip2.
    blip_model.to(device)


def region_classify_w_blip2(image):
    """Caption a single image region with BLIP2 and return the generated text."""
    inputs = processor(image, return_tensors="pt").to(device, data_type)
    generated_ids = blip_model.generate(**inputs, max_new_tokens=15)
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True)[0].strip()
    return generated_text
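
# Minimal usage sketch (the file path is hypothetical):
# region = Image.open("images/example_region.jpg")
# print(region_classify_w_blip2(region))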

def region_level_semantic_api(image, topk=5):
    """
    rank regions by area, and classify each region with blip2
    Args:
        image: numpy array
        topk: int
    Returns:
        topk_region_w_class_label: list of dict with key 'class_label'
    """
    topk_region_w_class_label = []
    anns = mask_generator.generate(image)
    if len(anns) == 0:
        return []
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    for i in range(min(topk, len(sorted_anns))):
        ann = sorted_anns[i]  # iterate regions in descending-area order
        m = ann['segmentation']
        # Expand the boolean mask to 3 channels so it can mask the RGB image.
        m_3c = m[:, :, np.newaxis]
        m_3c = np.concatenate((m_3c, m_3c, m_3c), axis=2)
        bbox = ann['bbox']
        region = mmcv.imcrop(image*m_3c, np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]), scale=1)
        region_class_label = region_classify_w_blip2(region)
        ann['class_label'] = region_class_label
        print(ann['class_label'], str(bbox))
        topk_region_w_class_label.append(ann)
    return topk_region_w_class_label
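
# Minimal usage sketch (assumes `image` is an HxWx3 uint8 RGB numpy array):
# regions = region_level_semantic_api(image, topk=3)
# for r in regions:
#     print(r['class_label'], r['bbox'], r['area'])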

def show_semantic_image_label(anns):
    """
    show semantic image label for each region
    Args:
        anns: list of dict with key 'class_label'
    Returns:
        full_img: numpy array
    """
    full_img = None
    # generate mask image
    for i in range(len(anns)):
        m = anns[i]['segmentation']
        if full_img is None:
            full_img = np.zeros((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        full_img[m != 0] = color_mask
    full_img = full_img*255
    # add text on this mask image
    for i in range(len(anns)):
        m = anns[i]['segmentation']
        class_label = anns[i]['class_label']
        # add text to region
        # Calculate the centroid of the region to place the text
        y, x = np.where(m != 0)
        x_center, y_center = int(np.mean(x)), int(np.mean(y))

        # Split the text into multiple lines
        max_width = 20  # Adjust this value based on your preferred maximum width
        wrapped_text = textwrap.wrap(class_label, width=max_width)

        # Add text to region
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1.2
        font_thickness = 2
        font_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))  # random color per region
        line_spacing = 40  # Adjust this value based on your preferred line spacing

        for idx, line in enumerate(wrapped_text):
            y_offset = y_center - (len(wrapped_text) - 1) * line_spacing // 2 + idx * line_spacing
            text_size = cv2.getTextSize(line, font, font_scale, font_thickness)[0]
            x_offset = x_center - text_size[0] // 2
            # Draw the text multiple times with small offsets to create a bolder appearance
            offsets = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]
            for off_x, off_y in offsets:
                cv2.putText(full_img, line, (x_offset + off_x, y_offset + off_y), font, font_scale, font_color, font_thickness, cv2.LINE_AA)

    return full_img



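# Demo: segment an example image with SAM, caption the largest regions with BLIP2,
# and save the input next to the labelled mask visualization.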
image_path = "images/sa_224577.jpg"
input_image = Image.open(image_path)
detect_resolution = 1024
input_image = resize_image(np.array(input_image, dtype=np.uint8), detect_resolution)
region_level_annots = region_level_semantic_api(input_image, topk=5)
output = show_semantic_image_label(region_level_annots)

# Resize both images to 512, convert to float tensors, and collect them for saving.
image_list = []
input_image = resize_image(input_image, 512)
output = resize_image(output, 512)
input_image = np.array(input_image, dtype=np.uint8)
output = np.array(output, dtype=np.uint8)
image_list.append(torch.tensor(input_image).float())
image_list.append(torch.tensor(output).float())
for each in image_list:
    print(each.shape, type(each))
    print(each.max(), each.min())


# Arrange as (N, C, H, W) for torchvision's save_image and write a side-by-side grid.
image_list = torch.stack(image_list).permute(0, 3, 1, 2)
print(image_list.shape)

save_image(image_list, "images/sample_semantic.jpg", nrow=2,
           normalize=True)