import torch | |
import numpy as np | |
from torchvision.transforms import InterpolationMode | |
BICUBIC = InterpolationMode.BICUBIC | |
from vpt.src.configs.config import get_cfg | |
import os | |
from time import sleep | |
from random import randint | |
from vpt.src.utils.file_io import PathManager | |
import matplotlib.pyplot as plt | |
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb | |
import warnings | |
import nltk | |
warnings.filterwarnings("ignore") | |
def get_noun_phrase(tokenized):
    """Extract noun phrases from a list of tokens.

    The tokens are POS-tagged with ``nltk.pos_tag`` and chunked with a
    regular-expression grammar (taken from Su Nam Kim's keyphrase paper).
    Consecutive chunks are joined with a space into a single phrase. A phrase
    ends when a non-chunk token follows it, or when the input ends. Duplicate
    phrases are dropped; first-seen order is preserved.

    Args:
        tokenized: sequence of token strings to tag and chunk.

    Returns:
        list[str]: the unique noun phrases, in order of first appearance.
    """
    # Taken from Su Nam Kim Paper...
    # NOTE(review): RegexpParser applies the NP rules in order, so `{<NBAR>}`
    # appears to chunk every NBAR before `{<NBAR><IN><NBAR>}` can match,
    # which makes the second rule effectively dead. It is kept unchanged so
    # that the existing output is preserved — confirm the intended grammar.
    grammar = r"""
        NBAR:
            {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns

        NP:
            {<NBAR>}
            {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
    """
    chunker = nltk.RegexpParser(grammar)
    chunked = chunker.parse(nltk.pos_tag(tokenized))

    continuous_chunk = []
    current_chunk = []

    def _flush():
        # Close the current run of chunks as one phrase, skipping duplicates.
        named_entity = ' '.join(current_chunk)
        if named_entity not in continuous_chunk:
            continuous_chunk.append(named_entity)
        current_chunk.clear()

    for subtree in chunked:
        if isinstance(subtree, nltk.Tree):
            # A chunk subtree: accumulate its words; adjacent chunks merge.
            current_chunk.append(' '.join(token for token, _pos in subtree.leaves()))
        elif current_chunk:
            # A plain (token, tag) leaf ends the current run of chunks.
            _flush()

    # A phrase that runs to the end of the input has no trailing non-chunk
    # token to close it; without this flush it would be silently dropped.
    if current_chunk:
        _flush()
    return continuous_chunk
