liuhuadai commited on
Commit
6efc863
·
verified ·
1 Parent(s): a00f5de

Upload 340 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. app.py +44 -0
  3. audiocaps_test_16000_struct.tsv +0 -0
  4. configs/audiolcm.yaml +130 -0
  5. configs/autoencoder1d.yaml +74 -0
  6. configs/teacher.yaml +121 -0
  7. infer.sh +4 -0
  8. infer_api.sh +4 -0
  9. ldm/__pycache__/lr_scheduler.cpython-37.pyc +0 -0
  10. ldm/__pycache__/lr_scheduler.cpython-38.pyc +0 -0
  11. ldm/__pycache__/util.cpython-310.pyc +0 -0
  12. ldm/__pycache__/util.cpython-37.pyc +0 -0
  13. ldm/__pycache__/util.cpython-38.pyc +0 -0
  14. ldm/data/__pycache__/joinaudiodataset_624.cpython-38.pyc +0 -0
  15. ldm/data/__pycache__/joinaudiodataset_anylen.cpython-37.pyc +0 -0
  16. ldm/data/__pycache__/joinaudiodataset_anylen.cpython-38.pyc +0 -0
  17. ldm/data/__pycache__/joinaudiodataset_struct.cpython-38.pyc +0 -0
  18. ldm/data/__pycache__/joinaudiodataset_struct_anylen.cpython-38.pyc +0 -0
  19. ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-37.pyc +0 -0
  20. ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-38.pyc +0 -0
  21. ldm/data/__pycache__/tsvdataset.cpython-38.pyc +0 -0
  22. ldm/data/joinaudiodataset_624.py +93 -0
  23. ldm/data/joinaudiodataset_anylen.py +331 -0
  24. ldm/data/joinaudiodataset_struct.py +95 -0
  25. ldm/data/joinaudiodataset_struct_anylen.py +336 -0
  26. ldm/data/joinaudiodataset_struct_sample.py +103 -0
  27. ldm/data/joinaudiodataset_struct_sample_anylen.py +230 -0
  28. ldm/data/preprocess/NAT_mel.py +131 -0
  29. ldm/data/preprocess/__pycache__/NAT_mel.cpython-38.pyc +0 -0
  30. ldm/data/preprocess/__pycache__/NAT_mel.cpython-39.pyc +0 -0
  31. ldm/data/preprocess/add_duration.py +45 -0
  32. ldm/data/preprocess/mel_spec.py +201 -0
  33. ldm/data/test.py +224 -0
  34. ldm/data/tsv_dirs/full_data/V1_new/audiocaps_train_16000.tsv +0 -0
  35. ldm/data/tsv_dirs/full_data/V2/MACS.tsv +0 -0
  36. ldm/data/tsv_dirs/full_data/V2/WavText5K.tsv +0 -0
  37. ldm/data/tsv_dirs/full_data/V2/adobe.tsv +0 -0
  38. ldm/data/tsv_dirs/full_data/V2/audiostock.tsv +0 -0
  39. ldm/data/tsv_dirs/full_data/V2/epidemic_sound.tsv +3 -0
  40. ldm/data/tsv_dirs/full_data/caps_struct/audiocaps_train_16000_struct2.tsv +0 -0
  41. ldm/data/tsv_dirs/full_data/clotho.tsv +0 -0
  42. ldm/data/tsvdataset.py +67 -0
  43. ldm/lr_scheduler.py +98 -0
  44. ldm/models/__pycache__/autoencoder.cpython-37.pyc +0 -0
  45. ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  46. ldm/models/__pycache__/autoencoder.cpython-39.pyc +0 -0
  47. ldm/models/__pycache__/autoencoder1d.cpython-37.pyc +0 -0
  48. ldm/models/__pycache__/autoencoder1d.cpython-38.pyc +0 -0
  49. ldm/models/__pycache__/autoencoder_multi.cpython-38.pyc +0 -0
  50. ldm/models/autoencoder.py +504 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ ldm/data/tsv_dirs/full_data/V2/epidemic_sound.tsv filter=lfs diff=lfs merge=lfs -text
37
+ vocoder/BigVGAN/LibriTTS/train-full.txt filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio
2
+
3
+ def infer(prompt):
4
+ config = OmegaConf.load("configs/audiolcm.yaml")
5
+
6
+ # print("-------quick debug no load ckpt---------")
7
+ # model = instantiate_from_config(config['model'])# for quick debug
8
+ model = load_model_from_config(config,
9
+ "../logs/2024-04-21T14-50-11_text2music-audioset-nonoverlap/epoch=000184.ckpt")
10
+
11
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
12
+ model = model.to(device)
13
+
14
+ sampler = LCMSampler(model)
15
+
16
+ os.makedirs("results/test", exist_ok=True)
17
+
18
+ vocoder = VocoderBigVGAN("../vocoder/logs/bigvnat16k93.5w", device)
19
+
20
+ generator = GenSamples(sampler, model, "results/test", vocoder, save_mel=False, save_wav=True,
21
+ original_inference_steps=config.model.params.num_ddim_timesteps)
22
+ csv_dicts = []
23
+
24
+ with torch.no_grad():
25
+ with model.ema_scope():
26
+ wav_name = f'{prompt.strip().replace(" ", "-")}'
27
+ generator.gen_test_sample(prompt, wav_name=wav_name)
28
+
29
+ print(f"Your samples are ready and waiting four you here: \nresults/test \nEnjoy.")
30
+
31
+
32
+ def my_inference_function(prompt_oir):
33
+ prompt = {'ori_caption':prompt_oir,'struct_caption':prompt_oir}
34
+ file_path = infer(prompt)
35
+ return "test.wav"
36
+
37
+
38
+
39
+ gradio_interface = gradio.Interface(
40
+ fn = my_inference_function,
41
+ inputs = "text",
42
+ outputs = "audio"
43
+ )
44
+ gradio_interface.launch()
audiocaps_test_16000_struct.tsv ADDED
The diff for this file is too large to render. See raw diff
 
configs/audiolcm.yaml ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 3.0e-06
3
+ target: ldm.models.diffusion.lcm_audio.LCM_audio
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.012
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: caption
12
+ mel_dim: 20
13
+ mel_length: 312
14
+ channels: 0
15
+ cond_stage_trainable: False
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_by_std: true
19
+ use_lcm: True
20
+ num_ddim_timesteps: 50
21
+ w_min: 4
22
+ w_max: 12
23
+ ckpt_path: ../ckpt/maa2.ckpt
24
+
25
+ use_ema: false
26
+ scheduler_config:
27
+ target: ldm.lr_scheduler.LambdaLinearScheduler
28
+ params:
29
+ warm_up_steps:
30
+ - 10000
31
+ cycle_lengths:
32
+ - 10000000000000
33
+ f_start:
34
+ - 1.0e-06
35
+ f_max:
36
+ - 1.0
37
+ f_min:
38
+ - 1.0
39
+ unet_config:
40
+ target: ldm.modules.diffusionmodules.concatDiT.ConcatDiT2MLP
41
+ params:
42
+ in_channels: 20
43
+ context_dim: 1024
44
+ hidden_size: 576
45
+ num_heads: 8
46
+ depth: 4
47
+ max_len: 1000
48
+ first_stage_config:
49
+ target: ldm.models.autoencoder1d.AutoencoderKL
50
+ params:
51
+ embed_dim: 20
52
+ monitor: val/rec_loss
53
+ ckpt_path: ./model/AutoencoderKL/epoch=000032.ckpt
54
+ ddconfig:
55
+ double_z: true
56
+ in_channels: 80
57
+ out_ch: 80
58
+ z_channels: 20
59
+ kernel_size: 5
60
+ ch: 384
61
+ ch_mult:
62
+ - 1
63
+ - 2
64
+ - 4
65
+ num_res_blocks: 2
66
+ attn_layers:
67
+ - 3
68
+ down_layers:
69
+ - 0
70
+ dropout: 0.0
71
+ lossconfig:
72
+ target: torch.nn.Identity
73
+ cond_stage_config:
74
+ target: ldm.modules.encoders.modules.FrozenCLAPFLANEmbedder
75
+ params:
76
+ weights_path: ./model/FrozenCLAPFLANEmbedder/CLAP_weights_2022.pth
77
+
78
+ lightning:
79
+ callbacks:
80
+ image_logger:
81
+ target: main.AudioLogger
82
+ params:
83
+ sample_rate: 16000
84
+ for_specs: true
85
+ increase_log_steps: false
86
+ batch_frequency: 5000
87
+ max_images: 8
88
+ melvmin: -5
89
+ melvmax: 1.5
90
+ vocoder_cfg:
91
+ target: vocoder.bigvgan.models.VocoderBigVGAN
92
+ params:
93
+ ckpt_vocoder: ./vocoder/logs/bigvnat16k93.5w
94
+ trainer:
95
+ benchmark: True
96
+ gradient_clip_val: 1.0
97
+ replace_sampler_ddp: false
98
+ max_epochs: 100
99
+ modelcheckpoint:
100
+ params:
101
+ monitor: epoch
102
+ mode: max
103
+ # every_n_train_steps: 2000
104
+ save_top_k: 100
105
+ every_n_epochs: 3
106
+
107
+
108
+ data:
109
+ target: main.SpectrogramDataModuleFromConfig
110
+ params:
111
+ batch_size: 8
112
+ num_workers: 32
113
+ spec_dir_path: 'ldm/data/tsv_dirs/full_data/caps_struct'
114
+ mel_num: 80
115
+ train:
116
+ target: ldm.data.joinaudiodataset_struct_anylen.JoinSpecsTrain
117
+ params:
118
+ specs_dataset_cfg:
119
+ validation:
120
+ target: ldm.data.joinaudiodataset_struct_anylen.JoinSpecsValidation
121
+ params:
122
+ specs_dataset_cfg:
123
+
124
+ test_dataset:
125
+ target: ldm.data.tsvdataset.TSVDatasetStruct
126
+ params:
127
+ tsv_path: audiocaps_test_16000_struct.tsv
128
+ spec_crop_len: 624
129
+
130
+
configs/autoencoder1d.yaml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-06
3
+ target: ldm.models.autoencoder1d.AutoencoderKL
4
+ params:
5
+ embed_dim: 20
6
+ monitor: val/rec_loss
7
+ ddconfig:
8
+ double_z: true
9
+ in_channels: 80
10
+ out_ch: 80
11
+ z_channels: 20
12
+ kernel_size: 5
13
+ ch: 384
14
+ ch_mult:
15
+ - 1
16
+ - 2
17
+ - 4
18
+ num_res_blocks: 2
19
+ attn_layers:
20
+ - 3
21
+ down_layers:
22
+ - 0
23
+ dropout: 0.0
24
+ lossconfig:
25
+ target: ldm.modules.losses_audio.contperceptual.LPAPSWithDiscriminator
26
+ params:
27
+ disc_start: 80001
28
+ perceptual_weight: 0.0
29
+ kl_weight: 1.0e-06
30
+ disc_weight: 0.5
31
+ disc_in_channels: 1
32
+ disc_loss: mse
33
+ disc_factor: 2
34
+ disc_conditional: false
35
+ r1_reg_weight: 3
36
+
37
+ lightning:
38
+ callbacks:
39
+ image_logger:
40
+ target: main.AudioLogger
41
+ params:
42
+ for_specs: true
43
+ increase_log_steps: false
44
+ batch_frequency: 5000
45
+ max_images: 8
46
+ rescale: false
47
+ melvmin: -5
48
+ melvmax: 1.5
49
+ vocoder_cfg:
50
+ target: vocoder.bigvgan.models.VocoderBigVGAN
51
+ params:
52
+ ckpt_vocoder: vocoder/logs/bigvnat16k93.5w
53
+ trainer:
54
+ sync_batchnorm: false # not working with r1_regularization
55
+ strategy: ddp
56
+
57
+
58
+ data:
59
+ target: main.SpectrogramDataModuleFromConfig
60
+ params:
61
+ batch_size: 4
62
+ num_workers: 16
63
+ spec_dir_path: ldm/data/tsv_dirs/full_data/V1_new
64
+ mel_num: 80
65
+ spec_len: 624
66
+ spec_crop_len: 624
67
+ train:
68
+ target: ldm.data.joinaudiodataset_624.JoinSpecsTrain
69
+ params:
70
+ specs_dataset_cfg: null
71
+ validation:
72
+ target: ldm.data.joinaudiodataset_624.JoinSpecsValidation
73
+ params:
74
+ specs_dataset_cfg: null
configs/teacher.yaml ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 3.0e-06
3
+ target: ldm.models.diffusion.ddpm_audio.LatentDiffusion_audio
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.012
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: caption
12
+ mel_dim: 20
13
+ mel_length: 312
14
+ channels: 0
15
+ cond_stage_trainable: True
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_by_std: true
19
+ use_ema: false
20
+ scheduler_config:
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps:
24
+ - 10000
25
+ cycle_lengths:
26
+ - 10000000000000
27
+ f_start:
28
+ - 1.0e-06
29
+ f_max:
30
+ - 1.0
31
+ f_min:
32
+ - 1.0
33
+ unet_config:
34
+ target: ldm.modules.diffusionmodules.concatDiT.ConcatDiT2MLP
35
+ params:
36
+ in_channels: 20
37
+ context_dim: 1024
38
+ hidden_size: 576
39
+ num_heads: 8
40
+ depth: 4
41
+ max_len: 1000
42
+ first_stage_config:
43
+ target: ldm.models.autoencoder1d.AutoencoderKL
44
+ params:
45
+ embed_dim: 20
46
+ monitor: val/rec_loss
47
+ ckpt_path: logs/trainae/ckpt/epoch=000032.ckpt
48
+ ddconfig:
49
+ double_z: true
50
+ in_channels: 80
51
+ out_ch: 80
52
+ z_channels: 20
53
+ kernel_size: 5
54
+ ch: 384
55
+ ch_mult:
56
+ - 1
57
+ - 2
58
+ - 4
59
+ num_res_blocks: 2
60
+ attn_layers:
61
+ - 3
62
+ down_layers:
63
+ - 0
64
+ dropout: 0.0
65
+ lossconfig:
66
+ target: torch.nn.Identity
67
+ cond_stage_config:
68
+ target: ldm.modules.encoders.modules.FrozenCLAPFLANEmbedder
69
+ params:
70
+ weights_path: useful_ckpts/CLAP/CLAP_weights_2022.pth
71
+
72
+ lightning:
73
+ callbacks:
74
+ image_logger:
75
+ target: main.AudioLogger
76
+ params:
77
+ sample_rate: 16000
78
+ for_specs: true
79
+ increase_log_steps: false
80
+ batch_frequency: 5000
81
+ max_images: 8
82
+ melvmin: -5
83
+ melvmax: 1.5
84
+ vocoder_cfg:
85
+ target: vocoder.bigvgan.models.VocoderBigVGAN
86
+ params:
87
+ ckpt_vocoder: vocoder/logs/bigvnat16k93.5w
88
+ trainer:
89
+ benchmark: True
90
+ gradient_clip_val: 1.0
91
+ replace_sampler_ddp: false
92
+ modelcheckpoint:
93
+ params:
94
+ monitor: epoch
95
+ mode: max
96
+ save_top_k: 10
97
+ every_n_epochs: 5
98
+
99
+ data:
100
+ target: main.SpectrogramDataModuleFromConfig
101
+ params:
102
+ batch_size: 4
103
+ num_workers: 32
104
+ main_spec_dir_path: 'ldm/data/tsv_dirs/full_data/caps_struct'
105
+ other_spec_dir_path: 'ldm/data/tsv_dirs/full_data/V2'
106
+ mel_num: 80
107
+ train:
108
+ target: ldm.data.joinaudiodataset_struct_sample_anylen.JoinSpecsTrain
109
+ params:
110
+ specs_dataset_cfg:
111
+ validation:
112
+ target: ldm.data.joinaudiodataset_struct_sample_anylen.JoinSpecsValidation
113
+ params:
114
+ specs_dataset_cfg:
115
+
116
+ test_dataset:
117
+ target: ldm.data.tsvdataset.TSVDatasetStruct
118
+ params:
119
+ tsv_path: musiccap.tsv
120
+ spec_crop_len: 624
121
+
infer.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES='1' python scripts/txt2audio_for_lcm.py --n_samples 1 --n_iter 1 --scale 5 --H 20 --W 312 \
2
+ --ddim_steps 2 -b configs/audiolcm.yaml \
3
+ --sample_rate 16000 --vocoder-ckpt ../vocoder/logs/bigvnat16k93.5w \
4
+ --outdir results/test --test-dataset audiocaps -r ../logs/2024-04-21T14-50-11_text2music-audioset-nonoverlap/epoch=000184.ckpt
infer_api.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES='1' python scripts/txt2audio_for_lcm.py --n_samples 1 --n_iter 1 --scale 5 --H 20 --W 312 \
2
+ --ddim_steps 2 -b configs/audiolcm.yaml \
3
+ --sample_rate 16000 --vocoder-ckpt ../vocoder/logs/bigvnat16k93.5w \
4
+ --outdir results/test -r ../logs/2024-04-21T14-50-11_text2music-audioset-nonoverlap/epoch=000184.ckpt --prompt_txt ./prompt.txt
ldm/__pycache__/lr_scheduler.cpython-37.pyc ADDED
Binary file (3.66 kB). View file
 
