# @package _global_

scratch:
  # Shared knobs; the rest of this file reads them through ${scratch.*}
  # interpolation.

  # Data
  resolution: 1024       # RandomResizeAPI `sizes` and model `image_size`
  train_batch_size: 1    # sole entry of trainer.data.train.batch_sizes
  num_train_workers: 10  # trainer.data.train.num_workers
  num_frames: 8          # passed to the VOS sampler
  max_num_objects: 3     # passed to the VOS sampler

  # Schedule: both values feed the `times` resolver for trainer.max_epochs
  num_epochs: 40
  phases_per_epoch: 1

  # Learning rates: each is the start_value of a cosine schedule
  base_lr: 5.0e-6
  vision_lr: 3.0e-6      # applied to parameters matching 'image_encoder.*'

dataset:
  # Dataset locations; the two folders are left as null placeholders here.
  # PATH to the MOSE JPEGImages folder
  img_folder: null
  # PATH to the MOSE Annotations folder
  gt_folder: null
  # Optional PATH to a file list naming the subset of videos used for training
  file_list_txt: training/assets/MOSE_sample_train_list.txt
  # Passed to the VOSDataset entry below as `multiplier`
  multiplier: 2

# Video transforms
vos:
  # Training-time augmentation pipeline, referenced by the VOSDataset entry as
  # ${vos.train_transforms}. The list order below is the order handed to
  # ComposeAPI.
  # NOTE(review): `consistent_transform` presumably selects between one shared
  # random draw for all frames of a clip (true) and an independent draw per
  # frame (false) - confirm against training.dataset.transforms.
  train_transforms:
    - _target_: training.dataset.transforms.ComposeAPI
      transforms:
        - _target_: training.dataset.transforms.RandomHorizontalFlip
          consistent_transform: true
        - _target_: training.dataset.transforms.RandomAffine
          degrees: 25
          shear: 20
          image_interpolation: bilinear
          consistent_transform: true
        - _target_: training.dataset.transforms.RandomResizeAPI
          sizes: ${scratch.resolution}
          square: true
          consistent_transform: true
        # First colour jitter: consistent_transform enabled.
        - _target_: training.dataset.transforms.ColorJitter
          consistent_transform: true
          brightness: 0.1
          contrast: 0.03
          saturation: 0.03
          hue: null
        - _target_: training.dataset.transforms.RandomGrayscale
          p: 0.05
          consistent_transform: true
        # Second colour jitter (slightly larger contrast/saturation range):
        # consistent_transform disabled.
        - _target_: training.dataset.transforms.ColorJitter
          consistent_transform: false
          brightness: 0.1
          contrast: 0.05
          saturation: 0.05
          hue: null
        - _target_: training.dataset.transforms.ToTensorAPI
        # These values match the widely used ImageNet RGB mean / std.
        - _target_: training.dataset.transforms.NormalizeAPI
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]

