Spaces:
Running
on
Zero
Running
on
Zero
# @package _global_ | |
scratch: | |
resolution: 1024 | |
train_batch_size: 1 | |
num_train_workers: 10 | |
num_frames: 8 | |
max_num_objects: 3 | |
base_lr: 5.0e-6 | |
vision_lr: 3.0e-06 | |
phases_per_epoch: 1 | |
num_epochs: 40 | |
dataset: | |
# PATHS to Dataset | |
img_folder: null # PATH to MOSE JPEGImages folder | |
gt_folder: null # PATH to MOSE Annotations folder | |
file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training | |
multiplier: 2 | |
# Video transforms | |
vos: | |
train_transforms: | |
- _target_: training.dataset.transforms.ComposeAPI | |
transforms: | |
- _target_: training.dataset.transforms.RandomHorizontalFlip | |
consistent_transform: True | |
- _target_: training.dataset.transforms.RandomAffine | |
degrees: 25 | |
shear: 20 | |
image_interpolation: bilinear | |
consistent_transform: True | |
- _target_: training.dataset.transforms.RandomResizeAPI | |
sizes: ${scratch.resolution} | |
square: true | |
consistent_transform: True | |
- _target_: training.dataset.transforms.ColorJitter | |
consistent_transform: True | |
brightness: 0.1 | |
contrast: 0.03 | |
saturation: 0.03 | |
hue: null | |
- _target_: training.dataset.transforms.RandomGrayscale | |
p: 0.05 | |
consistent_transform: True | |
- _target_: training.dataset.transforms.ColorJitter | |
consistent_transform: False | |
brightness: 0.1 | |
contrast: 0.05 | |
saturation: 0.05 | |
hue: null | |
- _target_: training.dataset.transforms.ToTensorAPI | |
- _target_: training.dataset.transforms.NormalizeAPI | |
mean: [0.485, 0.456, 0.406] | |
std: [0.229, 0.224, 0.225] | |
trainer: | |
_target_: training.trainer.Trainer | |
mode: train_only | |
max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}} | |
accelerator: cuda | |
seed_value: 123 | |
model: | |
_target_: training.model.sam2.SAM2Train | |
image_encoder: | |
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder | |
scalp: 1 | |
trunk: | |
_target_: sam2.modeling.backbones.hieradet.Hiera | |
embed_dim: 112 | |
num_heads: 2 | |
drop_path_rate: 0.1 | |
neck: | |
_target_: sam2.modeling.backbones.image_encoder.FpnNeck | |
position_encoding: | |
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine | |
num_pos_feats: 256 | |
normalize: true | |
scale: null | |
temperature: 10000 | |
d_model: 256 | |
backbone_channel_list: [896, 448, 224, 112] | |
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features | |
fpn_interp_model: nearest | |
memory_attention: | |
_target_: sam2.modeling.memory_attention.MemoryAttention | |
d_model: 256 | |
pos_enc_at_input: true | |
layer: | |
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer | |
activation: relu | |
dim_feedforward: 2048 | |
dropout: 0.1 | |
pos_enc_at_attn: false | |
self_attention: | |
_target_: sam2.modeling.sam.transformer.RoPEAttention | |
rope_theta: 10000.0 | |
feat_sizes: [32, 32] | |
embedding_dim: 256 | |
num_heads: 1 | |
downsample_rate: 1 | |
dropout: 0.1 | |
d_model: 256 | |
pos_enc_at_cross_attn_keys: true | |
pos_enc_at_cross_attn_queries: false | |
cross_attention: | |
_target_: sam2.modeling.sam.transformer.RoPEAttention | |
rope_theta: 10000.0 | |
feat_sizes: [32, 32] | |
rope_k_repeat: True | |
embedding_dim: 256 | |
num_heads: 1 | |
downsample_rate: 1 | |
dropout: 0.1 | |
kv_in_dim: 64 | |
num_layers: 4 | |
memory_encoder: | |
_target_: sam2.modeling.memory_encoder.MemoryEncoder | |
out_dim: 64 | |
position_encoding: | |
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine | |
num_pos_feats: 64 | |
normalize: true | |
scale: null | |
temperature: 10000 | |
mask_downsampler: | |
_target_: sam2.modeling.memory_encoder.MaskDownSampler | |
kernel_size: 3 | |
stride: 2 | |
padding: 1 | |
fuser: | |
_target_: sam2.modeling.memory_encoder.Fuser | |
layer: | |
_target_: sam2.modeling.memory_encoder.CXBlock | |
dim: 256 | |
kernel_size: 7 | |
padding: 3 | |
layer_scale_init_value: 1e-6 | |
use_dwconv: True # depth-wise convs | |
num_layers: 2 | |
num_maskmem: 7 | |
image_size: ${scratch.resolution} | |
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask | |
sigmoid_scale_for_mem_enc: 20.0 | |
sigmoid_bias_for_mem_enc: -10.0 | |
use_mask_input_as_output_without_sam: true | |
# Memory | |
directly_add_no_mem_embed: true | |
no_obj_embed_spatial: true | |
# use high-resolution feature map in the SAM mask decoder | |
use_high_res_features_in_sam: true | |
# output 3 masks on the first click on initial conditioning frames | |
multimask_output_in_sam: true | |
# SAM heads | |
iou_prediction_use_sigmoid: True | |
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder | |
use_obj_ptrs_in_encoder: true | |
add_tpos_enc_to_obj_ptrs: true | |
proj_tpos_enc_in_obj_ptrs: true | |
use_signed_tpos_enc_to_obj_ptrs: true | |
only_obj_ptrs_in_the_past_for_eval: true | |
# object occlusion prediction | |
pred_obj_scores: true | |
pred_obj_scores_mlp: true | |
fixed_no_obj_ptr: true | |
# multimask tracking settings | |
multimask_output_for_tracking: true | |
use_multimask_token_for_obj_ptr: true | |
multimask_min_pt_num: 0 | |
multimask_max_pt_num: 1 | |
use_mlp_for_obj_ptr_proj: true | |
# Compilation flag | |
# compile_image_encoder: False | |
####### Training specific params ####### | |
# box/point input and corrections | |
prob_to_use_pt_input_for_train: 0.5 | |
prob_to_use_pt_input_for_eval: 0.0 | |
prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points | |
prob_to_use_box_input_for_eval: 0.0 | |
prob_to_sample_from_gt_for_train: 0.1 # with a small probability, sample correction points from the GT mask instead of from prediction errors | |
num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame) | |
num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame | |
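rand_frames_to_correct_for_train: True # randomly pick the number of frames to correct, from the number of initial conditioning frames up to 2 (num_frames_to_correct_for_train) | |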
add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame) | |
# maximum 2 initial conditioning frames | |
num_init_cond_frames_for_train: 2 | |
rand_init_cond_frames_for_train: True # random 1~2 | |
num_correction_pt_per_frame: 7 | |
use_act_ckpt_iterative_pt_sampling: false | |
num_init_cond_frames_for_eval: 1 # only mask on the first frame | |
forward_backbone_per_frame_for_eval: True | |
data: | |
train: | |
_target_: training.dataset.sam2_datasets.TorchTrainMixedDataset | |
phases_per_epoch: ${scratch.phases_per_epoch} | |
batch_sizes: | |
- ${scratch.train_batch_size} | |
datasets: | |
- _target_: training.dataset.utils.RepeatFactorWrapper | |
dataset: | |
_target_: training.dataset.utils.ConcatDataset | |
datasets: | |
- _target_: training.dataset.vos_dataset.VOSDataset | |
transforms: ${vos.train_transforms} | |
training: true | |
video_dataset: | |
_target_: training.dataset.vos_raw_dataset.PNGRawDataset | |
img_folder: ${dataset.img_folder} | |
gt_folder: ${dataset.gt_folder} | |
file_list_txt: ${dataset.file_list_txt} | |
sampler: | |
_target_: training.dataset.vos_sampler.RandomUniformSampler | |
num_frames: ${scratch.num_frames} | |
max_num_objects: ${scratch.max_num_objects} | |
multiplier: ${dataset.multiplier} | |
shuffle: True | |
num_workers: ${scratch.num_train_workers} | |
pin_memory: True | |
drop_last: True | |
collate_fn: | |
_target_: training.utils.data_utils.collate_fn | |
_partial_: true | |
dict_key: all | |
optim: | |
amp: | |
enabled: True | |
amp_dtype: bfloat16 | |
optimizer: | |
_target_: torch.optim.AdamW | |
gradient_clip: | |
_target_: training.optimizer.GradientClipper | |
max_norm: 0.1 | |
norm_type: 2 | |
param_group_modifiers: | |
- _target_: training.optimizer.layer_decay_param_modifier | |
_partial_: True | |
layer_decay_value: 0.9 | |
apply_to: 'image_encoder.trunk' | |
overrides: | |
- pattern: '*pos_embed*' | |
value: 1.0 | |
options: | |
lr: | |
- scheduler: | |
_target_: fvcore.common.param_scheduler.CosineParamScheduler | |
start_value: ${scratch.base_lr} | |
end_value: ${divide:${scratch.base_lr},10} | |
- scheduler: | |
_target_: fvcore.common.param_scheduler.CosineParamScheduler | |
start_value: ${scratch.vision_lr} | |
end_value: ${divide:${scratch.vision_lr},10} | |
param_names: | |
- 'image_encoder.*' | |
weight_decay: | |
- scheduler: | |
_target_: fvcore.common.param_scheduler.ConstantParamScheduler | |
value: 0.1 | |
- scheduler: | |
_target_: fvcore.common.param_scheduler.ConstantParamScheduler | |
value: 0.0 | |
param_names: | |
- '*bias*' | |
module_cls_names: ['torch.nn.LayerNorm'] | |
loss: | |
all: | |
_target_: training.loss_fns.MultiStepMultiMasksAndIous | |
weight_dict: | |
loss_mask: 20 | |
loss_dice: 1 | |
loss_iou: 1 | |
loss_class: 1 | |
supervise_all_iou: true | |
iou_use_l1_loss: true | |
pred_obj_scores: true | |
focal_gamma_obj_score: 0.0 | |
focal_alpha_obj_score: -1.0 | |
distributed: | |
backend: nccl | |
find_unused_parameters: True | |
logging: | |
tensorboard_writer: | |
_target_: training.utils.logger.make_tensorboard_logger | |
log_dir: ${launcher.experiment_log_dir}/tensorboard | |
flush_secs: 120 | |
should_log: True | |
log_dir: ${launcher.experiment_log_dir}/logs | |
log_freq: 10 | |
# initialize from a SAM 2 checkpoint | |
checkpoint: | |
save_dir: ${launcher.experiment_log_dir}/checkpoints | |
save_freq: 0 # 0 means only the last checkpoint is saved. | |
model_weight_initializer: | |
_partial_: True | |
_target_: training.utils.checkpoint_utils.load_state_dict_into_model | |
strict: True | |
ignore_unexpected_keys: null | |
ignore_missing_keys: null | |
state_dict: | |
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels | |
checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint | |
ckpt_state_dict_keys: ['model'] | |
launcher: | |
num_nodes: 1 | |
gpus_per_node: 8 | |
experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name} | |
# SLURM args if running on a cluster | |
submitit: | |
partition: null | |
account: null | |
qos: null | |
cpus_per_task: 10 | |
use_cluster: false | |
timeout_hour: 24 | |
name: null | |
port_range: [10000, 65000] | |