ldm/__pycache__/lr_scheduler.cpython-38.pyc ADDED
Binary file (3.61 kB). View file
 
ldm/__pycache__/util.cpython-310.pyc ADDED
Binary file (8.36 kB). View file
 
ldm/__pycache__/util.cpython-37.pyc ADDED
Binary file (5.1 kB). View file
 
ldm/__pycache__/util.cpython-38.pyc ADDED
Binary file (8.3 kB). View file
 
ldm/data/__pycache__/joinaudiodataset_624.cpython-38.pyc ADDED
Binary file (3.62 kB). View file
 
ldm/data/__pycache__/joinaudiodataset_anylen.cpython-37.pyc ADDED
Binary file (12.4 kB). View file
 
ldm/data/__pycache__/joinaudiodataset_anylen.cpython-38.pyc ADDED
Binary file (12.1 kB). View file
 
ldm/data/__pycache__/joinaudiodataset_struct.cpython-38.pyc ADDED
Binary file (3.69 kB). View file
 
ldm/data/__pycache__/joinaudiodataset_struct_anylen.cpython-38.pyc ADDED
Binary file (12.5 kB). View file
 
ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-37.pyc ADDED
Binary file (8.29 kB). View file
 
ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-38.pyc ADDED
Binary file (8.09 kB). View file
 
ldm/data/__pycache__/tsvdataset.cpython-38.pyc ADDED
Binary file (2.66 kB). View file
 
