RegionSpot / regionspot /util /preprocessing.py
bklg's picture
Upload 114 files
a153c95
raw
history blame
5.61 kB
import torch
import numpy as np
import json
import torchvision.transforms.functional as F
from regionspot.modeling.segment_anything.utils.transforms import ResizeLongestSide
# Per-channel normalization statistics: three values each, reshaped from (3,)
# to (3, 1, 1) by the two unsqueeze calls. resize_and_normalize() subtracts
# NORM_MEAN and divides by NORM_STD.
# NOTE(review): presumably RGB order, meant to broadcast over a (3, H, W) image
# scaled to [0, 1]; the values look like the published CLIP preprocessing
# statistics -- confirm.
NORM_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073]).unsqueeze(1).unsqueeze(2)
NORM_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711]).unsqueeze(1).unsqueeze(2)
def resize_box(after_image_size, befor_image_size, boxes, size=800, max_size=1333):
    """Rescale boxes from one image size to another.

    Every box is multiplied element-wise by ``[r0, r1, r0, r1]``, where
    ``r0`` and ``r1`` are the ``after / before`` ratios of the first and
    second entries of the two size arguments. The first size entry therefore
    scales box columns 0 and 2 and the second scales columns 1 and 3, so for
    ``(x1, y1, x2, y2)`` boxes both sizes must be given as ``(width, height)``.

    Args:
        after_image_size: Two-element size of the resized image.
        befor_image_size: Two-element size of the image the boxes currently
            refer to, in the same axis order as ``after_image_size``.
        boxes: Tensor whose last dimension holds four box coordinates.
        size: Accepted for backward compatibility; does not influence the
            result.
        max_size: Accepted for backward compatibility; does not influence
            the result.

    Returns:
        A new tensor with the scaled boxes; ``boxes`` itself is not modified.

    Raises:
        ValueError: If the size arguments do not yield exactly two ratios.
        ZeroDivisionError: If an entry of ``befor_image_size`` is zero.
    """
    ratio_width, ratio_height = (
        float(after) / float(before)
        for after, before in zip(after_image_size, befor_image_size)
    )
    scale = torch.as_tensor(
        [ratio_width, ratio_height, ratio_width, ratio_height]
    )
    return boxes * scale
def resize_and_normalize(image, target_size=(224, 224)):
    """Resize an image and standardize it with the module-level statistics.

    Args:
        image: Image tensor accepted by ``torchvision`` ``F.resize``.
            NOTE(review): presumably (3, H, W) scaled to [0, 1] so that it
            broadcasts against the (3, 1, 1) statistics -- confirm.
        target_size: Output size forwarded to ``F.resize``.

    Returns:
        ``(resized - NORM_MEAN) / NORM_STD`` on the device of the resized
        image.
    """
    scaled = F.resize(image, target_size)
    # Move the statistics to wherever the image lives before broadcasting.
    mean = NORM_MEAN.to(scaled.device)
    std = NORM_STD.to(scaled.device)
    return (scaled - mean) / std
def get_pred_boxes(pred_results, image_id):
    """Fetch the stored predictions of one image as tensors.

    Args:
        pred_results: Mapping from image id to a mapping with the keys
            ``'scores'``, ``'labels'`` and ``'boxes'``.
        image_id: Key of the image to look up in ``pred_results``.

    Returns:
        Tuple ``(scores, labels, boxes)`` of freshly built tensors.

    Raises:
        KeyError: If the image id or one of the three keys is missing.
    """
    entry = pred_results[image_id]
    scores, labels, boxes = (
        torch.tensor(entry[key]) for key in ('scores', 'labels', 'boxes')
    )
    return scores, labels, boxes
