File size: 11,105 Bytes
b366428
 
 
 
 
 
3555c0a
 
 
 
 
 
 
b366428
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3555c0a
b366428
3555c0a
 
b366428
 
 
 
 
 
 
 
 
 
 
 
 
 
3555c0a
b366428
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3555c0a
b366428
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347

training:
  # Training precision setting (NOTE(review): presumably forwarded to the
  # trainer / torch matmul precision — confirm).
  precision: 'high'
  # NOTE(review): -1 presumably means "use every available node" — confirm.
  nodes_count: -1

logging:
  project_name: "genau"
  # Placeholder: replace the whole value with your W&B API key,
  # which you can obtain at https://wandb.ai/authorize.
  wandb_key: YOUR_WANDB_KEY (check wandb.ai/authorize)
  log_directory: "./run_logs/genau/train"

  # (optional) If an S3 path is specified, checkpoints are saved to
  # S3_FOLDER/log_directory and deleted from the local folder (except the
  # last checkpoint). Otherwise, checkpoints are kept locally indefinitely.
  # S3_BUCKET: "YOUR_S3_BUCKET"
  # S3_FOLDER: 'YOUR_S3_FOLDER'

  # NOTE(review): these two keys are also set under `step`; which copy the
  # trainer reads is not visible in this file — keep them in sync.
  save_checkpoint_every_n_steps: 1500
  save_top_k: -1
  

variables:
  # Shared values, defined once as anchors and reused below via `*alias`.
  sampling_rate: &sampling_rate 16000  # audio sampling rate (Hz)
  mel_bins: &mel_bins 64
  latent_embed_dim: &latent_embed_dim 64
  latent_t_size: &latent_t_size 256  # TODO might need to change
  latent_f_size: &latent_f_size 1
  in_channels: &unet_in_channels 256
  optimize_ddpm_parameter: &optimize_ddpm_parameter true
  optimize_gpt: &optimize_gpt true
  warmup_steps: &warmup_steps 5000
  lr: &lr 5.0e-3
  mx_steps: &mx_steps 80000000
  batch_size: &bs 20  # TODO: change to 256

data:
  metadata_root: "../dataset_preperation/data/metadata/dataset_root.json"
  # Training datasets (presumably names registered in metadata_root — confirm).
  train:
    - vggsounds
    - audiocaps
    - caption_audioset
    - wavcaps_audioset_strong
    - wavcaps_bbcsound
    - wavcaps_freesound
    - wavcaps_soundbible
    - clotho
    - fsd50k
  val: "audioset"
  test: "audioset"
  class_label_indices: "audioset_eval_subset"
  dataloader_add_ons: []
  augment_p: 0.0
  num_workers: 48
  consistent_start_time: true

  # NOTE(review): presumably each list names the metadata keys that are
  # treated as synonyms of the key above it — confirm in the dataloader.
  keys_synonyms:
    gt_audio_caption:
      - audiocaps_gt_captions
      - gt_caption
      - gt_captions
      - caption
      - best_model_w_meta_pred_caption
      - gt_audio_caption
      - autocap_caption
      - wavcaps_caption
    tags:
      - keywords
      - tags


step:
  validation_every_n_epochs: 3
  # NOTE(review): duplicated under `logging`; which copy is read is not
  # visible here — keep both in sync.
  save_checkpoint_every_n_steps: 1500
  # Test-only overrides (TODO: enable for test):
  # limit_val_batches: 1
  # limit_train_batches: 128
  max_steps: *mx_steps
  save_top_k: -1

preprocessing:
  video:
    fps: 1
    height: 224
    width: 224
  audio:
    sampling_rate: *sampling_rate
    # 32768 = 2**15 (presumably the int16 full-scale value used to normalize waveforms)
    max_wav_value: 32768.0
    # Presumably seconds: 10.24 s at 16 kHz is 1024 hops of 160 samples.
    duration: 10.24
  stft:
    filter_length: 1024
    hop_length: 160  # 10 ms at a 16 kHz sampling rate
    win_length: 1024
  mel:
    n_mel_channels: *mel_bins
    mel_fmin: 0
    mel_fmax: 8000  # Nyquist frequency for a 16 kHz sampling rate

augmentation:
  # Mixup setting (NOTE(review): 0.0 presumably disables mixup — confirm).
  mixup: 0.0