def setup(args):
    """
    Create configs and perform basic setups.

    The steps are:

    1. Build the default config.
    2. Merge in ``args.config_file``.
    3. Apply the ``args.opts`` override list.
    4. Choose the first non-existent directory
       ``<OUTPUT_DIR>/<DATA.NAME>/<DATA.FEATURE>/lr<BASE_LR>_wd<WEIGHT_DECAY>/run<k>``
       for ``k = 1 .. cfg.RUN_N_TIMES``.
    5. Create that directory and store it in ``cfg.OUTPUT_DIR``.

    Args:
        args: parsed command-line namespace exposing ``config_file`` and
            ``opts``.

    Returns:
        The frozen config object.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    output_dir = cfg.OUTPUT_DIR
    lr = cfg.SOLVER.BASE_LR
    # `wd` was previously used below without ever being assigned (NameError).
    # Assumes the VPT config defines SOLVER.WEIGHT_DECAY, which matches the
    # "wd" tag in the folder name — confirm against vpt.src.configs.config.
    wd = cfg.SOLVER.WEIGHT_DECAY
    output_folder = os.path.join(
        cfg.DATA.NAME, cfg.DATA.FEATURE, f"lr{lr}_wd{wd}")

    # train cfg.RUN_N_TIMES times: claim the first run<k> slot not yet taken
    for count in range(1, cfg.RUN_N_TIMES + 1):
        output_path = os.path.join(output_dir, output_folder, f"run{count}")
        # pause for a random time, so concurrent process with same setting won't interfere with each other. # noqa
        sleep(randint(3, 30))
        if not PathManager.exists(output_path):
            PathManager.mkdirs(output_path)
            cfg.OUTPUT_DIR = output_path
            break
    # NOTE(review): if every run<k> directory already exists, cfg.OUTPUT_DIR
    # is left as the base OUTPUT_DIR and no error is raised (unchanged
    # behavior) — consider failing loudly instead.

    cfg.freeze()
    return cfg
def get_similarity_map(sm, shape):
    """Min-max normalise a token similarity map and upsample it to ``shape``.

    Args:
        sm: tensor of shape ``(B, N, C)`` where ``N`` is a perfect square
            (e.g. ``torch.Size([1, 196, 1])`` for a 14x14 token grid). Each
            ``sm[b, :, c]`` slice is normalised independently to ``[0, 1]``.
            A constant slice is mapped to all zeros instead of NaN.
        shape: target spatial size ``(H, W)``, forwarded as ``size`` to
            ``torch.nn.functional.interpolate``.

    Returns:
        Tensor of shape ``(B, H, W, C)``, with the leading dimension squeezed
        out when ``B == 1`` (giving ``(H, W, C)``).
    """
    # min-max norm over the token axis (dim 1)
    sm_min = sm.min(1, keepdim=True)[0]
    sm_range = sm.max(1, keepdim=True)[0] - sm_min
    # A constant slice has max == min; dividing by 0 would produce NaNs.
    # Substitute a range of 1 there (so the result is 0) and leave every
    # other slice's range untouched.
    sm_range = torch.where(sm_range > 0, sm_range, torch.ones_like(sm_range))
    sm = (sm - sm_min) / sm_range

    # reshape tokens onto a square grid: (B, N, C) -> (B, C, side, side)
    side = int(sm.shape[1] ** 0.5)  # e.g. 196 tokens -> side = 14
    sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2)

    # interpolate; align_corners=False is PyTorch's default for bilinear mode,
    # passed explicitly so the sampling convention is visible.
    sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear', align_corners=False)
    sm = sm.permute(0, 2, 3, 1)  # back to channels-last: (B, H, W, C)
    return sm.squeeze(0)
def display_segmented_sketch(pixel_similarity_array,binary_sketch,classes,classes_colors,save_path=None,live=False): | |
# Find the class index with the highest similarity for each pixel | |
class_indices = np.argmax(pixel_similarity_array, axis=0) | |
# Create an HSV image placeholder | |
hsv_image = np.zeros(class_indices.shape + (3,)) # Shape (512, 512, 3) | |
hsv_image[..., 2] = 1 # Set Value to 1 for a white base | |
# Set the hue and value channels | |
for i, color in enumerate(classes_colors): | |
rgb_color = np.array(color).reshape(1, 1, 3) | |
hsv_color = rgb_to_hsv(rgb_color) | |
mask = class_indices == i | |
if i < len(classes): # For the first N-2 classes, set color based on similarity | |
hsv_image[..., 0][mask] = hsv_color[0, 0, 0] # Hue | |
hsv_image[..., 1][mask] = pixel_similarity_array[i][mask] > 0 # Saturation | |
hsv_image[..., 2][mask] = pixel_similarity_array[i][mask] # Value | |
else: # For the last two classes, set pixels to black | |
hsv_image[..., 0][mask] = 0 # Hue doesn't matter for black | |
hsv_image[..., 1][mask] = 0 # Saturation set to 0 | |
hsv_image[..., 2][mask] = 0 # Value set to 0, making it black | |
mask_tensor_org = binary_sketch[:,:,0]/255 | |
hsv_image[mask_tensor_org==1] = [0,0,1] | |
# Convert the HSV image back to RGB to display and save | |
rgb_image = hsv_to_rgb(hsv_image) | |
if len(classes) > 1: | |
# Calculate centroids and render class names | |
for i, class_name in enumerate(classes): | |
mask = class_indices == i | |
if np.any(mask): | |
y, x = np.nonzero(mask) | |
centroid_x, centroid_y = np.mean(x), np.mean(y) | |
plt.text(centroid_x, centroid_y, class_name, color=classes_colors[i], ha='center', va='center',fontsize=10, # color=classes_colors[i] | |
bbox=dict(facecolor='lightgrey', edgecolor='none', boxstyle='round,pad=0.2', alpha=0.8)) | |
# Display the image with class names | |
plt.imshow(rgb_image) | |
plt.axis('off') | |
plt.tight_layout() | |
if live: | |
plt.savefig('output.png', bbox_inches='tight', pad_inches=0) | |
else: | |
save_dir = "/".join(save_path.split("/")[:-1]) | |
if save_dir !='': | |
if not os.path.exists(save_dir): | |
os.makedirs(save_dir) | |
plt.savefig(save_path, bbox_inches='tight', pad_inches=0) | |
else: | | | |