target: !module src.models.pl_htsat_q_bart_captioning.AutoCap

variables:
  num_workers: &num_workers 90
  sampling_rate: &sampling_rate 32000 
  warmup_epochs: &warmup_epochs 2
  lr: &lr 1.0e-5
  batch_size: &bs 128 
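# The &name definitions above are plain YAML anchors; they are dereferenced
# later with *name aliases, e.g. under data_args:
#   batch_size: *bs            # resolves to 128
#   num_workers: *num_workers  # resolves to 90
# Change a value once here and every alias picks it up.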

training:
  seed: 20
  pretrain: True
  pretrain_path: "PRETRAINED_CHECKPOINT"
  resume_training: False # if true, the most recent checkpoint in the log folder is used to initialize training
  precision: "high"
  nodes_count: -1 # if -1, train on the full world size. For multi-node training, launch the module with torch.distributed.run
  device: "cuda"
  exclude_metrics: ['spice', 'meteor', 'spider']
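  # Example multi-node launch via torch.distributed.run (a sketch; the training
  # script name and its config flag are assumptions, substitute the repo's
  # actual entry point):
  #   python -m torch.distributed.run --nnodes=2 --nproc_per_node=8 \
  #       train.py -c train_config.yaml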

logging: 
  project_name: "autocap"
  wandb_key: YOUR_WANDB_KEY # obtain your key from wandb.ai/authorize
  log_directory: "./run_logs/autocap/train"

  # (optional) if an S3 path is specified, checkpoints will be saved at S3_FOLDER/log_directory and deleted from the local folder (except the last checkpoint). Otherwise, checkpoints are kept locally indefinitely.
  # S3_BUCKET: "YOUR_S3_BUCKET" 
  # S3_FOLDER: 'YOUR_S3_FOLDER'
  save_checkpoint_every_n_epochs: 5
  save_top_k: -1
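  # save_top_k: -1 follows PyTorch Lightning's ModelCheckpoint convention:
  # keep every saved checkpoint rather than only the best k.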

step:
  epochs: 20
  validation_every_n_epochs: 1
  num_sanity_val_steps: 1

  # debug
  # limit_train_batches: 20
  # limit_val_batches: 2
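  # num_sanity_val_steps and limit_train_batches / limit_val_batches map to the
  # PyTorch Lightning Trainer flags of the same names; uncomment the two lines
  # above for a quick smoke test on a handful of batches.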


model:
  clip_grad: 2
  audio_features_dropout_p: 0.5
  text_features_dropout_p: 0.5
  use_text_qformer: false # if false, the text tokens are fed directly to the decoder
  use_audio_qformer: true # if false, the audio features are fed directly to the decoder
  use_clap_embeds: true
  meta_input: true
  add_special_tokens: True # if false, metadata fields are prefixed with plain markers (Title:, Caption:, etc.) instead of special tokens
  meta_keys: ['video_caption', 'title']
  # meta_keys: ['video_caption', 'videollama_caption', 'title', 'description', 'subtitle', 'labels'] 


meta: 
  max_prompt_len: 128

clap_embeds:
  model: 'HTSAT-base'
  ckpt: 'pretrained_models/clap/music_speech_audioset_epoch_15_esc_89.98.pt'
  embed_dim: 512
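  # The ckpt filename matches the public LAION-CLAP release; place the file
  # under pretrained_models/clap/ before training.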

text_qformer:
  num_text_query_token: 64 # output tokens
  input_audio2tex_query_embed: true
  detach_video_query_embed: false
  frozen_text_Qformer: false
  hidden_size: 128
  add_cross_attention: true
  num_attention_heads: 8
  num_hidden_layers: 2

audio_qformer:
  num_audio_query_token: 256
  frozen_audio_Qformer: false
  hidden_size: 256
  add_cross_attention: true
  num_attention_heads: 8
  num_hidden_layers: 2

tokenizer:
  max_length: 30
  special_tokens: ['<HQVC>', '</HQVC>', '<AVC>', '</AVC>', '<TITLE>', '</TITLE>', '<DESC>', '</DESC>', '<SUB>', '</SUB>', '<LBL>', '</LBL>']
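  # Sketch of a serialized metadata prompt when add_special_tokens is true (the
  # exact field-to-tag mapping is an assumption, not confirmed by this file):
  #   <AVC>a man speaks over soft piano</AVC><TITLE>Morning practice</TITLE>
  # With add_special_tokens false, plain prefixes (Title:, Caption:, ...) are
  # used instead, per the comment in the model section above.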

audio_args:
  sr: 32000
  n_fft: 1024
  hop_length: 320
  f_min: 50
  f_max: 14000
  n_mels: 64
  max_length: 10 # set to 10 for HTSAT encoder, and set to 0 or 30 for CNN encoder
  mono: True
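  # Frame-count check: at sr 32000 with hop_length 320, a 10 s clip spans
  # 32000 * 10 / 320 = 1000 hops, i.e. about 1000 mel frames (the exact count
  # depends on STFT padding) of n_mels = 64 bins each.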

# caption key per dataset:
# audiocaps: audiocaps_gt_captions
# audioset: no caption; labels are available
# wavcaps_audioset_strong, wavcaps_bbcsound, wavcaps_freesound, wavcaps_soundbible: wavcaps_caption
# clotho: gt_captions
# fsd50k: no caption; labels are available
data_args:
  data: 
    metadata_root: "../dataset_preperation/data/metadata/dataset_root.json"
    train: ['32k_captioned_audiocaps', 'caption_audioset', 'wavcaps_audioset_strong', 'wavcaps_bbcsound', 'wavcaps_freesound', 'wavcaps_soundbible', 'clotho', 'fsd50k']
    val: ['32k_captioned_audiocaps']
    test: ['32k_captioned_audiocaps']

    keys_synonyms:
      gt_audio_caption:
        - audiocaps_gt_captions
        - gt_captions
        - gt_caption
        - caption
        - gt_audio_caption
        - wavcaps_caption
      tags:
        - keywords
        - tags
        - labels
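    # Presumably the loader scans each synonym list in order and maps the first
    # key present in a dataset's metadata onto the canonical field name
    # (gt_audio_caption / tags); the exact lookup rule is an assumption here.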

  batch_size: *bs 
  num_workers: *num_workers
  augmentation_p: 0.1

  preprocessing:
    video:
      fps : 1
      height: 224
      width: 224
    audio:
      sampling_rate: *sampling_rate
      max_wav_value: 32768.0
      duration: 10.0
    stft:
      filter_length: 1024
      hop_length: 320
      win_length: 1024
    mel:
      n_mel_channels: 64
      mel_fmin: 50
      mel_fmax: 14000 
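    # Note: stft/mel above mirror audio_args (filter_length/n_fft 1024,
    # hop_length 320, 64 mel bins, fmin 50, fmax 14000); keep the two in sync.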


audio_encoder_args:
  model_arch: "transformer"
  model_name: "htsat"
  pretrained: True
  freeze: True
  spec_augment: True

text_decoder_args:
  model_tag: "audio_qformer"
  name: "facebook/bart-base"
  pretrained: true
  freeze: False
  freeze_embed_layer: False
  bert_args:
    attention_probs_dropout_prob: 0.2
    hidden_act: "gelu"
    hidden_dropout_prob: 0.2
    hidden_size: 768
    initializer_range: 0.02
    intermediate_size: 2048
    layer_norm_eps: !!float 1e-5
    max_position_embeddings: 128
    model_type: "bert"
    num_attention_heads: 4
    num_hidden_layers: 2
    add_type_embeddings: false
    vocab_size: 30522
    add_cross_attention: true
    is_decoder: true
    num_labels: 0
    name: "bert-base-uncased"


optim_args:
  scheduler: cosine
  lr: *lr
  optimizer_name: "adam"
  betas: [0.9, 0.999]
  eps: !!float 1e-8
  momentum: 0.9
  gamma: 0.05
  warmup_epochs: *warmup_epochs
  weight_decay: !!float 1e-6
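  # Presumably only the fields relevant to the chosen optimizer/scheduler are
  # read: betas/eps apply to adam, momentum would apply to an SGD-style
  # optimizer, and gamma to a step-decay scheduler rather than cosine. Treat
  # this as an assumption about the training code, not a documented contract.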