def prepare_prompt_infer(batched_inputs, num_proposals=None, pred_results=None, target_size=(224,224)):
    """Prepare a batch of inference inputs in place.

    For every sample dict the function keeps a copy of the incoming image
    under ``'curr_image'``, replaces ``'image'`` with the
    ``ResizeLongestSide(1024)`` version (float, channels first, on CUDA),
    and stores ``'resized_image'`` (output of ``resize_and_normalize``),
    ``'boxes'`` (box prompts mapped by ``ResizeLongestSide.apply_boxes``)
    and ``'original_size'`` (``(H, W)`` of the incoming image).

    Box prompts come from ``pred_results`` when it is given; in that case
    ``'pred_boxes'`` and ``'scores'`` are stored as well. Otherwise the
    ground-truth boxes in ``x['instances'].gt_boxes`` are used, with one
    whole-image box as fallback when there are none.

    Args:
        batched_inputs: List of sample dicts with the keys ``'image'`` and
            ``'image_id'``, plus ``'height'``/``'width'`` (prediction mode)
            or ``'instances'`` (ground-truth mode).
            NOTE(review): ``'image'`` is assumed to be (C, H, W) with values
            in 0-255 and boxes to be (x1, y1, x2, y2) -- confirm.
        num_proposals: Unused; kept for interface compatibility.
        pred_results: Optional mapping ``str(image_id) -> {'scores',
            'labels', 'boxes'}`` with boxes expressed at the
            ``'height'``/``'width'`` resolution of the sample.
        target_size: Size forwarded to ``resize_and_normalize``.

    Returns:
        The same ``batched_inputs`` list, with every dict updated.
    """
    use_pred_boxes = pred_results is not None
    # The transform only stores its target length, so a single instance can
    # serve every image and every box set in the batch.
    longest_side = ResizeLongestSide(1024)
    for x in batched_inputs:
        curr_image = x["image"]
        x["curr_image"] = curr_image.clone()
        image_id = x["image_id"]
        image = curr_image.permute(1, 2, 0).to(torch.uint8)
        height, width = image.shape[0], image.shape[1]
        curr_size = (height, width)
        resized_image = resize_and_normalize(curr_image.cuda() / 255, target_size=target_size)
        x["image"] = torch.as_tensor(
            longest_side.apply_image(np.array(image.cpu())), dtype=torch.float
        ).permute(2, 0, 1).cuda()
        if use_pred_boxes:
            scores, _labels, boxes_prompt = get_pred_boxes(pred_results, str(image_id))
            # resize_box scales columns 0/2 by the first size entry and
            # columns 1/3 by the second, so sizes go in as (width, height)
            # to match (x1, y1, x2, y2) boxes.
            boxes_prompt = resize_box((width, height), (x['width'], x['height']), boxes_prompt)
            x['pred_boxes'] = boxes_prompt
            x['scores'] = scores
        else:
            boxes_prompt = x["instances"].gt_boxes.tensor.cpu()
            if len(boxes_prompt) == 0:
                # Whole-image fallback prompt in (x1, y1, x2, y2) order,
                # i.e. (0, 0, width, height).
                boxes_prompt = torch.tensor([[0, 0, width, height]])
        # apply_boxes takes the source size as (H, W).
        boxes_prompt = longest_side.apply_boxes(np.array(boxes_prompt), curr_size)
        x['boxes'] = torch.as_tensor(boxes_prompt, dtype=torch.float).cuda()
        x['resized_image'] = resized_image
        x['original_size'] = curr_size
    return batched_inputs
def prepare_prompt_train(batched_inputs, target_size=(224,224)):
    """Prepare a batch of training inputs in place.

    For every sample dict the function replaces ``'image'`` with the
    ``ResizeLongestSide(1024)`` version (float, channels first, on CUDA) and
    stores ``'resized_image'`` (output of ``resize_and_normalize``),
    ``'mask_tokens'``, ``'label'`` and ``'original_size'`` (``(H, W)`` of the
    incoming image). Mask tokens and labels of each sample are repeated so
    that every sample ends up with the same number of entries: the largest
    mask-token count in the batch, and at least one.

    Args:
        batched_inputs: Non-empty list of sample dicts with the keys
            ``'image'``, ``'dataset_name'`` and ``'extra_info'`` (holding
            ``'mask_tokens'`` and ``'classes'``).
            NOTE(review): ``'image'`` is assumed to be (C, H, W) with values
            in 0-255 -- confirm.
        target_size: Size forwarded to ``resize_and_normalize``.

    Returns:
        The same ``batched_inputs`` list, with every dict updated.

    Raises:
        ValueError: If ``batched_inputs`` is empty.
        ZeroDivisionError: If a sample has no mask tokens, since nothing can
            be repeated to fill the proposals.
    """
    max_boxes = max(len(x["extra_info"]['mask_tokens']) for x in batched_inputs)
    num_proposals = max(max_boxes, 1)
    # The transform only stores its target length; build it once per batch.
    longest_side = ResizeLongestSide(1024)
    for x in batched_inputs:
        raw_image = x["image"]
        image = raw_image.permute(1, 2, 0).to(torch.uint8)
        curr_size = (image.shape[0], image.shape[1])
        resized_image = resize_and_normalize(raw_image.cuda() / 255, target_size=target_size)
        input_image = longest_side.apply_image(np.array(image.cpu()))
        x["image"] = torch.as_tensor(input_image, dtype=torch.float).permute(2, 0, 1).cuda()

        mask_tokens = x["extra_info"]['mask_tokens'].clone().detach().cuda()
        labels = torch.tensor(x["extra_info"]['classes']).cuda()
        # A single tolist() replaces one device sync per label.
        label_values = labels.tolist()
        if x['dataset_name'] == 'coco':
            # Best-effort remap through constants.coco_new_dict; labels stay
            # untouched when the lookup fails.
            # NOTE(review): `constants` is not imported in the visible part
            # of this module, so this lookup currently fails and the remap
            # is a no-op -- confirm the intended import.
            try:
                remapped = [constants.coco_new_dict[value] for value in label_values]
            except Exception:
                remapped = None
            if remapped is not None:
                labels = torch.tensor(remapped).cuda()
        else:
            # Decrement each label by 1 unless it's zero
            shifted = [value - 1 if value != 0 else 0 for value in label_values]
            labels = torch.tensor(shifted).cuda()

        # Spread num_proposals over the num_gt entries: every entry is
        # repeated `base` times and the last `extra` entries once more, so
        # the counts sum to num_proposals.
        num_gt = len(mask_tokens)
        base, extra = num_proposals // num_gt, num_proposals % num_gt
        repeat_tensor = torch.tensor([base] * (num_gt - extra) + [base + 1] * extra).cuda()
        mask_tokens = torch.repeat_interleave(mask_tokens, repeat_tensor, dim=0)
        labels = torch.repeat_interleave(labels, repeat_tensor, dim=0)

        x['resized_image'] = resized_image
        x['label'] = labels
        x['mask_tokens'] = mask_tokens
        x['original_size'] = curr_size
    return batched_inputs