import logging

import numpy as np
import torch
import torch.distributed

from sam2.modeling.sam2_base import SAM2Base
from sam2.modeling.sam2_utils import (
    get_1d_sine_pe,
    get_next_point,
    sample_box_points,
    select_closest_cond_frames,
)
from sam2.utils.misc import concat_points

from training.utils.data_utils import BatchedVideoDatapoint


class SAM2Train(SAM2Base):
    def __init__(
        self,
        image_encoder,
        memory_attention=None,
        memory_encoder=None,
        prob_to_use_pt_input_for_train=0.0,
        prob_to_use_pt_input_for_eval=0.0,
        prob_to_use_box_input_for_train=0.0,
        prob_to_use_box_input_for_eval=0.0,
        # if greater than 1, we iteratively sample correction points on the first
        # frame and on other randomly selected frames
        num_frames_to_correct_for_train=1,  # default: only iteratively sample on first frame
        num_frames_to_correct_for_eval=1,  # default: only iteratively sample on first frame
        rand_frames_to_correct_for_train=False,
        rand_frames_to_correct_for_eval=False,
        # how many frames to use as initial conditioning frames (for both point and
        # mask input; the first frame is always an initial conditioning frame)
        # - if `rand_init_cond_frames` below is True, we randomly sample
        #   1 to `num_init_cond_frames` initial conditioning frames
        # - otherwise, we use a fixed number of `num_init_cond_frames` frames
        # note: for point input, correction points are sampled on all initial
        # conditioning frames, so we require
        # `num_frames_to_correct >= num_init_cond_frames`
        num_init_cond_frames_for_train=1,  # default: only use the first frame as initial conditioning frame
        num_init_cond_frames_for_eval=1,  # default: only use the first frame as initial conditioning frame
        rand_init_cond_frames_for_train=True,  # default: random 1~num_init_cond_frames_for_train cond frames
        rand_init_cond_frames_for_eval=False,
        # if True, frames that receive correction clicks during tracking are also
        # appended to the conditioning frame list; if False, only the initial
        # conditioning frames are used for conditioning
        add_all_frames_to_correct_as_cond=False,
        # how many additional correction points to sample on each frame with correction
        num_correction_pt_per_frame=7,
        # method for correction-point sampling during evaluation:
        # "uniform" (sample uniformly from the error region) or "center" (use the
        # point farthest from the error-region boundary); defaults to "center"
        pt_sampling_for_eval="center",
        # during training, with a small probability we sample the correction points
        # from the GT region instead of the prediction error region, which may reduce
        # overfitting to the error regions of the training data
        prob_to_sample_from_gt_for_train=0.0,
        use_act_ckpt_iterative_pt_sampling=False,
        # whether to forward image features frame by frame during evaluation (instead
        # of all frames at once); this avoids backbone OOM on very long videos
        forward_backbone_per_frame_for_eval=False,
        freeze_image_encoder=False,
        **kwargs,
    ):
        super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)
        self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling
        self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval

        # Point sampling and conditioning frames
        self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
        self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
        self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
        self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval
        if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0:
            logging.info(
                f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}"
            )
            assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
            assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval

        self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
        self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
        self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
        self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval
        # Initial (multi-)conditioning frames
        self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
        self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
        self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
        self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval
        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
        self.num_correction_pt_per_frame = num_correction_pt_per_frame
        self.pt_sampling_for_eval = pt_sampling_for_eval
        self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train

        # A random number generator with a fixed initial seed
        self.rng = np.random.default_rng(seed=42)

        if freeze_image_encoder:
            for p in self.image_encoder.parameters():
                p.requires_grad = False

    def forward(self, input: BatchedVideoDatapoint):
        if self.training or not self.forward_backbone_per_frame_for_eval:
            # precompute image features on all frames before tracking
            backbone_out = self.forward_image(input.flat_img_batch)
        else:
            # defer image feature computation on a frame until it's being tracked
            backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
        backbone_out = self.prepare_prompt_inputs(backbone_out, input)
        previous_stages_out = self.forward_tracking(backbone_out, input)

        return previous_stages_out

    def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
        """Compute the image backbone features on the fly for the given img_ids."""
        # Only forward the backbone on unique image ids to avoid repeated computation
        # (if `img_ids` has only one element, it is already unique, so we skip this step).
        if img_ids.numel() > 1:
            unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
        else:
            unique_img_ids, inv_ids = img_ids, None

        # Compute the image features on those unique image ids
        image = img_batch[unique_img_ids]
        backbone_out = self.forward_image(image)
        (
            _,
            vision_feats,
            vision_pos_embeds,
            feat_sizes,
        ) = self._prepare_backbone_features(backbone_out)
        # Inverse-map the features computed on `unique_img_ids` back to the
        # original (possibly repeated) `img_ids`.
        if inv_ids is not None:
            image = image[inv_ids]
            vision_feats = [x[:, inv_ids] for x in vision_feats]
            vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]

        return image, vision_feats, vision_pos_embeds, feat_sizes
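
    # A minimal sketch of the de-duplication trick above (illustrative comment
    # only, not executed): for img_ids = torch.tensor([2, 0, 2]),
    #     unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
    # yields unique_img_ids == tensor([0, 2]) and inv_ids == tensor([1, 0, 1]),
    # so the backbone runs on 2 images instead of 3, and gathering with
    # `x[:, inv_ids]` re-expands the per-unique-image features back to one
    # column per original id.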

    def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
        """
        Prepare input mask, point or box prompts. Optionally, we allow tracking from
        a custom `start_frame_idx` to the end of the video (for evaluation purposes).
        """
        # Load the ground-truth masks on all frames (so that we can later
        # sample correction points from them)
        gt_masks_per_frame = {
            stage_id: masks.unsqueeze(1)  # [B, 1, H_im, W_im]
            for stage_id, masks in enumerate(input.masks)
        }
        backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
        num_frames = input.num_frames
        backbone_out["num_frames"] = num_frames

        # Randomly decide whether to use point inputs or mask inputs
        if self.training:
            prob_to_use_pt_input = self.prob_to_use_pt_input_for_train
            prob_to_use_box_input = self.prob_to_use_box_input_for_train
            num_frames_to_correct = self.num_frames_to_correct_for_train
            rand_frames_to_correct = self.rand_frames_to_correct_for_train
            num_init_cond_frames = self.num_init_cond_frames_for_train
            rand_init_cond_frames = self.rand_init_cond_frames_for_train
        else:
            prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval
            prob_to_use_box_input = self.prob_to_use_box_input_for_eval
            num_frames_to_correct = self.num_frames_to_correct_for_eval
            rand_frames_to_correct = self.rand_frames_to_correct_for_eval
            num_init_cond_frames = self.num_init_cond_frames_for_eval
            rand_init_cond_frames = self.rand_init_cond_frames_for_eval
        if num_frames == 1:
            # special case of mixing video and static-image (SAM-style) training,
            # where we force point input on single-frame inputs
            prob_to_use_pt_input = 1.0
            num_frames_to_correct = 1
            num_init_cond_frames = 1
        assert num_init_cond_frames >= 1
        # (here `self.rng.random()` returns a value in the range 0.0 <= X < 1.0)
        use_pt_input = self.rng.random() < prob_to_use_pt_input
        if rand_init_cond_frames and num_init_cond_frames > 1:
            # randomly select 1 to `num_init_cond_frames` initial conditioning frames
            num_init_cond_frames = self.rng.integers(
                1, num_init_cond_frames, endpoint=True
            )
        if (
            use_pt_input
            and rand_frames_to_correct
            and num_frames_to_correct > num_init_cond_frames
        ):
            # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames
            # on which to sample correction clicks (only in the point-input case)
            num_frames_to_correct = self.rng.integers(
                num_init_cond_frames, num_frames_to_correct, endpoint=True
            )
        backbone_out["use_pt_input"] = use_pt_input

        # Sample the initial conditioning frames
        if num_init_cond_frames == 1:
            init_cond_frames = [start_frame_idx]  # starting frame only
        else:
            # starting frame + randomly selected remaining frames (without replacement)
            init_cond_frames = [start_frame_idx] + self.rng.choice(
                range(start_frame_idx + 1, num_frames),
                num_init_cond_frames - 1,
                replace=False,
            ).tolist()
        backbone_out["init_cond_frames"] = init_cond_frames
        backbone_out["frames_not_in_init_cond"] = [
            t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames
        ]
        # Prepare mask or point inputs on the initial conditioning frames
        backbone_out["mask_inputs_per_frame"] = {}  # {frame_idx: <input_masks>}
        backbone_out["point_inputs_per_frame"] = {}  # {frame_idx: <input_points>}
        for t in init_cond_frames:
            if not use_pt_input:
                backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t]
            else:
                # overall, P(box input) = prob_to_use_pt_input * prob_to_use_box_input
                use_box_input = self.rng.random() < prob_to_use_box_input
                if use_box_input:
                    points, labels = sample_box_points(
                        gt_masks_per_frame[t],
                    )
                else:
                    # (we only sample **one initial point** on each initial conditioning
                    # frame from the ground-truth mask; more correction points may be
                    # sampled on the fly later)
                    points, labels = get_next_point(
                        gt_masks=gt_masks_per_frame[t],
                        pred_masks=None,
                        method=(
                            "uniform" if self.training else self.pt_sampling_for_eval
                        ),
                    )

                point_inputs = {"point_coords": points, "point_labels": labels}
                backbone_out["point_inputs_per_frame"][t] = point_inputs

        # Sample the frames where correction clicks will be added on the fly,
        # based on the error between the prediction and the ground-truth masks
        if not use_pt_input:
            # no correction points are sampled when using mask inputs
            frames_to_add_correction_pt = []
        elif num_frames_to_correct == num_init_cond_frames:
            frames_to_add_correction_pt = init_cond_frames
        else:
            assert num_frames_to_correct > num_init_cond_frames
            # initial conditioning frames + randomly selected remaining frames
            # (without replacement)
            extra_num = num_frames_to_correct - num_init_cond_frames
            frames_to_add_correction_pt = (
                init_cond_frames
                + self.rng.choice(
                    backbone_out["frames_not_in_init_cond"], extra_num, replace=False
                ).tolist()
            )
        backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt

        return backbone_out
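
    # Worked example of the sampling above (a sketch with arbitrary numbers,
    # assuming use_pt_input is True): with num_frames=8, start_frame_idx=0,
    # num_init_cond_frames=2 and num_frames_to_correct=3, one possible draw is
    #     init_cond_frames            = [0, 5]
    #     frames_not_in_init_cond     = [1, 2, 3, 4, 6, 7]
    #     frames_to_add_correction_pt = [0, 5, 3]  # one extra frame sampled
    # i.e. every initial conditioning frame receives correction clicks, plus
    # (num_frames_to_correct - num_init_cond_frames) randomly chosen frames.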

    def forward_tracking(
        self, backbone_out, input: BatchedVideoDatapoint, return_dict=False
    ):
        """Forward video tracking on each frame (and sample correction clicks)."""
        img_feats_already_computed = backbone_out["backbone_fpn"] is not None
        if img_feats_already_computed:
            # Prepare the backbone features
            # - vision_feats and vision_pos_embeds are in (HW)BC format
            (
                _,
                vision_feats,
                vision_pos_embeds,
                feat_sizes,
            ) = self._prepare_backbone_features(backbone_out)

        # Start the stage loop
        num_frames = backbone_out["num_frames"]
        init_cond_frames = backbone_out["init_cond_frames"]
        frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
        # first process all the initial conditioning frames to encode them as memory,
        # then condition on them to track the remaining frames
        processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
        output_dict = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        for stage_id in processing_order:
            # Get the image features for the current frames
            img_ids = input.flat_obj_to_img_idx[stage_id]
            if img_feats_already_computed:
                # Retrieve the precomputed image features according to img_ids.
                current_vision_feats = [x[:, img_ids] for x in vision_feats]
                current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds]
            else:
                # Otherwise, compute the image features on the fly for the given
                # img_ids (e.g. during evaluation on long videos, to avoid OOM).
                (
                    _,
                    current_vision_feats,
                    current_vision_pos_embeds,
                    feat_sizes,
                ) = self._prepare_backbone_features_per_frame(
                    input.flat_img_batch, img_ids
                )

            # Get the output masks based on this frame's prompts and previous memory
            current_out = self.track_step(
                frame_idx=stage_id,
                is_init_cond_frame=stage_id in init_cond_frames,
                current_vision_feats=current_vision_feats,
                current_vision_pos_embeds=current_vision_pos_embeds,
                feat_sizes=feat_sizes,
                point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
                mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
                gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
                frames_to_add_correction_pt=frames_to_add_correction_pt,
                output_dict=output_dict,
                num_frames=num_frames,
            )

            # Append the output, depending on whether it's a conditioning frame
            add_output_as_cond_frame = stage_id in init_cond_frames or (
                self.add_all_frames_to_correct_as_cond
                and stage_id in frames_to_add_correction_pt
            )
            if add_output_as_cond_frame:
                output_dict["cond_frame_outputs"][stage_id] = current_out
            else:
                output_dict["non_cond_frame_outputs"][stage_id] = current_out

        if return_dict:
            return output_dict
        # turn `output_dict` into a list for the loss function
        all_frame_outputs = {}
        all_frame_outputs.update(output_dict["cond_frame_outputs"])
        all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
        all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
        # Make DDP happy with activation checkpointing by removing unused keys
        all_frame_outputs = [
            {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
        ]

        return all_frame_outputs
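
    # Note on ordering (illustrative): with init_cond_frames = [0, 5] and
    # num_frames = 8, the per-frame loop above runs in the order
    #     processing_order = [0, 5, 1, 2, 3, 4, 6, 7]
    # (conditioning frames first, so their memories exist before the remaining
    # frames are tracked), while the returned list is re-sorted back into frame
    # order 0..7 by the `range(num_frames)` indexing above.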

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # whether to track in reverse time order (unused in training)
        run_mem_encoder=True,  # whether to run the memory encoder on the predicted masks
        prev_sam_mask_logits=None,  # previously predicted SAM mask logits
        frames_to_add_correction_pt=None,
        gt_masks=None,
    ):
        if frames_to_add_correction_pt is None:
            frames_to_add_correction_pt = []
        current_out, sam_outputs, high_res_features, pix_feat = self._track_step(
            frame_idx,
            is_init_cond_frame,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
            point_inputs,
            mask_inputs,
            output_dict,
            num_frames,
            track_in_reverse,
            prev_sam_mask_logits,
        )

        (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        ) = sam_outputs

        current_out["multistep_pred_masks"] = low_res_masks
        current_out["multistep_pred_masks_high_res"] = high_res_masks
        current_out["multistep_pred_multimasks"] = [low_res_multimasks]
        current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
        current_out["multistep_pred_ious"] = [ious]
        current_out["multistep_point_inputs"] = [point_inputs]
        current_out["multistep_object_score_logits"] = [object_score_logits]

        # Optionally, sample correction points iteratively to correct the mask
        if frame_idx in frames_to_add_correction_pt:
            point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
                is_init_cond_frame,
                point_inputs,
                gt_masks,
                high_res_features,
                pix_feat,
                low_res_multimasks,
                high_res_multimasks,
                ious,
                low_res_masks,
                high_res_masks,
                object_score_logits,
                current_out,
            )
            (
                _,
                _,
                _,
                low_res_masks,
                high_res_masks,
                obj_ptr,
                object_score_logits,
            ) = final_sam_outputs

        # Use the masks after all correction steps as the final prediction
        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr

        # Finally, run the memory encoder on the predicted mask to encode it into
        # a new memory feature (which can be used in future frames)
        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
            point_inputs,
            run_mem_encoder,
            high_res_masks,
            object_score_logits,
            current_out,
        )
        return current_out

    def _iter_correct_pt_sampling(
        self,
        is_init_cond_frame,
        point_inputs,
        gt_masks,
        high_res_features,
        pix_feat_with_mem,
        low_res_multimasks,
        high_res_multimasks,
        ious,
        low_res_masks,
        high_res_masks,
        object_score_logits,
        current_out,
    ):
        """Iteratively sample correction points and rerun the SAM mask decoder."""
        assert gt_masks is not None
        all_pred_masks = [low_res_masks]
        all_pred_high_res_masks = [high_res_masks]
        all_pred_multimasks = [low_res_multimasks]
        all_pred_high_res_multimasks = [high_res_multimasks]
        all_pred_ious = [ious]
        all_point_inputs = [point_inputs]
        all_object_score_logits = [object_score_logits]
        for _ in range(self.num_correction_pt_per_frame):
            # sample a new point from the error between prediction and ground truth
            # (with a small probability, sample directly from the GT masks instead)
            if self.training and self.prob_to_sample_from_gt_for_train > 0:
                sample_from_gt = (
                    self.rng.random() < self.prob_to_sample_from_gt_for_train
                )
            else:
                sample_from_gt = False
            # if `pred_for_new_pt` is None, only the GT masks are used for sampling
            pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
            new_points, new_labels = get_next_point(
                gt_masks=gt_masks,
                pred_masks=pred_for_new_pt,
                method="uniform" if self.training else self.pt_sampling_for_eval,
            )
            point_inputs = concat_points(point_inputs, new_points, new_labels)

            # Feed the mask logits of the previous SAM output as a mask input to the
            # next SAM decoder step (so each correction click is accompanied by the
            # current prediction, as in interactive SAM usage)
            mask_inputs = low_res_masks
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
                sam_outputs = torch.utils.checkpoint.checkpoint(
                    self._forward_sam_heads,
                    backbone_features=pix_feat_with_mem,
                    point_inputs=point_inputs,
                    mask_inputs=mask_inputs,
                    high_res_features=high_res_features,
                    multimask_output=multimask_output,
                    use_reentrant=False,
                )
            else:
                sam_outputs = self._forward_sam_heads(
                    backbone_features=pix_feat_with_mem,
                    point_inputs=point_inputs,
                    mask_inputs=mask_inputs,
                    high_res_features=high_res_features,
                    multimask_output=multimask_output,
                )
            (
                low_res_multimasks,
                high_res_multimasks,
                ious,
                low_res_masks,
                high_res_masks,
                _,
                object_score_logits,
            ) = sam_outputs
            all_pred_masks.append(low_res_masks)
            all_pred_high_res_masks.append(high_res_masks)
            all_pred_multimasks.append(low_res_multimasks)
            all_pred_high_res_multimasks.append(high_res_multimasks)
            all_pred_ious.append(ious)
            all_point_inputs.append(point_inputs)
            all_object_score_logits.append(object_score_logits)

        # Concatenate the masks along the channel dimension
        # (so that losses can be computed on all correction steps)
        current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1)
        current_out["multistep_pred_masks_high_res"] = torch.cat(
            all_pred_high_res_masks, dim=1
        )
        current_out["multistep_pred_multimasks"] = all_pred_multimasks
        current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks
        current_out["multistep_pred_ious"] = all_pred_ious
        current_out["multistep_point_inputs"] = all_point_inputs
        current_out["multistep_object_score_logits"] = all_object_score_logits

        return point_inputs, sam_outputs
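
# A minimal usage sketch (commented out; illustrative only). In the training
# pipeline this class is typically instantiated from a config rather than
# constructed by hand, and the `build_*` helpers below are hypothetical
# placeholders, not real names from this repo:
#
#     model = SAM2Train(
#         image_encoder=build_image_encoder(),        # hypothetical helper
#         memory_attention=build_memory_attention(),  # hypothetical helper
#         memory_encoder=build_memory_encoder(),      # hypothetical helper
#         prob_to_use_pt_input_for_train=0.5,
#         num_frames_to_correct_for_train=2,  # must be >= num_init_cond_frames
#         num_init_cond_frames_for_train=2,
#     )
#     outputs = model(batched_video_datapoint)  # one output dict per frame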