samgis / src /prediction_api /sam_onnx.py
aletrn's picture
[feat] wip workflow from request to response
6d1f220
raw
history blame
10.2 kB
import json
from copy import deepcopy
import cv2
import numpy as np
import onnxruntime
from src import app_logger
from src.utilities.serialize import serialize
class SegmentAnythingONNX:
    """Segmentation model using SegmentAnything.

    Wraps two ONNX Runtime inference sessions: an encoder that turns an image
    into an embedding (see ``encode``) and a decoder that turns an embedding
    plus a prompt into masks (see ``predict_masks``).
    """

    # Directory receiving the diagnostic ``.npy`` / ``.json`` dumps written
    # right before each ``cv2.warpAffine`` call. Set to ``None`` to disable
    # the dumps entirely.
    DEBUG_DUMP_DIR = "/tmp"

    def __init__(self, encoder_model_path, decoder_model_path) -> None:
        """Create the encoder and decoder ONNX Runtime sessions.

        Args:
            encoder_model_path: path of the encoder ONNX model.
            decoder_model_path: path of the decoder ONNX model.
        """
        self.target_size = 1024
        # (H, W) of the image handed to the encoder.
        self.input_size = (684, 1024)

        # Load models
        providers = onnxruntime.get_available_providers()
        # Pop TensorRT Runtime due to crashing issues
        # TODO: Add back when TensorRT backend is stable
        providers = [p for p in providers if p != "TensorrtExecutionProvider"]
        if providers:
            app_logger.info(
                "Available providers for ONNXRuntime: %s", ", ".join(providers)
            )
        else:
            app_logger.warning("No available providers for ONNXRuntime")

        self.encoder_session = onnxruntime.InferenceSession(
            encoder_model_path, providers=providers
        )
        self.encoder_input_name = self.encoder_session.get_inputs()[0].name
        self.decoder_session = onnxruntime.InferenceSession(
            decoder_model_path, providers=providers
        )

    @staticmethod
    def get_input_points(prompt):
        """Convert prompt marks into point and label arrays.

        Each mark is a mapping with a ``"type"`` key:

        * ``"point"``: ``"data"`` is appended as one point and ``"label"``
          as its label;
        * ``"rectangle"``: ``"data"`` holds four values, emitted as a
          top-left point (label 2) and a bottom-right point (label 3).

        Marks of any other type are ignored.

        Args:
            prompt: iterable of marks.

        Returns:
            Tuple ``(points, labels)`` of numpy arrays. ``points`` has shape
            ``(N, 2)``; for an empty prompt it has shape ``(0, 2)`` so it can
            still be concatenated with the decoder padding point.
        """
        points = []
        labels = []
        for mark in prompt:
            mark_type = mark["type"]
            if mark_type == "point":
                points.append(mark["data"])
                labels.append(mark["label"])
            elif mark_type == "rectangle":
                x_min, y_min, x_max, y_max = mark["data"][:4]
                points.append([x_min, y_min])  # top left
                points.append([x_max, y_max])  # bottom right
                labels.extend((2, 3))
        # np.array([]) would be 1-D and break the later concatenation with
        # the (1, 2) padding point, so the empty case gets an explicit shape.
        points_array = np.array(points) if points else np.empty((0, 2))
        return points_array, np.array(labels)

    def run_encoder(self, encoder_inputs):
        """Run the encoder session and return its first output.

        Args:
            encoder_inputs: mapping from encoder input name to input array.

        Returns:
            The image embedding (first output of the encoder session).
        """
        output = self.encoder_session.run(None, encoder_inputs)
        return output[0]

    @staticmethod
    def get_preprocess_shape(old_h: int, old_w: int, long_side_length: int):
        """
        Compute the output size given input size and target long side length.

        The longest side is scaled to ``long_side_length`` and the other side
        is scaled by the same factor, both rounded half up.

        Returns:
            Tuple ``(new_h, new_w)`` of ints.
        """
        scale = long_side_length * 1.0 / max(old_h, old_w)
        new_h = int(old_h * scale + 0.5)
        new_w = int(old_w * scale + 0.5)
        return new_h, new_w

    def apply_coords(self, coords: np.ndarray, original_size, target_length):
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.

        Rescales ``(x, y)`` coordinates from ``original_size`` to the shape
        returned by ``get_preprocess_shape`` for ``target_length``.

        Returns:
            A new float array; the input array is left untouched.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(old_h, old_w, target_length)
        # np.array(..., dtype=float) already returns an independent copy, so
        # no separate deepcopy is needed to protect the caller's array.
        scaled = np.array(coords, dtype=float)
        scaled[..., 0] = scaled[..., 0] * (new_w / old_w)
        scaled[..., 1] = scaled[..., 1] * (new_h / old_h)
        return scaled

    def run_decoder(
        self, image_embedding, original_size, transform_matrix, prompt
    ):
        """Run the decoder for one embedding and one prompt.

        Args:
            image_embedding: embedding returned by ``run_encoder``.
            original_size: (H, W) of the image before resizing.
            transform_matrix: 3x3 matrix mapping original image coordinates
                to encoder input coordinates (as built by ``encode``).
            prompt: iterable of marks, see ``get_input_points``.

        Returns:
            The decoder masks warped back to ``original_size``.
        """
        input_points, input_labels = self.get_input_points(prompt)

        # Add a batch index, concatenate a padding point, and transform.
        onnx_coord = np.concatenate(
            [input_points, np.array([[0.0, 0.0]])], axis=0
        )[None, :, :]
        onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
            None, :
        ].astype(np.float32)
        onnx_coord = self.apply_coords(
            onnx_coord, self.input_size, self.target_size
        ).astype(np.float32)

        # Apply the transformation matrix to the coordinates
        # (homogeneous coordinates: append a column of ones first).
        onnx_coord = np.concatenate(
            [
                onnx_coord,
                np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
            ],
            axis=2,
        )
        onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
        onnx_coord = onnx_coord[:, :, :2].astype(np.float32)

        # Create an empty mask input and an indicator for no mask.
        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
        onnx_has_mask_input = np.zeros(1, dtype=np.float32)

        decoder_inputs = {
            "image_embeddings": image_embedding,
            "point_coords": onnx_coord,
            "point_labels": onnx_label,
            "mask_input": onnx_mask_input,
            "has_mask_input": onnx_has_mask_input,
            "orig_im_size": np.array(self.input_size, dtype=np.float32),
        }
        # The decoder session yields three outputs; only the masks are used.
        masks, _, _ = self.decoder_session.run(None, decoder_inputs)

        # Transform the masks back to the original image size.
        inv_transform_matrix = np.linalg.inv(transform_matrix)
        return self.transform_masks(
            masks, original_size, inv_transform_matrix
        )

    @staticmethod
    def _dump_debug_artifacts(tag, array, raw_args, flags):
        """Best-effort dump of a warpAffine input and its arguments.

        Writes ``<DEBUG_DUMP_DIR>/<tag>_cv2img_<shape>.npy`` and
        ``<DEBUG_DUMP_DIR>/<tag>_args.json``. Any failure is logged and
        swallowed on purpose: diagnostics must never abort inference.

        Args:
            tag: filename prefix identifying the call site.
            array: array about to be passed to ``cv2.warpAffine``.
            raw_args: mapping of argument name to value; every value goes
                through ``serialize`` before being written as JSON.
            flags: interpolation flag, stored as-is under ``"flags"``.
        """
        dump_dir = SegmentAnythingONNX.DEBUG_DUMP_DIR
        if dump_dir is None:
            return
        try:
            np_array = np.asarray(array)
            app_logger.info(
                "debug dump %s: array shape:%s, dtype:%s.",
                tag, np_array.shape, np_array.dtype
            )
            shape_tag = "_".join(str(dim) for dim in np_array.shape)
            array_path = f"{dump_dir}/{tag}_cv2img_{shape_tag}.npy"
            np.save(array_path, np_array, allow_pickle=True, fix_imports=True)
            app_logger.info("written: %s ...", array_path)

            payload = {key: serialize(value) for key, value in raw_args.items()}
            payload["flags"] = flags
            args_path = f"{dump_dir}/{tag}_args.json"
            with open(args_path, "w", encoding="utf-8") as args_dst:
                json.dump(payload, args_dst)
            app_logger.info("written: %s", args_path)
        except Exception as dump_error:  # deliberate: diagnostics only
            app_logger.error(
                "debug dump %s failed, continuing without it: %s",
                tag, dump_error
            )

    @staticmethod
    def transform_masks(masks, original_size, transform_matrix):
        """Transform masks

        Transform the masks back to the original image size.

        Args:
            masks: array indexed as ``masks[batch, mask_id]``, each entry
                being one 2D mask.
            original_size: (H, W) of the output masks.
            transform_matrix: 3x3 matrix; its first two rows are used as the
                affine matrix for ``cv2.warpAffine``.

        Returns:
            numpy array of the warped masks, indexed ``[batch, mask_id]``.

        Raises:
            Whatever ``cv2.warpAffine`` raises, after logging its inputs.
        """
        affine_matrix = transform_matrix[:2]
        dsize = (original_size[1], original_size[0])  # cv2 expects (W, H)
        raw_args = {
            "transform_matrix": transform_matrix,
            "M": affine_matrix,
            "original_size": original_size,
            "dsize": dsize,
        }
        output_masks = []
        for batch in range(masks.shape[0]):
            batch_masks = []
            for mask_id in range(masks.shape[1]):
                mask = masks[batch, mask_id]
                SegmentAnythingONNX._dump_debug_artifacts(
                    "2", mask, raw_args, cv2.INTER_LINEAR
                )
                try:
                    warped = cv2.warpAffine(
                        mask,
                        affine_matrix,
                        dsize,
                        flags=cv2.INTER_LINEAR,
                    )
                except Exception as e_warp_affine1:
                    app_logger.error(f"e_warp_affine1:{e_warp_affine1}.")
                    app_logger.error(f"e_warp_affine1 mask shape:{mask.shape}, dtype:{mask.dtype}.")
                    app_logger.error(f"e_warp_affine1 transform_matrix:{transform_matrix}, [:2] {affine_matrix}.")
                    app_logger.error(f"e_warp_affine1 original_size:{original_size}.")
                    raise
                batch_masks.append(warped)
            output_masks.append(batch_masks)
        return np.array(output_masks)

    def encode(self, cv_image):
        """
        Calculate embedding and metadata for a single image.

        The image is scaled (aspect ratio preserved) to fit ``self.input_size``
        and then passed to the encoder as float32.

        Args:
            cv_image: image array whose first two dimensions are (H, W).

        Returns:
            Dict with ``"image_embedding"``, ``"original_size"`` (H, W of the
            input image) and ``"transform_matrix"`` (3x3 scaling matrix).

        Raises:
            Whatever ``cv2.warpAffine`` raises, after logging its inputs.
        """
        original_size = cv_image.shape[:2]
        input_h, input_w = self.input_size

        # Calculate a transformation matrix to convert to self.input_size
        scale_x = input_w / cv_image.shape[1]
        scale_y = input_h / cv_image.shape[0]
        scale = min(scale_x, scale_y)
        transform_matrix = np.array(
            [
                [scale, 0, 0],
                [0, scale, 0],
                [0, 0, 1],
            ]
        )
        affine_matrix = transform_matrix[:2]

        self._dump_debug_artifacts(
            "1",
            cv_image,
            {"transform_matrix": transform_matrix, "M": affine_matrix},
            cv2.INTER_LINEAR,
        )
        try:
            cv_image = cv2.warpAffine(
                cv_image,
                affine_matrix,
                (input_w, input_h),  # cv2 expects (W, H)
                flags=cv2.INTER_LINEAR,
            )
        except Exception as e_warp_affine2:
            app_logger.error(f"e_warp_affine2:{e_warp_affine2}.")
            np_cv_image = np.array(cv_image)
            app_logger.error(f"e_warp_affine2 cv_image shape:{np_cv_image.shape}, dtype:{np_cv_image.dtype}.")
            app_logger.error(f"e_warp_affine2 transform_matrix:{transform_matrix}, [:2] {affine_matrix}")
            app_logger.error(f"e_warp_affine2 self.input_size:{self.input_size}.")
            raise

        encoder_inputs = {
            self.encoder_input_name: cv_image.astype(np.float32),
        }
        image_embedding = self.run_encoder(encoder_inputs)
        return {
            "image_embedding": image_embedding,
            "original_size": original_size,
            "transform_matrix": transform_matrix,
        }

    def predict_masks(self, embedding, prompt):
        """
        Predict masks for a single image.

        Args:
            embedding: dict returned by ``encode``.
            prompt: iterable of marks, see ``get_input_points``.

        Returns:
            The masks returned by ``run_decoder``.
        """
        return self.run_decoder(
            embedding["image_embedding"],
            embedding["original_size"],
            embedding["transform_matrix"],
            prompt,
        )