model:
  target: src.models.genau_ddpm.GenAu
  params:
    # Dataset token (presumably the size of the dataset-ID embedding — confirm)
    dataset_embed_dim: 32

    # Validation / logging
    validate_uncond: false
    validate_wo_ema: true
    num_val_sampled_timestamps: 10

    # Evaluation is disabled; uncomment the block below to enable it.
    # evaluator:
    #   target: audioldm_eval.EvaluationHelper
    #   params:
    #     sampling_rate: 16000
    #     device: 'cuda'

    # Optimizer
    optimizer_config:
      # Optimizer class (LAMB), loaded through the custom !module tag
      target: !module src.modules.optimizers.lamb.Lamb
      # Learning rate
      lr: *lr
      # Weight decay
      weight_decay: 0.01
      # (beta1, beta2) moment decay rates, as in Adam-style optimizers
      betas: [0.9, 0.99]
      # Epsilon for numerical stability (1e-8)
      eps: 0.00000001

    # Learning-rate schedule: warmup, then cosine annealing from
    # base_learning_rate down to final_lr.
    base_learning_rate: *lr
    # Lower bound of the cosine schedule. It stays above 0 because
    # performance degrades with a very small lr.
    final_lr: 0.0015
    # Warmup length, in steps
    warmup_steps: *warmup_steps
    # The lr is recomputed once every this many steps
    lr_update_each_steps: 10
    # Total number of training steps
    max_steps: *mx_steps  # TODO enable

    # Autoencoder
    first_stage_config:
      base_learning_rate: 8.0e-06
      target: src.modules.latent_encoder.autoencoder_1d.AutoencoderKL1D
      params:
        # reload_from_ckpt: "data/checkpoints/vae_mel_16k_64bins.ckpt"
        reload_from_ckpt: "1dvae_64ch_16k_64bins"
        sampling_rate: *sampling_rate
        batchsize: *bs  # TODO: change
        monitor: val/rec_loss
        image_key: fbank
        subband: 1
        embed_dim: *latent_embed_dim
        time_shuffle: 1
        lossconfig:
          target: src.losses.LPIPSWithDiscriminator
          params:
            disc_start: 50001
            kl_weight: 1000.0
            disc_weight: 0.5
            disc_in_channels: 1
        ddconfig:
          double_z: true
          mel_bins: *mel_bins  # The frequency bins of the mel spectrogram
          z_channels: *unet_in_channels
          resolution: 256
          downsample_time: false
          # In and out channels must stay at 64. NOTE(review): this equals
          # mel_bins (64); presumably the two must match — confirm.
          in_channels: 64
          out_ch: 64
          ch: 512
          ch_mult:
            - 1
            - 2
            - 4
          num_res_blocks: 3
          attn_resolutions: []
          dropout: 0.0
      
    # Other parameters
    clip_grad: 0.5
    optimize_ddpm_parameter: *optimize_ddpm_parameter
    sampling_rate: *sampling_rate
    batchsize: *bs
    # Linear beta schedule from linear_start to linear_end. For reference:
    # DDPM uses a linear schedule with betas from 1e-4 to 0.02, and LDM also
    # uses a linear schedule (NOTE(review): its exact start/end values vary by
    # config — confirm). Make-An-Audio uses different start and end values.
    # Improved DDPM introduced a cosine schedule and RIN a sigmoid one.
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    # NOTE(review): presumably the probability of dropping the condition
    # during training, for classifier-free guidance — confirm.
    unconditional_prob_cfg: 0.1
    parameterization: eps  # one of [eps, x0, v]
    first_stage_key: fbank
    latent_t_size: *latent_t_size  # TODO might need to change
    latent_f_size: *latent_f_size
    channels: *latent_embed_dim  # TODO might need to change
    monitor: val/loss_simple_ema

    scale_by_std: true
    # scale_factor: 1.0144787

    
    backbone_type: fit
    unet_config:
      target: src.modules.fit.fit_audio.FIT

      params:
        weight_initializer:
          target: !module src.modules.initializers.initializers.RINWeightScalerInitializer
          scale: 0.57735026919  # 1/sqrt(3), from Yuwei's findings

        fit_block_module: !module src.modules.fit.layers.fit_layers.FITBlockV5
        context_channels: 1024
        summary_text_embeddings_channels: 1536  # text embedding (e.g. CLAP) size

        # If true, inserts the conditioning information in the context
        conditioning_in_context: true

        # The type of positional encodings to use for the time input
        time_pe_type: learned
        # Uses a label that specifies the id of the dataset the current input comes from
        use_dataset_id_conditioning: true
        # Uses a label that specifies the resolution of the current input
        use_resolution_conditioning: false

        # Size of the latent input as (frames_count, height, width): height is
        # latent_t_size and width is latent_f_size
        input_size: [1, *latent_t_size, *latent_f_size]
        # The size of each patch along each input dimension
        patch_size: [1, 1, 1]
        # The number of patches in each group along each input dimension
        # (with the current latent_t_size of 256: 256 / 32 = 8 groups along height)
        group_size: [1, 32, 1]
        input_channels: *latent_embed_dim
        # The number of channels in the patch embeddings
        patch_channels: 1024
        # The number of fit blocks
        fit_blocks_count: 6
        # The number of local layers in each fit block
        local_layers_per_block: 2
        # The number of global layers in each fit block
        global_layers_per_block: 4
        # The number of latent tokens
        latent_count: 256
        # The number of channels in the latent tokens
        latent_channels: 1536

        self_conditioning_ff_config: {}
        fit_block_config:
          attention_class: !module src.modules.fit.layers.rin_layers.Attention
          ff_class: !module src.modules.fit.layers.rin_layers.FeedForward

          # Dropout parameters
          drop_units: 0.1
          drop_path: 0.0

          # Whether to use feedforward layers after cross attention
          use_cross_attention_feedforward: true

          # Configuration for attention layers
          default_attention_config:
            heads: 8
            dim_head: 128
          read_attention_config:
            # Ensure heads * dim_head = min(input_channels, patch_channels)
            # NOTE(review): with input_channels = 64 and patch_channels = 1024 that
            # min is 64, yet heads * dim_head = 1024 here; the intended constraint
            # is probably min(patch_channels, latent_channels) — confirm in FIT code.
            heads: 8
            dim_head: 128
          read_context_attention_config:
            # Ensure heads * dim_head = min(latent_channels, context_channels)
            heads: 8
            dim_head: 128
          read_latent_conditioning_attention_config:
            # Ensure heads * dim_head = latent_channels
            heads: 12
            dim_head: 128
          write_attention_config:
            # Ensure heads * dim_head = min(input_channels, patch_channels)
            # NOTE(review): same mismatch as read_attention_config (64 vs 1024) —
            # confirm the intended constraint.
            heads: 8
            dim_head: 128
          local_attention_config:
            # Ensure heads * dim_head = patch_channels
            heads: 8
            dim_head: 128
          global_attention_config:
            # Ensure heads * dim_head = latent_channels
            heads: 12
            dim_head: 128

          ff_config: {}
    # unet_config:
    #   target: audioldm_train.modules.diffusionmodules.openaimodel.UNetModel
    #   params:
    #     image_size: 64 
    #     extra_film_condition_dim: 512 # If you use film as extra condition, set this parameter. For example if you have two conditioning vectors each have dimension 512, then this number would be 1024
    #     # context_dim: 
    #     # - 768
    #     in_channels: *unet_in_channels # The input channel of the UNet model
    #     out_channels: *latent_embed_dim # TODO might need to change
    #     model_channels: 128 # TODO might need to change
    #     attention_resolutions:
    #     - 8
    #     - 4
    #     - 2
    #     num_res_blocks: 2
    #     channel_mult: 
    #     - 1
    #     - 2
    #     - 3
    #     - 5
    #     num_head_channels: 32
    #     use_spatial_transformer: true
    #     transformer_depth: 1
    #     extra_sa_layer: false
    
    cond_stage_config:
      film_clap_cond1:
        cond_stage_key: text
        conditioning_key: film
        target: src.modules.conditional.conditional_models.CLAPAudioEmbeddingClassifierFreev2
        params:
          pretrained_path: clap_htsat_tiny
          # Same value as the global sampling rate, now referenced via its alias
          sampling_rate: *sampling_rate
          embed_mode: text  # presumably 'text' or 'audio' — confirm
          amodel: HTSAT-tiny
      film_flan_t5_cond2:
        cond_stage_key: text
        conditioning_key: film
        target: src.modules.conditional.conditional_models.FlanT5HiddenState
        params:
          text_encoder_name: google/flan-t5-large  # alternative: google/flan-t5-xxl
          freeze_text_encoder: true
          return_embeds: true
          pool_tokens: true

      # For a non-FIT backbone, use film_dataset_ids and enable encode_dataset_ids instead
      noncond_dataset_ids:
        cond_stage_key: all
        conditioning_key: ignore
        target: src.modules.conditional.conditional_models.DatasetIDs
        params:
          encode_dataset_ids: false
          # NOTE(review): the val/test dataset "audioset" has no entry here —
          # confirm how it is mapped.
          dataset2id:
            audiocaps: 0
            clotho: 1
            vggsounds: 2
            wavcaps_audioset_strong: 3
            wavcaps_bbcsound: 4
            wavcaps_freesound: 5
            wavcaps_soundbible: 6
            fsd50k: 7
            caption_audioset: 8
            autocap: 9
            unconditional: 0  # the unconditional id is set to 0 for future experiments



    # Sampling settings used when generating samples for evaluation
    evaluation_params:
      unconditional_guidance_scale: 3.5  # classifier-free guidance scale
      ddim_sampling_steps: 200
      # NOTE(review): presumably the number of candidates generated per prompt — confirm
      n_candidates_per_samples: 3