samgis / src /prediction_api /sam_onnx.py
aletrn's picture
[debug] first test on lambda for fastsam prediction
fa76f5f
raw
history blame
7.04 kB
from copy import deepcopy
import cv2
import numpy as np
import onnxruntime
from src import app_logger
class SegmentAnythingONNX:
"""Segmentation model using SegmentAnything"""
def __init__(self, encoder_model_path, decoder_model_path) -> None:
self.target_size = 1024
self.input_size = (684, 1024)
# Load models
providers = onnxruntime.get_available_providers()
# Pop TensorRT Runtime due to crashing issues
# TODO: Add back when TensorRT backend is stable
providers = [p for p in providers if p != "TensorrtExecutionProvider"]
if providers:
app_logger.info(
"Available providers for ONNXRuntime: %s", ", ".join(providers)
)
else:
app_logger.warning("No available providers for ONNXRuntime")
self.encoder_session = onnxruntime.InferenceSession(
encoder_model_path, providers=providers
)
self.encoder_input_name = self.encoder_session.get_inputs()[0].name
self.decoder_session = onnxruntime.InferenceSession(
decoder_model_path, providers=providers
)
@staticmethod
def get_input_points(prompt):
"""Get input points"""
points = []
labels = []
for mark in prompt:
if mark["type"] == "point":
points.append(mark["data"])
labels.append(mark["label"])
elif mark["type"] == "rectangle":
points.append([mark["data"][0], mark["data"][1]]) # top left
points.append(
[mark["data"][2], mark["data"][3]]
) # bottom right
labels.append(2)
labels.append(3)
points, labels = np.array(points), np.array(labels)
return points, labels
def run_encoder(self, encoder_inputs):
"""Run encoder"""
output = self.encoder_session.run(None, encoder_inputs)
image_embedding = output[0]
return image_embedding
@staticmethod
def get_preprocess_shape(old_h: int, old_w: int, long_side_length: int):
"""
Compute the output size given input size and target long side length.
"""
scale = long_side_length * 1.0 / max(old_h, old_w)
new_h, new_w = old_h * scale, old_w * scale
new_w = int(new_w + 0.5)
new_h = int(new_h + 0.5)
return new_h, new_w
def apply_coords(self, coords: np.ndarray, original_size, target_length):
"""
Expects a numpy array of length 2 in the final dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(
original_size[0], original_size[1], target_length
)
coords = deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def run_decoder(
self, image_embedding, original_size, transform_matrix, prompt
):
"""Run decoder"""
input_points, input_labels = self.get_input_points(prompt)
# Add a batch index, concatenate a padding point, and transform.
onnx_coord = np.concatenate(
[input_points, np.array([[0.0, 0.0]])], axis=0
)[None, :, :]
onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
None, :
].astype(np.float32)
onnx_coord = self.apply_coords(
onnx_coord, self.input_size, self.target_size
).astype(np.float32)
# Apply the transformation matrix to the coordinates.
onnx_coord = np.concatenate(
[
onnx_coord,
np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
],
axis=2,
)
onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
onnx_coord = onnx_coord[:, :, :2].astype(np.float32)
# Create an empty mask input and an indicator for no mask.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
decoder_inputs = {
"image_embeddings": image_embedding,
"point_coords": onnx_coord,
"point_labels": onnx_label,
"mask_input": onnx_mask_input,
"has_mask_input": onnx_has_mask_input,
"orig_im_size": np.array(self.input_size, dtype=np.float32),
}
masks, _, _ = self.decoder_session.run(None, decoder_inputs)
# Transform the masks back to the original image size.
inv_transform_matrix = np.linalg.inv(transform_matrix)
transformed_masks = self.transform_masks(
masks, original_size, inv_transform_matrix
)
return transformed_masks
@staticmethod
def transform_masks(masks, original_size, transform_matrix):
"""Transform masks
Transform the masks back to the original image size.
"""
output_masks = []
for batch in range(masks.shape[0]):
batch_masks = []
for mask_id in range(masks.shape[1]):
mask = masks[batch, mask_id]
mask = cv2.warpAffine(
mask,
transform_matrix[:2],
(original_size[1], original_size[0]),
flags=cv2.INTER_LINEAR,
)
batch_masks.append(mask)
output_masks.append(batch_masks)
return np.array(output_masks)
def encode(self, cv_image):
"""
Calculate embedding and metadata for a single image.
"""
original_size = cv_image.shape[:2]
# Calculate a transformation matrix to convert to self.input_size
scale_x = self.input_size[1] / cv_image.shape[1]
scale_y = self.input_size[0] / cv_image.shape[0]
scale = min(scale_x, scale_y)
transform_matrix = np.array(
[
[scale, 0, 0],
[0, scale, 0],
[0, 0, 1],
]
)
cv_image = cv2.warpAffine(
cv_image,
transform_matrix[:2],
(self.input_size[1], self.input_size[0]),
flags=cv2.INTER_LINEAR,
)
encoder_inputs = {
self.encoder_input_name: cv_image.astype(np.float32),
}
image_embedding = self.run_encoder(encoder_inputs)
return {
"image_embedding": image_embedding,
"original_size": original_size,
"transform_matrix": transform_matrix,
}
def predict_masks(self, embedding, prompt):
"""
Predict masks for a single image.
"""
masks = self.run_decoder(
embedding["image_embedding"],
embedding["original_size"],
embedding["transform_matrix"],
prompt,
)
return masks