""" | |
inference.py | |
------------ | |
Provides functionality to run the OPDMulti model on an input image, independent of dataset and ground truth, and | |
visualize the output. Large portions of the code originate from get_prediction.py, rgbd_to_pcd_vis.py, | |
evaluate_on_log.py, and other related files. The primary goal was to create a more standalone script which could be | |
converted more easily into a public demo, thus the goal was to sever most dependencies on existing ground truth or | |
datasets. | |
Example usage: | |
python inference.py \ | |
--rgb path/to/59-4860.png \ | |
--depth path/to/59-4860_d.png \ | |
--model path/to/model.pth \ | |
--output path/to/output_dir | |
""" | |
import argparse
import logging
import os
import time
from copy import deepcopy
from typing import Any

import imageio
import open3d as o3d
import numpy as np
import torch
import torch.nn as nn
from detectron2 import engine, evaluation
from detectron2.modeling import build_model
from detectron2.config import get_cfg, CfgNode
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.structures import instances
from detectron2.utils import comm
from detectron2.utils.logger import setup_logger
from PIL import Image, ImageChops

from mask2former import (
    add_maskformer2_config,
    add_motionnet_config,
)
from utilities import prediction_to_json

# Import based on torch version. Required for model loading. Code is taken from fvcore.common.checkpoint, in order to
# replicate model loading without the overhead of setting up an OPDTrainer.
TORCH_VERSION: tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 11):
    from torch.ao import quantization
    from torch.ao.quantization import FakeQuantizeBase, ObserverBase
elif (
    TORCH_VERSION >= (1, 8)
    and hasattr(torch.quantization, "FakeQuantizeBase")
    and hasattr(torch.quantization, "ObserverBase")
):
    from torch import quantization
    from torch.quantization import FakeQuantizeBase, ObserverBase

# TODO: find a global place for this instead of in many places in code
TYPE_CLASSIFICATION = {
    0: "rotation",
    1: "translation",
}

POINT_COLOR = [1, 0, 0]  # red for demonstration
ARROW_COLOR = [0, 1, 0]  # green
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
def get_parser() -> argparse.ArgumentParser:
    """
    Specify command-line arguments.

    The primary inputs to the script should be the image paths (RGBD) and camera intrinsics. Other arguments are
    provided to facilitate script testing and model changes. Run the file with -h/--help to see all arguments.

    :return: parser for extracting command-line arguments
    """
    parser = argparse.ArgumentParser(description="Inference for OPDMulti")

    # The main arguments which should be specified by the user
    parser.add_argument(
        "--rgb",
        dest="rgb_image",
        metavar="FILE",
        help="path to RGB image file on which to run model",
    )
    parser.add_argument(
        "--depth",
        dest="depth_image",
        metavar="FILE",
        help="path to depth image file on which to run model",
    )
    parser.add_argument(  # FIXME: might make more sense to make this a path
        "-i",
        "--intrinsics",
        nargs=9,
        type=float,
        default=[
            214.85935872395834,
            0.0,
            0.0,
            0.0,
            214.85935872395834,
            0.0,
            125.90160319010417,
            95.13726399739583,
            1.0,
        ],
        dest="intrinsics",
        help="camera intrinsics matrix, given as 9 values in column-major order",
    )
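    # For reference (added note): main() reshapes these 9 values column-major, so the default above corresponds to
    #     [[fx,  0, cx],       [[214.86,   0.00, 125.90],
    #      [ 0, fy, cy],   ~=   [  0.00, 214.86,  95.14],
    #      [ 0,  0,  1]]        [  0.00,   0.00,   1.00]]
    # i.e. focal lengths fx = fy ~= 214.86 and principal point (cx, cy) ~= (125.90, 95.14); supply values matching
    # the camera used to capture the input images.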
    # optional parameters for user to specify
    parser.add_argument(
        "-n",
        "--num-samples",
        default=10,
        type=int,
        dest="num_samples",
        metavar="NUM",
        help="number of sample states to generate in visualization",
    )
    parser.add_argument(
        "--crop",
        action="store_true",
        dest="crop",
        help="crop whitespace out of images for visualization",
    )

    # local script development arguments
    parser.add_argument(
        "-m",
        "--model",
        default="path/to/model/file",  # FIXME: set a good default path
        dest="model",
        metavar="FILE",
        help="path to model file to run",
    )
    parser.add_argument(
        "-c",
        "--config",
        default="configs/coco/instance-segmentation/swin/opd_v1_real.yaml",
        metavar="FILE",
        dest="config_file",
        help="path to config file",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="output",  # FIXME: set a good default path
        dest="output",
        help="path to output directory in which to save results",
    )
    parser.add_argument(
        "--num-processes",
        default=1,
        type=int,
        dest="num_processes",
        help="number of processes per machine. When using GPUs, this should be the number of GPUs.",
    )
    parser.add_argument(
        "-s",
        "--score-threshold",
        default=0.8,
        type=float,
        dest="score_threshold",
        help="threshold between 0.0 and 1.0 by which to filter out bad predictions",
    )
    parser.add_argument(
        "--input-format",
        default="RGB",
        dest="input_format",
        help="input format of image. Must be one of RGB, RGBD, or depth",
    )
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="flag to require code to use CPU only",
    )
    return parser
def setup_cfg(args: argparse.Namespace) -> CfgNode:
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()

    # add model configurations
    add_deeplab_config(cfg)
    add_maskformer2_config(cfg)
    add_motionnet_config(cfg)
    cfg.merge_from_file(args.config_file)

    # set additional config parameters
    cfg.MODEL.WEIGHTS = args.model
    cfg.OBJ_DETECT = False  # TODO: figure out if this is needed, and parameterize it
    cfg.MODEL.MOTIONNET.VOTING = "none"

    # Output directory
    cfg.OUTPUT_DIR = args.output
    cfg.MODEL.DEVICE = "cpu" if args.cpu else "cuda"
    cfg.MODEL.MODELATTRPATH = None

    # Input format
    cfg.INPUT.FORMAT = args.input_format
    if args.input_format == "RGB":
        cfg.MODEL.PIXEL_MEAN = cfg.MODEL.PIXEL_MEAN[0:3]
        cfg.MODEL.PIXEL_STD = cfg.MODEL.PIXEL_STD[0:3]
    elif args.input_format == "depth":
        cfg.MODEL.PIXEL_MEAN = cfg.MODEL.PIXEL_MEAN[3:4]
        cfg.MODEL.PIXEL_STD = cfg.MODEL.PIXEL_STD[3:4]
    elif args.input_format == "RGBD":
        pass
    else:
        raise ValueError("Invalid input format")

    cfg.freeze()
    engine.default_setup(cfg, args)

    # Setup logger for "mask_former" module
    setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="opdformer")
    return cfg
def format_input(rgb_path: str) -> list[dict[str, Any]]:
    """
    Read and format input image into detectron2 form so that it can be passed to the model.

    :param rgb_path: path to RGB image file
    :return: list of dictionaries per image, where each dictionary is of the form
        {
            "file_name": path to RGB image,
            "image": torch.Tensor of dimensions [channel, height, width] representing the image
        }
    """
    image = imageio.imread(rgb_path).astype(np.float32)
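    # Note (added for clarity): the transpose below assumes the image was read with a trailing channel dimension
    # (e.g. an RGB image of shape [height, width, 3]); a single-channel depth image read as a 2-D array would need an
    # explicit channel axis (e.g. np.expand_dims) before this step.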
    image_tensor = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))  # dim: [channel, height, width]
    return [{"file_name": rgb_path, "image": image_tensor}]


def load_model(model: nn.Module, checkpoint: Any) -> None:
    """
    Load weights from a checkpoint.

    The majority of the function definition is taken from the DetectionCheckpointer implementation provided in
    detectron2. While not all of this code is necessarily needed for model loading, it was ported with the intention
    of keeping the implementation and output as close to the original as possible, and reusing the checkpoint class
    here in isolation was determined to be infeasible.

    :param model: model for which to load weights
    :param checkpoint: checkpoint containing the weights
    """

    def _strip_prefix_if_present(state_dict: dict[str, Any], prefix: str) -> None:
        """If the prefix is found on all keys in the state dict, remove it."""
        keys = sorted(state_dict.keys())
        if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
            return
        for key in keys:
            newkey = key[len(prefix) :]
            state_dict[newkey] = state_dict.pop(key)
    checkpoint_state_dict = checkpoint.pop("model")

    # convert from numpy to tensor
    for k, v in checkpoint_state_dict.items():
        if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
            raise ValueError("Unsupported type found in checkpoint! {}: {}".format(k, type(v)))
        if not isinstance(v, torch.Tensor):
            checkpoint_state_dict[k] = torch.from_numpy(v)

    # if the state_dict comes from a model that was wrapped in a
    # DataParallel or DistributedDataParallel during serialization,
    # remove the "module" prefix before performing the matching.
    _strip_prefix_if_present(checkpoint_state_dict, "module.")

    # workaround https://github.com/pytorch/pytorch/issues/24139
    model_state_dict = model.state_dict()
    incorrect_shapes = []
    for k in list(checkpoint_state_dict.keys()):  # state dict is modified in loop, so list() is necessary
        if k in model_state_dict:
            model_param = model_state_dict[k]
            # Allow mismatch for uninitialized parameters
            if TORCH_VERSION >= (1, 8) and isinstance(model_param, nn.parameter.UninitializedParameter):
                continue
            shape_model = tuple(model_param.shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                has_observer_base_classes = (
                    TORCH_VERSION >= (1, 8)
                    and hasattr(quantization, "ObserverBase")
                    and hasattr(quantization, "FakeQuantizeBase")
                )
                if has_observer_base_classes:
                    # Handle the special case of quantization per channel observers,
                    # where buffer shape mismatches are expected.
                    def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
                        # foo.bar.param_or_buffer_name -> [foo, bar]
                        key_parts = key.split(".")[:-1]
                        cur_module = model
                        for key_part in key_parts:
                            cur_module = getattr(cur_module, key_part)
                        return cur_module

                    cls_to_skip = (
                        ObserverBase,
                        FakeQuantizeBase,
                    )
                    target_module = _get_module_for_key(model, k)
                    if isinstance(target_module, cls_to_skip):
                        # Do not remove modules with expected shape mismatches from the state_dict loading;
                        # they have special logic in _load_from_state_dict to handle the mismatches.
                        continue

                incorrect_shapes.append((k, shape_checkpoint, shape_model))
                checkpoint_state_dict.pop(k)
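    # Note (added for clarity): incorrect_shapes is collected only to mirror the original fvcore/DetectionCheckpointer
    # code path; unlike DetectionCheckpointer, this standalone loader does not report incompatible or missing keys,
    # and strict=False below silently ignores them.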
    model.load_state_dict(checkpoint_state_dict, strict=False)


def predict(model: nn.Module, inp: list[dict[str, Any]]) -> list[dict[str, instances.Instances]]:
    """
    Compute model predictions.

    :param model: model to run on input
    :param inp: input, in the form
        {
            "file_name": path to image,
            "image": float32 torch.Tensor of dimensions [channel, height, width] as RGB/RGBD/depth image
        }
    :return: list of detected instances and predicted openable parameters
    """
    with torch.no_grad(), evaluation.inference_context(model):
        out = model(inp)
    return out
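
# For reference (added note): the fields of each predicted Instances object that are consumed later in main() are
# scores, pred_masks, morigin, maxis, mtype (see TYPE_CLASSIFICATION), mstate, and mstatemax.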
def generate_rotation_visualization(
    pcd: o3d.geometry.PointCloud,
    axis_arrow: o3d.geometry.TriangleMesh,
    mask: np.ndarray,
    axis_vector: np.ndarray,
    origin: np.ndarray,
    range_min: float,
    range_max: float,
    num_samples: int,
    output_dir: str,
) -> None:
    """
    Generate visualization files for a rotation motion of a part.

    :param pcd: point cloud object representing 2D image input (RGBD) as a point cloud
    :param axis_arrow: mesh object representing axis arrow of rotation to be rendered in visualization
    :param mask: mask np.array of dimensions (height, width) representing the part to be rotated in the image
    :param axis_vector: np.array of dimensions (3, ) representing the vector of the axis of rotation
    :param origin: np.array of dimensions (3, ) representing the origin point of the axis of rotation
    :param range_min: float representing the minimum range of motion in radians
    :param range_max: float representing the maximum range of motion in radians
    :param num_samples: number of sample states to visualize in between range_min and range_max of motion
    :param output_dir: string path to directory in which to save visualization output
    """
    angles_in_radians = np.linspace(range_min, range_max, num_samples)
    for idx, angle_rad in enumerate(angles_in_radians):
        # Make a copy of the original point cloud and arrow for each rotation
        rotated_pcd = deepcopy(pcd)
        rotated_arrow = deepcopy(axis_arrow)
        rotated_pcd = rotate_part(rotated_pcd, mask, axis_vector, origin, angle_rad)

        # Create a Visualizer object for each rotation
        vis = o3d.visualization.Visualizer()
        vis.create_window()

        # Add the rotated geometries
        vis.add_geometry(rotated_pcd)
        vis.add_geometry(rotated_arrow)

        # Apply the additional rotation around the x-axis if desired
        angle_x = np.pi * 5.5 / 5  # 198 degrees
        rotation_matrix = o3d.geometry.get_rotation_matrix_from_axis_angle(np.asarray([1, 0, 0]) * angle_x)
        rotated_pcd.rotate(rotation_matrix, center=rotated_pcd.get_center())
        rotated_arrow.rotate(rotation_matrix, center=rotated_pcd.get_center())

        # Capture and save the image
        output_filename = f"{output_dir}/{idx}.png"
        vis.capture_screen_image(output_filename, do_render=True)
        vis.destroy_window()
def generate_translation_visualization(
    pcd: o3d.geometry.PointCloud,
    axis_arrow: o3d.geometry.TriangleMesh,
    mask: np.ndarray,
    end: np.ndarray,
    range_min: float,
    range_max: float,
    num_samples: int,
    output_dir: str,
) -> None:
    """
    Generate visualization files for a translation motion of a part.

    :param pcd: point cloud object representing 2D image input (RGBD) as a point cloud
    :param axis_arrow: mesh object representing axis arrow of translation to be rendered in visualization
    :param mask: mask np.array of dimensions (height, width) representing the part to be translated in the image
    :param end: np.array of dimensions (3, ) representing the vector of the axis of translation
    :param range_min: float representing the minimum range of motion
    :param range_max: float representing the maximum range of motion
    :param num_samples: number of sample states to visualize in between range_min and range_max of motion
    :param output_dir: string path to directory in which to save visualization output
    """
    translate_distances = np.linspace(range_min, range_max, num_samples)
    for idx, translate_distance in enumerate(translate_distances):
        translated_pcd = deepcopy(pcd)
        translated_arrow = deepcopy(axis_arrow)
        translated_pcd = translate_part(translated_pcd, mask, end, translate_distance.item())

        # Create a Visualizer object for each translation
        vis = o3d.visualization.Visualizer()
        vis.create_window()

        # Add the translated geometries
        vis.add_geometry(translated_pcd)
        vis.add_geometry(translated_arrow)

        # Apply the additional rotation around the x-axis if desired
        # TODO: not sure why we need this rotation for the translation, and when it would be desired
        angle_x = np.pi * 5.5 / 5  # 198 degrees
        R = o3d.geometry.get_rotation_matrix_from_axis_angle(np.asarray([1, 0, 0]) * angle_x)
        translated_pcd.rotate(R, center=translated_pcd.get_center())
        translated_arrow.rotate(R, center=translated_pcd.get_center())

        # Capture and save the image
        output_filename = f"{output_dir}/{idx}.png"
        vis.capture_screen_image(output_filename, do_render=True)
        vis.destroy_window()
def get_rotation_matrix_from_vectors(vec1: np.ndarray, vec2: np.ndarray) -> np.ndarray:
    """
    Find the rotation matrix that aligns vec1 to vec2.

    :param vec1: a 3D "source" vector
    :param vec2: a 3D "destination" vector
    :return: a transform matrix (3x3) which, when applied to vec1, aligns it with vec2
    """
    a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), (vec2 / np.linalg.norm(vec2)).reshape(3)
    v = np.cross(a, b)
    c = np.dot(a, b)
    s = np.linalg.norm(v)
    kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
    rotation_matrix = np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s**2))
    return rotation_matrix
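
# Worked example for get_rotation_matrix_from_vectors (added, not part of the original script): aligning the x-axis
# to the y-axis, i.e. vec1 = [1, 0, 0] and vec2 = [0, 1, 0], gives v = [0, 0, 1], c = 0, s = 1, and therefore
#     R = I + K + K^2 = [[0, -1, 0],
#                        [1,  0, 0],
#                        [0,  0, 1]]
# which is a 90-degree rotation about z, so R @ vec1 == vec2. Note the formula divides by s**2 and is therefore
# undefined when vec1 and vec2 are parallel (s == 0).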
def draw_line(start_point: np.ndarray, end_point: np.ndarray) -> o3d.geometry.TriangleMesh:
    """
    Generate 3D mesh representing axis from start_point to end_point.

    :param start_point: np.ndarray of dimensions (3, ) representing the start point of the axis
    :param end_point: np.ndarray of dimensions (3, ) representing the end point of the axis
    :return: mesh object representing axis from start to end
    """
    # Compute direction vector and normalize it
    direction_vector = end_point - start_point
    normalized_vector = direction_vector / np.linalg.norm(direction_vector)

    # Compute the rotation matrix to align the Z-axis with the desired direction
    target_vector = np.array([0, 0, 1])
    rot_mat = get_rotation_matrix_from_vectors(target_vector, normalized_vector)

    # Create the cylinder (shaft of the arrow)
    cylinder_length = 0.9  # 90% of the total arrow length, adjust as needed
    cylinder_radius = 0.01  # Adjust the thickness of the arrow shaft
    cylinder = o3d.geometry.TriangleMesh.create_cylinder(radius=cylinder_radius, height=cylinder_length)

    # Move base of cylinder to origin, rotate, then translate to start_point
    cylinder.translate([0, 0, 0])
    cylinder.rotate(rot_mat, center=[0, 0, 0])
    cylinder.translate(start_point)

    # Create the cone (head of the arrow)
    cone_height = 0.1  # 10% of the total arrow length, adjust as needed
    cone_radius = 0.03  # Adjust the size of the arrowhead
    cone = o3d.geometry.TriangleMesh.create_cone(radius=cone_radius, height=cone_height)

    # Move base of cone to origin, rotate, then translate to end of cylinder
    cone.translate([-0, 0, 0])
    cone.rotate(rot_mat, center=[0, 0, 0])
    cone.translate(start_point + normalized_vector * 0.4)

    arrow = cylinder + cone
    return arrow
def rotate_part(
    pcd: o3d.geometry.PointCloud, mask: np.ndarray, axis_vector: np.ndarray, origin: np.ndarray, angle_rad: float
) -> o3d.geometry.PointCloud:
    """
    Generate rotated point cloud of mask based on provided angle around axis.

    :param pcd: point cloud object representing points of image
    :param mask: mask np.array of dimensions (height, width) representing the part to be rotated in the image
    :param axis_vector: np.array of dimensions (3, ) representing the vector of the axis of rotation
    :param origin: np.array of dimensions (3, ) representing the origin point of the axis of rotation
    :param angle_rad: angle in radians to rotate mask part
    :return: point cloud object after rotation of masked part
    """
    # Get the coordinates of the point cloud as a numpy array
    points_np = np.asarray(pcd.points)

    # Convert point cloud colors to numpy array for easier manipulation
    colors_np = np.asarray(pcd.colors)

    # Create skew-symmetric matrix from the axis vector
    K = np.array(
        [
            [0, -axis_vector[2], axis_vector[1]],
            [axis_vector[2], 0, -axis_vector[0]],
            [-axis_vector[1], axis_vector[0], 0],
        ]
    )

    # Compute rotation matrix using Rodrigues' formula
    R = np.eye(3) + np.sin(angle_rad) * K + (1 - np.cos(angle_rad)) * np.dot(K, K)
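    # Note (added for clarity): Rodrigues' formula above yields an exact rotation only when axis_vector is a unit
    # vector; if the predicted axis is not normalized, the result is no longer a pure rotation.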
    # Iterate over the mask and rotate the points corresponding to the object pixels
    for i in range(mask.shape[0]):
        for j in range(mask.shape[1]):
            if mask[i, j] > 0:  # This condition checks if the pixel belongs to the object
                point_index = i * mask.shape[1] + j
                # Translate the point such that the rotation origin is at the world origin
                translated_point = points_np[point_index] - origin
                # Rotate the translated point
                rotated_point = np.dot(R, translated_point)
                # Translate the point back
                points_np[point_index] = rotated_point + origin
                colors_np[point_index] = POINT_COLOR

    # Update the point cloud's coordinates
    pcd.points = o3d.utility.Vector3dVector(points_np)

    # Update point cloud colors
    pcd.colors = o3d.utility.Vector3dVector(colors_np)
    return pcd
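
# Performance note (added): the per-pixel loops in rotate_part and translate_part assume that point i * width + j of
# the point cloud corresponds to pixel (i, j), which only holds when every pixel produced a point (no depth
# truncation). Under that same assumption, a vectorized sketch of the rotation step would be:
#     sel = mask.reshape(-1) > 0
#     points_np[sel] = (points_np[sel] - origin) @ R.T + origin
#     colors_np[sel] = POINT_COLOR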
def translate_part(
    pcd: o3d.geometry.PointCloud, mask: np.ndarray, axis_vector: np.ndarray, distance: float
) -> o3d.geometry.PointCloud:
    """
    Generate translated point cloud of mask based on provided distance along axis.

    :param pcd: point cloud object representing points of image
    :param mask: mask np.array of dimensions (height, width) representing the part to be translated in the image
    :param axis_vector: np.array of dimensions (3, ) representing the vector of the axis of translation
    :param distance: distance within coordinate system to translate mask part
    :return: point cloud object after translation of masked part
    """
    normalized_vector = axis_vector / np.linalg.norm(axis_vector)
    translation_vector = normalized_vector * distance

    # Convert point cloud colors to numpy array for easier manipulation
    colors_np = np.asarray(pcd.colors)

    # Get the coordinates of the point cloud as a numpy array
    points_np = np.asarray(pcd.points)

    # Iterate over the mask and translate (and recolor) the points corresponding to the object pixels
    for i in range(mask.shape[0]):
        for j in range(mask.shape[1]):
            if mask[i, j] > 0:  # This condition checks if the pixel belongs to the object
                point_index = i * mask.shape[1] + j
                colors_np[point_index] = POINT_COLOR
                points_np[point_index] += translation_vector

    # Update point cloud colors
    pcd.colors = o3d.utility.Vector3dVector(colors_np)

    # Update the point cloud's coordinates
    pcd.points = o3d.utility.Vector3dVector(points_np)
    return pcd
def batch_trim(images_path: str, save_path: str, identical: bool = False) -> None:
    """
    Trim white space from all images in the given path and save the new images to a folder.

    :param images_path: local path to folder containing all images. Images must have the extension ".png", ".jpg", or
        ".jpeg".
    :param save_path: local path to folder in which to save trimmed images
    :param identical: if True, the same crop is applied to all images; otherwise each image has its whitespace trimmed
        independently. Note that in the latter case, each image may have a slightly different size.
    """

    def get_trim(im):
        """Compute the bounding box of non-whitespace content in an image."""
        bg = Image.new(im.mode, im.size, im.getpixel((0, 0)))
        diff = ImageChops.difference(im, bg)
        diff = ImageChops.add(diff, diff, 2.0, -100)
        bbox = diff.getbbox()
        return bbox

    if identical:  # apply the same crop to every image
        images = []
        optimal_box = None

        # load all images
        for image_file in sorted(os.listdir(images_path)):
            if image_file.endswith(IMAGE_EXTENSIONS):
                image_path = os.path.join(images_path, image_file)
                images.append(Image.open(image_path))

        # find optimal box size
        for im in images:
            bbox = get_trim(im)
            if bbox is None:
                bbox = (0, 0, im.size[0], im.size[1])  # bound entire image
            if optimal_box is None:
                optimal_box = bbox
            else:
                optimal_box = (
                    min(optimal_box[0], bbox[0]),
                    min(optimal_box[1], bbox[1]),
                    max(optimal_box[2], bbox[2]),
                    max(optimal_box[3], bbox[3]),
                )

        # apply cropping, if optimal box was found
        for idx, im in enumerate(images):
            cropped = im.crop(optimal_box)  # Image.crop returns a new image, so the result must be kept
            cropped.save(os.path.join(save_path, f"{idx}.png"))
            im.close()
    else:  # trim each image separately
        for image_file in os.listdir(images_path):
            if image_file.endswith(IMAGE_EXTENSIONS):
                image_path = os.path.join(images_path, image_file)
                with Image.open(image_path) as im:
                    bbox = get_trim(im)
                    trimmed = im.crop(bbox) if bbox else im
                    trimmed.save(os.path.join(save_path, image_file))
def create_gif(image_folder_path: str, num_samples: int, gif_filename: str = "output.gif") -> None:
    """
    Create a gif out of a folder of images and save it to file.

    :param image_folder_path: path to folder containing images (non-recursive). Assumes images are named {i}.png for
        each i from 0 to num_samples.
    :param num_samples: number of sampled images to compile into the gif
    :param gif_filename: filename for the gif, defaults to "output.gif"
    """
    # Generate a list of image filenames (assuming the images are saved as 0.png, 1.png, etc.)
    image_files = [f"{image_folder_path}/{i}.png" for i in range(num_samples)]

    # Read the images using imageio
    images = [imageio.imread(image_file) for image_file in image_files]
    assert all(
        images[0].shape == im.shape for im in images
    ), f"Found some images with a different shape: {[im.shape for im in images]}"

    # Save images as a gif
    gif_output_path = f"{image_folder_path}/{gif_filename}"
    imageio.mimsave(gif_output_path, images, duration=0.1)
def main(
    cfg: CfgNode,
    rgb_image: str,
    depth_image: str,
    intrinsics: list[float],
    num_samples: int,
    crop: bool,
    score_threshold: float,
) -> None:
    """
    Main inference method.

    :param cfg: configuration object
    :param rgb_image: local path to RGB image
    :param depth_image: local path to depth image
    :param intrinsics: camera intrinsics matrix as a list of 9 values
    :param num_samples: number of sample visualization states to generate
    :param crop: if True, images will be cropped to remove whitespace before visualization
    :param score_threshold: float between 0 and 1 representing threshold at which to filter instances based on score
    """
    logger = logging.getLogger("detectron2")

    # setup data
    logger.info("Loading image.")
    inp = format_input(rgb_image)

    # setup model
    logger.info("Loading model.")
    model = build_model(cfg)
    weights = torch.load(cfg.MODEL.WEIGHTS, map_location=torch.device("cpu"))
    if "model" not in weights:
        weights = {"model": weights}
    load_model(model, weights)

    # run model on data
    logger.info("Running model.")
    prediction = predict(model, inp)[0]  # index 0 since there is only one image
    pred_instances = prediction["instances"]

    # log results
    image_id = os.path.splitext(os.path.basename(rgb_image))[0]
    pred_dict = {"image_id": image_id}
    instances_cpu = pred_instances.to(torch.device("cpu"))  # renamed to avoid shadowing the imported instances module
    pred_dict["instances"] = prediction_to_json(instances_cpu, image_id)
    torch.save(pred_dict, os.path.join(cfg.OUTPUT_DIR, f"{image_id}_prediction.pth"))
    # select best predictions to visualize
    score_ranking = np.argsort([-1 * pred_instances[i].scores.item() for i in range(len(pred_instances))])
    score_ranking = [idx for idx in score_ranking if pred_instances[int(idx)].scores.item() > score_threshold]
    if len(score_ranking) == 0:
        logging.warning("The model did not predict any moving parts above the score threshold.")
        return

    for idx in score_ranking:  # iterate through all predictions above the score threshold, highest score first
        pred = pred_instances[int(idx)]
        logger.info("Rendering prediction for instance %d", int(idx))
        output_dir = os.path.join(cfg.OUTPUT_DIR, str(idx))
        os.makedirs(output_dir, exist_ok=True)
        # extract predicted values for visualization
        mask = np.squeeze(pred.pred_masks.cpu().numpy())  # dim: [height, width]
        origin = pred.morigin.cpu().numpy().flatten()  # dim: [3, ]
        axis_vector = pred.maxis.cpu().numpy().flatten()  # dim: [3, ]
        pred_type = TYPE_CLASSIFICATION.get(pred.mtype.item())
        range_min = 0 - pred.mstate.cpu().numpy()
        range_max = pred.mstatemax.cpu().numpy() - pred.mstate.cpu().numpy()

        # process visualization
        color = o3d.io.read_image(rgb_image)
        depth = o3d.io.read_image(depth_image)
        rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(color, depth, convert_rgb_to_intensity=False)
        color_np = np.asarray(color)
        height, width = color_np.shape[:2]

        # generate intrinsics
        intrinsic_matrix = np.reshape(intrinsics, (3, 3), order="F")
        intrinsic_obj = o3d.camera.PinholeCameraIntrinsic(
            width,
            height,
            intrinsic_matrix[0, 0],
            intrinsic_matrix[1, 1],
            intrinsic_matrix[0, 2],
            intrinsic_matrix[1, 2],
        )

        # Convert the RGBD image to a point cloud
        pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic_obj)

        # Create a LineSet to visualize the direction vector
        axis_arrow = draw_line(origin, axis_vector + origin)
        axis_arrow.paint_uniform_color(ARROW_COLOR)

        # if USE_GT:
        #     anno_path = f"/localhome/atw7/projects/opdmulti/data/data_demo_dev/59-4860.json"
        #     part_id = 32
        #     # get annotation for the frame
        #     import json
        #     with open(anno_path, "r") as f:
        #         anno = json.load(f)
        #     articulations = anno["articulation"]
        #     for articulation in articulations:
        #         if articulation["partId"] == part_id:
        #             range_min = articulation["rangeMin"] - articulation["state"]
        #             range_max = articulation["rangeMax"] - articulation["state"]
        #             break

        if pred_type == "rotation":
            generate_rotation_visualization(
                pcd,
                axis_arrow,
                mask,
                axis_vector,
                origin,
                range_min,
                range_max,
                num_samples,
                output_dir,
            )
        elif pred_type == "translation":
            generate_translation_visualization(
                pcd,
                axis_arrow,
                mask,
                axis_vector,
                range_min,
                range_max,
                num_samples,
                output_dir,
            )
        else:
            raise ValueError(f"Invalid motion prediction type: {pred_type}")

        if pred_type:
            if crop:  # crop images to remove shared extraneous whitespace
                output_dir_cropped = f"{output_dir}_cropped"
                if not os.path.isdir(output_dir_cropped):
                    os.makedirs(output_dir_cropped)
                batch_trim(output_dir, output_dir_cropped, identical=True)
                # create_gif(output_dir_cropped, num_samples)
            else:  # leave original dimensions of image as-is
                # create_gif(output_dir, num_samples)
                pass
if __name__ == "__main__":
    # parse arguments
    start_time = time.time()
    args = get_parser().parse_args()
    cfg = setup_cfg(args)

    # run main code
    engine.launch(
        main,
        args.num_processes,
        args=(
            cfg,
            args.rgb_image,
            args.depth_image,
            args.intrinsics,
            args.num_samples,
            args.crop,
            args.score_threshold,
        ),
    )
    end_time = time.time()
    print(f"Inference time: {end_time - start_time:.2f} seconds")