target: !module src.models.pl_htsat_q_bart_captioning.AutoCap

variables:
  num_workers: &num_workers 90
  sampling_rate: &sampling_rate 32000
  warmup_epochs: &warmup_epochs 2
  lr: &lr 1.0e-5
  batch_size: &bs 128
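  # NOTE: the '&name' entries above are YAML anchors; they are reused below via aliases
  # such as 'batch_size: *bs', 'num_workers: *num_workers' and 'lr: *lr', so changing a
  # value here updates every place in this file that references it.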

training:
  seed: 20
  pretrain: True
  pretrain_path: "PRETRAINED_CHECKPOINT"
  resume_training: False # if true, the most recent checkpoint in the log folder is used to initialize training
  precision: "high"
  nodes_count: -1 # if -1, train on all available nodes (the whole world size); for multi-node training, launch the module with torch.distributed.run
  device: "cuda"
  exclude_metrics: ['spice', 'meteor', 'spider']

logging:
  project_name: "autocap"
  wandb_key: YOUR_WANDB_KEY # obtain a key from wandb.ai/authorize
  log_directory: "./run_logs/autocap/train"
  # (optional) if an S3 path is specified, checkpoints will be saved under S3_FOLDER/log_directory and deleted from
  # the local folder (except the last checkpoint); otherwise, checkpoints are kept locally indefinitely
  # S3_BUCKET: "YOUR_S3_BUCKET"
  # S3_FOLDER: 'YOUR_S3_FOLDER'
  save_checkpoint_every_n_epochs: 5
  save_top_k: -1

step:
  epochs: 20
  validation_every_n_epochs: 1
  num_sanity_val_steps: 1
  # debug
  # limit_train_batches: 20
  # limit_val_batches: 2
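  # NOTE (assumption): num_sanity_val_steps and the commented-out limit_train_batches /
  # limit_val_batches appear to be passed through to the PyTorch Lightning Trainer; the
  # two limits cap the number of batches per epoch, which is useful for quick debug runs.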

model:
  clip_grad: 2
  audio_features_dropout_p: 0.5
  text_features_dropout_p: 0.5
  use_text_qformer: false # if false, the text tokens are fed directly to the decoder
  use_audio_qformer: true # if false, the audio features are fed directly to the decoder
  use_clap_embeds: true
  meta_input: true
  add_special_tokens: True # if False, the metadata fields are prefixed with plain labels such as Title:, Caption:, etc.
  meta_keys: ['video_caption', 'title']
  # meta_keys: ['video_caption', 'videollama_caption', 'title', 'description', 'subtitle', 'labels']
  meta:
    max_prompt_len: 128
  clap_embeds:
    model: 'HTSAT-base'
    ckpt: 'pretrained_models/clap/music_speech_audioset_epoch_15_esc_89.98.pt'
    embed_dim: 512
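    # NOTE (assumption): the filename above matches a publicly released LAION-CLAP
    # checkpoint; it is expected to be downloaded to pretrained_models/clap/ beforehand.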
  text_qformer:
    num_text_query_token: 64 # number of output query tokens
    input_audio2tex_query_embed: true
    detach_video_query_embed: false
    frozen_text_Qformer: false
    hidden_size: 128
    add_cross_attention: true
    num_attention_heads: 8
    num_hidden_layers: 2
  audio_qformer:
    num_audio_query_token: 256
    frozen_audio_Qformer: false
    hidden_size: 256
    add_cross_attention: true
    num_attention_heads: 8
    num_hidden_layers: 2
  tokenizer:
    max_length: 30
    special_tokens: ['<HQVC>', '</HQVC>', '<AVC>', '</AVC>', '<TITLE>', '</TITLE>', '<DESC>', '</DESC>', '<SUB>', '</SUB>', '<LBL>', '</LBL>']
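    # NOTE (assumption): these tokens appear to delimit the metadata fields selected by
    # meta_keys, e.g. <TITLE>...</TITLE> for 'title', <DESC>...</DESC> for 'description',
    # <SUB>...</SUB> for 'subtitle' and <LBL>...</LBL> for 'labels'; with
    # add_special_tokens: False the fields are instead prefixed with plain labels ("Title:").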

audio_args:
  sr: 32000
  n_fft: 1024
  hop_length: 320
  f_min: 50
  f_max: 14000
  n_mels: 64
  max_length: 10 # set to 10 for the HTSAT encoder; set to 0 or 30 for a CNN encoder
  mono: True
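  # NOTE: with sr 32000 and hop_length 320 the spectrogram hop is 10 ms (100 frames/s),
  # i.e. roughly 1000 frames of 64 mel bins for a 10 s (max_length) clip.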

# Caption sources per dataset:
#   audiocaps: audiocaps_gt_captions
#   audioset: no captions; labels are available
#   'wavcaps_audioset_strong', 'wavcaps_bbcsound', 'wavcaps_freesound', 'wavcaps_soundbible': wavcaps_caption
#   clotho: gt_captions
#   fsd50k: no captions; labels are available
data_args:
  data:
    metadata_root: "../dataset_preperation/data/metadata/dataset_root.json"
    train: ['32k_captioned_audiocaps', 'caption_audioset', 'wavcaps_audioset_strong', 'wavcaps_bbcsound', 'wavcaps_freesound', 'wavcaps_soundbible', 'clotho', 'fsd50k']
    val: ['32k_captioned_audiocaps']
    test: ['32k_captioned_audiocaps']
    keys_synonyms:
      gt_audio_caption:
        - audiocaps_gt_captions
        - gt_captions
        - gt_caption
        - caption
        - gt_audio_caption
        - wavcaps_caption
      tags:
        - keywords
        - tags
        - labels
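    # NOTE (assumption): each dataset names its caption/label field differently (see the
    # per-dataset comment above); keys_synonyms presumably maps any of the listed field
    # names onto the common keys 'gt_audio_caption' and 'tags' when metadata is loaded.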
  batch_size: *bs
  num_workers: *num_workers
  augmentation_p: 0.1
  preprocessing:
    video:
      fps: 1
      height: 224
      width: 224
    audio:
      sampling_rate: *sampling_rate
      max_wav_value: 32768.0
      duration: 10.0
    stft:
      filter_length: 1024
      hop_length: 320
      win_length: 1024
    mel:
      n_mel_channels: 64
      mel_fmin: 50
      mel_fmax: 14000
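    # NOTE: the stft/mel settings mirror the top-level audio_args block (filter_length /
    # n_fft 1024, hop_length 320, 64 mel bins, f_min 50, f_max 14000); if one block is
    # edited, the other presumably needs to be kept in sync.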

audio_encoder_args:
  model_arch: "transformer"
  model_name: "htsat"
  pretrained: True
  freeze: True
  spec_augment: True
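  # NOTE (assumption): freeze: True keeps the pretrained HTSAT encoder weights fixed during
  # training, and spec_augment: True appears to enable SpecAugment-style masking on the
  # input spectrograms.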

text_decoder_args:
  model_tag: "audio_qformer"
  name: "facebook/bart-base"
  pretrained: true
  freeze: False
  freeze_embed_layer: False
  bert_args:
    attention_probs_dropout_prob: 0.2
    hidden_act: "gelu"
    hidden_dropout_prob: 0.2
    hidden_size: 768
    initializer_range: 0.02
    intermediate_size: 2048
    layer_norm_eps: !!float 1e-5
    max_position_embeddings: 128
    model_type: "bert"
    num_attention_heads: 4
    num_hidden_layers: 2
    add_type_embeddings: false
    vocab_size: 30522
    add_cross_attention: true
    is_decoder: true
    num_labels: 0
    name: "bert-base-uncased"

optim_args:
  scheduler: cosine
  lr: *lr
  optimizer_name: "adam"
  betas: [0.9, 0.999]
  eps: !!float 1e-8
  momentum: 0.9
  gamma: 0.05
  warmup_epochs: *warmup_epochs
  weight_decay: !!float 1e-6
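  # NOTE (assumption): with scheduler 'cosine' and warmup_epochs 2, the learning rate
  # presumably ramps up to lr (1.0e-5) over the first two epochs and then follows a cosine
  # decay over the remaining epochs configured under step.epochs.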