ldm/data/joinaudiodataset_624.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+ import torch
4
+ import logging
5
+ import pandas as pd
6
+ import glob
7
+ logger = logging.getLogger(f'main.{__name__}')
8
+
9
+ sys.path.insert(0, '.') # nopep8
10
+
11
+ class JoinManifestSpecs(torch.utils.data.Dataset):
12
+ def __init__(self, split, spec_dir_path, mel_num=None, spec_crop_len=None,drop=0,**kwargs):
13
+ super().__init__()
14
+ self.split = split
15
+ self.batch_max_length = spec_crop_len
16
+ self.batch_min_length = 50
17
+ self.mel_num = mel_num
18
+ self.drop = drop
19
+ manifest_files = []
20
+ for dir_path in spec_dir_path.split(','):
21
+ manifest_files += glob.glob(f'{dir_path}/**/*.tsv',recursive=True)
22
+ df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
23
+ df = pd.concat(df_list,ignore_index=True)
24
+
25
+ if split == 'train':
26
+ self.dataset = df.iloc[100:]
27
+ elif split == 'valid' or split == 'val':
28
+ self.dataset = df.iloc[:100]
29
+ elif split == 'test':
30
+ df = self.add_name_num(df)
31
+ self.dataset = df
32
+ else:
33
+ raise ValueError(f'Unknown split {split}')
34
+ self.dataset.reset_index(inplace=True)
35
+ print('dataset len:', len(self.dataset))
36
+
37
+ def add_name_num(self,df):
38
+ """each file may have different caption, we add num to filename to identify each audio-caption pair"""
39
+ name_count_dict = {}
40
+ change = []
41
+ for t in df.itertuples():
42
+ name = getattr(t,'name')
43
+ if name in name_count_dict:
44
+ name_count_dict[name] += 1
45
+ else:
46
+ name_count_dict[name] = 0
47
+ change.append((t[0],name_count_dict[name]))
48
+ for t in change:
49
+ df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
50
+ return df
51
+
52
+ def __getitem__(self, idx):
53
+ data = self.dataset.iloc[idx]
54
+ item = {}
55
+ try:
56
+ spec = np.load(data['mel_path']) # mel spec [80, 624]
57
+ except:
58
+ mel_path = data['mel_path']
59
+ print(f'corrupted:{mel_path}')
60
+ spec = np.zeros((self.mel_num,self.batch_max_length)).astype(np.float32)
61
+
62
+ if spec.shape[1] < self.batch_max_length:
63
+ # spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1]))) # [80, 624]
64
+ spec = np.tile(spec,reps=(self.batch_max_length//spec.shape[1])+1)
65
+
66
+ item['image'] = spec[:,:self.batch_max_length]
67
+ p = np.random.uniform(0,1)
68
+ if p > self.drop:
69
+ item["caption"] = data['caption']
70
+ else:
71
+ item["caption"] = ""
72
+ if self.split == 'test':
73
+ item['f_name'] = data['name']
74
+ return item
75
+
76
+ def __len__(self):
77
+ return len(self.dataset)
78
+
79
+
80
+ class JoinSpecsTrain(JoinManifestSpecs):
81
+ def __init__(self, specs_dataset_cfg):
82
+ super().__init__('train', **specs_dataset_cfg)
83
+
84
+ class JoinSpecsValidation(JoinManifestSpecs):
85
+ def __init__(self, specs_dataset_cfg):
86
+ super().__init__('valid', **specs_dataset_cfg)
87
+
88
+ class JoinSpecsTest(JoinManifestSpecs):
89
+ def __init__(self, specs_dataset_cfg):
90
+ super().__init__('test', **specs_dataset_cfg)
91
+
92
+
93
+
ldm/data/joinaudiodataset_anylen.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils.data.sampler import Sampler
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ import torch.distributed
9
+ from typing import TypeVar, Optional, Iterator,List
10
+ import logging
11
+ import pandas as pd
12
+ import glob
13
+ import torch.distributed as dist
14
+ logger = logging.getLogger(f'main.{__name__}')
15
+
16
+ sys.path.insert(0, '.') # nopep8
17
+
18
+ class JoinManifestSpecs(torch.utils.data.Dataset):
19
+ def __init__(self, split, spec_dir_path, mel_num=80,spec_crop_len=1248,mode='pad',pad_value=-5,drop=0,**kwargs):
20
+ super().__init__()
21
+ self.split = split
22
+ self.max_batch_len = spec_crop_len
23
+ self.min_batch_len = 64
24
+ self.mel_num = mel_num
25
+ self.min_factor = 4
26
+ self.drop = drop
27
+ self.pad_value = pad_value
28
+ assert mode in ['pad','tile']
29
+ self.collate_mode = mode
30
+ # print(f"################# self.collate_mode {self.collate_mode} ##################")
31
+
32
+ manifest_files = []
33
+ for dir_path in spec_dir_path.split(','):
34
+ manifest_files += glob.glob(f'{dir_path}/*.tsv')
35
+ df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
36
+ df = pd.concat(df_list,ignore_index=True)
37
+
38
+ if split == 'train':
39
+ self.dataset = df.iloc[100:]
40
+ elif split == 'valid' or split == 'val':
41
+ self.dataset = df.iloc[:100]
42
+ elif split == 'test':
43
+ df = self.add_name_num(df)
44
+ self.dataset = df
45
+ else:
46
+ raise ValueError(f'Unknown split {split}')
47
+ self.dataset.reset_index(inplace=True)
48
+ print('dataset len:', len(self.dataset))
49
+
50
+ def add_name_num(self,df):
51
+ """each file may have different caption, we add num to filename to identify each audio-caption pair"""
52
+ name_count_dict = {}
53
+ change = []
54
+ for t in df.itertuples():
55
+ name = getattr(t,'name')
56
+ if name in name_count_dict:
57
+ name_count_dict[name] += 1
58
+ else:
59
+ name_count_dict[name] = 0
60
+ change.append((t[0],name_count_dict[name]))
61
+ for t in change:
62
+ df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
63
+ return df
64
+
65
+ def ordered_indices(self):
66
+ index2dur = self.dataset[['duration']]
67
+ index2dur = index2dur.sort_values(by='duration')
68
+ return list(index2dur.index)
69
+
70
+ def __getitem__(self, idx):
71
+ item = {}
72
+ data = self.dataset.iloc[idx]
73
+ try:
74
+ spec = np.load(data['mel_path']) # mel spec [80, 624]
75
+ except:
76
+ mel_path = data['mel_path']
77
+ print(f'corrupted:{mel_path}')
78
+ spec = np.ones((self.mel_num,self.min_batch_len)).astype(np.float32)*self.pad_value
79
+
80
+
81
+ item['image'] = spec
82
+ p = np.random.uniform(0,1)
83
+ if p > self.drop:
84
+ item["caption"] = data['caption']
85
+ else:
86
+ item["caption"] = ""
87
+ if self.split == 'test':
88
+ item['f_name'] = data['name']
89
+ # item['f_name'] = data['mel_path']
90
+ return item
91
+
92
+ def collater(self,inputs):
93
+ to_dict = {}
94
+ for l in inputs:
95
+ for k,v in l.items():
96
+ if k in to_dict:
97
+ to_dict[k].append(v)
98
+ else:
99
+ to_dict[k] = [v]
100
+ if self.collate_mode == 'pad':
101
+ to_dict['image'] = collate_1d_or_2d(to_dict['image'],pad_idx=self.pad_value,min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
102
+ elif self.collate_mode == 'tile':
103
+ to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'],min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
104
+ else:
105
+ raise NotImplementedError
106
+
107
+ return to_dict
108
+
109
+ def __len__(self):
110
+ return len(self.dataset)
111
+
112
+
113
+ class JoinSpecsTrain(JoinManifestSpecs):
114
+ def __init__(self, specs_dataset_cfg):
115
+ super().__init__('train', **specs_dataset_cfg)
116
+
117
+ class JoinSpecsValidation(JoinManifestSpecs):
118
+ def __init__(self, specs_dataset_cfg):
119
+ super().__init__('valid', **specs_dataset_cfg)
120
+
121
+ class JoinSpecsTest(JoinManifestSpecs):
122
+ def __init__(self, specs_dataset_cfg):
123
+ super().__init__('test', **specs_dataset_cfg)
124
+
125
+ class JoinSpecsDebug(JoinManifestSpecs):
126
+ def __init__(self, specs_dataset_cfg):
127
+ super().__init__('valid', **specs_dataset_cfg)
128
+ self.dataset = self.dataset.iloc[:37]
129
+
130
+ class DDPIndexBatchSampler(Sampler):# 让长度相似的音频的indices合到一个batch中以避免过长的pad
131
+ def __init__(self, indices ,batch_size, num_replicas: Optional[int] = None,
132
+ rank: Optional[int] = None, shuffle: bool = True,
133
+ seed: int = 0, drop_last: bool = False) -> None:
134
+ if num_replicas is None:
135
+ if not dist.is_initialized():
136
+ # raise RuntimeError("Requires distributed package to be available")
137
+ print("Not in distributed mode")
138
+ num_replicas = 1
139
+ else:
140
+ num_replicas = dist.get_world_size()
141
+ if rank is None:
142
+ if not dist.is_initialized():
143
+ # raise RuntimeError("Requires distributed package to be available")
144
+ rank = 0
145
+ else:
146
+ rank = dist.get_rank()
147
+ if rank >= num_replicas or rank < 0:
148
+ raise ValueError(
149
+ "Invalid rank {}, rank should be in the interval"
150
+ " [0, {}]".format(rank, num_replicas - 1))
151
+ self.indices = indices
152
+ self.num_replicas = num_replicas
153
+ self.rank = rank
154
+ self.epoch = 0
155
+ self.drop_last = drop_last
156
+ self.batch_size = batch_size
157
+
158
+ self.batches = self.build_batches()
159
+ print(f"rank: {self.rank}, batches_num {len(self.batches)}")
160
+ # If the dataset length is evenly divisible by replicas, then there
161
+ # is no need to drop any data, since the dataset will be split equally.
162
+ if self.drop_last and len(self.batches) % self.num_replicas != 0:
163
+ self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
164
+ if len(self.batches) > self.num_replicas:
165
+ self.batches = self.batches[self.rank::self.num_replicas]
166
+ else: # may happen in sanity checking
167
+ self.batches = [self.batches[0]]
168
+ print(f"after split batches_num {len(self.batches)}")
169
+ self.shuffle = shuffle
170
+ if self.shuffle:
171
+ self.batches = np.random.permutation(self.batches)
172
+ self.seed = seed
173
+
174
+ def set_epoch(self,epoch):
175
+ self.epoch = epoch
176
+ if self.shuffle:
177
+ np.random.seed(self.seed+self.epoch)
178
+ self.batches = np.random.permutation(self.batches)
179
+
180
+ def build_batches(self):
181
+ batches,batch = [],[]
182
+ for index in self.indices:
183
+ batch.append(index)
184
+ if len(batch) == self.batch_size:
185
+ batches.append(batch)
186
+ batch = []
187
+ if not self.drop_last and len(batch) > 0:
188
+ batches.append(batch)
189
+ return batches
190
+
191
+ def __iter__(self) -> Iterator[List[int]]:
192
+ for batch in self.batches:
193
+ yield batch
194
+
195
+ def __len__(self) -> int:
196
+ return len(self.batches)
197
+
198
+ def set_epoch(self, epoch: int) -> None:
199
+ r"""
200
+ Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
201
+ use a different random ordering for each epoch. Otherwise, the next iteration of this
202
+ sampler will yield the same ordering.
203
+
204
+ Args:
205
+ epoch (int): Epoch number.
206
+ """
207
+ self.epoch = epoch
208
+
209
+
210
+ def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False,min_len = None, max_len=None,min_factor=None, shift_id=1):
211
+ if len(values[0].shape) == 1:
212
+ return collate_1d(values, pad_idx, left_pad, shift_right,min_len, max_len,min_factor, shift_id)
213
+ else:
214
+ return collate_2d(values, pad_idx, left_pad, shift_right,min_len,max_len,min_factor)
215
+
216
+ def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False,min_len=None, max_len=None,min_factor=None, shift_id=1):
217
+ """Convert a list of 1d tensors into a padded 2d tensor."""
218
+ size = max(v.size(0) for v in values)
219
+ if max_len:
220
+ size = min(size,max_len)
221
+ if min_len:
222
+ size = max(size,min_len)
223
+ if min_factor and (size % min_factor!=0):# size must be the multiple of min_factor
224
+ size += (min_factor - size % min_factor)
225
+ res = values[0].new(len(values), size).fill_(pad_idx)
226
+
227
+ def copy_tensor(src, dst):
228
+ assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
229
+ if shift_right:
230
+ dst[1:] = src[:-1]
231
+ dst[0] = shift_id
232
+ else:
233
+ dst.copy_(src)
234
+
235
+ for i, v in enumerate(values):
236
+ copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
237
+ return res
238
+
239
+
240
+ def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, min_len=None,max_len=None,min_factor=None):
241
+ """Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension.
242
+ values[0] shape: (melbins,mel_length)
243
+ """
244
+ size = max(v.shape[1] for v in values) # if max_len is None else max_len
245
+ if max_len:
246
+ size = min(size,max_len)
247
+ if min_len:
248
+ size = max(size,min_len)
249
+ if min_factor and (size % min_factor!=0):# size must be the multiple of min_factor
250
+ size += (min_factor - size % min_factor)
251
+
252
+ if isinstance(values,np.ndarray):
253
+ values = torch.FloatTensor(values)
254
+ if isinstance(values,list):
255
+ values = [torch.FloatTensor(v) for v in values]
256
+ res = torch.ones(len(values), values[0].shape[0],size).to(dtype=torch.float32)*pad_idx
257
+
258
+ def copy_tensor(src, dst):
259
+ assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
260
+ if shift_right:
261
+ dst[1:] = src[:-1]
262
+ else:
263
+ dst.copy_(src)
264
+
265
+ for i, v in enumerate(values):
266
+ copy_tensor(v[:,:size], res[i][:,size - v.shape[1]:] if left_pad else res[i][:,:v.shape[1]])
267
+ return res
268
+
269
+
270
+ def collate_1d_or_2d_tile(values, shift_right=False,min_len = None, max_len=None,min_factor=None, shift_id=1):
271
+ if len(values[0].shape) == 1:
272
+ return collate_1d_tile(values, shift_right,min_len, max_len,min_factor, shift_id)
273
+ else:
274
+ return collate_2d_tile(values, shift_right,min_len,max_len,min_factor)
275
+
276
+ def collate_1d_tile(values, shift_right=False,min_len=None, max_len=None,min_factor=None,shift_id=1):
277
+ """Convert a list of 1d tensors into a padded 2d tensor."""
278
+ size = max(v.size(0) for v in values)
279
+ if max_len:
280
+ size = min(size,max_len)
281
+ if min_len:
282
+ size = max(size,min_len)
283
+ if min_factor and (size%min_factor!=0):# size must be the multiple of min_factor
284
+ size += (min_factor - size % min_factor)
285
+ res = values[0].new(len(values), size)
286
+
287
+ def copy_tensor(src, dst):
288
+ assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
289
+ if shift_right:
290
+ dst[1:] = src[:-1]
291
+ dst[0] = shift_id
292
+ else:
293
+ dst.copy_(src)
294
+
295
+ for i, v in enumerate(values):
296
+ n_repeat = math.ceil((size + 1) / v.shape[0])
297
+ v = torch.tile(v,dims=(1,n_repeat))[:size]
298
+ copy_tensor(v, res[i])
299
+
300
+ return res
301
+
302
+
303
+ def collate_2d_tile(values, shift_right=False, min_len=None,max_len=None,min_factor=None):
304
+ """Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension. """
305
+ size = max(v.shape[1] for v in values) # if max_len is None else max_len
306
+ if max_len:
307
+ size = min(size,max_len)
308
+ if min_len:
309
+ size = max(size,min_len)
310
+ if min_factor and (size % min_factor!=0):# size must be the multiple of min_factor
311
+ size += (min_factor - size % min_factor)
312
+
313
+ if isinstance(values,np.ndarray):
314
+ values = torch.FloatTensor(values)
315
+ if isinstance(values,list):
316
+ values = [torch.FloatTensor(v) for v in values]
317
+ res = torch.zeros(len(values), values[0].shape[0],size).to(dtype=torch.float32)
318
+
319
+ def copy_tensor(src, dst):
320
+ assert dst.numel() == src.numel()
321
+ if shift_right:
322
+ dst[1:] = src[:-1]
323
+ else:
324
+ dst.copy_(src)
325
+
326
+ for i, v in enumerate(values):
327
+ n_repeat = math.ceil((size + 1) / v.shape[1])
328
+ v = torch.tile(v,dims=(1,n_repeat))[:,:size]
329
+ copy_tensor(v, res[i])
330
+
331
+ return res
ldm/data/joinaudiodataset_struct.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+ import torch
4
+ import logging
5
+ import pandas as pd
6
+ import glob
7
+ logger = logging.getLogger(f'main.{__name__}')
8
+
9
+ sys.path.insert(0, '.') # nopep8
10
+
11
+ class JoinManifestSpecs(torch.utils.data.Dataset):
12
+ def __init__(self, split, spec_dir_path, mel_num=None, spec_crop_len=None,drop=0,**kwargs):
13
+ super().__init__()
14
+ self.split = split
15
+ self.batch_max_length = spec_crop_len
16
+ self.batch_min_length = 50
17
+ self.drop = drop
18
+ self.mel_num = mel_num
19
+
20
+ manifest_files = []
21
+ for dir_path in spec_dir_path.split(','):
22
+ manifest_files += glob.glob(f'{dir_path}/*.tsv')
23
+ df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
24
+ df = pd.concat(df_list,ignore_index=True)
25
+
26
+ if split == 'train':
27
+ self.dataset = df.iloc[100:]
28
+ elif split == 'valid' or split == 'val':
29
+ self.dataset = df.iloc[:100]
30
+ elif split == 'test':
31
+ df = self.add_name_num(df)
32
+ self.dataset = df
33
+ else:
34
+ raise ValueError(f'Unknown split {split}')
35
+ self.dataset.reset_index(inplace=True)
36
+ print('dataset len:', len(self.dataset))
37
+
38
+ def add_name_num(self,df):
39
+ """each file may have different caption, we add num to filename to identify each audio-caption pair"""
40
+ name_count_dict = {}
41
+ change = []
42
+ for t in df.itertuples():
43
+ name = getattr(t,'name')
44
+ if name in name_count_dict:
45
+ name_count_dict[name] += 1
46
+ else:
47
+ name_count_dict[name] = 0
48
+ change.append((t[0],name_count_dict[name]))
49
+ for t in change:
50
+ df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
51
+ return df
52
+
53
+ def __getitem__(self, idx):
54
+ data = self.dataset.iloc[idx]
55
+ item = {}
56
+ try:
57
+ spec = np.load(data['mel_path']) # mel spec [80, 624]
58
+ except:
59
+ mel_path = data['mel_path']
60
+ print(f'corrupted:{mel_path}')
61
+ spec = np.zeros((self.mel_num,self.batch_max_length)).astype(np.float32)
62
+
63
+ if spec.shape[1] <= self.batch_max_length:
64
+ spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1]))) # [80, 624]
65
+
66
+
67
+ item['image'] = spec[:self.mel_num,:self.batch_max_length]
68
+ p = np.random.uniform(0,1)
69
+ if p > self.drop:
70
+ item["caption"] = {"ori_caption":data['ori_cap'],"struct_caption":data['caption']}
71
+ else:
72
+ item["caption"] = {"ori_caption":"","struct_caption":""}
73
+
74
+ if self.split == 'test':
75
+ item['f_name'] = data['name']
76
+ return item
77
+
78
+ def __len__(self):
79
+ return len(self.dataset)
80
+
81
+
82
+ class JoinSpecsTrain(JoinManifestSpecs):
83
+ def __init__(self, specs_dataset_cfg):
84
+ super().__init__('train', **specs_dataset_cfg)
85
+
86
+ class JoinSpecsValidation(JoinManifestSpecs):
87
+ def __init__(self, specs_dataset_cfg):
88
+ super().__init__('valid', **specs_dataset_cfg)
89
+
90
+ class JoinSpecsTest(JoinManifestSpecs):
91
+ def __init__(self, specs_dataset_cfg):
92
+ super().__init__('test', **specs_dataset_cfg)
93
+
94
+
95
+
ldm/data/joinaudiodataset_struct_anylen.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils.data.sampler import Sampler
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ import torch.distributed
9
+ from typing import TypeVar, Optional, Iterator,List
10
+ import logging
11
+ import pandas as pd
12
+ import glob
13
+ import torch.distributed as dist
14
+ logger = logging.getLogger(f'main.{__name__}')
15
+
16
+ sys.path.insert(0, '.') # nopep8
17
+
18
+ class JoinManifestSpecs(torch.utils.data.Dataset):
19
+ def __init__(self, split, spec_dir_path, mel_num=80,spec_crop_len=1248,mode='pad',pad_value=-5,drop=0,**kwargs):
20
+ super().__init__()
21
+ self.split = split
22
+ self.max_batch_len = spec_crop_len
23
+ self.min_batch_len = 64
24
+ self.mel_num = mel_num
25
+ self.min_factor = 4
26
+ self.drop = drop
27
+ self.pad_value = pad_value
28
+ assert mode in ['pad','tile']
29
+ self.collate_mode = mode
30
+ # print(f"################# self.collate_mode {self.collate_mode} ##################")
31
+
32
+ manifest_files = []
33
+ for dir_path in spec_dir_path.split(','):
34
+ manifest_files += glob.glob(f'{dir_path}/*.tsv')
35
+ df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
36
+ df = pd.concat(df_list,ignore_index=True)
37
+
38
+ if split == 'train':
39
+ self.dataset = df.iloc[100:]
40
+ elif split == 'valid' or split == 'val':
41
+ self.dataset = df.iloc[:100]
42
+ elif split == 'test':
43
+ df = self.add_name_num(df)
44
+ self.dataset = df
45
+ else:
46
+ raise ValueError(f'Unknown split {split}')
47
+ self.dataset.reset_index(inplace=True)
48
+ print('dataset len:', len(self.dataset))
49
+
50
+ def add_name_num(self,df):
51
+ """each file may have different caption, we add num to filename to identify each audio-caption pair"""
52
+ name_count_dict = {}
53
+ change = []
54
+ for t in df.itertuples():
55
+ name = getattr(t,'name')
56
+ if name in name_count_dict:
57
+ name_count_dict[name] += 1
58
+ else:
59
+ name_count_dict[name] = 0
60
+ change.append((t[0],name_count_dict[name]))
61
+ for t in change:
62
+ df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
63
+ return df
64
+
65
+ def ordered_indices(self):
66
+ index2dur = self.dataset[['duration']]
67
+ index2dur = index2dur.sort_values(by='duration')
68
+ return list(index2dur.index)
69
+
70
+ def __getitem__(self, idx):
71
+ item = {}
72
+ data = self.dataset.iloc[idx]
73
+ try:
74
+ spec = np.load(data['mel_path']) # mel spec [80, 624]
75
+ except:
76
+ mel_path = data['mel_path']
77
+ print(f'corrupted:{mel_path}')
78
+ spec = np.ones((self.mel_num,self.min_batch_len)).astype(np.float32)*self.pad_value
79
+
80
+
81
+ item['image'] = spec
82
+ p = np.random.uniform(0,1)
83
+ if p > self.drop:
84
+ ori_caption = data['caption']
85
+ struct_caption = f'<{ori_caption}& all>'
86
+ else:
87
+ ori_caption = ""
88
+ struct_caption = ""
89
+ item["caption"] = {"ori_caption":ori_caption,"struct_caption":struct_caption}
90
+ if self.split == 'test':
91
+ item['f_name'] = data['name']
92
+ # item['f_name'] = data['mel_path']
93
+ return item
94
+
95
+ def collater(self,inputs):
96
+ to_dict = {}
97
+ for l in inputs:
98
+ for k,v in l.items():
99
+ if k in to_dict:
100
+ to_dict[k].append(v)
101
+ else:
102
+ to_dict[k] = [v]
103
+ if self.collate_mode == 'pad':
104
+ to_dict['image'] = collate_1d_or_2d(to_dict['image'],pad_idx=self.pad_value,min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
105
+ elif self.collate_mode == 'tile':
106
+ to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'],min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
107
+ else:
108
+ raise NotImplementedError
109
+ to_dict['caption'] = {'ori_caption':[c['ori_caption'] for c in to_dict['caption']],
110
+ 'struct_caption':[c['struct_caption'] for c in to_dict['caption']]}
111
+
112
+ return to_dict
113
+
114
+ def __len__(self):
115
+ return len(self.dataset)
116
+
117
+
118
+ class JoinSpecsTrain(JoinManifestSpecs):
119
+ def __init__(self, specs_dataset_cfg):
120
+ super().__init__('train', **specs_dataset_cfg)
121
+
122
+ class JoinSpecsValidation(JoinManifestSpecs):
123
+ def __init__(self, specs_dataset_cfg):
124
+ super().__init__('valid', **specs_dataset_cfg)
125
+
126
+ class JoinSpecsTest(JoinManifestSpecs):
127
+ def __init__(self, specs_dataset_cfg):
128
+ super().__init__('test', **specs_dataset_cfg)
129
+
130
+ class JoinSpecsDebug(JoinManifestSpecs):
131
+ def __init__(self, specs_dataset_cfg):
132
+ super().__init__('valid', **specs_dataset_cfg)
133
+ self.dataset = self.dataset.iloc[:37]
134
+
135
+ class DDPIndexBatchSampler(Sampler):# 让长度相似的音频的indices合到��个batch中以避免过长的pad
136
+ def __init__(self, indices ,batch_size, num_replicas: Optional[int] = None,
137
+ rank: Optional[int] = None, shuffle: bool = True,
138
+ seed: int = 0, drop_last: bool = False) -> None:
139
+ if num_replicas is None:
140
+ if not dist.is_initialized():
141
+ # raise RuntimeError("Requires distributed package to be available")
142
+ print("Not in distributed mode")
143
+ num_replicas = 1
144
+ else:
145
+ num_replicas = dist.get_world_size()
146
+ if rank is None:
147
+ if not dist.is_initialized():
148
+ # raise RuntimeError("Requires distributed package to be available")
149
+ rank = 0
150
+ else:
151
+ rank = dist.get_rank()
152
+ if rank >= num_replicas or rank < 0:
153
+ raise ValueError(
154
+ "Invalid rank {}, rank should be in the interval"
155
+ " [0, {}]".format(rank, num_replicas - 1))
156
+ self.indices = indices
157
+ self.num_replicas = num_replicas
158
+ self.rank = rank
159
+ self.epoch = 0
160
+ self.drop_last = drop_last
161
+ self.batch_size = batch_size
162
+
163
+ self.batches = self.build_batches()
164
+ print(f"rank: {self.rank}, batches_num {len(self.batches)}")
165
+ # If the dataset length is evenly divisible by replicas, then there
166
+ # is no need to drop any data, since the dataset will be split equally.
167
+ if self.drop_last and len(self.batches) % self.num_replicas != 0:
168
+ self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
169
+ if len(self.batches) > self.num_replicas:
170
+ self.batches = self.batches[self.rank::self.num_replicas]
171
+ else: # may happen in sanity checking
172
+ self.batches = [self.batches[0]]
173
+ print(f"after split batches_num {len(self.batches)}")
174
+ self.shuffle = shuffle
175
+ if self.shuffle:
176
+ self.batches = np.random.permutation(self.batches)
177
+ self.seed = seed
178
+
179
+ def set_epoch(self,epoch):
180
+ self.epoch = epoch
181
+ if self.shuffle:
182
+ np.random.seed(self.seed+self.epoch)
183
+ self.batches = np.random.permutation(self.batches)
184
+
185
+ def build_batches(self):
186
+ batches,batch = [],[]
187
+ for index in self.indices:
188
+ batch.append(index)
189
+ if len(batch) == self.batch_size:
190
+ batches.append(batch)
191
+ batch = []
192
+ if not self.drop_last and len(batch) > 0:
193
+ batches.append(batch)
194
+ return batches
195
+
196
+ def __iter__(self) -> Iterator[List[int]]:
197
+ for batch in self.batches:
198
+ yield batch
199
+
200
+ def __len__(self) -> int:
201
+ return len(self.batches)
202
+
203
+ def set_epoch(self, epoch: int) -> None:
204
+ r"""
205
+ Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
206
+ use a different random ordering for each epoch. Otherwise, the next iteration of this
207
+ sampler will yield the same ordering.
208
+
209
+ Args:
210
+ epoch (int): Epoch number.
211
+ """
212
+ self.epoch = epoch
213
+
214
+
215
+ def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False,min_len = None, max_len=None,min_factor=None, shift_id=1):
216
+ if len(values[0].shape) == 1:
217
+ return collate_1d(values, pad_idx, left_pad, shift_right,min_len, max_len,min_factor, shift_id)
218
+ else:
219
+ return collate_2d(values, pad_idx, left_pad, shift_right,min_len,max_len,min_factor)
220
+
221
+ def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False,min_len=None, max_len=None,min_factor=None, shift_id=1):
222
+ """Convert a list of 1d tensors into a padded 2d tensor."""
223
+ size = max(v.size(0) for v in values)
224
+ if max_len:
225
+ size = min(size,max_len)
226
+ if min_len:
227
+ size = max(size,min_len)
228
+ if min_factor and (size % min_factor!=0):# size must be the multiple of min_factor
229
+ size += (min_factor - size % min_factor)
230
+ res = values[0].new(len(values), size).fill_(pad_idx)
231
+
232
+ def copy_tensor(src, dst):
233
+ assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
234
+ if shift_right:
235
+ dst[1:] = src[:-1]
236
+ dst[0] = shift_id
237
+ else:
238
+ dst.copy_(src)
239
+
240
+ for i, v in enumerate(values):
241
+ copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
242
+ return res
243
+
244
+
245
+ def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, min_len=None,max_len=None,min_factor=None):
246
+ """Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension.
247
+ values[0] shape: (melbins,mel_length)
248
+ """
249
+ size = max(v.shape[1] for v in values) # if max_len is None else max_len
250
+ if max_len:
251
+ size = min(size,max_len)
252
+ if min_len:
253
+ size = max(size,min_len)
254
+ if min_factor and (size % min_factor!=0):# size must be the multiple of min_factor
255
+ size += (min_factor - size % min_factor)
256
+
257
+ if isinstance(values,np.ndarray):
258
+ values = torch.FloatTensor(values)
259
+ if isinstance(values,list):
260
+ values = [torch.FloatTensor(v) for v in values]
261
+ res = torch.ones(len(values), values[0].shape[0],size).to(dtype=torch.float32)*pad_idx
262
+
263
+ def copy_tensor(src, dst):
264
+ assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
265
+ if shift_right:
266
+ dst[1:] = src[:-1]
267
+ else:
268
+ dst.copy_(src)
269
+
270
+ for i, v in enumerate(values):
271
+ copy_tensor(v[:,:size], res[i][:,size - v.shape[1]:] if left_pad else res[i][:,:v.shape[1]])
272
+ return res
273
+
274
+
275
+ def collate_1d_or_2d_tile(values, shift_right=False,min_len = None, max_len=None,min_factor=None, shift_id=1):
276
+ if len(values[0].shape) == 1:
277
+ return collate_1d_tile(values, shift_right,min_len, max_len,min_factor, shift_id)
278
+ else:
279
+ return collate_2d_tile(values, shift_right,min_len,max_len,min_factor)
280
+
281
+ def collate_1d_tile(values, shift_right=False,min_len=None, max_len=None,min_factor=None,shift_id=1):
282
+ """Convert a list of 1d tensors into a padded 2d tensor."""
283
+ size = max(v.size(0) for v in values)
284
+ if max_len:
285
+ size = min(size,max_len)
286
+ if min_len:
287
+ size = max(size,min_len)
288
+ if min_factor and (size%min_factor!=0):# size must be the multiple of min_factor
289
+ size += (min_factor - size % min_factor)
290
+ res = values[0].new(len(values), size)
291
+
292
+ def copy_tensor(src, dst):
293
+ assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
294
+ if shift_right:
295
+ dst[1:] = src[:-1]
296
+ dst[0] = shift_id
297
+ else:
298
+ dst.copy_(src)
299
+
300
+ for i, v in enumerate(values):
301
+ n_repeat = math.ceil((size + 1) / v.shape[0])
302
+ v = torch.tile(v,dims=(1,n_repeat))[:size]
303
+ copy_tensor(v, res[i])
304
+
305
+ return res
306
+
307
+
308
+ def collate_2d_tile(values, shift_right=False, min_len=None,max_len=None,min_factor=None):
309
+ """Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension. """
310
+ size = max(v.shape[1] for v in values) # if max_len is None else max_len
311
+ if max_len:
312
+ size = min(size,max_len)
313
+ if min_len:
314
+ size = max(size,min_len)
315
+ if min_factor and (size % min_factor!=0):# size must be the multiple of min_factor
316
+ size += (min_factor - size % min_factor)
317
+
318
+ if isinstance(values,np.ndarray):
319
+ values = torch.FloatTensor(values)
320
+ if isinstance(values,list):
321
+ values = [torch.FloatTensor(v) for v in values]
322
+ res = torch.zeros(len(values), values[0].shape[0],size).to(dtype=torch.float32)
323
+
324
+ def copy_tensor(src, dst):
325
+ assert dst.numel() == src.numel()
326
+ if shift_right:
327
+ dst[1:] = src[:-1]
328
+ else:
329
+ dst.copy_(src)
330
+
331
+ for i, v in enumerate(values):
332
+ n_repeat = math.ceil((size + 1) / v.shape[1])
333
+ v = torch.tile(v,dims=(1,n_repeat))[:,:size]
334
+ copy_tensor(v, res[i])
335
+
336
+ return res
ldm/data/joinaudiodataset_struct_sample.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+ import torch
4
+ import logging
5
+ import pandas as pd
6
+ import glob
7
+ logger = logging.getLogger(f'main.{__name__}')
8
+
9
+ sys.path.insert(0, '.') # nopep8
10
+
11
+ class JoinManifestSpecs(torch.utils.data.Dataset):
12
+ def __init__(self, split, main_spec_dir_path,other_spec_dir_path, mel_num=None, spec_crop_len=None,pad_value=-5,**kwargs):
13
+ super().__init__()
14
+ self.main_prob = 0.5
15
+ self.split = split
16
+ self.batch_max_length = spec_crop_len
17
+ self.batch_min_length = 50
18
+ self.mel_num = mel_num
19
+ self.pad_value = pad_value
20
+ manifest_files = []
21
+ for dir_path in main_spec_dir_path.split(','):
22
+ manifest_files += glob.glob(f'{dir_path}/*.tsv')
23
+ df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
24
+ self.df_main = pd.concat(df_list,ignore_index=True)
25
+
26
+ manifest_files = []
27
+ for dir_path in other_spec_dir_path.split(','):
28
+ manifest_files += glob.glob(f'{dir_path}/*.tsv')
29
+ df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
30
+ self.df_other = pd.concat(df_list,ignore_index=True)
31
+
32
+ if split == 'train':
33
+ self.dataset = self.df_main.iloc[100:]
34
+ elif split == 'valid' or split == 'val':
35
+ self.dataset = self.df_main.iloc[:100]
36
+ elif split == 'test':
37
+ self.df_main = self.add_name_num(self.df_main)
38
+ self.dataset = self.df_main
39
+ else:
40
+ raise ValueError(f'Unknown split {split}')
41
+ self.dataset.reset_index(inplace=True)
42
+ print('dataset len:', len(self.dataset))
43
+
44
+ def add_name_num(self,df):
45
+ """each file may have different caption, we add num to filename to identify each audio-caption pair"""
46
+ name_count_dict = {}
47
+ change = []
48
+ for t in df.itertuples():
49
+ name = getattr(t,'name')
50
+ if name in name_count_dict:
51
+ name_count_dict[name] += 1
52
+ else:
53
+ name_count_dict[name] = 0
54
+ change.append((t[0],name_count_dict[name]))
55
+ for t in change:
56
+ df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
57
+ return df
58
+
59
+ def __getitem__(self, idx):
60
+ if np.random.uniform(0,1) < self.main_prob:
61
+ data = self.dataset.iloc[idx]
62
+ ori_caption = data['ori_cap']
63
+ struct_caption = data['caption']
64
+ else:
65
+ randidx = np.random.randint(0,len(self.df_other))
66
+ data = self.df_other.iloc[randidx]
67
+ ori_caption = data['caption']
68
+ struct_caption = f'<{ori_caption}, all>'
69
+ item = {}
70
+ try:
71
+ spec = np.load(data['mel_path']) # mel spec [80, 624]
72
+ except:
73
+ mel_path = data['mel_path']
74
+ print(f'corrupted:{mel_path}')
75
+ spec = np.ones((self.mel_num,self.batch_max_length)).astype(np.float32)*self.pad_value
76
+
77
+ if spec.shape[1] <= self.batch_max_length:
78
+ spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1])),mode='constant',constant_values = (self.pad_value,self.pad_value)) # [80, 624]
79
+
80
+ item['image'] = spec[:self.mel_num,:self.batch_max_length]
81
+ item["caption"] = {"ori_caption":ori_caption,"struct_caption":struct_caption}
82
+ if self.split == 'test':
83
+ item['f_name'] = data['name']
84
+ return item
85
+
86
+ def __len__(self):
87
+ return len(self.dataset)
88
+
89
+
90
+ class JoinSpecsTrain(JoinManifestSpecs):
91
+ def __init__(self, specs_dataset_cfg):
92
+ super().__init__('train', **specs_dataset_cfg)
93
+
94
+ class JoinSpecsValidation(JoinManifestSpecs):
95
+ def __init__(self, specs_dataset_cfg):
96
+ super().__init__('valid', **specs_dataset_cfg)
97
+
98
+ class JoinSpecsTest(JoinManifestSpecs):
99
+ def __init__(self, specs_dataset_cfg):
100
+ super().__init__('test', **specs_dataset_cfg)
101
+
102
+
103
+
ldm/data/joinaudiodataset_struct_sample_anylen.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+ import torch
4
+ from typing import TypeVar, Optional, Iterator
5
+ import logging
6
+ import pandas as pd
7
+ from ldm.data.joinaudiodataset_anylen import *
8
+ import glob
9
+ logger = logging.getLogger(f'main.{__name__}')
10
+
11
+ sys.path.insert(0, '.') # nopep8
12
+
13
+ class JoinManifestSpecs(torch.utils.data.Dataset):
14
+ def __init__(self, split, main_spec_dir_path,other_spec_dir_path, mel_num=80,mode='pad', spec_crop_len=1248,pad_value=-5,drop=0,**kwargs):
15
+ super().__init__()
16
+ self.split = split
17
+ self.max_batch_len = spec_crop_len
18
+ self.min_batch_len = 64
19
+ self.min_factor = 4
20
+ self.mel_num = mel_num
21
+ self.drop = drop
22
+ self.pad_value = pad_value
23
+ assert mode in ['pad','tile']
24
+ self.collate_mode = mode
25
+ manifest_files = []
26
+
27
+ for dir_path in main_spec_dir_path.split(','):
28
+ manifest_files += glob.glob(f'{dir_path}/*.tsv')
29
+ df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
30
+ self.df_main = pd.concat(df_list,ignore_index=True)
31
+
32
+ manifest_files = []
33
+ for dir_path in other_spec_dir_path.split(','):
34
+ manifest_files += glob.glob(f'{dir_path}/*.tsv')
35
+ df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
36
+ # import ipdb
37
+ # ipdb.set_trace()
38
+ self.df_other = pd.concat(df_list,ignore_index=True)
39
+ self.df_other.reset_index(inplace=True)
40
+
41
+ if split == 'train':
42
+ self.dataset = self.df_main.iloc[100:]
43
+ elif split == 'valid' or split == 'val':
44
+ self.dataset = self.df_main.iloc[:100]
45
+ elif split == 'test':
46
+ self.df_main = self.add_name_num(self.df_main)
47
+ self.dataset = self.df_main
48
+ else:
49
+ raise ValueError(f'Unknown split {split}')
50
+ self.dataset.reset_index(inplace=True)
51
+ print('dataset len:', len(self.dataset),"drop_rate",self.drop)
52
+
53
+ def add_name_num(self,df):
54
+ """each file may have different caption, we add num to filename to identify each audio-caption pair"""
55
+ name_count_dict = {}
56
+ change = []
57
+ for t in df.itertuples():
58
+ name = getattr(t,'name')
59
+ if name in name_count_dict:
60
+ name_count_dict[name] += 1
61
+ else:
62
+ name_count_dict[name] = 0
63
+ change.append((t[0],name_count_dict[name]))
64
+ for t in change:
65
+ df.loc[t[0],'name'] = str(df.loc[t[0],'name']) + f'_{t[1]}'
66
+ return df
67
+
68
+ def ordered_indices(self):
69
+ index2dur = self.dataset[['duration']].sort_values(by='duration')
70
+ index2dur_other = self.df_other[['duration']].sort_values(by='duration')
71
+ other_indices = list(index2dur_other.index)
72
+ offset = len(self.dataset)
73
+ other_indices = [x + offset for x in other_indices]
74
+ return list(index2dur.index),other_indices
75
+ # return list(index2dur.index)
76
+
77
+ def collater(self,inputs):
78
+ to_dict = {}
79
+ for l in inputs:
80
+ for k,v in l.items():
81
+ if k in to_dict:
82
+ to_dict[k].append(v)
83
+ else:
84
+ to_dict[k] = [v]
85
+
86
+ if self.collate_mode == 'pad':
87
+ to_dict['image'] = collate_1d_or_2d(to_dict['image'],pad_idx=self.pad_value,min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
88
+ elif self.collate_mode == 'tile':
89
+ to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'],min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
90
+ else:
91
+ raise NotImplementedError
92
+ to_dict['caption'] = {'ori_caption':[c['ori_caption'] for c in to_dict['caption']],
93
+ 'struct_caption':[c['struct_caption'] for c in to_dict['caption']]}
94
+
95
+ return to_dict
96
+
97
+ def __getitem__(self, idx):
98
+ if idx < len(self.dataset):
99
+ data = self.dataset.iloc[idx]
100
+ # p = np.random.uniform(0,1)
101
+ # if p > self.drop:
102
+ ori_caption = data['ori_cap']
103
+ struct_caption = data['caption']
104
+ # else:
105
+ # ori_caption = ""
106
+ # struct_caption = ""
107
+ else:
108
+ data = self.df_other.iloc[idx-len(self.dataset)]
109
+ # p = np.random.uniform(0,1)
110
+ # if p > self.drop:
111
+ ori_caption = data['caption']
112
+ struct_caption = f'<{ori_caption}& all>'
113
+ # else:
114
+ # ori_caption = ""
115
+ # struct_caption = ""
116
+ item = {}
117
+ try:
118
+ spec = np.load(data['mel_path']) # mel spec [80, T]
119
+ if spec.shape[1] > self.max_batch_len:
120
+ spec = spec[:,:self.max_batch_len]
121
+ except:
122
+ mel_path = data['mel_path']
123
+ print(f'corrupted:{mel_path}')
124
+ spec = np.ones((self.mel_num,self.min_batch_len)).astype(np.float32)*self.pad_value
125
+
126
+ item['image'] = spec
127
+ item["caption"] = {"ori_caption":ori_caption,"struct_caption":struct_caption}
128
+ if self.split == 'test':
129
+ item['f_name'] = data['name']
130
+ return item
131
+
132
+ def __len__(self):
133
+ return len(self.dataset) + len(self.df_other)
134
+ # return len(self.dataset)
135
+
136
+
137
+ class JoinSpecsTrain(JoinManifestSpecs):
138
+ def __init__(self, specs_dataset_cfg):
139
+ super().__init__('train', **specs_dataset_cfg)
140
+
141
+ class JoinSpecsValidation(JoinManifestSpecs):
142
+ def __init__(self, specs_dataset_cfg):
143
+ super().__init__('valid', **specs_dataset_cfg)
144
+
145
+ class JoinSpecsTest(JoinManifestSpecs):
146
+ def __init__(self, specs_dataset_cfg):
147
+ super().__init__('test', **specs_dataset_cfg)
148
+
149
+
150
+
151
+ class DDPIndexBatchSampler(Sampler):# 让长度相似的音频的indices合到一个batch中以避免过长的pad
152
+ def __init__(self, main_indices,other_indices,batch_size, num_replicas: Optional[int] = None,
153
+ # def __init__(self, main_indices,batch_size, num_replicas: Optional[int] = None,
154
+ rank: Optional[int] = None, shuffle: bool = True,
155
+ seed: int = 0, drop_last: bool = False) -> None:
156
+ if num_replicas is None:
157
+ if not dist.is_initialized():
158
+ # raise RuntimeError("Requires distributed package to be available")
159
+ print("Not in distributed mode")
160
+ num_replicas = 1
161
+ else:
162
+ num_replicas = dist.get_world_size()
163
+ if rank is None:
164
+ if not dist.is_initialized():
165
+ # raise RuntimeError("Requires distributed package to be available")
166
+ rank = 0
167
+ else:
168
+ rank = dist.get_rank()
169
+ if rank >= num_replicas or rank < 0:
170
+ raise ValueError(
171
+ "Invalid rank {}, rank should be in the interval"
172
+ " [0, {}]".format(rank, num_replicas - 1))
173
+ self.main_indices = main_indices
174
+ self.other_indices = other_indices
175
+ self.max_index = max(self.other_indices)
176
+ self.num_replicas = num_replicas
177
+ self.rank = rank
178
+ self.epoch = 0
179
+ self.drop_last = drop_last
180
+ self.batch_size = batch_size
181
+ self.shuffle = shuffle
182
+ self.batches = self.build_batches()
183
+ self.seed = seed
184
+
185
+ def set_epoch(self,epoch):
186
+ # print("!!!!!!!!!!!set epoch is called!!!!!!!!!!!!!!")
187
+ self.epoch = epoch
188
+ if self.shuffle:
189
+ np.random.seed(self.seed+self.epoch)
190
+ self.batches = self.build_batches()
191
+
192
+ def build_batches(self):
193
+ batches,batch = [],[]
194
+ for index in self.main_indices:
195
+ batch.append(index)
196
+ if len(batch) == self.batch_size:
197
+ batches.append(batch)
198
+ batch = []
199
+ if not self.drop_last and len(batch) > 0:
200
+ batches.append(batch)
201
+ selected_others = np.random.choice(len(self.other_indices),len(batches),replace=False)
202
+ for index in selected_others:
203
+ if index + self.batch_size > len(self.other_indices):
204
+ index = len(self.other_indices) - self.batch_size
205
+ batch = [self.other_indices[index + i] for i in range(self.batch_size)]
206
+ batches.append(batch)
207
+ self.batches = batches
208
+ if self.shuffle:
209
+ self.batches = np.random.permutation(self.batches)
210
+ if self.rank == 0:
211
+ print(f"rank: {self.rank}, batches_num {len(self.batches)}")
212
+
213
+ if self.drop_last and len(self.batches) % self.num_replicas != 0:
214
+ self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
215
+ if len(self.batches) >= self.num_replicas:
216
+ self.batches = self.batches[self.rank::self.num_replicas]
217
+ else: # may happen in sanity checking
218
+ self.batches = [self.batches[0]]
219
+ if self.rank == 0:
220
+ print(f"after split batches_num {len(self.batches)}")
221
+
222
+ return self.batches
223
+
224
+ def __iter__(self) -> Iterator[List[int]]:
225
+ print(f"len(self.batches):{len(self.batches)}")
226
+ for batch in self.batches:
227
+ yield batch
228
+
229
+ def __len__(self) -> int:
230
+ return len(self.batches)
ldm/data/preprocess/NAT_mel.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.utils.data
4
+ from librosa.filters import mel as librosa_mel_fn
5
+ from scipy.io.wavfile import read
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ MAX_WAV_VALUE = 32768.0
10
+
11
+
12
+ def load_wav(full_path):
13
+ sampling_rate, data = read(full_path)
14
+ return data, sampling_rate
15
+
16
+
17
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
18
+ return np.log10(np.clip(x, a_min=clip_val, a_max=None) * C)
19
+
20
+
21
+ def dynamic_range_decompression(x, C=1):
22
+ return np.exp(x) / C
23
+
24
+
25
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
26
+ return torch.log10(torch.clamp(x, min=clip_val) * C)
27
+
28
+
29
+ def dynamic_range_decompression_torch(x, C=1):
30
+ return torch.exp(x) / C
31
+
32
+
33
+ def spectral_normalize_torch(magnitudes):
34
+ output = dynamic_range_compression_torch(magnitudes)
35
+ return output
36
+
37
+
38
+ def spectral_de_normalize_torch(magnitudes):
39
+ output = dynamic_range_decompression_torch(magnitudes)
40
+ return output
41
+
42
+ class MelNet(nn.Module):
43
+ def __init__(self,hparams,device='cpu') -> None:
44
+ super().__init__()
45
+ self.n_fft = hparams['fft_size']
46
+ self.num_mels = hparams['audio_num_mel_bins']
47
+ self.sampling_rate = hparams['audio_sample_rate']
48
+ self.hop_size = hparams['hop_size']
49
+ self.win_size = hparams['win_size']
50
+ self.fmin = hparams['fmin']
51
+ self.fmax = hparams['fmax']
52
+ self.device = device
53
+
54
+ mel = librosa_mel_fn(self.sampling_rate, self.n_fft, self.num_mels, self.fmin, self.fmax)
55
+ self.mel_basis = torch.from_numpy(mel).float().to(self.device)
56
+ self.hann_window = torch.hann_window(self.win_size).to(self.device)
57
+
58
+ def to(self,device,**kwagrs):
59
+ super().to(device=device,**kwagrs)
60
+ self.mel_basis = self.mel_basis.to(device)
61
+ self.hann_window = self.hann_window.to(device)
62
+ self.device = device
63
+
64
+ def forward(self,y,center=False, complex=False):
65
+ if isinstance(y,np.ndarray):
66
+ y = torch.FloatTensor(y)
67
+ if len(y.shape) == 1:
68
+ y = y.unsqueeze(0)
69
+ y = y.clamp(min=-1., max=1.).to(self.device)
70
+
71
+ y = torch.nn.functional.pad(y.unsqueeze(1), [int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)],
72
+ mode='reflect')
73
+ y = y.squeeze(1)
74
+
75
+ spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window,
76
+ center=center, pad_mode='reflect', normalized=False, onesided=True,return_complex=complex)
77
+
78
+ if not complex:
79
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
80
+ spec = torch.matmul(self.mel_basis, spec)
81
+ spec = spectral_normalize_torch(spec)
82
+ else:
83
+ B, C, T, _ = spec.shape
84
+ spec = spec.transpose(1, 2) # [B, T, n_fft, 2]
85
+ return spec
86
+
87
+ ## below can be used in one gpu, but not ddp
88
+ mel_basis = {}
89
+ hann_window = {}
90
+
91
+
92
+ def mel_spectrogram(y, hparams, center=False, complex=False): # y should be a tensor with shape (b,wav_len)
93
+ # hop_size: 512 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
94
+ # win_size: 2048 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
95
+ # fmin: 55 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
96
+ # fmax: 10000 # To be increased/reduced depending on data.
97
+ # fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter
98
+ # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax,
99
+ n_fft = hparams['fft_size']
100
+ num_mels = hparams['audio_num_mel_bins']
101
+ sampling_rate = hparams['audio_sample_rate']
102
+ hop_size = hparams['hop_size']
103
+ win_size = hparams['win_size']
104
+ fmin = hparams['fmin']
105
+ fmax = hparams['fmax']
106
+ if isinstance(y,np.ndarray):
107
+ y = torch.FloatTensor(y)
108
+ if len(y.shape) == 1:
109
+ y = y.unsqueeze(0)
110
+ y = y.clamp(min=-1., max=1.)
111
+ global mel_basis, hann_window
112
+ if fmax not in mel_basis:
113
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
114
+ mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
115
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
116
+
117
+ y = torch.nn.functional.pad(y.unsqueeze(1), [int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)],
118
+ mode='reflect')
119
+ y = y.squeeze(1)
120
+
121
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
122
+ center=center, pad_mode='reflect', normalized=False, onesided=True,return_complex=complex)
123
+
124
+ if not complex:
125
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
126
+ spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
127
+ spec = spectral_normalize_torch(spec)
128
+ else:
129
+ B, C, T, _ = spec.shape
130
+ spec = spec.transpose(1, 2) # [B, T, n_fft, 2]
131
+ return spec
ldm/data/preprocess/__pycache__/NAT_mel.cpython-38.pyc ADDED
Binary file (4.25 kB). View file
 
