|
import json |
|
from copy import deepcopy |
|
|
|
import cv2 |
|
import numpy as np |
|
import onnxruntime |
|
|
|
from src import app_logger |
|
from src.utilities.serialize import serialize |
|
|
|
|
|
class SegmentAnythingONNX:
    """Segmentation model using SegmentAnything.

    Wraps two ONNX Runtime sessions: an image encoder (``encode``) that
    produces an image embedding, and a mask decoder (``predict_masks``) that
    turns that embedding plus a point/rectangle prompt into masks mapped back
    to the original image size.
    """

    def __init__(self, encoder_model_path, decoder_model_path) -> None:
        """Load the encoder and decoder ONNX models.

        Args:
            encoder_model_path: path to the encoder ONNX model.
            decoder_model_path: path to the decoder ONNX model.
        """
        # Long-side length passed to apply_coords / get_preprocess_shape.
        self.target_size = 1024
        # (height, width) canvas the input image is warped onto before encoding.
        self.input_size = (684, 1024)

        providers = onnxruntime.get_available_providers()
        # TensorRT is deliberately excluded from the providers used below.
        providers = [p for p in providers if p != "TensorrtExecutionProvider"]

        if providers:
            app_logger.info(
                "Available providers for ONNXRuntime: %s", ", ".join(providers)
            )
        else:
            app_logger.warning("No available providers for ONNXRuntime")
        self.encoder_session = onnxruntime.InferenceSession(
            encoder_model_path, providers=providers
        )
        self.encoder_input_name = self.encoder_session.get_inputs()[0].name
        self.decoder_session = onnxruntime.InferenceSession(
            decoder_model_path, providers=providers
        )

    @staticmethod
    def get_input_points(prompt):
        """Convert a prompt into point coordinates and point labels.

        Args:
            prompt: iterable of dict marks keyed by ``"type"``:
                - ``"point"``: adds ``mark["data"]`` with label ``mark["label"]``;
                - ``"rectangle"``: adds ``(data[0], data[1])`` with label 2 and
                  ``(data[2], data[3])`` with label 3 (presumably the top-left
                  and bottom-right corners -- confirm with the caller).
                Marks of any other type are ignored.

        Returns:
            ``(points, labels)`` as numpy arrays built from the collected lists.
        """
        points = []
        labels = []
        for mark in prompt:
            if mark["type"] == "point":
                points.append(mark["data"])
                labels.append(mark["label"])
            elif mark["type"] == "rectangle":
                points.append([mark["data"][0], mark["data"][1]])
                points.append([mark["data"][2], mark["data"][3]])
                labels.append(2)
                labels.append(3)
        return np.array(points), np.array(labels)

    def run_encoder(self, encoder_inputs):
        """Run the encoder session and return its first output (the embedding).

        Args:
            encoder_inputs: feed dict passed straight to ``InferenceSession.run``.
        """
        output = self.encoder_session.run(None, encoder_inputs)
        return output[0]

    @staticmethod
    def get_preprocess_shape(old_h: int, old_w: int, long_side_length: int):
        """Compute the output size given input size and target long side length.

        The shape is scaled uniformly so its longer side becomes
        ``long_side_length``; both sides are rounded half-up to ints.

        Returns:
            ``(new_h, new_w)``.
        """
        scale = long_side_length * 1.0 / max(old_h, old_w)
        new_h, new_w = old_h * scale, old_w * scale
        new_w = int(new_w + 0.5)
        new_h = int(new_h + 0.5)
        return new_h, new_w

    def apply_coords(self, coords: np.ndarray, original_size, target_length):
        """Rescale (x, y) coordinates from ``original_size`` to the resized shape.

        Expects an array (or nested sequence) of length 2 in the final
        dimension, x first. Requires the original image size in (H, W) format;
        the resized shape is ``get_preprocess_shape(H, W, target_length)``.

        Returns:
            A new float array; ``coords`` itself is not modified.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(old_h, old_w, target_length)
        # np.array always copies, so the caller's coordinates stay untouched.
        coords = np.array(coords, dtype=float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def run_decoder(
        self, image_embedding, original_size, transform_matrix, prompt
    ):
        """Run the mask decoder for ``prompt`` and map the masks back.

        Args:
            image_embedding: encoder output, as returned by ``run_encoder``.
            original_size: (H, W) of the image before ``encode`` warped it.
            transform_matrix: 3x3 matrix that ``encode`` used to warp the
                original image onto ``self.input_size``.
            prompt: list of point/rectangle marks, see ``get_input_points``.

        Returns:
            Masks warped back to ``original_size`` via ``transform_masks``.

        Raises:
            ValueError: if ``prompt`` yields no points.
        """
        input_points, input_labels = self.get_input_points(prompt)
        if input_points.size == 0:
            # Fail clearly instead of with np.concatenate's shape-mismatch error.
            raise ValueError(
                "prompt must contain at least one 'point' or 'rectangle' mark"
            )

        # Add a padding point (label -1) and a leading batch axis.
        onnx_coord = np.concatenate(
            [input_points, np.array([[0.0, 0.0]])], axis=0
        )[None, :, :]
        onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
            None, :
        ].astype(np.float32)
        onnx_coord = self.apply_coords(
            onnx_coord, self.input_size, self.target_size
        ).astype(np.float32)

        # Map the points through the encoder's affine transform using
        # homogeneous coordinates, then drop the homogeneous component.
        onnx_coord = np.concatenate(
            [
                onnx_coord,
                np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
            ],
            axis=2,
        )
        onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
        onnx_coord = onnx_coord[:, :, :2].astype(np.float32)

        # No previous mask is supplied: all-zero mask input, has_mask_input = 0.
        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
        onnx_has_mask_input = np.zeros(1, dtype=np.float32)

        decoder_inputs = {
            "image_embeddings": image_embedding,
            "point_coords": onnx_coord,
            "point_labels": onnx_label,
            "mask_input": onnx_mask_input,
            "has_mask_input": onnx_has_mask_input,
            "orig_im_size": np.array(self.input_size, dtype=np.float32),
        }
        masks, _, _ = self.decoder_session.run(None, decoder_inputs)

        # Undo the encoder warp so masks line up with the original image.
        inv_transform_matrix = np.linalg.inv(transform_matrix)
        return self.transform_masks(masks, original_size, inv_transform_matrix)

    @staticmethod
    def transform_masks(masks, original_size, transform_matrix):
        """Transform the masks back to the original image size.

        Args:
            masks: array indexed as ``masks[batch][mask_id]``; each entry is a
                2D mask passed to ``cv2.warpAffine``.
            original_size: (H, W) of the output masks.
            transform_matrix: 3x3 affine matrix; its first two rows are used.

        Returns:
            numpy array of the warped masks, grouped per batch.

        Raises:
            Whatever ``cv2.warpAffine`` raises, after logging the inputs.
        """
        affine = transform_matrix[:2]
        # cv2 takes dsize as (width, height).
        dsize = (original_size[1], original_size[0])
        output_masks = []
        for batch in masks:
            batch_masks = []
            for mask in batch:
                app_logger.debug(
                    "mask_shape transform_masks:%s, dtype:%s.", mask.shape, mask.dtype
                )
                try:
                    mask = cv2.warpAffine(mask, affine, dsize, flags=cv2.INTER_LINEAR)
                except Exception:
                    app_logger.error(
                        "e_warp_affine1 mask shape:%s, dtype:%s.", mask.shape, mask.dtype
                    )
                    app_logger.error(
                        "e_warp_affine1 transform_matrix:%s, [:2] %s.",
                        transform_matrix,
                        affine,
                    )
                    app_logger.error("e_warp_affine1 original_size:%s.", original_size)
                    raise
                batch_masks.append(mask)
            output_masks.append(batch_masks)
        return np.array(output_masks)

    def encode(self, cv_image):
        """Calculate embedding and metadata for a single image.

        The image is scaled uniformly (aspect ratio preserved) by the largest
        factor that fits it inside ``self.input_size`` and warped onto a canvas
        of that size, then cast to float32 and fed to the encoder.

        Args:
            cv_image: image array; ``cv_image.shape[:2]`` is read as (H, W).

        Returns:
            dict with ``image_embedding``, ``original_size`` ((H, W) of the
            input) and ``transform_matrix`` (3x3 scale matrix used for the warp).

        Raises:
            Whatever ``cv2.warpAffine`` raises, after logging the inputs.
        """
        original_size = cv_image.shape[:2]

        scale_x = self.input_size[1] / cv_image.shape[1]
        scale_y = self.input_size[0] / cv_image.shape[0]
        scale = min(scale_x, scale_y)
        transform_matrix = np.array(
            [
                [scale, 0, 0],
                [0, scale, 0],
                [0, 0, 1],
            ]
        )

        np_cv_image = np.asarray(cv_image)
        app_logger.debug(
            "cv_image shape_encode:%s, dtype:%s.", np_cv_image.shape, np_cv_image.dtype
        )
        try:
            cv_image = cv2.warpAffine(
                cv_image,
                transform_matrix[:2],
                (self.input_size[1], self.input_size[0]),
                flags=cv2.INTER_LINEAR,
            )
        except Exception as err:
            app_logger.error("e_warp_affine2:%s.", err)
            np_cv_image = np.array(cv_image)
            app_logger.error(
                "e_warp_affine2 cv_image shape:%s, dtype:%s.",
                np_cv_image.shape,
                np_cv_image.dtype,
            )
            app_logger.error(
                "e_warp_affine2 transform_matrix:%s, [:2] %s",
                transform_matrix,
                transform_matrix[:2],
            )
            app_logger.error("e_warp_affine2 self.input_size:%s.", self.input_size)
            raise

        encoder_inputs = {
            self.encoder_input_name: cv_image.astype(np.float32),
        }
        image_embedding = self.run_encoder(encoder_inputs)
        return {
            "image_embedding": image_embedding,
            "original_size": original_size,
            "transform_matrix": transform_matrix,
        }

    def predict_masks(self, embedding, prompt):
        """Predict masks for a single image.

        Args:
            embedding: dict returned by ``encode``.
            prompt: list of point/rectangle marks, see ``get_input_points``.

        Returns:
            Masks in original image coordinates, as returned by ``run_decoder``.
        """
        return self.run_decoder(
            embedding["image_embedding"],
            embedding["original_size"],
            embedding["transform_matrix"],
            prompt,
        )
|
|