trainer:
  # Top-level trainer object; `_target_` names the class to construct.
  _target_: training.trainer.Trainer
  mode: train_only
  # Combines scratch.num_epochs and scratch.phases_per_epoch through the custom
  # `times` resolver (registered outside this file; presumably a product).
  max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
  accelerator: cuda
  # Fixed seed value handed to the trainer.
  seed_value: 123

  model:
    _target_: training.model.sam2.SAM2Train
    # Image encoder: a Hiera trunk followed by an FPN neck.
    # NOTE(review): the architecture values below presumably have to match the
    # checkpoint configured under trainer.checkpoint, which is loaded in strict
    # mode - confirm before changing any of them.
    image_encoder:
      _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
      scalp: 1
      trunk:
        _target_: sam2.modeling.backbones.hieradet.Hiera
        embed_dim: 112
        num_heads: 2
        drop_path_rate: 0.1
      neck:
        _target_: sam2.modeling.backbones.image_encoder.FpnNeck
        position_encoding:
          _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
          num_pos_feats: 256
          normalize: true
          scale: null
          temperature: 10000
        d_model: 256
        # 896 / 448 / 224 / 112 = trunk embed_dim (112) x 8 / 4 / 2 / 1
        backbone_channel_list: [896, 448, 224, 112]
        fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
        fpn_interp_model: nearest

    # Memory attention module. `layer` describes a single attention layer and
    # `num_layers` gives the layer count (presumably the layer is replicated
    # that many times - confirm in MemoryAttention).
    memory_attention:
      _target_: sam2.modeling.memory_attention.MemoryAttention
      d_model: 256
      pos_enc_at_input: true
      layer:
        _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
        activation: relu
        dim_feedforward: 2048
        dropout: 0.1
        pos_enc_at_attn: false
        self_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          rope_theta: 10000.0
          feat_sizes: [32, 32]
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
        d_model: 256
        pos_enc_at_cross_attn_keys: true
        pos_enc_at_cross_attn_queries: false
        cross_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          rope_theta: 10000.0
          feat_sizes: [32, 32]
          rope_k_repeat: true
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
          # Same value as memory_encoder.out_dim (64); presumably the two must
          # stay in sync - confirm before changing either.
          kv_in_dim: 64
      num_layers: 4

    # Memory encoder: mask downsampler + fuser, emitting out_dim-channel
    # features (64, the same value as memory_attention's cross-attention
    # kv_in_dim).
    memory_encoder:
      _target_: sam2.modeling.memory_encoder.MemoryEncoder
      out_dim: 64
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 64
        normalize: true
        scale: null
        temperature: 10000
      mask_downsampler:
        _target_: sam2.modeling.memory_encoder.MaskDownSampler
        kernel_size: 3
        stride: 2
        padding: 1
      fuser:
        _target_: sam2.modeling.memory_encoder.Fuser
        layer:
          _target_: sam2.modeling.memory_encoder.CXBlock
          dim: 256
          kernel_size: 7
          padding: 3
          # Written with a decimal point so that YAML 1.1 loaders also resolve
          # it as a float rather than a string.
          layer_scale_init_value: 1.0e-6
          use_dwconv: true  # depth-wise convs
        num_layers: 2

    num_maskmem: 7
    image_size: ${scratch.resolution}
    # apply scaled sigmoid on mask logits for memory encoder, and directly feed
    # input mask as output mask
    sigmoid_scale_for_mem_enc: 20.0
    sigmoid_bias_for_mem_enc: -10.0
    use_mask_input_as_output_without_sam: true
    # Memory
    directly_add_no_mem_embed: true
    no_obj_embed_spatial: true
    # use high-resolution feature map in the SAM mask decoder
    use_high_res_features_in_sam: true
    # output 3 masks on the first click on initial conditioning frames
    multimask_output_in_sam: true
    # SAM heads
    iou_prediction_use_sigmoid: true
    # cross-attend to object pointers from other frames (based on SAM output
    # tokens) in the encoder
    use_obj_ptrs_in_encoder: true
    add_tpos_enc_to_obj_ptrs: true
    proj_tpos_enc_in_obj_ptrs: true
    use_signed_tpos_enc_to_obj_ptrs: true
    only_obj_ptrs_in_the_past_for_eval: true
    # object occlusion prediction
    pred_obj_scores: true
    pred_obj_scores_mlp: true
    fixed_no_obj_ptr: true
    # multimask tracking settings
    multimask_output_for_tracking: true
    use_multimask_token_for_obj_ptr: true
    multimask_min_pt_num: 0
    multimask_max_pt_num: 1
    use_mlp_for_obj_ptr_proj: true
    # Compilation flag
    # compile_image_encoder: false

    ####### Training specific params #######
    # box/point input and corrections
    prob_to_use_pt_input_for_train: 0.5
    prob_to_use_pt_input_for_eval: 0.0
    # 0.5*0.5 = 0.25 prob to use box instead of points
    prob_to_use_box_input_for_train: 0.5
    prob_to_use_box_input_for_eval: 0.0
    # with a small prob, sample correction points from the GT mask instead of
    # from prediction errors
    prob_to_sample_from_gt_for_train: 0.1
    # iteratively sample on random 1~2 frames (always include the first frame)
    num_frames_to_correct_for_train: 2
    # only iteratively sample on the first frame
    num_frames_to_correct_for_eval: 1
    # random count between #init-cond-frames and 2
    rand_frames_to_correct_for_train: true
    # when a frame receives a correction click, it becomes a conditioning frame
    # (even if it's not initially a conditioning frame)
    add_all_frames_to_correct_as_cond: true
    # maximum 2 initial conditioning frames
    num_init_cond_frames_for_train: 2
    rand_init_cond_frames_for_train: true  # random 1~2
    num_correction_pt_per_frame: 7
    use_act_ckpt_iterative_pt_sampling: false

    # Evaluation-side settings
    num_init_cond_frames_for_eval: 1  # only mask on the first frame
    forward_backbone_per_frame_for_eval: true

  # Training data pipeline: a single VOS dataset (MOSE paths from `dataset.*`)
  # wrapped in ConcatDataset and RepeatFactorWrapper.
  data:
    train:
      _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
      phases_per_epoch: ${scratch.phases_per_epoch}
      # One batch size and one dataset entry are listed; presumably the two
      # lists are matched by position - confirm in TorchTrainMixedDataset.
      batch_sizes:
        - ${scratch.train_batch_size}

      datasets:
        - _target_: training.dataset.utils.RepeatFactorWrapper
          dataset:
            _target_: training.dataset.utils.ConcatDataset
            datasets:
              - _target_: training.dataset.vos_dataset.VOSDataset
                transforms: ${vos.train_transforms}
                training: true
                video_dataset:
                  _target_: training.dataset.vos_raw_dataset.PNGRawDataset
                  img_folder: ${dataset.img_folder}
                  gt_folder: ${dataset.gt_folder}
                  file_list_txt: ${dataset.file_list_txt}
                sampler:
                  _target_: training.dataset.vos_sampler.RandomUniformSampler
                  num_frames: ${scratch.num_frames}
                  max_num_objects: ${scratch.max_num_objects}
                multiplier: ${dataset.multiplier}
      shuffle: true
      num_workers: ${scratch.num_train_workers}
      pin_memory: true
      drop_last: true
      collate_fn:
        _target_: training.utils.data_utils.collate_fn
        _partial_: true
        # Same name as the single entry under trainer.loss (`all`).
        dict_key: all

  # Optimisation: AdamW under bfloat16 autocast, gradient clipping, layer-wise
  # LR decay on the image-encoder trunk, and per-group LR / weight-decay
  # schedules.
  optim:
    amp:
      enabled: true
      amp_dtype: bfloat16

    # Only the class is given here; lr and weight_decay come from `options`.
    optimizer:
      _target_: torch.optim.AdamW

    gradient_clip:
      _target_: training.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2

    param_group_modifiers:
      - _target_: training.optimizer.layer_decay_param_modifier
        _partial_: true
        layer_decay_value: 0.9
        apply_to: 'image_encoder.trunk'
        overrides:
          - pattern: '*pos_embed*'
            value: 1.0

    options:
      lr:
        # First entry has no param_names (presumably the default group):
        # cosine schedule from base_lr down to the `divide` resolver's result
        # for (base_lr, 10).
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.base_lr}
            end_value: ${divide:${scratch.base_lr},10}
        # Image-encoder parameters: same schedule shape, starting at vision_lr.
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.vision_lr}
            end_value: ${divide:${scratch.vision_lr},10}
          param_names:
            - 'image_encoder.*'
      weight_decay:
        # Constant 0.1 for the entry without filters ...
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.1
        # ... and 0.0 for parameters matching '*bias*' and for
        # torch.nn.LayerNorm modules.
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*bias*'
          module_cls_names: ['torch.nn.LayerNorm']

  # Loss configuration. The single entry is keyed `all`, the same name used as
  # `dict_key` by the training collate_fn.
  loss:
    all:
      _target_: training.loss_fns.MultiStepMultiMasksAndIous
      # Relative weights of the individual loss terms.
      weight_dict:
        loss_mask: 20
        loss_dice: 1
        loss_iou: 1
        loss_class: 1
      supervise_all_iou: true
      iou_use_l1_loss: true
      pred_obj_scores: true
      # NOTE(review): gamma 0.0 / alpha -1.0 presumably switch off the focal
      # modulation and alpha balancing for the object-score term - confirm in
      # training.loss_fns.
      focal_gamma_obj_score: 0.0
      focal_alpha_obj_score: -1.0

  # Distributed-training settings handed to the trainer.
  distributed:
    backend: nccl
    find_unused_parameters: true

  # Logging: TensorBoard events and text logs, both placed under
  # launcher.experiment_log_dir.
  logging:
    tensorboard_writer:
      _target_: training.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: true
    log_dir: ${launcher.experiment_log_dir}/logs
    log_freq: 10

  # Checkpointing. Saves go under launcher.experiment_log_dir; the model is
  # initialized from a SAM 2 checkpoint via `model_weight_initializer`.
  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0  # 0: only the last checkpoint is saved
    model_weight_initializer:
      _partial_: true
      _target_: training.utils.checkpoint_utils.load_state_dict_into_model
      # Strict loading with no ignore lists configured.
      strict: true
      ignore_unexpected_keys: null
      ignore_missing_keys: null

      state_dict:
        _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
        # PATH to SAM 2.1 checkpoint
        checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt
        ckpt_state_dict_keys: ['model']

# Launch settings. `experiment_log_dir` is the root that the logging and
# checkpoint paths in this file are built from.
launcher:
  num_nodes: 1
  gpus_per_node: 8
  experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}

# SLURM args if running on a cluster
submitit:
  # Cluster-specific fields are left unset (null) in this file.
  partition: null
  account: null
  qos: null
  # Same value as scratch.num_train_workers (10).
  cpus_per_task: 10
  # Cluster submission is disabled in this file.
  use_cluster: false
  timeout_hour: 24
  name: null
  # NOTE(review): presumably the [min, max] range a communication port is
  # picked from - confirm in the launcher script.
  port_range: [10000, 65000]