ldm/data/preprocess/__pycache__/NAT_mel.cpython-39.pyc ADDED
Binary file (4.23 kB). View file
 
ldm/data/preprocess/add_duration.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import audioread
3
+ from tqdm import tqdm
4
+ from tqdm.contrib.concurrent import process_map
5
+
6
+ def map_duration(tsv_withdur,tsv_toadd):# tsv_withdur 和 tsv_toadd 'name'列相同且tsv_withdur有duration信息,目标是给tsv_toadd的相同行加上duration信息。
7
+ df1 = pd.read_csv(tsv_withdur,sep='\t')
8
+ df2 = pd.read_csv(tsv_toadd,sep='\t')
9
+
10
+ df = df2.merge(df1,on=['name'],suffixes=['','_y'])
11
+ dropset = list(set(df.columns) - set(df1.columns))
12
+ df = df.drop(dropset,axis=1)
13
+ df.to_csv(tsv_toadd,sep='\t',index=False)
14
+ return df
15
+
16
+ def add_duration(args):
17
+ index,audiopath = args
18
+ try:
19
+ with audioread.audio_open(audiopath) as f:
20
+ totalsec = f.duration
21
+ except:
22
+ totalsec = -1
23
+ return (index,totalsec)
24
+
25
+ def add_dur2tsv(tsv_path,save_path):
26
+ df = pd.read_csv(tsv_path,sep='\t')
27
+ item_list = []
28
+ for item in tqdm(df.itertuples()):
29
+ item_list.append((item[0],getattr(item,'audio_path')))
30
+
31
+ r = process_map(add_duration,item_list,max_workers=16,chunksize=32)
32
+ index2dur = {}
33
+ for index,dur in r:
34
+ if dur == -1:
35
+ bad_wav = df.loc[index,'audio_path']
36
+ print(f'bad wav:{bad_wav}')
37
+ index2dur[index] = dur
38
+
39
+ df['duration'] = df.index.map(index2dur)
40
+ df.to_csv(save_path,sep='\t',index=False)
41
+
42
+ if __name__ == '__main__':
43
+ add_dur2tsv('/root/autodl-tmp/liuhuadai/AudioLCM/now.tsv','/root/autodl-tmp/liuhuadai/AudioLCM/now_duration.tsv')
44
+ #map_duration(tsv_withdur='tsv_maker/filter_audioset.tsv',
45
+ # tsv_toadd='MAA1 Dataset tsvs/V3/refilter_audioset.tsv')
ldm/data/preprocess/mel_spec.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ldm.data.preprocess.NAT_mel import MelNet
2
+ import os
3
+ from tqdm import tqdm
4
+ from glob import glob
5
+ import math
6
+ import pandas as pd
7
+ import logging
8
+ import math
9
+ import audioread
10
+ from tqdm.contrib.concurrent import process_map
11
+ import torch
12
+ import torch.nn as nn
13
+ import torchaudio
14
+ import numpy as np
15
+ from torch.distributed import init_process_group
16
+ from torch.utils.data import Dataset,DataLoader,DistributedSampler
17
+ import torch.multiprocessing as mp
18
+ from argparse import Namespace
19
+ from multiprocessing import Pool
20
+ import json
21
+
22
+
23
+ class tsv_dataset(Dataset):
24
+ def __init__(self,tsv_path,sr,mode='none',hop_size = None,target_mel_length = None) -> None:
25
+ super().__init__()
26
+ if os.path.isdir(tsv_path):
27
+ files = glob(os.path.join(tsv_path,'*.tsv'))
28
+ df = pd.concat([pd.read_csv(file,sep='\t') for file in files])
29
+ else:
30
+ df = pd.read_csv(tsv_path,sep='\t')
31
+ self.audio_paths = []
32
+ self.sr = sr
33
+ self.mode = mode
34
+ self.target_mel_length = target_mel_length
35
+ self.hop_size = hop_size
36
+ for t in tqdm(df.itertuples()):
37
+ self.audio_paths.append(getattr(t,'audio_path'))
38
+
39
+ def __len__(self):
40
+ return len(self.audio_paths)
41
+
42
+ def pad_wav(self,wav):
43
+ # wav should be in shape(1,wav_len)
44
+ wav_length = wav.shape[-1]
45
+ assert wav_length > 100, "wav is too short, %s" % wav_length
46
+ segment_length = (self.target_mel_length + 1) * self.hop_size # final mel will crop the last mel, mel = mel[:,:-1]
47
+ if segment_length is None or wav_length == segment_length:
48
+ return wav
49
+ elif wav_length > segment_length:
50
+ return wav[:,:segment_length]
51
+ elif wav_length < segment_length:
52
+ temp_wav = torch.zeros((1, segment_length),dtype=torch.float32)
53
+ temp_wav[:, :wav_length] = wav
54
+ return temp_wav
55
+
56
+
57
+ def __getitem__(self, index):
58
+ audio_path = self.audio_paths[index]
59
+ wav, orisr = torchaudio.load(audio_path)
60
+ if wav.shape[0] != 1: # stereo to mono (2,wav_len) -> (1,wav_len)
61
+ wav = wav.mean(0,keepdim=True)
62
+ wav = torchaudio.functional.resample(wav, orig_freq=orisr, new_freq=self.sr)
63
+ if self.mode == 'pad':
64
+ assert self.target_mel_length is not None
65
+ wav = self.pad_wav(wav)
66
+ return audio_path,wav
67
+
68
+ def process_audio_by_tsv(rank,args):
69
+ if args.num_gpus > 1:
70
+ init_process_group(backend=args.dist_config['dist_backend'], init_method=args.dist_config['dist_url'],
71
+ world_size=args.dist_config['world_size'] * args.num_gpus, rank=rank)
72
+
73
+ sr = args.audio_sample_rate
74
+ dataset = tsv_dataset(args.tsv_path,sr = sr,mode=args.mode,hop_size=args.hop_size,target_mel_length=args.batch_max_length)
75
+ sampler = DistributedSampler(dataset,shuffle=False) if args.num_gpus > 1 else None
76
+ # batch_size must == 1,since wav_len is not equal
77
+ loader = DataLoader(dataset, sampler=sampler,batch_size=1, num_workers=16,drop_last=False)
78
+
79
+ device = torch.device('cuda:{:d}'.format(rank))
80
+
81
+ mel_net = MelNet(args.__dict__)
82
+ mel_net.to(device)
83
+ # if args.num_gpus > 1: # RuntimeError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
84
+ # mel_net = DistributedDataParallel(mel_net, device_ids=[rank]).to(device)
85
+
86
+ loader = tqdm(loader) if rank == 0 else loader
87
+ for batch in loader:
88
+ audio_paths,wavs = batch
89
+ wavs = wavs.to(device)
90
+ if args.save_resample:
91
+ for audio_path,wav in zip(audio_paths,wavs):
92
+ psplits = audio_path.split('/')
93
+ root,wav_name = psplits[0],psplits[-1]
94
+ # save resample
95
+ resample_root,resample_name = root+f'_{sr}',wav_name[:-4]+'_audio.npy'
96
+ resample_dir_name = os.path.join(resample_root,*psplits[1:-1])
97
+ resample_path = os.path.join(resample_dir_name,resample_name)
98
+ os.makedirs(resample_dir_name,exist_ok=True)
99
+ np.save(resample_path,wav.cpu().numpy().squeeze(0))
100
+
101
+ if args.save_mel:
102
+ mode = args.mode
103
+ batch_max_length = args.batch_max_length
104
+
105
+ for audio_path,wav in zip(audio_paths,wavs):
106
+ psplits = audio_path.split('/')
107
+ root,wav_name = psplits[0],psplits[-1]
108
+ mel_root,mel_name = root+f'_mel{mode}{sr}nfft{args.fft_size}',wav_name[:-4]+'_mel.npy'
109
+ mel_dir_name = os.path.join(mel_root,*psplits[1:-1])
110
+ mel_path = os.path.join(mel_dir_name,mel_name)
111
+ if not os.path.exists(mel_path):
112
+ mel_spec = mel_net(wav).cpu().numpy().squeeze(0) # (mel_bins,mel_len)
113
+ if mel_spec.shape[1] <= batch_max_length:
114
+ if mode == 'tile': # pad is done in dataset as pad wav
115
+ n_repeat = math.ceil((batch_max_length + 1) / mel_spec.shape[1])
116
+ mel_spec = np.tile(mel_spec,reps=(1,n_repeat))
117
+ elif mode == 'none' or mode == 'pad':
118
+ pass
119
+ else:
120
+ raise ValueError(f'mode:{mode} is not supported')
121
+ mel_spec = mel_spec[:,:batch_max_length]
122
+ os.makedirs(mel_dir_name,exist_ok=True)
123
+ np.save(mel_path,mel_spec)
124
+
125
+
126
+ def split_list(i_list,num):
127
+ each_num = math.ceil(i_list / num)
128
+ result = []
129
+ for i in range(num):
130
+ s = each_num * i
131
+ e = (each_num * (i+1))
132
+ result.append(i_list[s:e])
133
+ return result
134
+
135
+
136
+ def drop_bad_wav(item):
137
+ index,path = item
138
+ try:
139
+ with audioread.audio_open(path) as f:
140
+ totalsec = f.duration
141
+ if totalsec < 0.1:
142
+ return index # index
143
+ except:
144
+ print(f"corrupted wav:{path}")
145
+ return index
146
+ return False
147
+
148
+ def drop_bad_wavs(tsv_path):# 'audioset.csv'
149
+ df = pd.read_csv(tsv_path,sep='\t')
150
+ item_list = []
151
+ for item in tqdm(df.itertuples()):
152
+ item_list.append((item[0],getattr(item,'audio_path')))
153
+
154
+ r = process_map(drop_bad_wav,item_list,max_workers=16,chunksize=16)
155
+ bad_indices = list(filter(lambda x:x!= False,r))
156
+
157
+ print(bad_indices)
158
+ with open('bad_wavs.json','w') as f:
159
+ x = [item_list[i] for i in bad_indices]
160
+ json.dump(x,f)
161
+ df = df.drop(bad_indices,axis=0)
162
+ df.to_csv(tsv_path,sep='\t',index=False)
163
+
164
+ if __name__ == '__main__':
165
+ logging.basicConfig(filename='example.log', level=logging.INFO,
166
+ format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
167
+ tsv_path = './musiccap.tsv'
168
+ if os.path.isdir(tsv_path):
169
+ files = glob(os.path.join(tsv_path,'*.tsv'))
170
+ for file in files:
171
+ drop_bad_wavs(file)
172
+ else:
173
+ drop_bad_wavs(tsv_path)
174
+ num_gpus = 1
175
+ args = {
176
+ 'audio_sample_rate': 16000,
177
+ 'audio_num_mel_bins':80,
178
+ 'fft_size': 1024,# 4000:512 ,16000:1024,
179
+ 'win_size': 1024,
180
+ 'hop_size': 256,
181
+ 'fmin': 0,
182
+ 'fmax': 8000,
183
+ 'batch_max_length': 1560, # 4000:312 (nfft = 512,hoplen=128,mellen = 313), 16000:624 , 22050:848 #
184
+ 'tsv_path': tsv_path,
185
+ 'num_gpus': num_gpus,
186
+ 'mode': 'none',
187
+ 'save_resample':False,
188
+ 'save_mel' :True
189
+ }
190
+ args = Namespace(**args)
191
+ args.dist_config = {
192
+ "dist_backend": "nccl",
193
+ "dist_url": "tcp://localhost:54189",
194
+ "world_size": 1
195
+ }
196
+ if args.num_gpus>1:
197
+ mp.spawn(process_audio_by_tsv,nprocs=args.num_gpus,args=(args,))
198
+ else:
199
+ process_audio_by_tsv(0,args=args)
200
+ print("done")
201
+
ldm/data/test.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+ import torch
4
+ from typing import TypeVar, Optional, Iterator
5
+ import logging
6
+ import pandas as pd
7
+ from ldm.data.joinaudiodataset_anylen import *
8
+ import glob
9
+ logger = logging.getLogger(f'main.{__name__}')
10
+
11
+ sys.path.insert(0, '.') # nopep8
12
+
13
+ class JoinManifestSpecs(torch.utils.data.Dataset):
14
+ def __init__(self, split, main_spec_dir_path,other_spec_dir_path, mel_num=80,mode='pad', spec_crop_len=1248,pad_value=-5,drop=0,**kwargs):
15
+ super().__init__()
16
+ self.split = split
17
+ self.max_batch_len = spec_crop_len
18
+ self.min_batch_len = 64
19
+ self.min_factor = 4
20
+ self.mel_num = mel_num
21
+ self.drop = drop
22
+ self.pad_value = pad_value
23
+ assert mode in ['pad','tile']
24
+ self.collate_mode = mode
25
+ manifest_files = []
26
+ for dir_path in main_spec_dir_path.split(','):
27
+ manifest_files += glob.glob(f'{dir_path}/*.tsv')
28
+ df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
29
+ self.df_main = pd.concat(df_list,ignore_index=True)
30
+
31
+ manifest_files = []
32
+ for dir_path in other_spec_dir_path.split(','):
33
+ manifest_files += glob.glob(f'{dir_path}/*.tsv')
34
+ df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
35
+ self.df_other = pd.concat(df_list,ignore_index=True)
36
+ self.df_other.reset_index(inplace=True)
37
+
38
+ if split == 'train':
39
+ self.dataset = self.df_main.iloc[100:]
40
+ elif split == 'valid' or split == 'val':
41
+ self.dataset = self.df_main.iloc[:100]
42
+ elif split == 'test':
43
+ self.df_main = self.add_name_num(self.df_main)
44
+ self.dataset = self.df_main
45
+ else:
46
+ raise ValueError(f'Unknown split {split}')
47
+ self.dataset.reset_index(inplace=True)
48
+ print('dataset len:', len(self.dataset),"drop_rate",self.drop)
49
+
50
+ def add_name_num(self,df):
51
+ """each file may have different caption, we add num to filename to identify each audio-caption pair"""
52
+ name_count_dict = {}
53
+ change = []
54
+ for t in df.itertuples():
55
+ name = getattr(t,'name')
56
+ if name in name_count_dict:
57
+ name_count_dict[name] += 1
58
+ else:
59
+ name_count_dict[name] = 0
60
+ change.append((t[0],name_count_dict[name]))
61
+ for t in change:
62
+ df.loc[t[0],'name'] = str(df.loc[t[0],'name']) + f'_{t[1]}'
63
+ return df
64
+
65
+ def ordered_indices(self):
66
+ index2dur = self.dataset[['duration']].sort_values(by='duration')
67
+ index2dur_other = self.df_other[['duration']].sort_values(by='duration')
68
+ other_indices = list(index2dur_other.index)
69
+ offset = len(self.dataset)
70
+ other_indices = [x + offset for x in other_indices]
71
+ return list(index2dur.index),other_indices
72
+
73
+ def collater(self,inputs):
74
+ to_dict = {}
75
+ for l in inputs:
76
+ for k,v in l.items():
77
+ if k in to_dict:
78
+ to_dict[k].append(v)
79
+ else:
80
+ to_dict[k] = [v]
81
+
82
+ if self.collate_mode == 'pad':
83
+ to_dict['image'] = collate_1d_or_2d(to_dict['image'],pad_idx=self.pad_value,min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
84
+ elif self.collate_mode == 'tile':
85
+ to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'],min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
86
+ else:
87
+ raise NotImplementedError
88
+ to_dict['caption'] = {'ori_caption':[c['ori_caption'] for c in to_dict['caption']],
89
+ 'struct_caption':[c['struct_caption'] for c in to_dict['caption']]}
90
+
91
+ return to_dict
92
+
93
+ def __getitem__(self, idx):
94
+ if idx < len(self.dataset):
95
+ data = self.dataset.iloc[idx]
96
+ p = np.random.uniform(0,1)
97
+ if p > self.drop:
98
+ ori_caption = data['ori_cap']
99
+ struct_caption = data['caption']
100
+ else:
101
+ ori_caption = ""
102
+ struct_caption = ""
103
+ else:
104
+ data = self.df_other.iloc[idx-len(self.dataset)]
105
+ p = np.random.uniform(0,1)
106
+ if p > self.drop:
107
+ ori_caption = data['caption']
108
+ struct_caption = f'<{ori_caption}& all>'
109
+ else:
110
+ ori_caption = ""
111
+ struct_caption = ""
112
+ item = {}
113
+ try:
114
+ spec = np.load(data['mel_path']) # mel spec [80, T]
115
+ if spec.shape[1] > self.max_batch_len:
116
+ spec = spec[:,:self.max_batch_len]
117
+ except:
118
+ mel_path = data['mel_path']
119
+ print(f'corrupted:{mel_path}')
120
+ spec = np.ones((self.mel_num,self.min_batch_len)).astype(np.float32)*self.pad_value
121
+
122
+ item['image'] = spec
123
+ item["caption"] = {"ori_caption":ori_caption,"struct_caption":struct_caption}
124
+ if self.split == 'test':
125
+ item['f_name'] = data['name']
126
+ return item
127
+
128
+ def __len__(self):
129
+ return len(self.dataset) + len(self.df_other)
130
+
131
+
132
+ class JoinSpecsTrain(JoinManifestSpecs):
133
+ def __init__(self, specs_dataset_cfg):
134
+ super().__init__('train', **specs_dataset_cfg)
135
+
136
+ class JoinSpecsValidation(JoinManifestSpecs):
137
+ def __init__(self, specs_dataset_cfg):
138
+ super().__init__('valid', **specs_dataset_cfg)
139
+
140
+ class JoinSpecsTest(JoinManifestSpecs):
141
+ def __init__(self, specs_dataset_cfg):
142
+ super().__init__('test', **specs_dataset_cfg)
143
+
144
+
145
+
146
+ class DDPIndexBatchSampler(Sampler):# 让长度相似的音频的indices合到一个batch中以避免过长的pad
147
+ def __init__(self, main_indices,other_indices,batch_size, num_replicas: Optional[int] = None,
148
+ rank: Optional[int] = None, shuffle: bool = True,
149
+ seed: int = 0, drop_last: bool = False) -> None:
150
+ if num_replicas is None:
151
+ if not dist.is_initialized():
152
+ # raise RuntimeError("Requires distributed package to be available")
153
+ print("Not in distributed mode")
154
+ num_replicas = 1
155
+ else:
156
+ num_replicas = dist.get_world_size()
157
+ if rank is None:
158
+ if not dist.is_initialized():
159
+ # raise RuntimeError("Requires distributed package to be available")
160
+ rank = 0
161
+ else:
162
+ rank = dist.get_rank()
163
+ if rank >= num_replicas or rank < 0:
164
+ raise ValueError(
165
+ "Invalid rank {}, rank should be in the interval"
166
+ " [0, {}]".format(rank, num_replicas - 1))
167
+ self.main_indices = main_indices
168
+ self.other_indices = other_indices
169
+ self.max_index = max(self.other_indices)
170
+ self.num_replicas = num_replicas
171
+ self.rank = rank
172
+ self.epoch = 0
173
+ self.drop_last = drop_last
174
+ self.batch_size = batch_size
175
+ self.shuffle = shuffle
176
+ self.batches = self.build_batches()
177
+ self.seed = seed
178
+
179
+ def set_epoch(self,epoch):
180
+ # print("!!!!!!!!!!!set epoch is called!!!!!!!!!!!!!!")
181
+ self.epoch = epoch
182
+ if self.shuffle:
183
+ np.random.seed(self.seed+self.epoch)
184
+ self.batches = self.build_batches()
185
+
186
+ def build_batches(self):
187
+ batches,batch = [],[]
188
+ for index in self.main_indices:
189
+ batch.append(index)
190
+ if len(batch) == self.batch_size:
191
+ batches.append(batch)
192
+ batch = []
193
+ if not self.drop_last and len(batch) > 0:
194
+ batches.append(batch)
195
+ selected_others = np.random.choice(len(self.other_indices),len(batches),replace=False)
196
+ for index in selected_others:
197
+ if index + self.batch_size > len(self.other_indices):
198
+ index = len(self.other_indices) - self.batch_size
199
+ batch = [self.other_indices[index + i] for i in range(self.batch_size)]
200
+ batches.append(batch)
201
+ self.batches = batches
202
+ if self.shuffle:
203
+ self.batches = np.random.permutation(self.batches)
204
+ if self.rank == 0:
205
+ print(f"rank: {self.rank}, batches_num {len(self.batches)}")
206
+
207
+ if self.drop_last and len(self.batches) % self.num_replicas != 0:
208
+ self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
209
+ if len(self.batches) >= self.num_replicas:
210
+ self.batches = self.batches[self.rank::self.num_replicas]
211
+ else: # may happen in sanity checking
212
+ self.batches = [self.batches[0]]
213
+ if self.rank == 0:
214
+ print(f"after split batches_num {len(self.batches)}")
215
+
216
+ return self.batches
217
+
218
+ def __iter__(self) -> Iterator[List[int]]:
219
+ print(f"len(self.batches):{len(self.batches)}")
220
+ for batch in self.batches:
221
+ yield batch
222
+
223
+ def __len__(self) -> int:
224
+ return len(self.batches)
ldm/data/tsv_dirs/full_data/V1_new/audiocaps_train_16000.tsv ADDED
The diff for this file is too large to render. See raw diff
 
ldm/data/tsv_dirs/full_data/V2/MACS.tsv ADDED
The diff for this file is too large to render. See raw diff
 
ldm/data/tsv_dirs/full_data/V2/WavText5K.tsv ADDED
The diff for this file is too large to render. See raw diff
 
ldm/data/tsv_dirs/full_data/V2/adobe.tsv ADDED
The diff for this file is too large to render. See raw diff
 
ldm/data/tsv_dirs/full_data/V2/audiostock.tsv ADDED
The diff for this file is too large to render. See raw diff
 
ldm/data/tsv_dirs/full_data/V2/epidemic_sound.tsv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc67e42c9defa98edfc2c6b23c731fafa4a22307fddfd1fb95ccfc00d0168951
3
+ size 15062608
ldm/data/tsv_dirs/full_data/caps_struct/audiocaps_train_16000_struct2.tsv ADDED
The diff for this file is too large to render. See raw diff
 
ldm/data/tsv_dirs/full_data/clotho.tsv ADDED
The diff for this file is too large to render. See raw diff
 
ldm/data/tsvdataset.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from glob import glob
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ import pandas as pd
5
+
6
+ class TSVDataset(Dataset):
7
+ def __init__(self, tsv_path, spec_crop_len=None):
8
+ super().__init__()
9
+ self.batch_max_length = spec_crop_len
10
+ self.batch_min_length = 50
11
+ df = pd.read_csv(tsv_path,sep='\t')
12
+ df = self.add_name_num(df)
13
+ self.dataset = df
14
+ print('dataset len:', len(self.dataset))
15
+
16
+ def add_name_num(self,df):
17
+ """each file may have different caption, we add num to filename to identify each audio-caption pair"""
18
+ name_count_dict = {}
19
+ change = []
20
+ for t in df.itertuples():
21
+ name = getattr(t,'name')
22
+ if name in name_count_dict:
23
+ name_count_dict[name] += 1
24
+ else:
25
+ name_count_dict[name] = 0
26
+ change.append((t[0],name_count_dict[name]))
27
+ for t in change:
28
+ df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
29
+ return df
30
+
31
+
32
+ def __getitem__(self, idx):
33
+ data = self.dataset.iloc[idx]
34
+ item = {}
35
+ spec = np.load(data['mel_path']) # mel spec [80, 624]
36
+ if spec.shape[1] <= self.batch_max_length:
37
+ spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1]))) # [80, 624]
38
+
39
+ item['image'] = spec
40
+ item["caption"] = data['caption']
41
+ item["f_name"] = data['name']
42
+ return item
43
+
44
+ def __len__(self):
45
+ return len(self.dataset)
46
+
47
+ class TSVDatasetStruct(TSVDataset):
48
+ def __getitem__(self, idx):
49
+ data = self.dataset.iloc[idx]
50
+ item = {}
51
+ spec = np.load(data['mel_path']) # mel spec [80, 624]
52
+ if spec.shape[1] <= self.batch_max_length:
53
+ spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1]))) # [80, 624]
54
+
55
+ item['image'] = spec[:,:self.batch_max_length]
56
+ item["caption"] = {'ori_caption':data['ori_cap'],'struct_caption':data['caption']}
57
+ item["f_name"] = data['name']
58
+ return item
59
+
60
+ class TSVDatasetTestFake(TSVDataset):
61
+ def __init__(self, specs_dataset_cfg):
62
+ super().__init__(phase='test', **specs_dataset_cfg)
63
+ self.dataset = [self.dataset[0]]
64
+
65
+
66
+
67
+
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n, **kwargs):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n, **kwargs):
33
+ return self.schedule(n,**kwargs)
34
+
35
+
36
+ class LambdaWarmUpCosineScheduler2:
37
+ """
38
+ supports repeated iterations, configurable via lists
39
+ note: use with a base_lr of 1.0.
40
+ """
41
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43
+ self.lr_warm_up_steps = warm_up_steps
44
+ self.f_start = f_start
45
+ self.f_min = f_min
46
+ self.f_max = f_max
47
+ self.cycle_lengths = cycle_lengths
48
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49
+ self.last_f = 0.
50
+ self.verbosity_interval = verbosity_interval
51
+
52
+ def find_in_interval(self, n):
53
+ interval = 0
54
+ for cl in self.cum_cycles[1:]:
55
+ if n <= cl:
56
+ return interval
57
+ interval += 1
58
+
59
+ def schedule(self, n, **kwargs):
60
+ cycle = self.find_in_interval(n)
61
+ n = n - self.cum_cycles[cycle]
62
+ if self.verbosity_interval > 0:
63
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64
+ f"current cycle {cycle}")
65
+ if n < self.lr_warm_up_steps[cycle]:
66
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67
+ self.last_f = f
68
+ return f
69
+ else:
70
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71
+ t = min(t, 1.0)
72
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73
+ 1 + np.cos(t * np.pi))
74
+ self.last_f = f
75
+ return f
76
+
77
+ def __call__(self, n, **kwargs):
78
+ return self.schedule(n, **kwargs)
79
+
80
+
81
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88
+ f"current cycle {cycle}")
89
+
90
+ if n < self.lr_warm_up_steps[cycle]:
91
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92
+ self.last_f = f
93
+ return f
94
+ else:
95
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96
+ self.last_f = f
97
+ return f
98
+
ldm/models/__pycache__/autoencoder.cpython-37.pyc ADDED
Binary file (15.6 kB). View file
 
