from copy import deepcopy import cv2 import numpy as np import onnxruntime from src import app_logger class SegmentAnythingONNX: """Segmentation model using SegmentAnything""" def __init__(self, encoder_model_path, decoder_model_path) -> None: self.target_size = 1024 self.input_size = (684, 1024) # Load models providers = onnxruntime.get_available_providers() # Pop TensorRT Runtime due to crashing issues # TODO: Add back when TensorRT backend is stable providers = [p for p in providers if p != "TensorrtExecutionProvider"] if providers: app_logger.info( "Available providers for ONNXRuntime: %s", ", ".join(providers) ) else: app_logger.warning("No available providers for ONNXRuntime") self.encoder_session = onnxruntime.InferenceSession( encoder_model_path, providers=providers ) self.encoder_input_name = self.encoder_session.get_inputs()[0].name self.decoder_session = onnxruntime.InferenceSession( decoder_model_path, providers=providers ) @staticmethod def get_input_points(prompt): """Get input points""" points = [] labels = [] for mark in prompt: if mark["type"] == "point": points.append(mark["data"]) labels.append(mark["label"]) elif mark["type"] == "rectangle": points.append([mark["data"][0], mark["data"][1]]) # top left points.append( [mark["data"][2], mark["data"][3]] ) # bottom right labels.append(2) labels.append(3) points, labels = np.array(points), np.array(labels) return points, labels def run_encoder(self, encoder_inputs): """Run encoder""" output = self.encoder_session.run(None, encoder_inputs) image_embedding = output[0] return image_embedding @staticmethod def get_preprocess_shape(old_h: int, old_w: int, long_side_length: int): """ Compute the output size given input size and target long side length. """ scale = long_side_length * 1.0 / max(old_h, old_w) new_h, new_w = old_h * scale, old_w * scale new_w = int(new_w + 0.5) new_h = int(new_h + 0.5) return new_h, new_w def apply_coords(self, coords: np.ndarray, original_size, target_length): """ Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format. """ old_h, old_w = original_size new_h, new_w = self.get_preprocess_shape( original_size[0], original_size[1], target_length ) coords = deepcopy(coords).astype(float) coords[..., 0] = coords[..., 0] * (new_w / old_w) coords[..., 1] = coords[..., 1] * (new_h / old_h) return coords def run_decoder( self, image_embedding, original_size, transform_matrix, prompt ): """Run decoder""" input_points, input_labels = self.get_input_points(prompt) # Add a batch index, concatenate a padding point, and transform. onnx_coord = np.concatenate( [input_points, np.array([[0.0, 0.0]])], axis=0 )[None, :, :] onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[ None, : ].astype(np.float32) onnx_coord = self.apply_coords( onnx_coord, self.input_size, self.target_size ).astype(np.float32) # Apply the transformation matrix to the coordinates. onnx_coord = np.concatenate( [ onnx_coord, np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32), ], axis=2, ) onnx_coord = np.matmul(onnx_coord, transform_matrix.T) onnx_coord = onnx_coord[:, :, :2].astype(np.float32) # Create an empty mask input and an indicator for no mask. onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) onnx_has_mask_input = np.zeros(1, dtype=np.float32) decoder_inputs = { "image_embeddings": image_embedding, "point_coords": onnx_coord, "point_labels": onnx_label, "mask_input": onnx_mask_input, "has_mask_input": onnx_has_mask_input, "orig_im_size": np.array(self.input_size, dtype=np.float32), } masks, _, _ = self.decoder_session.run(None, decoder_inputs) # Transform the masks back to the original image size. inv_transform_matrix = np.linalg.inv(transform_matrix) transformed_masks = self.transform_masks( masks, original_size, inv_transform_matrix ) return transformed_masks @staticmethod def transform_masks(masks, original_size, transform_matrix): """Transform masks Transform the masks back to the original image size. """ output_masks = [] for batch in range(masks.shape[0]): batch_masks = [] for mask_id in range(masks.shape[1]): mask = masks[batch, mask_id] try: try: app_logger.debug(f"mask_shape transform_masks:{mask.shape}, dtype:{mask.dtype}.") except Exception as e_mask_shape_transform_masks: app_logger.error(f"e_mask_shape_transform_masks:{e_mask_shape_transform_masks}.") # raise e_mask_shape_transform_masks mask = cv2.warpAffine( mask, transform_matrix[:2], (original_size[1], original_size[0]), flags=cv2.INTER_LINEAR, ) except Exception as e_warp_affine1: app_logger.error(f"e_warp_affine1 mask shape:{mask.shape}, dtype:{mask.dtype}.") app_logger.error(f"e_warp_affine1 transform_matrix:{transform_matrix}, [:2] {transform_matrix[:2]}.") app_logger.error(f"e_warp_affine1 original_size:{original_size}.") raise e_warp_affine1 batch_masks.append(mask) output_masks.append(batch_masks) return np.array(output_masks) def encode(self, cv_image): """ Calculate embedding and metadata for a single image. """ original_size = cv_image.shape[:2] # Calculate a transformation matrix to convert to self.input_size scale_x = self.input_size[1] / cv_image.shape[1] scale_y = self.input_size[0] / cv_image.shape[0] scale = min(scale_x, scale_y) transform_matrix = np.array( [ [scale, 0, 0], [0, scale, 0], [0, 0, 1], ] ) try: cv_image = cv2.warpAffine( cv_image, transform_matrix[:2], (self.input_size[1], self.input_size[0]), flags=cv2.INTER_LINEAR, ) except Exception as e_warp_affine2: app_logger.error(f"e_warp_affine2:{e_warp_affine2}.") np_cv_image = np.array(cv_image) app_logger.error(f"e_warp_affine2 cv_image shape:{np_cv_image.shape}, dtype:{np_cv_image.dtype}.") app_logger.error(f"e_warp_affine2 transform_matrix:{transform_matrix}, [:2] {transform_matrix[:2]}") app_logger.error(f"e_warp_affine2 self.input_size:{self.input_size}.") raise e_warp_affine2 encoder_inputs = { self.encoder_input_name: cv_image.astype(np.float32), } image_embedding = self.run_encoder(encoder_inputs) return { "image_embedding": image_embedding, "original_size": original_size, "transform_matrix": transform_matrix, } def predict_masks(self, embedding, prompt): """ Predict masks for a single image. """ masks = self.run_decoder( embedding["image_embedding"], embedding["original_size"], embedding["transform_matrix"], prompt, ) return masks