ldm/models/__pycache__/autoencoder.cpython-38.pyc ADDED
Binary file (15.5 kB). View file
 
ldm/models/__pycache__/autoencoder.cpython-39.pyc ADDED
Binary file (14.9 kB). View file
 
ldm/models/__pycache__/autoencoder1d.cpython-37.pyc ADDED
Binary file (13.5 kB). View file
 
ldm/models/__pycache__/autoencoder1d.cpython-38.pyc ADDED
Binary file (13.4 kB). View file
 
ldm/models/__pycache__/autoencoder_multi.cpython-38.pyc ADDED
Binary file (14.8 kB). View file
 
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ import torch.nn.functional as F
5
+ from contextlib import contextmanager
6
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7
+ from packaging import version
8
+ import numpy as np
9
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
10
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
11
+ from torch.optim.lr_scheduler import LambdaLR
12
+ from ldm.util import instantiate_from_config
13
+ from icecream import ic
14
+
15
+ class VQModel(pl.LightningModule):
16
+ def __init__(self,
17
+ ddconfig,
18
+ lossconfig,
19
+ n_embed,
20
+ embed_dim,
21
+ ckpt_path=None,
22
+ ignore_keys=[],
23
+ image_key="image",
24
+ colorize_nlabels=None,
25
+ monitor=None,
26
+ batch_resize_range=None,
27
+ scheduler_config=None,
28
+ lr_g_factor=1.0,
29
+ remap=None,
30
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
31
+ use_ema=False
32
+ ):
33
+ super().__init__()
34
+ self.embed_dim = embed_dim
35
+ self.n_embed = n_embed
36
+ self.image_key = image_key
37
+ self.encoder = Encoder(**ddconfig)
38
+ self.decoder = Decoder(**ddconfig)
39
+ self.loss = instantiate_from_config(lossconfig)
40
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
41
+ remap=remap,
42
+ sane_index_shape=sane_index_shape)
43
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
44
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
45
+ if colorize_nlabels is not None:
46
+ assert type(colorize_nlabels)==int
47
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
48
+ if monitor is not None:
49
+ self.monitor = monitor
50
+ self.batch_resize_range = batch_resize_range
51
+ if self.batch_resize_range is not None:
52
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
53
+
54
+ self.use_ema = use_ema
55
+ if self.use_ema:
56
+ self.model_ema = LitEma(self)
57
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
58
+
59
+ if ckpt_path is not None:
60
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
61
+ self.scheduler_config = scheduler_config
62
+ self.lr_g_factor = lr_g_factor
63
+
64
+ @contextmanager
65
+ def ema_scope(self, context=None):
66
+ if self.use_ema:
67
+ self.model_ema.store(self.parameters())
68
+ self.model_ema.copy_to(self)
69
+ if context is not None:
70
+ print(f"{context}: Switched to EMA weights")
71
+ try:
72
+ yield None
73
+ finally:
74
+ if self.use_ema:
75
+ self.model_ema.restore(self.parameters())
76
+ if context is not None:
77
+ print(f"{context}: Restored training weights")
78
+
79
+ def init_from_ckpt(self, path, ignore_keys=list()):
80
+ sd = torch.load(path, map_location="cpu")["state_dict"]
81
+ keys = list(sd.keys())
82
+ for k in keys:
83
+ for ik in ignore_keys:
84
+ if k.startswith(ik):
85
+ print("Deleting key {} from state_dict.".format(k))
86
+ del sd[k]
87
+ missing, unexpected = self.load_state_dict(sd, strict=False)
88
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
89
+ if len(missing) > 0:
90
+ print(f"Missing Keys: {missing}")
91
+ print(f"Unexpected Keys: {unexpected}")
92
+
93
+ def on_train_batch_end(self, *args, **kwargs):
94
+ if self.use_ema:
95
+ self.model_ema(self)
96
+
97
+ def encode(self, x):
98
+ h = self.encoder(x)
99
+ h = self.quant_conv(h)
100
+ quant, emb_loss, info = self.quantize(h)
101
+ return quant, emb_loss, info
102
+
103
+ def encode_to_prequant(self, x):
104
+ h = self.encoder(x)
105
+ h = self.quant_conv(h)
106
+ return h
107
+
108
+ def decode(self, quant):
109
+ quant = self.post_quant_conv(quant)
110
+ dec = self.decoder(quant)
111
+ return dec
112
+
113
+ def decode_code(self, code_b):
114
+ quant_b = self.quantize.embed_code(code_b)
115
+ dec = self.decode(quant_b)
116
+ return dec
117
+
118
+ def forward(self, input, return_pred_indices=False):
119
+ quant, diff, (_,_,ind) = self.encode(input)
120
+ dec = self.decode(quant)
121
+ if return_pred_indices:
122
+ return dec, diff, ind
123
+ return dec, diff
124
+
125
+ def get_input(self, batch, k):
126
+ x = batch[k]
127
+ if len(x.shape) == 3:
128
+ x = x[..., None]
129
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
130
+ if self.batch_resize_range is not None:
131
+ lower_size = self.batch_resize_range[0]
132
+ upper_size = self.batch_resize_range[1]
133
+ if self.global_step <= 4:
134
+ # do the first few batches with max size to avoid later oom
135
+ new_resize = upper_size
136
+ else:
137
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
138
+ if new_resize != x.shape[2]:
139
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
140
+ x = x.detach()
141
+ return x
142
+
143
+ def training_step(self, batch, batch_idx, optimizer_idx):
144
+ # https://github.com/pytorch/pytorch/issues/37142
145
+ # try not to fool the heuristics
146
+ x = self.get_input(batch, self.image_key)
147
+ xrec, qloss, ind = self(x, return_pred_indices=True)
148
+
149
+ if optimizer_idx == 0:
150
+ # autoencode
151
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
152
+ last_layer=self.get_last_layer(), split="train",
153
+ predicted_indices=ind)
154
+
155
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
156
+ return aeloss
157
+
158
+ if optimizer_idx == 1:
159
+ # discriminator
160
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
161
+ last_layer=self.get_last_layer(), split="train")
162
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
163
+ return discloss
164
+
165
+ def validation_step(self, batch, batch_idx):
166
+ log_dict = self._validation_step(batch, batch_idx)
167
+ with self.ema_scope():
168
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
169
+ return log_dict
170
+
171
+ def _validation_step(self, batch, batch_idx, suffix=""):
172
+ x = self.get_input(batch, self.image_key)
173
+ xrec, qloss, ind = self(x, return_pred_indices=True)
174
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
175
+ self.global_step,
176
+ last_layer=self.get_last_layer(),
177
+ split="val"+suffix,
178
+ predicted_indices=ind
179
+ )
180
+
181
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
182
+ self.global_step,
183
+ last_layer=self.get_last_layer(),
184
+ split="val"+suffix,
185
+ predicted_indices=ind
186
+ )
187
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
188
+ self.log(f"val{suffix}/rec_loss", rec_loss,
189
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
190
+ self.log(f"val{suffix}/aeloss", aeloss,
191
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
192
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
193
+ del log_dict_ae[f"val{suffix}/rec_loss"]
194
+ self.log_dict(log_dict_ae)
195
+ self.log_dict(log_dict_disc)
196
+ return self.log_dict
197
+
198
+ def test_step(self, batch, batch_idx):
199
+ x = self.get_input(batch, self.image_key)
200
+ xrec, qloss, ind = self(x, return_pred_indices=True)
201
+ reconstructions = (xrec + 1)/2 # to mel scale
202
+ test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
203
+ savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
204
+ if not os.path.exists(savedir):
205
+ os.makedirs(savedir)
206
+
207
+ file_names = batch['f_name']
208
+ # print(f"reconstructions.shape:{reconstructions.shape}",file_names)
209
+ reconstructions = reconstructions.cpu().numpy().squeeze(1) # squuze channel dim
210
+ for b in range(reconstructions.shape[0]):
211
+ vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
212
+ v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
213
+ save_img_path = os.path.join(savedir,f'{v_n}_sample_{num}.npy')
214
+ np.save(save_img_path,reconstructions[b])
215
+
216
+ return None
217
+
218
+ def configure_optimizers(self):
219
+ lr_d = self.learning_rate
220
+ lr_g = self.lr_g_factor*self.learning_rate
221
+ print("lr_d", lr_d)
222
+ print("lr_g", lr_g)
223
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
224
+ list(self.decoder.parameters())+
225
+ list(self.quantize.parameters())+
226
+ list(self.quant_conv.parameters())+
227
+ list(self.post_quant_conv.parameters()),
228
+ lr=lr_g, betas=(0.5, 0.9))
229
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
230
+ lr=lr_d, betas=(0.5, 0.9))
231
+
232
+ if self.scheduler_config is not None:
233
+ scheduler = instantiate_from_config(self.scheduler_config)
234
+
235
+ print("Setting up LambdaLR scheduler...")
236
+ scheduler = [
237
+ {
238
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
239
+ 'interval': 'step',
240
+ 'frequency': 1
241
+ },
242
+ {
243
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
244
+ 'interval': 'step',
245
+ 'frequency': 1
246
+ },
247
+ ]
248
+ return [opt_ae, opt_disc], scheduler
249
+ return [opt_ae, opt_disc], []
250
+
251
+ def get_last_layer(self):
252
+ return self.decoder.conv_out.weight
253
+
254
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
255
+ log = dict()
256
+ x = self.get_input(batch, self.image_key)
257
+ x = x.to(self.device)
258
+ if only_inputs:
259
+ log["inputs"] = x
260
+ return log
261
+ xrec, _ = self(x)
262
+ if x.shape[1] > 3:
263
+ # colorize with random projection
264
+ assert xrec.shape[1] > 3
265
+ x = self.to_rgb(x)
266
+ xrec = self.to_rgb(xrec)
267
+ log["inputs"] = x
268
+ log["reconstructions"] = xrec
269
+ if plot_ema:
270
+ with self.ema_scope():
271
+ xrec_ema, _ = self(x)
272
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
273
+ log["reconstructions_ema"] = xrec_ema
274
+ return log
275
+
276
+ def to_rgb(self, x):
277
+ assert self.image_key == "segmentation"
278
+ if not hasattr(self, "colorize"):
279
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
280
+ x = F.conv2d(x, weight=self.colorize)
281
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
282
+ return x
283
+
284
+
285
+ class VQModelInterface(VQModel):
286
+ def __init__(self, embed_dim, *args, **kwargs):
287
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
288
+ self.embed_dim = embed_dim
289
+
290
+ def encode(self, x):# VQModel的quantize写在encoder里,VQModelInterface则将其写在decoder里
291
+ h = self.encoder(x)
292
+ h = self.quant_conv(h)
293
+ return h
294
+
295
+ def decode(self, h, force_not_quantize=False):
296
+ # also go through quantization layer
297
+ if not force_not_quantize:
298
+ quant, emb_loss, info = self.quantize(h)
299
+ else:
300
+ quant = h
301
+ quant = self.post_quant_conv(quant)
302
+ dec = self.decoder(quant)
303
+ return dec
304
+
305
+
306
+ class AutoencoderKL(pl.LightningModule):
307
+ def __init__(self,
308
+ ddconfig,
309
+ lossconfig,
310
+ embed_dim,
311
+ ckpt_path=None,
312
+ ignore_keys=[],
313
+ image_key="image",
314
+ colorize_nlabels=None,
315
+ monitor=None,
316
+ ):
317
+ super().__init__()
318
+ self.to_1d = False
319
+ print(f"to_1d is {self.to_1d} in AUTOENCODER")
320
+ self.image_key = image_key
321
+ self.encoder = Encoder(**ddconfig)
322
+ self.decoder = Decoder(**ddconfig)
323
+ self.loss = instantiate_from_config(lossconfig)
324
+ assert ddconfig["double_z"]
325
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
326
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
327
+ self.embed_dim = embed_dim
328
+ if colorize_nlabels is not None:
329
+ assert type(colorize_nlabels)==int
330
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
331
+ if monitor is not None:
332
+ self.monitor = monitor
333
+ if ckpt_path is not None:
334
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
335
+ # self.automatic_optimization = False # hjw for debug
336
+
337
+ def init_from_ckpt(self, path, ignore_keys=list()):
338
+ sd = torch.load(path, map_location="cpu")["state_dict"]
339
+ keys = list(sd.keys())
340
+ for k in keys:
341
+ for ik in ignore_keys:
342
+ if k.startswith(ik):
343
+ print("Deleting key {} from state_dict.".format(k))
344
+ del sd[k]
345
+ self.load_state_dict(sd, strict=False)
346
+ print(f"Restored from {path}")
347
+
348
+ def encode(self, x):
349
+ if self.to_1d and len(x.shape)==3:
350
+ x = x.unsqueeze(1)
351
+ h = self.encoder(x)
352
+ moments = self.quant_conv(h)
353
+ if self.to_1d:
354
+ b,c,h,w = moments.shape
355
+ moments = moments.reshape(b,c*h,w)
356
+ posterior = DiagonalGaussianDistribution(moments)
357
+ return posterior
358
+
359
+ def decode(self, z):
360
+ if self.to_1d:
361
+ b,c_h,w = z.shape
362
+ c = self.post_quant_conv.in_channels
363
+ z = z.reshape(b,c,-1,w)
364
+ z = self.post_quant_conv(z)
365
+ dec = self.decoder(z)
366
+ return dec
367
+
368
+ def forward(self, input, sample_posterior=True):
369
+ posterior = self.encode(input)
370
+ if sample_posterior:
371
+ z = posterior.sample()
372
+ else:
373
+ z = posterior.mode()
374
+ dec = self.decode(z)
375
+ return dec, posterior
376
+
377
+ def get_input(self, batch, k):
378
+ x = batch[k]
379
+ if len(x.shape) == 3:
380
+ x = x[..., None]
381
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
382
+ return x
383
+
384
+ def training_step(self, batch, batch_idx, optimizer_idx):
385
+ inputs = self.get_input(batch, self.image_key)
386
+ reconstructions, posterior = self(inputs)
387
+
388
+ if optimizer_idx == 0:
389
+ # train encoder+decoder+logvar
390
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
391
+ last_layer=self.get_last_layer(), split="train")
392
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
393
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
394
+ # print(optimizer_idx,log_dict_ae)
395
+ return aeloss
396
+
397
+ if optimizer_idx == 1:
398
+ # train the discriminator
399
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
400
+ last_layer=self.get_last_layer(), split="train")
401
+
402
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
403
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
404
+ # print(optimizer_idx,log_dict_disc)
405
+ return discloss
406
+
407
+ def validation_step(self, batch, batch_idx):
408
+ inputs = self.get_input(batch, self.image_key)
409
+ reconstructions, posterior = self(inputs)
410
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
411
+ last_layer=self.get_last_layer(), split="val")
412
+
413
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
414
+ last_layer=self.get_last_layer(), split="val")
415
+
416
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
417
+ self.log_dict(log_dict_ae)
418
+ self.log_dict(log_dict_disc)
419
+ return self.log_dict
420
+
421
+ def test_step(self, batch, batch_idx):
422
+ inputs = self.get_input(batch, self.image_key)# inputs shape:(b,mel_len,T)
423
+ reconstructions, posterior = self(inputs)# reconstructions:(b,mel_len,T)
424
+ mse_loss = torch.nn.functional.mse_loss(reconstructions,inputs)
425
+ self.log('test/mse_loss',mse_loss)
426
+
427
+ test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
428
+ savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
429
+ if batch_idx == 0:
430
+ print(f"save_path is: {savedir}")
431
+ if not os.path.exists(savedir):
432
+ os.makedirs(savedir)
433
+ print(f"save_path is: {savedir}")
434
+
435
+ file_names = batch['f_name']
436
+ # print(f"reconstructions.shape:{reconstructions.shape}",file_names)
437
+ # reconstructions = (reconstructions + 1)/2 # to mel scale
438
+ reconstructions = reconstructions.cpu().numpy().squeeze(1) # squeeze channel dim
439
+ for b in range(reconstructions.shape[0]):
440
+ vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
441
+ v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
442
+ save_img_path = os.path.join(savedir, f'{v_n}.npy') # f'{v_n}_sample_{num}.npy' f'{v_n}.npy'
443
+ np.save(save_img_path,reconstructions[b])
444
+
445
+ return None
446
+
447
+ def configure_optimizers(self):
448
+ lr = self.learning_rate
449
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
450
+ list(self.decoder.parameters())+
451
+ list(self.quant_conv.parameters())+
452
+ list(self.post_quant_conv.parameters()),
453
+ lr=lr, betas=(0.5, 0.9))
454
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
455
+ lr=lr, betas=(0.5, 0.9))
456
+ return [opt_ae, opt_disc], []
457
+
458
+ def get_last_layer(self):
459
+ return self.decoder.conv_out.weight
460
+
461
+ @torch.no_grad()
462
+ def log_images(self, batch, only_inputs=False,save_dir = 'mel_result_ae13_26_debug/fake_class', **kwargs): # 在main.py的on_validation_batch_end中调用
463
+ log = dict()
464
+ x = self.get_input(batch, self.image_key)
465
+ x = x.to(self.device)
466
+ if not only_inputs:
467
+ xrec, posterior = self(x)
468
+ if x.shape[1] > 3:
469
+ # colorize with random projection
470
+ assert xrec.shape[1] > 3
471
+ x = self.to_rgb(x)
472
+ xrec = self.to_rgb(xrec)
473
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
474
+ log["reconstructions"] = xrec
475
+ log["inputs"] = x
476
+ return log
477
+
478
+ def to_rgb(self, x):
479
+ assert self.image_key == "segmentation"
480
+ if not hasattr(self, "colorize"):
481
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
482
+ x = F.conv2d(x, weight=self.colorize)
483
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
484
+ return x
485
+
486
+
487
+ class IdentityFirstStage(torch.nn.Module):
488
+ def __init__(self, *args, vq_interface=False, **kwargs):
489
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
490
+ super().__init__()
491
+
492
+ def encode(self, x, *args, **kwargs):
493
+ return x
494
+
495
+ def decode(self, x, *args, **kwargs):
496
+ return x
497
+
498
+ def quantize(self, x, *args, **kwargs):
499
+ if self.vq_interface:
500
+ return x, None, [None, None, None]
501
+ return x
502
+
503
+ def forward(self, x, *args, **kwargs):
504
+ return x