Respair commited on
Commit
bcdb559
·
verified ·
1 Parent(s): 12073cc

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +9 -0
  2. .gradio/certificate.pem +31 -0
  3. Configs/config.yml +116 -0
  4. Configs/config_ft.yml +116 -0
  5. Configs/config_kanade.yml +118 -0
  6. Inference/infer_24khz_mod.ipynb +0 -0
  7. Inference/input_for_prompt.txt +4 -0
  8. Inference/prompt.txt +4 -0
  9. Inference/random_texts.txt +14 -0
  10. LICENSE +25 -0
  11. Models/Style_Tsukasa_v02/Top_ckpt_24khz.pth +3 -0
  12. Models/Style_Tsukasa_v02/config_kanade.yml +121 -0
  13. Modules/KotoDama_sampler.py +269 -0
  14. Modules/__init__.py +1 -0
  15. Modules/__pycache__/KotoDama_sampler.cpython-311.pyc +0 -0
  16. Modules/__pycache__/__init__.cpython-311.pyc +0 -0
  17. Modules/__pycache__/discriminators.cpython-311.pyc +0 -0
  18. Modules/__pycache__/hifigan.cpython-311.pyc +0 -0
  19. Modules/__pycache__/istftnet.cpython-311.pyc +0 -0
  20. Modules/__pycache__/slmadv.cpython-311.pyc +0 -0
  21. Modules/__pycache__/utils.cpython-311.pyc +0 -0
  22. Modules/diffusion/__init__.py +1 -0
  23. Modules/diffusion/__pycache__/__init__.cpython-311.pyc +0 -0
  24. Modules/diffusion/__pycache__/diffusion.cpython-311.pyc +0 -0
  25. Modules/diffusion/__pycache__/modules.cpython-311.pyc +0 -0
  26. Modules/diffusion/__pycache__/sampler.cpython-311.pyc +0 -0
  27. Modules/diffusion/__pycache__/utils.cpython-311.pyc +0 -0
  28. Modules/diffusion/audio_diffusion_pytorch/__init__.py +20 -0
  29. Modules/diffusion/audio_diffusion_pytorch/__pycache__/__init__.cpython-311.pyc +0 -0
  30. Modules/diffusion/audio_diffusion_pytorch/__pycache__/components.cpython-311.pyc +0 -0
  31. Modules/diffusion/audio_diffusion_pytorch/__pycache__/diffusion.cpython-311.pyc +0 -0
  32. Modules/diffusion/audio_diffusion_pytorch/__pycache__/models.cpython-311.pyc +0 -0
  33. Modules/diffusion/audio_diffusion_pytorch/__pycache__/utils.cpython-311.pyc +0 -0
  34. Modules/diffusion/audio_diffusion_pytorch/components.py +236 -0
  35. Modules/diffusion/audio_diffusion_pytorch/diffusion.py +354 -0
  36. Modules/diffusion/audio_diffusion_pytorch/models.py +250 -0
  37. Modules/diffusion/audio_diffusion_pytorch/utils.py +125 -0
  38. Modules/diffusion/diffusion.py +94 -0
  39. Modules/diffusion/modules.py +693 -0
  40. Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/.github/workflows/python-publish.yml +39 -0
  41. Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/.gitignore +2 -0
  42. Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/.pre-commit-config.yaml +41 -0
  43. Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/LICENSE +21 -0
  44. Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/LOGO.png +0 -0
  45. Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/README.md +251 -0
  46. Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/setup.py +29 -0
  47. Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/tests/testcustomloss.py +39 -0
  48. Modules/diffusion/sampler.py +691 -0
  49. Modules/diffusion/utils.py +82 -0
  50. Modules/discriminators.py +188 -0
.gitattributes CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Utils/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
37
+ Utils/PLBERT/step_1050000.t7 filter=lfs diff=lfs merge=lfs -text
38
+ reference_sample_wavs/01008270.wav filter=lfs diff=lfs merge=lfs -text
39
+ reference_sample_wavs/kaede_san.wav filter=lfs diff=lfs merge=lfs -text
40
+ reference_sample_wavs/riamu_zeroshot_02.wav filter=lfs diff=lfs merge=lfs -text
41
+ reference_sample_wavs/sample_ref01.wav filter=lfs diff=lfs merge=lfs -text
42
+ reference_sample_wavs/sample_ref02.wav filter=lfs diff=lfs merge=lfs -text
43
+ reference_sample_wavs/shiki_fine05.wav filter=lfs diff=lfs merge=lfs -text
44
+ reference_sample_wavs/syuukovoice_200918_3_01.wav filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
Configs/config.yml ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "Models/LJSpeech"
2
+ first_stage_path: "first_stage.pth"
3
+ save_freq: 2
4
+ log_interval: 10
5
+ device: "cuda"
6
+ epochs_1st: 200 # number of epochs for first stage training (pre-training)
7
+ epochs_2nd: 100 # number of peochs for second stage training (joint training)
8
+ batch_size: 16
9
+ max_len: 400 # maximum number of frames
10
+ pretrained_model: ""
11
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
12
+ load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ F0_path: "Utils/JDC/bst.t7"
15
+ ASR_config: "Utils/ASR/config.yml"
16
+ ASR_path: "Utils/ASR/epoch_00080.pth"
17
+ PLBERT_dir: 'Utils/PLBERT/'
18
+
19
+ data_params:
20
+ train_data: "Data/train_list.txt"
21
+ val_data: "Data/val_list.txt"
22
+ root_path: "/local/LJSpeech-1.1/wavs"
23
+ OOD_data: "Data/OOD_texts.txt"
24
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
25
+
26
+ preprocess_params:
27
+ sr: 24000
28
+ spect_params:
29
+ n_fft: 2048
30
+ win_length: 1200
31
+ hop_length: 300
32
+
33
+ model_params:
34
+ multispeaker: false
35
+
36
+ dim_in: 64
37
+ hidden_dim: 512
38
+ max_conv_dim: 512
39
+ n_layer: 3
40
+ n_mels: 80
41
+
42
+ n_token: 178 # number of phoneme tokens
43
+ max_dur: 50 # maximum duration of a single phoneme
44
+ style_dim: 128 # style vector size
45
+
46
+ dropout: 0.2
47
+
48
+ # config for decoder
49
+ decoder:
50
+ type: 'istftnet' # either hifigan or istftnet
51
+ resblock_kernel_sizes: [3,7,11]
52
+ upsample_rates : [10, 6]
53
+ upsample_initial_channel: 512
54
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
55
+ upsample_kernel_sizes: [20, 12]
56
+ gen_istft_n_fft: 20
57
+ gen_istft_hop_size: 5
58
+
59
+ # speech language model config
60
+ slm:
61
+ model: 'microsoft/wavlm-base-plus'
62
+ sr: 16000 # sampling rate of SLM
63
+ hidden: 768 # hidden size of SLM
64
+ nlayers: 13 # number of layers of SLM
65
+ initial_channel: 64 # initial channels of SLM discriminator head
66
+
67
+ # style diffusion model config
68
+ diffusion:
69
+ embedding_mask_proba: 0.1
70
+ # transformer config
71
+ transformer:
72
+ num_layers: 3
73
+ num_heads: 8
74
+ head_features: 64
75
+ multiplier: 2
76
+
77
+ # diffusion distribution config
78
+ dist:
79
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
80
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
81
+ mean: -3.0
82
+ std: 1.0
83
+
84
+ loss_params:
85
+ lambda_mel: 5. # mel reconstruction loss
86
+ lambda_gen: 1. # generator loss
87
+ lambda_slm: 1. # slm feature matching loss
88
+
89
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
90
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
91
+ TMA_epoch: 50 # TMA starting epoch (1st stage)
92
+
93
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
94
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
95
+ lambda_dur: 1. # duration loss (2nd stage)
96
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
97
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
98
+ lambda_diff: 1. # score matching loss (2nd stage)
99
+
100
+ diff_epoch: 20 # style diffusion starting epoch (2nd stage)
101
+ joint_epoch: 50 # joint training starting epoch (2nd stage)
102
+
103
+ optimizer_params:
104
+ lr: 0.0001 # general learning rate
105
+ bert_lr: 0.00001 # learning rate for PLBERT
106
+ ft_lr: 0.00001 # learning rate for acoustic modules
107
+
108
+ slmadv_params:
109
+ min_len: 400 # minimum length of samples
110
+ max_len: 500 # maximum length of samples
111
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
112
+ iter: 10 # update the discriminator every this iterations of generator update
113
+ thresh: 5 # gradient norm above which the gradient is scaled
114
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
115
+ sig: 1.5 # sigma for differentiable duration modeling
116
+
Configs/config_ft.yml ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "Models/IMAS_FineTuned"
2
+ save_freq: 1
3
+ log_interval: 10
4
+ device: "cuda"
5
+ epochs: 50 # number of finetuning epoch (1 hour of data)
6
+ batch_size: 3
7
+ max_len: 2500 # maximum number of frames
8
+ pretrained_model: "/home/austin/disk2/llmvcs/tt/stylekan/Models/Style_Kanade/NO_SLM_3_epoch_2nd_00002.pth"
9
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
10
+ load_only_params: true # set to true if do not want to load epoch numbers and optimizer parameters
11
+
12
+ F0_path: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/JDC/bst.t7"
13
+ ASR_config: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/config.yml"
14
+ ASR_path: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/bst_00080.pth"
15
+
16
+ PLBERT_dir: 'Utils/PLBERT/'
17
+
18
+ data_params:
19
+ train_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/FT_imas.csv"
20
+ val_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/FT_imas_valid.csv"
21
+ root_path: ""
22
+ OOD_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/OOD_LargeScale_.csv"
23
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
24
+
25
+
26
+ preprocess_params:
27
+ sr: 24000
28
+ spect_params:
29
+ n_fft: 2048
30
+ win_length: 1200
31
+ hop_length: 300
32
+
33
+ model_params:
34
+ multispeaker: true
35
+
36
+ dim_in: 64
37
+ hidden_dim: 512
38
+ max_conv_dim: 512
39
+ n_layer: 3
40
+ n_mels: 80
41
+
42
+ n_token: 178 # number of phoneme tokens
43
+ max_dur: 50 # maximum duration of a single phoneme
44
+ style_dim: 128 # style vector size
45
+
46
+ dropout: 0.2
47
+
48
+ decoder:
49
+ type: 'istftnet' # either hifigan or istftnet
50
+ resblock_kernel_sizes: [3,7,11]
51
+ upsample_rates : [10, 6]
52
+ upsample_initial_channel: 512
53
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
54
+ upsample_kernel_sizes: [20, 12]
55
+ gen_istft_n_fft: 20
56
+ gen_istft_hop_size: 5
57
+
58
+
59
+
60
+ # speech language model config
61
+ slm:
62
+ model: 'Respair/Whisper_Large_v2_Encoder_Block' # The model itself is hardcoded, change it through -> losses.py
63
+ sr: 16000 # sampling rate of SLM
64
+ hidden: 1280 # hidden size of SLM
65
+ nlayers: 33 # number of layers of SLM
66
+ initial_channel: 64 # initial channels of SLM discriminator head
67
+
68
+ # style diffusion model config
69
+ diffusion:
70
+ embedding_mask_proba: 0.1
71
+ # transformer config
72
+ transformer:
73
+ num_layers: 3
74
+ num_heads: 8
75
+ head_features: 64
76
+ multiplier: 2
77
+
78
+ # diffusion distribution config
79
+ dist:
80
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
81
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
82
+ mean: -3.0
83
+ std: 1.0
84
+
85
+ loss_params:
86
+ lambda_mel: 10. # mel reconstruction loss
87
+ lambda_gen: 1. # generator loss
88
+ lambda_slm: 1. # slm feature matching loss
89
+
90
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
91
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
92
+ TMA_epoch: 9 # TMA starting epoch (1st stage)
93
+
94
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
95
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
96
+ lambda_dur: 1. # duration loss (2nd stage)
97
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
98
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
99
+ lambda_diff: 1. # score matching loss (2nd stage)
100
+
101
+ diff_epoch: 0 # style diffusion starting epoch (2nd stage)
102
+ joint_epoch: 30 # joint training starting epoch (2nd stage)
103
+
104
+ optimizer_params:
105
+ lr: 0.0001 # general learning rate
106
+ bert_lr: 0.00001 # learning rate for PLBERT
107
+ ft_lr: 0.00001 # learning rate for acoustic modules
108
+
109
+ slmadv_params:
110
+ min_len: 400 # minimum length of samples
111
+ max_len: 500 # maximum length of samples
112
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
113
+ iter: 20 # update the discriminator every this iterations of generator update
114
+ thresh: 5 # gradient norm above which the gradient is scaled
115
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
116
+ sig: 1.5 # sigma for differentiable duration modeling
Configs/config_kanade.yml ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "Models/Style_Kanade_v02"
2
+ first_stage_path: ""
3
+ save_freq: 1
4
+ log_interval: 10
5
+ device: "cuda"
6
+ epochs_1st: 30 # number of epochs for first stage training (pre-training)
7
+ epochs_2nd: 20 # number of peochs for second stage training (joint training)
8
+ batch_size: 64
9
+ max_len: 560 # maximum number of frames
10
+ pretrained_model: "Models/Style_Kanade_v02/epoch_2nd_00007.pth"
11
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
12
+ load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ F0_path: "Utils/JDC/bst.t7"
15
+ ASR_config: "Utils/ASR/config.yml"
16
+ ASR_path: "Utils/ASR/bst_00080.pth"
17
+
18
+ PLBERT_dir: 'Utils/PLBERT/'
19
+
20
+ data_params:
21
+ train_data: "Data/metadata_cleanest/DATA.csv"
22
+ val_data: "Data/VALID.txt"
23
+ root_path: ""
24
+ OOD_data: "Data/OOD_LargeScale_.csv"
25
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
26
+
27
+
28
+ preprocess_params:
29
+ sr: 24000
30
+ spect_params:
31
+ n_fft: 2048
32
+ win_length: 1200
33
+ hop_length: 300
34
+
35
+ model_params:
36
+ multispeaker: true
37
+
38
+ dim_in: 64
39
+ hidden_dim: 512
40
+ max_conv_dim: 512
41
+ n_layer: 3
42
+ n_mels: 80
43
+
44
+ n_token: 178 # number of phoneme tokens
45
+ max_dur: 50 # maximum duration of a single phoneme
46
+ style_dim: 128 # style vector size
47
+
48
+ dropout: 0.2
49
+
50
+ decoder:
51
+ type: 'istftnet' # either hifigan or istftnet
52
+ resblock_kernel_sizes: [3,7,11]
53
+ upsample_rates : [10, 6]
54
+ upsample_initial_channel: 512
55
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
56
+ upsample_kernel_sizes: [20, 12]
57
+ gen_istft_n_fft: 20
58
+ gen_istft_hop_size: 5
59
+
60
+
61
+
62
+ # speech language model config
63
+ slm:
64
+ model: 'Respair/Whisper_Large_v2_Encoder_Block' # The model itself is hardcoded, change it through -> losses.py
65
+ sr: 16000 # sampling rate of SLM
66
+ hidden: 1280 # hidden size of SLM
67
+ nlayers: 33 # number of layers of SLM
68
+ initial_channel: 64 # initial channels of SLM discriminator head
69
+
70
+ # style diffusion model config
71
+ diffusion:
72
+ embedding_mask_proba: 0.1
73
+ # transformer config
74
+ transformer:
75
+ num_layers: 3
76
+ num_heads: 8
77
+ head_features: 64
78
+ multiplier: 2
79
+
80
+ # diffusion distribution config
81
+ dist:
82
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
83
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
84
+ mean: -3.0
85
+ std: 1.0
86
+
87
+ loss_params:
88
+ lambda_mel: 10. # mel reconstruction loss
89
+ lambda_gen: 1. # generator loss
90
+ lambda_slm: 1. # slm feature matching loss
91
+
92
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
93
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
94
+ TMA_epoch: 5 # TMA starting epoch (1st stage)
95
+
96
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
97
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
98
+ lambda_dur: 1. # duration loss (2nd stage)
99
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
100
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
101
+ lambda_diff: 1. # score matching loss (2nd stage)
102
+
103
+ diff_epoch: 4 # style diffusion starting epoch (2nd stage)
104
+ joint_epoch: 999 # joint training starting epoch (2nd stage)
105
+
106
+ optimizer_params:
107
+ lr: 0.0001 # general learning rate
108
+ bert_lr: 0.00001 # learning rate for PLBERT
109
+ ft_lr: 0.00001 # learning rate for acoustic modules
110
+
111
+ slmadv_params:
112
+ min_len: 400 # minimum length of samples
113
+ max_len: 500 # maximum length of samples
114
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
115
+ iter: 20 # update the discriminator every this iterations of generator update
116
+ thresh: 5 # gradient norm above which the gradient is scaled
117
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
118
+ sig: 1.5 # sigma for differentiable duration modeling
Inference/infer_24khz_mod.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Inference/input_for_prompt.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ この俺に何度も同じことを説明させるな!! お前たちは俺の忠告を完全に無視して、とんでもない結果を招いてしまった。これが最後の警告だ。次は絶対に許さないぞ!
2
+ 時には、静けさの中にこそ、本当の答えが見つかるものですね。慌てる必要はないのです。
3
+ 人生には、表現しきれないほどの驚きがあると思うよ。それは、目には見えない力で、人々を繋ぐ不思議な絆だ。私は、その驚きを胸に秘め、日々を楽しく過ごしているんだ。言葉を伝えるたびに、未来への期待を込めて、元気に話す。それは、夢を叶えるための魔法のようなものだ。
4
+ かなたの次元より迫り来る混沌の使者たちよ、貴様らの野望を我が焔の業火で焼き尽くす! 運命の歯車は、我が意思と共にすでに動き出したのだ。我が宿命の敵に立ち向かうため、禁断の呪文を紡ぐ時は今ここに訪れる。さあ、見るがよい。我が力を!!
Inference/prompt.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ A male voice that resonates with deep, thunderous intensity. His rapid-fire words slam like aggressive drumbeats, each syllable charged with intense rage. The expressive tone fluctuates between restrained fury and explosive outbursts.
2
+ A female voice with a distinctively deep, low pitch that commands attention. Her slightly monotone delivery creates an air of composure and gravitas, while maintaining a calm, measured pace. Her voice carries a soothing weight, like gentle thunder in the distance, making her words feel grounded and reassuring.
3
+ a female voice that is gentle and soft, with a slightly high pitch that adds a comforting warmth to her words. Her tone is moderate, neither too expressive nor too flat, creating a balanced and soothing atmosphere. The slow speed of her speech gives her words a deliberate and thoughtful cadence, allowing each phrase to resonate fully. There's a sense of wonder and optimism in her voice, as if she is constantly marveling at the mysteries of life. Her gentle demeanor and soft delivery make her sound approachable and kind, inviting listeners to share in her sense of wonder.
4
+ A female voice that resonates with deep intensity and a distinctly low pitch. Her words flow with the force and rhythm of a relentless tide, each syllable weighted with profound determination. The expressive tone navigates between measured intensity and powerful surges.
Inference/random_texts.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Akashi: 不思議な人ですね、レザさんは。たまには子供扱いしてくれてちょっとむきになりますけど、とても頼りがいのある人だと思いますよ?
2
+ Kimiji: 人生は、果てしない探求の旅のようなもの。私たちは、自分自身や周囲の世界について、常に新しい発見をしていく。それは、時として喜びをもたらすこともあれば、困難に直面することもある。しかしそれら全てが、自分を形作る貴重な経験である。
3
+ Reira: 私に何度も同じことを説明させないでよ! お前たちは私の忠告を完全に無視して、とんでもない結果を招いてしまった!! これが最後の警告だ。次は絶対に許さないぞ! ------------------------------------------- (NOTE: enable the diffusion, then set the Intensity to 3, also remove this line!)
4
+ Yoriko: この世には、言葉にできない悲しみがある。 それは、胸の奥に沈んでいくような重さで、時間が経つにつれて、じわじわと広がっていく。私は、その悲しみを抱えながら、日々を過ごしていいた。 言葉を発するたびに、心の中で何度も繰り返し、慎重に選び抜いている。それは、痛みを和らげるための儀式のようなものだ.
5
+ Kimiji: 人生には、表現しきれないほどの驚きがあると思います。それは、目には見えない力で、人々を繋ぐ不思議な絆です。私は、その驚きを胸に秘め、日々を楽しく過ごしています。言葉を伝えるたびに、未来への期待を込めて、元気に話す。それは、夢を叶えるための魔法のようなものです。
6
+ Teppei: そうだな、この新しいシステムの仕組みについて説明しておこう。基本的には三層構造になっていて、各層が独立して機能している。一番下のレイヤーでデータの処理と保存を行い、真ん中の層でビジネスロジックを実装している。ユーザーが直接触れるのは最上位層だけで、そこでインターフェースの制御をしているんだ。パフォーマンスを考えると、データベースへのアクセスを最小限に抑える必要があるから、キャッシュの実装も検討している。ただ、システムの規模を考えると、当面は現状の構成で十分だと思われる。将来的にスケールする必要が出てきた場合は、その時点で見直しを検討すればいいだろう。
7
+ Kirara: べ、別にそんなことじゃないってば!あんたのことなんて全然気にしてないんだからね!
8
+ Maiko: ねえ、ちょっと今日の空を見上げて。朝から少しずつ変わっていく雲の形が、まるで漫画の中の風景みたい。東の方からゆっくりと暖かい風が吹いてきて、桜の花びらが舞い散るように、優しく大地を撫でていくの。春の陽気が徐々に夏の暑さに変わっていくこの季節は、なんだかわくわくするよね。でも、週間予報によると、明後日あたりから天気が崩れるみたいで、しばらくは傘の出番かもしれないんだ。梅雨の時期が近づいているから、空気も少しずつ湿っぽくなってきているのを感じない?でも、雨上がりの空気って、なんだか特別な匂いがして、私、結構好きなんだよね。
9
+ Amane: そ、そうかな?あなたと過ごした今までの時間は、私に、とても、とーっても大事だよ?だって、あなたの言葉には、どこか特別な温もりがあるような気がする。まるで、心の中に直接響いてくるようで、ほんの少しの間だけ、世界が優しくなった気がするよ。
10
+ Shioji: つい昨日のことみたいなのに、もう随分と昔のお話なんだね ! 今でも古い本の匂いがすると、あの静かで穏やかな時間がまるで透明で綺麗な海みたいに心の中でふわりと溶けていくの。
11
+ Riemi: かなたの次元より迫り来る混沌の使者たちよ、貴様らの野望を我が焔の業火で焼き尽くす! 運命の歯車は、我が意思と共にすでに動き出したのだ。我が宿命の敵に立ち向かうため、禁断の呪文を紡ぐ時は今ここに訪れる。さあ、見るがよい。我が力を!!
12
+ だめですよ?そんなことは。もっと慎重に行動してくださいね。
13
+ Kazunari: ojousama wa, katachi dake no tsukibito ga ireba sore de ii, to omotte orareru you desuga, katachi bakari de mo, sannen ka o tomo ni seikatsu o suru koto ni narimasɯ. --------------------------------------- (NOTE!: your spacing will impact the intonations. also remove this line!)
14
+ Kimiu: お前 wa muzukashiku 考えると kekka to shite 空回り suru taipu dakara na. 大体, sonna kanji de oboete ikeba iin da. -------------------------------------------------- (NOTE!: your spacing will impact the intonations. also remove this line!)
LICENSE ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE FOR STYLETTS2 DEMO PAGE: GPL (I KNOW I DON'T LIKE GPL BUT I HAVE TO BECAUSE OF PHONEMIZER REQUIREMENT)
2
+
3
+
4
+ styletts 2 license:
5
+
6
+ Copyright (c) 2023 Aaron (Yinghao) Li
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ SOFTWARE.
25
+
Models/Style_Tsukasa_v02/Top_ckpt_24khz.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f50001d997d5a73328c9f12155451372de0cdaf3a10ec5ab9cab7d577cd69b5
3
+ size 2044670782
Models/Style_Tsukasa_v02/config_kanade.yml ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "/home/austin/disk2/llmvcs/tt/stylekan/Models/Style_Kanade_v02"
2
+ first_stage_path: "/home/austin/disk2/llmvcs/tt/stylekan/Models/Style_Kanade_v02/epoch_1st_00026.pth"
3
+ save_freq: 1
4
+ log_interval: 10
5
+ device: "cuda"
6
+ epochs_1st: 30 # number of epochs for first stage training (pre-training)
7
+ epochs_2nd: 20 # number of peochs for second stage training (joint training)
8
+ batch_size: 64
9
+ max_len: 4000 # maximum number of frames
10
+ pretrained_model: "/home/austin/disk2/llmvcs/tt/stylekan/Models/Style_Kanade_v02/epoch_2nd_00007.pth"
11
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
12
+ load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ # CUDA_VISIBLE_DEVICES=1,2,3 accelerate launch train_first.py --config_path ./Configs/config_kanade.yml
15
+ # CUDA_VISIBLE_DEVICES=6,7 accelerate launch accelerate_train_second.py --config_path ./Configs/config_kanade_test.yml
16
+
17
+ F0_path: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/JDC/bst.t7"
18
+ ASR_config: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/config.yml"
19
+ ASR_path: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/bst_00080.pth"
20
+
21
+ PLBERT_dir: 'Utils/PLBERT/'
22
+
23
+ data_params:
24
+ train_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/filtered_train_list_no_nsp_plus.csv"
25
+ val_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/mg_valid.txt"
26
+ root_path: ""
27
+ OOD_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/OOD_LargeScale_.csv"
28
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
29
+
30
+
31
+ preprocess_params:
32
+ sr: 24000
33
+ spect_params:
34
+ n_fft: 2048
35
+ win_length: 1200
36
+ hop_length: 300
37
+
38
+ model_params:
39
+ multispeaker: true
40
+
41
+ dim_in: 64
42
+ hidden_dim: 512
43
+ max_conv_dim: 512
44
+ n_layer: 3
45
+ n_mels: 80
46
+
47
+ n_token: 178 # number of phoneme tokens
48
+ max_dur: 50 # maximum duration of a single phoneme
49
+ style_dim: 128 # style vector size
50
+
51
+ dropout: 0.2
52
+
53
+ decoder:
54
+ type: 'istftnet' # either hifigan or istftnet
55
+ resblock_kernel_sizes: [3,7,11]
56
+ upsample_rates : [10, 6]
57
+ upsample_initial_channel: 512
58
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
59
+ upsample_kernel_sizes: [20, 12]
60
+ gen_istft_n_fft: 20
61
+ gen_istft_hop_size: 5
62
+
63
+
64
+
65
+ # speech language model config
66
+ slm:
67
+ model: 'Respair/Whisper_Large_v2_Encoder_Block' # The model itself is hardcoded, change it through -> losses.py
68
+ sr: 16000 # sampling rate of SLM
69
+ hidden: 1280 # hidden size of SLM
70
+ nlayers: 33 # number of layers of SLM
71
+ initial_channel: 64 # initial channels of SLM discriminator head
72
+
73
+ # style diffusion model config
74
+ diffusion:
75
+ embedding_mask_proba: 0.1
76
+ # transformer config
77
+ transformer:
78
+ num_layers: 3
79
+ num_heads: 8
80
+ head_features: 64
81
+ multiplier: 2
82
+
83
+ # diffusion distribution config
84
+ dist:
85
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
86
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
87
+ mean: -3.0
88
+ std: 1.0
89
+
90
+ loss_params:
91
+ lambda_mel: 10. # mel reconstruction loss
92
+ lambda_gen: 1. # generator loss
93
+ lambda_slm: 1. # slm feature matching loss
94
+
95
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
96
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
97
+ TMA_epoch: 5 # TMA starting epoch (1st stage)
98
+
99
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
100
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
101
+ lambda_dur: 1. # duration loss (2nd stage)
102
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
103
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
104
+ lambda_diff: 1. # score matching loss (2nd stage)
105
+
106
+ diff_epoch: 4 # style diffusion starting epoch (2nd stage)
107
+ joint_epoch: 999 # joint training starting epoch (2nd stage)
108
+
109
+ optimizer_params:
110
+ lr: 0.0001 # general learning rate
111
+ bert_lr: 0.00001 # learning rate for PLBERT
112
+ ft_lr: 0.00001 # learning rate for acoustic modules
113
+
114
+ slmadv_params:
115
+ min_len: 400 # minimum length of samples
116
+ max_len: 500 # maximum length of samples
117
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
118
+ iter: 20 # update the discriminator every this iterations of generator update
119
+ thresh: 5 # gradient norm above which the gradient is scaled
120
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
121
+ sig: 1.5 # sigma for differentiable duration modeling
Modules/KotoDama_sampler.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSequenceClassification, PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer
2
+ import torch
3
+ import torch.nn as nn
4
+ from text_utils import TextCleaner
5
+ textclenaer = TextCleaner()
6
+
7
+
8
+ def length_to_mask(lengths):
9
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
10
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
11
+ return mask
12
+
13
+
14
+
15
+
16
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
+
18
+
19
+ # tokenizer_koto_prompt = AutoTokenizer.from_pretrained("google/mt5-small", trust_remote_code=True)
20
+ tokenizer_koto_prompt = AutoTokenizer.from_pretrained("ku-nlp/deberta-v3-base-japanese", trust_remote_code=True)
21
+ tokenizer_koto_text = AutoTokenizer.from_pretrained("line-corporation/line-distilbert-base-japanese", trust_remote_code=True)
22
+
23
+ class KotoDama_Prompt(PreTrainedModel):
24
+
25
+ def __init__(self, config):
26
+ super().__init__(config)
27
+
28
+ self.backbone = AutoModel.from_config(config)
29
+
30
+ self.output = nn.Sequential(nn.Linear(config.hidden_size, 512),
31
+ nn.LeakyReLU(0.2),
32
+ nn.Linear(512, config.num_labels))
33
+
34
+
35
+
36
+ def forward(
37
+ self,
38
+ input_ids,
39
+ attention_mask=None,
40
+ token_type_ids=None,
41
+ position_ids=None,
42
+ labels=None,
43
+ ):
44
+ outputs = self.backbone(
45
+ input_ids,
46
+ attention_mask=attention_mask,
47
+ token_type_ids=token_type_ids,
48
+ position_ids=position_ids,
49
+ )
50
+
51
+
52
+ sequence_output = outputs.last_hidden_state[:, 0, :]
53
+ outputs = self.output(sequence_output)
54
+
55
+ # if labels, then we are training
56
+ loss = None
57
+ if labels is not None:
58
+
59
+ loss_fn = nn.MSELoss()
60
+ # labels = labels.unsqueeze(1)
61
+ loss = loss_fn(outputs, labels)
62
+
63
+ return {
64
+ "loss": loss,
65
+ "logits": outputs
66
+ }
67
+
68
+
69
+ class KotoDama_Text(PreTrainedModel):
70
+
71
+ def __init__(self, config):
72
+ super().__init__(config)
73
+
74
+ self.backbone = AutoModel.from_config(config)
75
+
76
+ self.output = nn.Sequential(nn.Linear(config.hidden_size, 512),
77
+ nn.LeakyReLU(0.2),
78
+ nn.Linear(512, config.num_labels))
79
+
80
+
81
+
82
+ def forward(
83
+ self,
84
+ input_ids,
85
+ attention_mask=None,
86
+ # token_type_ids=None,
87
+ # position_ids=None,
88
+ labels=None,
89
+ ):
90
+ outputs = self.backbone(
91
+ input_ids,
92
+ attention_mask=attention_mask,
93
+ # token_type_ids=token_type_ids,
94
+ # position_ids=position_ids,
95
+ )
96
+
97
+
98
+ sequence_output = outputs.last_hidden_state[:, 0, :]
99
+ outputs = self.output(sequence_output)
100
+
101
+ # if labels, then we are training
102
+ loss = None
103
+ if labels is not None:
104
+
105
+ loss_fn = nn.MSELoss()
106
+ # labels = labels.unsqueeze(1)
107
+ loss = loss_fn(outputs, labels)
108
+
109
+ return {
110
+ "loss": loss,
111
+ "logits": outputs
112
+ }
113
+
114
+
115
+ def inference(model, diffusion_sampler, text=None, ref_s=None, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, rate_of_speech=1.):
116
+
117
+ tokens = textclenaer(text)
118
+ tokens.insert(0, 0)
119
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
120
+
121
+ with torch.no_grad():
122
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
123
+
124
+ text_mask = length_to_mask(input_lengths).to(device)
125
+
126
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
127
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
128
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
129
+
130
+
131
+
132
+ s_pred = diffusion_sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
133
+ embedding=bert_dur,
134
+ embedding_scale=embedding_scale,
135
+ features=ref_s, # reference from the same speaker as the embedding
136
+ num_steps=diffusion_steps).squeeze(1)
137
+
138
+
139
+ s = s_pred[:, 128:]
140
+ ref = s_pred[:, :128]
141
+
142
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
143
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
144
+
145
+ d = model.predictor.text_encoder(d_en,
146
+ s, input_lengths, text_mask)
147
+
148
+
149
+
150
+ x = model.predictor.lstm(d)
151
+ x_mod = model.predictor.prepare_projection(x)
152
+ duration = model.predictor.duration_proj(x_mod)
153
+
154
+
155
+ duration = torch.sigmoid(duration).sum(axis=-1) / rate_of_speech
156
+
157
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
158
+
159
+
160
+
161
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
162
+
163
+ c_frame = 0
164
+ for i in range(pred_aln_trg.size(0)):
165
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
166
+ c_frame += int(pred_dur[i].data)
167
+
168
+ # encode prosody
169
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
170
+
171
+
172
+
173
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
174
+
175
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
176
+
177
+
178
+ out = model.decoder(asr,
179
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
180
+
181
+
182
+ return out.squeeze().cpu().numpy()[..., :-50]
183
+
184
+
185
+ def Longform(model, diffusion_sampler, text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1, rate_of_speech=1.0):
186
+
187
+ tokens = textclenaer(text)
188
+ tokens.insert(0, 0)
189
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
190
+
191
+ with torch.no_grad():
192
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
193
+ text_mask = length_to_mask(input_lengths).to(device)
194
+
195
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
196
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
197
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
198
+
199
+ s_pred = diffusion_sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
200
+ embedding=bert_dur,
201
+ embedding_scale=embedding_scale,
202
+ features=ref_s,
203
+ num_steps=diffusion_steps).squeeze(1)
204
+
205
+ if s_prev is not None:
206
+ # convex combination of previous and current style
207
+ s_pred = t * s_prev + (1 - t) * s_pred
208
+
209
+ s = s_pred[:, 128:]
210
+ ref = s_pred[:, :128]
211
+
212
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
213
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
214
+
215
+ s_pred = torch.cat([ref, s], dim=-1)
216
+
217
+ d = model.predictor.text_encoder(d_en,
218
+ s, input_lengths, text_mask)
219
+
220
+ x = model.predictor.lstm(d)
221
+ x_mod = model.predictor.prepare_projection(x) # 640 -> 512
222
+ duration = model.predictor.duration_proj(x_mod)
223
+
224
+ duration = torch.sigmoid(duration).sum(axis=-1) / rate_of_speech
225
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
226
+
227
+
228
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
229
+ c_frame = 0
230
+ for i in range(pred_aln_trg.size(0)):
231
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
232
+ c_frame += int(pred_dur[i].data)
233
+
234
+ # encode prosody
235
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
236
+
237
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
238
+
239
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
240
+
241
+ out = model.decoder(asr,
242
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
243
+
244
+
245
+ return out.squeeze().cpu().numpy()[..., :-100], s_pred
246
+
247
+
248
+ def merge_short_elements(lst):
249
+ i = 0
250
+ while i < len(lst):
251
+ if i > 0 and len(lst[i]) < 10:
252
+ lst[i-1] += ' ' + lst[i]
253
+ lst.pop(i)
254
+ else:
255
+ i += 1
256
+ return lst
257
+
258
+
259
+ def merge_three(text_list, maxim=2):
260
+
261
+ merged_list = []
262
+ for i in range(0, len(text_list), maxim):
263
+ merged_text = ' '.join(text_list[i:i+maxim])
264
+ merged_list.append(merged_text)
265
+ return merged_list
266
+
267
+
268
+ def merging_sentences(lst):
269
+ return merge_three(merge_short_elements(lst))
Modules/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
Modules/__pycache__/KotoDama_sampler.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
Modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (178 Bytes). View file
 
Modules/__pycache__/discriminators.cpython-311.pyc ADDED
Binary file (12.2 kB). View file
 
Modules/__pycache__/hifigan.cpython-311.pyc ADDED
Binary file (30 kB). View file
 
Modules/__pycache__/istftnet.cpython-311.pyc ADDED
Binary file (34.4 kB). View file
 
Modules/__pycache__/slmadv.cpython-311.pyc ADDED
Binary file (13.7 kB). View file
 
Modules/__pycache__/utils.cpython-311.pyc ADDED
Binary file (1.18 kB). View file
 
Modules/diffusion/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
Modules/diffusion/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (188 Bytes). View file
 
Modules/diffusion/__pycache__/diffusion.cpython-311.pyc ADDED
Binary file (5.55 kB). View file
 
Modules/diffusion/__pycache__/modules.cpython-311.pyc ADDED
Binary file (32.8 kB). View file
 
Modules/diffusion/__pycache__/sampler.cpython-311.pyc ADDED
Binary file (37.8 kB). View file
 
Modules/diffusion/__pycache__/utils.cpython-311.pyc ADDED
Binary file (5.87 kB). View file
 
Modules/diffusion/audio_diffusion_pytorch/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .components import LTPlugin, MelSpectrogram, UNetV0, XUNet
2
+ from .diffusion import (
3
+ Diffusion,
4
+ Distribution,
5
+ LinearSchedule,
6
+ Sampler,
7
+ Schedule,
8
+ UniformDistribution,
9
+ VDiffusion,
10
+ VInpainter,
11
+ VSampler,
12
+ )
13
+ from .models import (
14
+ DiffusionAE,
15
+ DiffusionAR,
16
+ DiffusionModel,
17
+ DiffusionUpsampler,
18
+ DiffusionVocoder,
19
+ EncoderBase,
20
+ )
Modules/diffusion/audio_diffusion_pytorch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (921 Bytes). View file
 
Modules/diffusion/audio_diffusion_pytorch/__pycache__/components.cpython-311.pyc ADDED
Binary file (9.72 kB). View file
 
Modules/diffusion/audio_diffusion_pytorch/__pycache__/diffusion.cpython-311.pyc ADDED
Binary file (22.5 kB). View file
 
Modules/diffusion/audio_diffusion_pytorch/__pycache__/models.cpython-311.pyc ADDED
Binary file (13.7 kB). View file
 
Modules/diffusion/audio_diffusion_pytorch/__pycache__/utils.cpython-311.pyc ADDED
Binary file (8.28 kB). View file
 
Modules/diffusion/audio_diffusion_pytorch/components.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Sequence
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from a_unet import (
6
+ ClassifierFreeGuidancePlugin,
7
+ Conv,
8
+ Module,
9
+ TextConditioningPlugin,
10
+ TimeConditioningPlugin,
11
+ default,
12
+ exists,
13
+ )
14
+ from a_unet.apex import (
15
+ AttentionItem,
16
+ CrossAttentionItem,
17
+ InjectChannelsItem,
18
+ ModulationItem,
19
+ ResnetItem,
20
+ SkipCat,
21
+ SkipModulate,
22
+ XBlock,
23
+ XUNet,
24
+ )
25
+ from einops import pack, unpack
26
+ from torch import Tensor, nn
27
+ from torchaudio import transforms
28
+
29
+ """
30
+ UNets (built with a-unet: https://github.com/archinetai/a-unet)
31
+ """
32
+
33
+
34
+ def UNetV0(
35
+ dim: int,
36
+ in_channels: int,
37
+ channels: Sequence[int],
38
+ factors: Sequence[int],
39
+ items: Sequence[int],
40
+ attentions: Optional[Sequence[int]] = None,
41
+ cross_attentions: Optional[Sequence[int]] = None,
42
+ context_channels: Optional[Sequence[int]] = None,
43
+ attention_features: Optional[int] = None,
44
+ attention_heads: Optional[int] = None,
45
+ embedding_features: Optional[int] = None,
46
+ resnet_groups: int = 8,
47
+ use_modulation: bool = True,
48
+ modulation_features: int = 1024,
49
+ embedding_max_length: Optional[int] = None,
50
+ use_time_conditioning: bool = True,
51
+ use_embedding_cfg: bool = False,
52
+ use_text_conditioning: bool = False,
53
+ out_channels: Optional[int] = None,
54
+ ):
55
+ # Set defaults and check lengths
56
+ num_layers = len(channels)
57
+ attentions = default(attentions, [0] * num_layers)
58
+ cross_attentions = default(cross_attentions, [0] * num_layers)
59
+ context_channels = default(context_channels, [0] * num_layers)
60
+ xs = (channels, factors, items, attentions, cross_attentions, context_channels)
61
+ assert all(len(x) == num_layers for x in xs) # type: ignore
62
+
63
+ # Define UNet type
64
+ UNetV0 = XUNet
65
+
66
+ if use_embedding_cfg:
67
+ msg = "use_embedding_cfg requires embedding_max_length"
68
+ assert exists(embedding_max_length), msg
69
+ UNetV0 = ClassifierFreeGuidancePlugin(UNetV0, embedding_max_length)
70
+
71
+ if use_text_conditioning:
72
+ UNetV0 = TextConditioningPlugin(UNetV0)
73
+
74
+ if use_time_conditioning:
75
+ assert use_modulation, "use_time_conditioning requires use_modulation=True"
76
+ UNetV0 = TimeConditioningPlugin(UNetV0)
77
+
78
+ # Build
79
+ return UNetV0(
80
+ dim=dim,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ blocks=[
84
+ XBlock(
85
+ channels=channels,
86
+ factor=factor,
87
+ context_channels=ctx_channels,
88
+ items=(
89
+ [ResnetItem]
90
+ + [ModulationItem] * use_modulation
91
+ + [InjectChannelsItem] * (ctx_channels > 0)
92
+ + [AttentionItem] * att
93
+ + [CrossAttentionItem] * cross
94
+ )
95
+ * items,
96
+ )
97
+ for channels, factor, items, att, cross, ctx_channels in zip(*xs) # type: ignore # noqa
98
+ ],
99
+ skip_t=SkipModulate if use_modulation else SkipCat,
100
+ attention_features=attention_features,
101
+ attention_heads=attention_heads,
102
+ embedding_features=embedding_features,
103
+ modulation_features=modulation_features,
104
+ resnet_groups=resnet_groups,
105
+ )
106
+
107
+
108
+ """
109
+ Plugins
110
+ """
111
+
112
+
113
+ def LTPlugin(
114
+ net_t: Callable, num_filters: int, window_length: int, stride: int
115
+ ) -> Callable[..., nn.Module]:
116
+ """Learned Transform Plugin"""
117
+
118
+ def Net(
119
+ dim: int, in_channels: int, out_channels: Optional[int] = None, **kwargs
120
+ ) -> nn.Module:
121
+ out_channels = default(out_channels, in_channels)
122
+ in_channel_transform = in_channels * num_filters
123
+ out_channel_transform = out_channels * num_filters # type: ignore
124
+
125
+ padding = window_length // 2 - stride // 2
126
+ encode = Conv(
127
+ dim=dim,
128
+ in_channels=in_channels,
129
+ out_channels=in_channel_transform,
130
+ kernel_size=window_length,
131
+ stride=stride,
132
+ padding=padding,
133
+ padding_mode="reflect",
134
+ bias=False,
135
+ )
136
+ decode = nn.ConvTranspose1d(
137
+ in_channels=out_channel_transform,
138
+ out_channels=out_channels, # type: ignore
139
+ kernel_size=window_length,
140
+ stride=stride,
141
+ padding=padding,
142
+ bias=False,
143
+ )
144
+ net = net_t( # type: ignore
145
+ dim=dim,
146
+ in_channels=in_channel_transform,
147
+ out_channels=out_channel_transform,
148
+ **kwargs
149
+ )
150
+
151
+ def forward(x: Tensor, *args, **kwargs):
152
+ x = encode(x)
153
+ x = net(x, *args, **kwargs)
154
+ x = decode(x)
155
+ return x
156
+
157
+ return Module([encode, decode, net], forward)
158
+
159
+ return Net
160
+
161
+
162
+ def AppendChannelsPlugin(
163
+ net_t: Callable,
164
+ channels: int,
165
+ ):
166
+ def Net(
167
+ in_channels: int, out_channels: Optional[int] = None, **kwargs
168
+ ) -> nn.Module:
169
+ out_channels = default(out_channels, in_channels)
170
+ net = net_t( # type: ignore
171
+ in_channels=in_channels + channels, out_channels=out_channels, **kwargs
172
+ )
173
+
174
+ def forward(x: Tensor, *args, append_channels: Tensor, **kwargs):
175
+ x = torch.cat([x, append_channels], dim=1)
176
+ return net(x, *args, **kwargs)
177
+
178
+ return Module([net], forward)
179
+
180
+ return Net
181
+
182
+
183
+ """
184
+ Other
185
+ """
186
+
187
+
188
+ class MelSpectrogram(nn.Module):
189
+ def __init__(
190
+ self,
191
+ n_fft: int,
192
+ hop_length: int,
193
+ win_length: int,
194
+ sample_rate: int,
195
+ n_mel_channels: int,
196
+ center: bool = False,
197
+ normalize: bool = False,
198
+ normalize_log: bool = False,
199
+ ):
200
+ super().__init__()
201
+ self.padding = (n_fft - hop_length) // 2
202
+ self.normalize = normalize
203
+ self.normalize_log = normalize_log
204
+ self.hop_length = hop_length
205
+
206
+ self.to_spectrogram = transforms.Spectrogram(
207
+ n_fft=n_fft,
208
+ hop_length=hop_length,
209
+ win_length=win_length,
210
+ center=center,
211
+ power=None,
212
+ )
213
+
214
+ self.to_mel_scale = transforms.MelScale(
215
+ n_mels=n_mel_channels, n_stft=n_fft // 2 + 1, sample_rate=sample_rate
216
+ )
217
+
218
+ def forward(self, waveform: Tensor) -> Tensor:
219
+ # Pack non-time dimension
220
+ waveform, ps = pack([waveform], "* t")
221
+ # Pad waveform
222
+ waveform = F.pad(waveform, [self.padding] * 2, mode="reflect")
223
+ # Compute STFT
224
+ spectrogram = self.to_spectrogram(waveform)
225
+ # Compute magnitude
226
+ spectrogram = torch.abs(spectrogram)
227
+ # Convert to mel scale
228
+ mel_spectrogram = self.to_mel_scale(spectrogram)
229
+ # Normalize
230
+ if self.normalize:
231
+ mel_spectrogram = mel_spectrogram / torch.max(mel_spectrogram)
232
+ mel_spectrogram = 2 * torch.pow(mel_spectrogram, 0.25) - 1
233
+ if self.normalize_log:
234
+ mel_spectrogram = torch.log(torch.clamp(mel_spectrogram, min=1e-5))
235
+ # Unpack non-spectrogram dimension
236
+ return unpack(mel_spectrogram, ps, "* f l")[0]
Modules/diffusion/audio_diffusion_pytorch/diffusion.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import pi
2
+ from typing import Any, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, repeat
8
+ from torch import Tensor
9
+ from tqdm import tqdm
10
+
11
+ from .utils import default
12
+
13
+ """ Distributions """
14
+
15
+
16
+ class Distribution:
17
+ """Interface used by different distributions"""
18
+
19
+ def __call__(self, num_samples: int, device: torch.device):
20
+ raise NotImplementedError()
21
+
22
+
23
+ class UniformDistribution(Distribution):
24
+ def __init__(self, vmin: float = 0.0, vmax: float = 1.0):
25
+ super().__init__()
26
+ self.vmin, self.vmax = vmin, vmax
27
+
28
+ def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
29
+ vmax, vmin = self.vmax, self.vmin
30
+ return (vmax - vmin) * torch.rand(num_samples, device=device) + vmin
31
+
32
+
33
+ """ Diffusion Methods """
34
+
35
+
36
+ def pad_dims(x: Tensor, ndim: int) -> Tensor:
37
+ # Pads additional ndims to the right of the tensor
38
+ return x.view(*x.shape, *((1,) * ndim))
39
+
40
+
41
+ def clip(x: Tensor, dynamic_threshold: float = 0.0):
42
+ if dynamic_threshold == 0.0:
43
+ return x.clamp(-1.0, 1.0)
44
+ else:
45
+ # Dynamic thresholding
46
+ # Find dynamic threshold quantile for each batch
47
+ x_flat = rearrange(x, "b ... -> b (...)")
48
+ scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
49
+ # Clamp to a min of 1.0
50
+ scale.clamp_(min=1.0)
51
+ # Clamp all values and scale
52
+ scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
53
+ x = x.clamp(-scale, scale) / scale
54
+ return x
55
+
56
+
57
+ def extend_dim(x: Tensor, dim: int):
58
+ # e.g. if dim = 4: shape [b] => [b, 1, 1, 1],
59
+ return x.view(*x.shape + (1,) * (dim - x.ndim))
60
+
61
+
62
+ class Diffusion(nn.Module):
63
+ """Interface used by different diffusion methods"""
64
+
65
+ pass
66
+
67
+
68
+ class VDiffusion(Diffusion):
69
+ def __init__(
70
+ self, net: nn.Module, sigma_distribution: Distribution = UniformDistribution(), loss_fn: Any = F.mse_loss
71
+ ):
72
+ super().__init__()
73
+ self.net = net
74
+ self.sigma_distribution = sigma_distribution
75
+ self.loss_fn = loss_fn
76
+
77
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
78
+ angle = sigmas * pi / 2
79
+ alpha, beta = torch.cos(angle), torch.sin(angle)
80
+ return alpha, beta
81
+
82
+ def forward(self, x: Tensor, **kwargs) -> Tensor: # type: ignore
83
+ batch_size, device = x.shape[0], x.device
84
+ # Sample amount of noise to add for each batch element
85
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
86
+ sigmas_batch = extend_dim(sigmas, dim=x.ndim)
87
+ # Get noise
88
+ noise = torch.randn_like(x)
89
+ # Combine input and noise weighted by half-circle
90
+ alphas, betas = self.get_alpha_beta(sigmas_batch)
91
+ x_noisy = alphas * x + betas * noise
92
+ v_target = alphas * noise - betas * x
93
+ # Predict velocity and return loss
94
+ v_pred = self.net(x_noisy, sigmas, **kwargs)
95
+ return self.loss_fn(v_pred, v_target)
96
+
97
+
98
+ class ARVDiffusion(Diffusion):
99
+ def __init__(self, net: nn.Module, length: int, num_splits: int, loss_fn: Any = F.mse_loss):
100
+ super().__init__()
101
+ assert length % num_splits == 0, "length must be divisible by num_splits"
102
+ self.net = net
103
+ self.length = length
104
+ self.num_splits = num_splits
105
+ self.split_length = length // num_splits
106
+ self.loss_fn = loss_fn
107
+
108
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
109
+ angle = sigmas * pi / 2
110
+ alpha, beta = torch.cos(angle), torch.sin(angle)
111
+ return alpha, beta
112
+
113
+ def forward(self, x: Tensor, **kwargs) -> Tensor:
114
+ """Returns diffusion loss of v-objective with different noises per split"""
115
+ b, _, t, device, dtype = *x.shape, x.device, x.dtype
116
+ assert t == self.length, "input length must match length"
117
+ # Sample amount of noise to add for each split
118
+ sigmas = torch.rand((b, 1, self.num_splits), device=device, dtype=dtype)
119
+ sigmas = repeat(sigmas, "b 1 n -> b 1 (n l)", l=self.split_length)
120
+ # Get noise
121
+ noise = torch.randn_like(x)
122
+ # Combine input and noise weighted by half-circle
123
+ alphas, betas = self.get_alpha_beta(sigmas)
124
+ x_noisy = alphas * x + betas * noise
125
+ v_target = alphas * noise - betas * x
126
+ # Sigmas will be provided as additional channel
127
+ channels = torch.cat([x_noisy, sigmas], dim=1)
128
+ # Predict velocity and return loss
129
+ v_pred = self.net(channels, **kwargs)
130
+ return self.loss_fn(v_pred, v_target)
131
+
132
+ """ Schedules """
133
+
134
+
135
+ class Schedule(nn.Module):
136
+ """Interface used by different sampling schedules"""
137
+
138
+ def forward(self, num_steps: int, device: torch.device) -> Tensor:
139
+ raise NotImplementedError()
140
+
141
+
142
+ class LinearSchedule(Schedule):
143
+ def __init__(self, start: float = 1.0, end: float = 0.0):
144
+ super().__init__()
145
+ self.start, self.end = start, end
146
+
147
+ def forward(self, num_steps: int, device: Any) -> Tensor:
148
+ return torch.linspace(self.start, self.end, num_steps, device=device)
149
+
150
+
151
+ """ Samplers """
152
+
153
+
154
+ class Sampler(nn.Module):
155
+ pass
156
+
157
+
158
+ class VSampler(Sampler):
159
+
160
+ diffusion_types = [VDiffusion]
161
+
162
+ def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
163
+ super().__init__()
164
+ self.net = net
165
+ self.schedule = schedule
166
+
167
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
168
+ angle = sigmas * pi / 2
169
+ alpha, beta = torch.cos(angle), torch.sin(angle)
170
+ return alpha, beta
171
+
172
+ @torch.no_grad()
173
+ def forward( # type: ignore
174
+ self, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs
175
+ ) -> Tensor:
176
+ b = x_noisy.shape[0]
177
+ sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
178
+ sigmas = repeat(sigmas, "i -> i b", b=b)
179
+ sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
180
+ alphas, betas = self.get_alpha_beta(sigmas_batch)
181
+ progress_bar = tqdm(range(num_steps), disable=not show_progress)
182
+
183
+ for i in progress_bar:
184
+ v_pred = self.net(x_noisy, sigmas[i], **kwargs)
185
+ x_pred = alphas[i] * x_noisy - betas[i] * v_pred
186
+ noise_pred = betas[i] * x_noisy + alphas[i] * v_pred
187
+ x_noisy = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred
188
+ progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0]:.2f})")
189
+
190
+ return x_noisy
191
+
192
+
193
+ class ARVSampler(Sampler):
194
+ def __init__(self, net: nn.Module, in_channels: int, length: int, num_splits: int):
195
+ super().__init__()
196
+ assert length % num_splits == 0, "length must be divisible by num_splits"
197
+ self.length = length
198
+ self.in_channels = in_channels
199
+ self.num_splits = num_splits
200
+ self.split_length = length // num_splits
201
+ self.net = net
202
+
203
+ @property
204
+ def device(self):
205
+ return next(self.net.parameters()).device
206
+
207
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
208
+ angle = sigmas * pi / 2
209
+ alpha = torch.cos(angle)
210
+ beta = torch.sin(angle)
211
+ return alpha, beta
212
+
213
+ def get_sigmas_ladder(self, num_items: int, num_steps_per_split: int) -> Tensor:
214
+ b, n, l, i = num_items, self.num_splits, self.split_length, num_steps_per_split
215
+ n_half = n // 2 # Only half ladder, rest is zero, to leave some context
216
+ sigmas = torch.linspace(1, 0, i * n_half, device=self.device)
217
+ sigmas = repeat(sigmas, "(n i) -> i b 1 (n l)", b=b, l=l, n=n_half)
218
+ sigmas = torch.flip(sigmas, dims=[-1]) # Lowest noise level first
219
+ sigmas = F.pad(sigmas, pad=[0, 0, 0, 0, 0, 0, 0, 1]) # Add index i+1
220
+ sigmas[-1, :, :, l:] = sigmas[0, :, :, :-l] # Loop back at index i+1
221
+ return torch.cat([torch.zeros_like(sigmas), sigmas], dim=-1)
222
+
223
+ def sample_loop(
224
+ self, current: Tensor, sigmas: Tensor, show_progress: bool = False, **kwargs
225
+ ) -> Tensor:
226
+ num_steps = sigmas.shape[0] - 1
227
+ alphas, betas = self.get_alpha_beta(sigmas)
228
+ progress_bar = tqdm(range(num_steps), disable=not show_progress)
229
+
230
+ for i in progress_bar:
231
+ channels = torch.cat([current, sigmas[i]], dim=1)
232
+ v_pred = self.net(channels, **kwargs)
233
+ x_pred = alphas[i] * current - betas[i] * v_pred
234
+ noise_pred = betas[i] * current + alphas[i] * v_pred
235
+ current = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred
236
+ progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0,0,0]:.2f})")
237
+
238
+ return current
239
+
240
+ def sample_start(self, num_items: int, num_steps: int, **kwargs) -> Tensor:
241
+ b, c, t = num_items, self.in_channels, self.length
242
+ # Same sigma schedule over all chunks
243
+ sigmas = torch.linspace(1, 0, num_steps + 1, device=self.device)
244
+ sigmas = repeat(sigmas, "i -> i b 1 t", b=b, t=t)
245
+ noise = torch.randn((b, c, t), device=self.device) * sigmas[0]
246
+ # Sample start
247
+ return self.sample_loop(current=noise, sigmas=sigmas, **kwargs)
248
+
249
+ @torch.no_grad()
250
+ def forward(
251
+ self,
252
+ num_items: int,
253
+ num_chunks: int,
254
+ num_steps: int,
255
+ start: Optional[Tensor] = None,
256
+ show_progress: bool = False,
257
+ **kwargs,
258
+ ) -> Tensor:
259
+ assert_message = f"required at least {self.num_splits} chunks"
260
+ assert num_chunks >= self.num_splits, assert_message
261
+
262
+ # Sample initial chunks
263
+ start = self.sample_start(num_items=num_items, num_steps=num_steps, **kwargs)
264
+ # Return start if only num_splits chunks
265
+ if num_chunks == self.num_splits:
266
+ return start
267
+
268
+ # Get sigmas for autoregressive ladder
269
+ b, n = num_items, self.num_splits
270
+ assert num_steps >= n, "num_steps must be greater than num_splits"
271
+ sigmas = self.get_sigmas_ladder(
272
+ num_items=b,
273
+ num_steps_per_split=num_steps // self.num_splits,
274
+ )
275
+ alphas, betas = self.get_alpha_beta(sigmas)
276
+
277
+ # Noise start to match ladder and set starting chunks
278
+ start_noise = alphas[0] * start + betas[0] * torch.randn_like(start)
279
+ chunks = list(start_noise.chunk(chunks=n, dim=-1))
280
+
281
+ # Loop over ladder shifts
282
+ num_shifts = num_chunks # - self.num_splits
283
+ progress_bar = tqdm(range(num_shifts), disable=not show_progress)
284
+
285
+ for j in progress_bar:
286
+ # Decrease ladder noise of last n chunks
287
+ updated = self.sample_loop(
288
+ current=torch.cat(chunks[-n:], dim=-1), sigmas=sigmas, **kwargs
289
+ )
290
+ # Update chunks
291
+ chunks[-n:] = list(updated.chunk(chunks=n, dim=-1))
292
+ # Add fresh noise chunk
293
+ shape = (b, self.in_channels, self.split_length)
294
+ chunks += [torch.randn(shape, device=self.device)]
295
+
296
+ return torch.cat(chunks[:num_chunks], dim=-1)
297
+
298
+
299
+ """ Inpainters """
300
+
301
+
302
+ class Inpainter(nn.Module):
303
+ pass
304
+
305
+
306
+ class VInpainter(Inpainter):
307
+
308
+ diffusion_types = [VDiffusion]
309
+
310
+ def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
311
+ super().__init__()
312
+ self.net = net
313
+ self.schedule = schedule
314
+
315
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
316
+ angle = sigmas * pi / 2
317
+ alpha, beta = torch.cos(angle), torch.sin(angle)
318
+ return alpha, beta
319
+
320
+ @torch.no_grad()
321
+ def forward( # type: ignore
322
+ self,
323
+ source: Tensor,
324
+ mask: Tensor,
325
+ num_steps: int,
326
+ num_resamples: int,
327
+ show_progress: bool = False,
328
+ x_noisy: Optional[Tensor] = None,
329
+ **kwargs,
330
+ ) -> Tensor:
331
+ x_noisy = default(x_noisy, lambda: torch.randn_like(source))
332
+ b = x_noisy.shape[0]
333
+ sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
334
+ sigmas = repeat(sigmas, "i -> i b", b=b)
335
+ sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
336
+ alphas, betas = self.get_alpha_beta(sigmas_batch)
337
+ progress_bar = tqdm(range(num_steps), disable=not show_progress)
338
+
339
+ for i in progress_bar:
340
+ for r in range(num_resamples):
341
+ v_pred = self.net(x_noisy, sigmas[i], **kwargs)
342
+ x_pred = alphas[i] * x_noisy - betas[i] * v_pred
343
+ noise_pred = betas[i] * x_noisy + alphas[i] * v_pred
344
+ # Renoise to current noise level if resampling
345
+ j = r == num_resamples - 1
346
+ x_noisy = alphas[i + j] * x_pred + betas[i + j] * noise_pred
347
+ s_noisy = alphas[i + j] * source + betas[i + j] * torch.randn_like(
348
+ source
349
+ )
350
+ x_noisy = s_noisy * mask + x_noisy * ~mask
351
+
352
+ progress_bar.set_description(f"Inpainting (noise={sigmas[i+1,0]:.2f})")
353
+
354
+ return x_noisy
Modules/diffusion/audio_diffusion_pytorch/models.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from math import floor
3
+ from typing import Any, Callable, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ from einops import pack, rearrange, unpack
7
+ from torch import Generator, Tensor, nn
8
+
9
+ from .components import AppendChannelsPlugin, MelSpectrogram
10
+ from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler
11
+ from .utils import (
12
+ closest_power_2,
13
+ default,
14
+ downsample,
15
+ exists,
16
+ groupby,
17
+ randn_like,
18
+ upsample,
19
+ )
20
+
21
+
22
+ class DiffusionModel(nn.Module):
23
+ def __init__(
24
+ self,
25
+ net_t: Callable,
26
+ diffusion_t: Callable = VDiffusion,
27
+ sampler_t: Callable = VSampler,
28
+ loss_fn: Callable = torch.nn.functional.mse_loss,
29
+ dim: int = 1,
30
+ **kwargs,
31
+ ):
32
+ super().__init__()
33
+ diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
34
+ sampler_kwargs, kwargs = groupby("sampler_", kwargs)
35
+
36
+ self.net = net_t(dim=dim, **kwargs)
37
+ self.diffusion = diffusion_t(net=self.net, loss_fn=loss_fn, **diffusion_kwargs)
38
+ self.sampler = sampler_t(net=self.net, **sampler_kwargs)
39
+
40
+ def forward(self, *args, **kwargs) -> Tensor:
41
+ return self.diffusion(*args, **kwargs)
42
+
43
+ @torch.no_grad()
44
+ def sample(self, *args, **kwargs) -> Tensor:
45
+ return self.sampler(*args, **kwargs)
46
+
47
+
48
+ class EncoderBase(nn.Module, ABC):
49
+ """Abstract class for DiffusionAE encoder"""
50
+
51
+ @abstractmethod
52
+ def __init__(self):
53
+ super().__init__()
54
+ self.out_channels = None
55
+ self.downsample_factor = None
56
+
57
+
58
+ class AdapterBase(nn.Module, ABC):
59
+ """Abstract class for DiffusionAE encoder"""
60
+
61
+ @abstractmethod
62
+ def encode(self, x: Tensor) -> Tensor:
63
+ pass
64
+
65
+ @abstractmethod
66
+ def decode(self, x: Tensor) -> Tensor:
67
+ pass
68
+
69
+
70
+ class DiffusionAE(DiffusionModel):
71
+ """Diffusion Auto Encoder"""
72
+
73
+ def __init__(
74
+ self,
75
+ in_channels: int,
76
+ channels: Sequence[int],
77
+ encoder: EncoderBase,
78
+ inject_depth: int,
79
+ latent_factor: Optional[int] = None,
80
+ adapter: Optional[AdapterBase] = None,
81
+ **kwargs,
82
+ ):
83
+ context_channels = [0] * len(channels)
84
+ context_channels[inject_depth] = encoder.out_channels
85
+ super().__init__(
86
+ in_channels=in_channels,
87
+ channels=channels,
88
+ context_channels=context_channels,
89
+ **kwargs,
90
+ )
91
+ self.in_channels = in_channels
92
+ self.encoder = encoder
93
+ self.inject_depth = inject_depth
94
+ # Optional custom latent factor and adapter
95
+ self.latent_factor = default(latent_factor, self.encoder.downsample_factor)
96
+ self.adapter = adapter.requires_grad_(False) if exists(adapter) else None
97
+
98
+ def forward( # type: ignore
99
+ self, x: Tensor, with_info: bool = False, **kwargs
100
+ ) -> Union[Tensor, Tuple[Tensor, Any]]:
101
+ # Encode input to latent channels
102
+ latent, info = self.encode(x, with_info=True)
103
+ channels = [None] * self.inject_depth + [latent]
104
+ # Adapt input to diffusion if adapter provided
105
+ x = self.adapter.encode(x) if exists(self.adapter) else x
106
+ # Compute diffusion loss
107
+ loss = super().forward(x, channels=channels, **kwargs)
108
+ return (loss, info) if with_info else loss
109
+
110
+ def encode(self, *args, **kwargs):
111
+ return self.encoder(*args, **kwargs)
112
+
113
+ @torch.no_grad()
114
+ def decode(
115
+ self, latent: Tensor, generator: Optional[Generator] = None, **kwargs
116
+ ) -> Tensor:
117
+ b = latent.shape[0]
118
+ noise_length = closest_power_2(latent.shape[2] * self.latent_factor)
119
+ # Compute noise by inferring shape from latent length
120
+ noise = torch.randn(
121
+ (b, self.in_channels, noise_length),
122
+ device=latent.device,
123
+ dtype=latent.dtype,
124
+ generator=generator,
125
+ )
126
+ # Compute context from latent
127
+ channels = [None] * self.inject_depth + [latent] # type: ignore
128
+ # Decode by sampling while conditioning on latent channels
129
+ out = super().sample(noise, channels=channels, **kwargs)
130
+ # Decode output with adapter if provided
131
+ return self.adapter.decode(out) if exists(self.adapter) else out
132
+
133
+
134
+ class DiffusionUpsampler(DiffusionModel):
135
+ def __init__(
136
+ self,
137
+ in_channels: int,
138
+ upsample_factor: int,
139
+ net_t: Callable,
140
+ **kwargs,
141
+ ):
142
+ self.upsample_factor = upsample_factor
143
+ super().__init__(
144
+ net_t=AppendChannelsPlugin(net_t, channels=in_channels),
145
+ in_channels=in_channels,
146
+ **kwargs,
147
+ )
148
+
149
+ def reupsample(self, x: Tensor) -> Tensor:
150
+ x = x.clone()
151
+ x = downsample(x, factor=self.upsample_factor)
152
+ x = upsample(x, factor=self.upsample_factor)
153
+ return x
154
+
155
+ def forward(self, x: Tensor, *args, **kwargs) -> Tensor: # type: ignore
156
+ reupsampled = self.reupsample(x)
157
+ return super().forward(x, *args, append_channels=reupsampled, **kwargs)
158
+
159
+ @torch.no_grad()
160
+ def sample( # type: ignore
161
+ self, downsampled: Tensor, generator: Optional[Generator] = None, **kwargs
162
+ ) -> Tensor:
163
+ reupsampled = upsample(downsampled, factor=self.upsample_factor)
164
+ noise = randn_like(reupsampled, generator=generator)
165
+ return super().sample(noise, append_channels=reupsampled, **kwargs)
166
+
167
+
168
+ class DiffusionVocoder(DiffusionModel):
169
+ def __init__(
170
+ self,
171
+ net_t: Callable,
172
+ mel_channels: int,
173
+ mel_n_fft: int,
174
+ mel_hop_length: Optional[int] = None,
175
+ mel_win_length: Optional[int] = None,
176
+ in_channels: int = 1, # Ignored: channels are automatically batched.
177
+ **kwargs,
178
+ ):
179
+ mel_hop_length = default(mel_hop_length, floor(mel_n_fft) // 4)
180
+ mel_win_length = default(mel_win_length, mel_n_fft)
181
+ mel_kwargs, kwargs = groupby("mel_", kwargs)
182
+ super().__init__(
183
+ net_t=AppendChannelsPlugin(net_t, channels=1),
184
+ in_channels=1,
185
+ **kwargs,
186
+ )
187
+ self.to_spectrogram = MelSpectrogram(
188
+ n_fft=mel_n_fft,
189
+ hop_length=mel_hop_length,
190
+ win_length=mel_win_length,
191
+ n_mel_channels=mel_channels,
192
+ **mel_kwargs,
193
+ )
194
+ self.to_flat = nn.ConvTranspose1d(
195
+ in_channels=mel_channels,
196
+ out_channels=1,
197
+ kernel_size=mel_win_length,
198
+ stride=mel_hop_length,
199
+ padding=(mel_win_length - mel_hop_length) // 2,
200
+ bias=False,
201
+ )
202
+
203
+ def forward(self, x: Tensor, *args, **kwargs) -> Tensor: # type: ignore
204
+ # Get spectrogram, pack channels and flatten
205
+ spectrogram = rearrange(self.to_spectrogram(x), "b c f l -> (b c) f l")
206
+ spectrogram_flat = self.to_flat(spectrogram)
207
+ # Pack wave channels
208
+ x = rearrange(x, "b c t -> (b c) 1 t")
209
+ return super().forward(x, *args, append_channels=spectrogram_flat, **kwargs)
210
+
211
+ @torch.no_grad()
212
+ def sample( # type: ignore
213
+ self, spectrogram: Tensor, generator: Optional[Generator] = None, **kwargs
214
+ ) -> Tensor: # type: ignore
215
+ # Pack channels and flatten spectrogram
216
+ spectrogram, ps = pack([spectrogram], "* f l")
217
+ spectrogram_flat = self.to_flat(spectrogram)
218
+ # Get start noise and sample
219
+ noise = randn_like(spectrogram_flat, generator=generator)
220
+ waveform = super().sample(noise, append_channels=spectrogram_flat, **kwargs)
221
+ # Unpack wave channels
222
+ waveform = rearrange(waveform, "... 1 t -> ... t")
223
+ waveform = unpack(waveform, ps, "* t")[0]
224
+ return waveform
225
+
226
+
227
+ class DiffusionAR(DiffusionModel):
228
+ def __init__(
229
+ self,
230
+ in_channels: int,
231
+ length: int,
232
+ num_splits: int,
233
+ diffusion_t: Callable = ARVDiffusion,
234
+ sampler_t: Callable = ARVSampler,
235
+ **kwargs,
236
+ ):
237
+ super().__init__(
238
+ in_channels=in_channels + 1,
239
+ out_channels=in_channels,
240
+ diffusion_t=diffusion_t,
241
+ diffusion_length=length,
242
+ diffusion_num_splits=num_splits,
243
+ sampler_t=sampler_t,
244
+ sampler_in_channels=in_channels,
245
+ sampler_length=length,
246
+ sampler_num_splits=num_splits,
247
+ use_time_conditioning=False,
248
+ use_modulation=False,
249
+ **kwargs,
250
+ )
Modules/diffusion/audio_diffusion_pytorch/utils.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce
2
+ from inspect import isfunction
3
+ from math import ceil, floor, log2, pi
4
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch import Generator, Tensor
10
+ from typing_extensions import TypeGuard
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
+ def exists(val: Optional[T]) -> TypeGuard[T]:
16
+ return val is not None
17
+
18
+
19
+ def iff(condition: bool, value: T) -> Optional[T]:
20
+ return value if condition else None
21
+
22
+
23
+ def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
24
+ return isinstance(obj, list) or isinstance(obj, tuple)
25
+
26
+
27
+ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
28
+ if exists(val):
29
+ return val
30
+ return d() if isfunction(d) else d
31
+
32
+
33
+ def to_list(val: Union[T, Sequence[T]]) -> List[T]:
34
+ if isinstance(val, tuple):
35
+ return list(val)
36
+ if isinstance(val, list):
37
+ return val
38
+ return [val] # type: ignore
39
+
40
+
41
+ def prod(vals: Sequence[int]) -> int:
42
+ return reduce(lambda x, y: x * y, vals)
43
+
44
+
45
+ def closest_power_2(x: float) -> int:
46
+ exponent = log2(x)
47
+ distance_fn = lambda z: abs(x - 2 ** z) # noqa
48
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
+ return 2 ** int(exponent_closest)
50
+
51
+
52
+ """
53
+ Kwargs Utils
54
+ """
55
+
56
+
57
+ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
58
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
59
+ for key in d.keys():
60
+ no_prefix = int(not key.startswith(prefix))
61
+ return_dicts[no_prefix][key] = d[key]
62
+ return return_dicts
63
+
64
+
65
+ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
66
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
67
+ if keep_prefix:
68
+ return kwargs_with_prefix, kwargs
69
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
70
+ return kwargs_no_prefix, kwargs
71
+
72
+
73
+ def prefix_dict(prefix: str, d: Dict) -> Dict:
74
+ return {prefix + str(k): v for k, v in d.items()}
75
+
76
+
77
+ """
78
+ DSP Utils
79
+ """
80
+
81
+
82
+ def resample(
83
+ waveforms: Tensor,
84
+ factor_in: int,
85
+ factor_out: int,
86
+ rolloff: float = 0.99,
87
+ lowpass_filter_width: int = 6,
88
+ ) -> Tensor:
89
+ """Resamples a waveform using sinc interpolation, adapted from torchaudio"""
90
+ b, _, length = waveforms.shape
91
+ length_target = int(factor_out * length / factor_in)
92
+ d = dict(device=waveforms.device, dtype=waveforms.dtype)
93
+
94
+ base_factor = min(factor_in, factor_out) * rolloff
95
+ width = ceil(lowpass_filter_width * factor_in / base_factor)
96
+ idx = torch.arange(-width, width + factor_in, **d)[None, None] / factor_in # type: ignore # noqa
97
+ t = torch.arange(0, -factor_out, step=-1, **d)[:, None, None] / factor_out + idx # type: ignore # noqa
98
+ t = (t * base_factor).clamp(-lowpass_filter_width, lowpass_filter_width) * pi
99
+
100
+ window = torch.cos(t / lowpass_filter_width / 2) ** 2
101
+ scale = base_factor / factor_in
102
+ kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
103
+ kernels *= window * scale
104
+
105
+ waveforms = rearrange(waveforms, "b c t -> (b c) t")
106
+ waveforms = F.pad(waveforms, (width, width + factor_in))
107
+ resampled = F.conv1d(waveforms[:, None], kernels, stride=factor_in)
108
+ resampled = rearrange(resampled, "(b c) k l -> b c (l k)", b=b)
109
+ return resampled[..., :length_target]
110
+
111
+
112
+ def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
113
+ return resample(waveforms, factor_in=factor, factor_out=1, **kwargs)
114
+
115
+
116
+ def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
117
+ return resample(waveforms, factor_in=1, factor_out=factor, **kwargs)
118
+
119
+
120
+ """ Torch Utils """
121
+
122
+
123
+ def randn_like(tensor: Tensor, *args, generator: Optional[Generator] = None, **kwargs):
124
+ """randn_like that supports generator"""
125
+ return torch.randn(tensor.shape, *args, generator=generator, **kwargs).to(tensor)
Modules/diffusion/diffusion.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import pi
2
+ from random import randint
3
+ from typing import Any, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import Tensor, nn
8
+ from tqdm import tqdm
9
+
10
+ from .utils import *
11
+ from .sampler import *
12
+
13
+ """
14
+ Diffusion Classes (generic for 1d data)
15
+ """
16
+
17
+
18
+ class Model1d(nn.Module):
19
+ def __init__(self, unet_type: str = "base", **kwargs):
20
+ super().__init__()
21
+ diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
22
+ self.unet = None
23
+ self.diffusion = None
24
+
25
+ def forward(self, x: Tensor, **kwargs) -> Tensor:
26
+ return self.diffusion(x, **kwargs)
27
+
28
+ def sample(self, *args, **kwargs) -> Tensor:
29
+ return self.diffusion.sample(*args, **kwargs)
30
+
31
+
32
+ """
33
+ Audio Diffusion Classes (specific for 1d audio data)
34
+ """
35
+
36
+
37
+ def get_default_model_kwargs():
38
+ return dict(
39
+ channels=128,
40
+ patch_size=16,
41
+ multipliers=[1, 2, 4, 4, 4, 4, 4],
42
+ factors=[4, 4, 4, 2, 2, 2],
43
+ num_blocks=[2, 2, 2, 2, 2, 2],
44
+ attentions=[0, 0, 0, 1, 1, 1, 1],
45
+ attention_heads=8,
46
+ attention_features=64,
47
+ attention_multiplier=2,
48
+ attention_use_rel_pos=False,
49
+ diffusion_type="v",
50
+ diffusion_sigma_distribution=UniformDistribution(),
51
+ )
52
+
53
+
54
+ def get_default_sampling_kwargs():
55
+ return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
56
+
57
+
58
+ class AudioDiffusionModel(Model1d):
59
+ def __init__(self, **kwargs):
60
+ super().__init__(**{**get_default_model_kwargs(), **kwargs})
61
+
62
+ def sample(self, *args, **kwargs):
63
+ return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
64
+
65
+
66
+ class AudioDiffusionConditional(Model1d):
67
+ def __init__(
68
+ self,
69
+ embedding_features: int,
70
+ embedding_max_length: int,
71
+ embedding_mask_proba: float = 0.1,
72
+ **kwargs,
73
+ ):
74
+ self.embedding_mask_proba = embedding_mask_proba
75
+ default_kwargs = dict(
76
+ **get_default_model_kwargs(),
77
+ unet_type="cfg",
78
+ context_embedding_features=embedding_features,
79
+ context_embedding_max_length=embedding_max_length,
80
+ )
81
+ super().__init__(**{**default_kwargs, **kwargs})
82
+
83
+ def forward(self, *args, **kwargs):
84
+ default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
85
+ return super().forward(*args, **{**default_kwargs, **kwargs})
86
+
87
+ def sample(self, *args, **kwargs):
88
+ default_kwargs = dict(
89
+ **get_default_sampling_kwargs(),
90
+ embedding_scale=5.0,
91
+ )
92
+ return super().sample(*args, **{**default_kwargs, **kwargs})
93
+
94
+
Modules/diffusion/modules.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import floor, log, pi
2
+ from typing import Any, List, Optional, Sequence, Tuple, Union
3
+
4
+ from .utils import *
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange, reduce, repeat
9
+ from einops.layers.torch import Rearrange
10
+ from einops_exts import rearrange_many
11
+ from torch import Tensor, einsum
12
+
13
+
14
+ """
15
+ Utils
16
+ """
17
+
18
+ class AdaLayerNorm(nn.Module):
19
+ def __init__(self, style_dim, channels, eps=1e-5):
20
+ super().__init__()
21
+ self.channels = channels
22
+ self.eps = eps
23
+
24
+ self.fc = nn.Linear(style_dim, channels*2)
25
+
26
+ def forward(self, x, s):
27
+ x = x.transpose(-1, -2)
28
+ x = x.transpose(1, -1)
29
+
30
+ h = self.fc(s)
31
+ h = h.view(h.size(0), h.size(1), 1)
32
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
33
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
34
+
35
+
36
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
37
+ x = (1 + gamma) * x + beta
38
+ return x.transpose(1, -1).transpose(-1, -2)
39
+
40
+ class StyleTransformer1d(nn.Module):
41
+ def __init__(
42
+ self,
43
+ num_layers: int,
44
+ channels: int,
45
+ num_heads: int,
46
+ head_features: int,
47
+ multiplier: int,
48
+ use_context_time: bool = True,
49
+ use_rel_pos: bool = False,
50
+ context_features_multiplier: int = 1,
51
+ rel_pos_num_buckets: Optional[int] = None,
52
+ rel_pos_max_distance: Optional[int] = None,
53
+ context_features: Optional[int] = None,
54
+ context_embedding_features: Optional[int] = None,
55
+ embedding_max_length: int = 512,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.blocks = nn.ModuleList(
60
+ [
61
+ StyleTransformerBlock(
62
+ features=channels + context_embedding_features,
63
+ head_features=head_features,
64
+ num_heads=num_heads,
65
+ multiplier=multiplier,
66
+ style_dim=context_features,
67
+ use_rel_pos=use_rel_pos,
68
+ rel_pos_num_buckets=rel_pos_num_buckets,
69
+ rel_pos_max_distance=rel_pos_max_distance,
70
+ )
71
+ for i in range(num_layers)
72
+ ]
73
+ )
74
+
75
+ self.to_out = nn.Sequential(
76
+ Rearrange("b t c -> b c t"),
77
+ nn.Conv1d(
78
+ in_channels=channels + context_embedding_features,
79
+ out_channels=channels,
80
+ kernel_size=1,
81
+ ),
82
+ )
83
+
84
+ use_context_features = exists(context_features)
85
+ self.use_context_features = use_context_features
86
+ self.use_context_time = use_context_time
87
+
88
+ if use_context_time or use_context_features:
89
+ context_mapping_features = channels + context_embedding_features
90
+
91
+ self.to_mapping = nn.Sequential(
92
+ nn.Linear(context_mapping_features, context_mapping_features),
93
+ nn.GELU(),
94
+ nn.Linear(context_mapping_features, context_mapping_features),
95
+ nn.GELU(),
96
+ )
97
+
98
+ if use_context_time:
99
+ assert exists(context_mapping_features)
100
+ self.to_time = nn.Sequential(
101
+ TimePositionalEmbedding(
102
+ dim=channels, out_features=context_mapping_features
103
+ ),
104
+ nn.GELU(),
105
+ )
106
+
107
+ if use_context_features:
108
+ assert exists(context_features) and exists(context_mapping_features)
109
+ self.to_features = nn.Sequential(
110
+ nn.Linear(
111
+ in_features=context_features, out_features=context_mapping_features
112
+ ),
113
+ nn.GELU(),
114
+ )
115
+
116
+ self.fixed_embedding = FixedEmbedding(
117
+ max_length=embedding_max_length, features=context_embedding_features
118
+ )
119
+
120
+
121
+ def get_mapping(
122
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
123
+ ) -> Optional[Tensor]:
124
+ """Combines context time features and features into mapping"""
125
+ items, mapping = [], None
126
+ # Compute time features
127
+ if self.use_context_time:
128
+ assert_message = "use_context_time=True but no time features provided"
129
+ assert exists(time), assert_message
130
+ items += [self.to_time(time)]
131
+ # Compute features
132
+ if self.use_context_features:
133
+ assert_message = "context_features exists but no features provided"
134
+ assert exists(features), assert_message
135
+ items += [self.to_features(features)]
136
+
137
+ # Compute joint mapping
138
+ if self.use_context_time or self.use_context_features:
139
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
140
+ mapping = self.to_mapping(mapping)
141
+
142
+ return mapping
143
+
144
+ def run(self, x, time, embedding, features):
145
+
146
+ mapping = self.get_mapping(time, features)
147
+ x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
148
+ mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
149
+
150
+ for block in self.blocks:
151
+ x = x + mapping
152
+ x = block(x, features)
153
+
154
+ x = x.mean(axis=1).unsqueeze(1)
155
+ x = self.to_out(x)
156
+ x = x.transpose(-1, -2)
157
+
158
+ return x
159
+
160
+ def forward(self, x: Tensor,
161
+ time: Tensor,
162
+ embedding_mask_proba: float = 0.0,
163
+ embedding: Optional[Tensor] = None,
164
+ features: Optional[Tensor] = None,
165
+ embedding_scale: float = 1.0) -> Tensor:
166
+
167
+ b, device = embedding.shape[0], embedding.device
168
+ fixed_embedding = self.fixed_embedding(embedding)
169
+ if embedding_mask_proba > 0.0:
170
+ # Randomly mask embedding
171
+ batch_mask = rand_bool(
172
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
173
+ )
174
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
175
+
176
+ if embedding_scale != 1.0:
177
+ # Compute both normal and fixed embedding outputs
178
+ out = self.run(x, time, embedding=embedding, features=features)
179
+ out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
180
+ # Scale conditional output using classifier-free guidance
181
+ return out_masked + (out - out_masked) * embedding_scale
182
+ else:
183
+ return self.run(x, time, embedding=embedding, features=features)
184
+
185
+ return x
186
+
187
+
188
+ class StyleTransformerBlock(nn.Module):
189
+ def __init__(
190
+ self,
191
+ features: int,
192
+ num_heads: int,
193
+ head_features: int,
194
+ style_dim: int,
195
+ multiplier: int,
196
+ use_rel_pos: bool,
197
+ rel_pos_num_buckets: Optional[int] = None,
198
+ rel_pos_max_distance: Optional[int] = None,
199
+ context_features: Optional[int] = None,
200
+ ):
201
+ super().__init__()
202
+
203
+ self.use_cross_attention = exists(context_features) and context_features > 0
204
+
205
+ self.attention = StyleAttention(
206
+ features=features,
207
+ style_dim=style_dim,
208
+ num_heads=num_heads,
209
+ head_features=head_features,
210
+ use_rel_pos=use_rel_pos,
211
+ rel_pos_num_buckets=rel_pos_num_buckets,
212
+ rel_pos_max_distance=rel_pos_max_distance,
213
+ )
214
+
215
+ if self.use_cross_attention:
216
+ self.cross_attention = StyleAttention(
217
+ features=features,
218
+ style_dim=style_dim,
219
+ num_heads=num_heads,
220
+ head_features=head_features,
221
+ context_features=context_features,
222
+ use_rel_pos=use_rel_pos,
223
+ rel_pos_num_buckets=rel_pos_num_buckets,
224
+ rel_pos_max_distance=rel_pos_max_distance,
225
+ )
226
+
227
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
228
+
229
+ def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
230
+ x = self.attention(x, s) + x
231
+ if self.use_cross_attention:
232
+ x = self.cross_attention(x, s, context=context) + x
233
+ x = self.feed_forward(x) + x
234
+ return x
235
+
236
+ class StyleAttention(nn.Module):
237
+ def __init__(
238
+ self,
239
+ features: int,
240
+ *,
241
+ style_dim: int,
242
+ head_features: int,
243
+ num_heads: int,
244
+ context_features: Optional[int] = None,
245
+ use_rel_pos: bool,
246
+ rel_pos_num_buckets: Optional[int] = None,
247
+ rel_pos_max_distance: Optional[int] = None,
248
+ ):
249
+ super().__init__()
250
+ self.context_features = context_features
251
+ mid_features = head_features * num_heads
252
+ context_features = default(context_features, features)
253
+
254
+ self.norm = AdaLayerNorm(style_dim, features)
255
+ self.norm_context = AdaLayerNorm(style_dim, context_features)
256
+ self.to_q = nn.Linear(
257
+ in_features=features, out_features=mid_features, bias=False
258
+ )
259
+ self.to_kv = nn.Linear(
260
+ in_features=context_features, out_features=mid_features * 2, bias=False
261
+ )
262
+ self.attention = AttentionBase(
263
+ features,
264
+ num_heads=num_heads,
265
+ head_features=head_features,
266
+ use_rel_pos=use_rel_pos,
267
+ rel_pos_num_buckets=rel_pos_num_buckets,
268
+ rel_pos_max_distance=rel_pos_max_distance,
269
+ )
270
+
271
+ def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
272
+ assert_message = "You must provide a context when using context_features"
273
+ assert not self.context_features or exists(context), assert_message
274
+ # Use context if provided
275
+ context = default(context, x)
276
+ # Normalize then compute q from input and k,v from context
277
+ x, context = self.norm(x, s), self.norm_context(context, s)
278
+
279
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
280
+ # Compute and return attention
281
+ return self.attention(q, k, v)
282
+
283
+ class Transformer1d(nn.Module):
284
+ def __init__(
285
+ self,
286
+ num_layers: int,
287
+ channels: int,
288
+ num_heads: int,
289
+ head_features: int,
290
+ multiplier: int,
291
+ use_context_time: bool = True,
292
+ use_rel_pos: bool = False,
293
+ context_features_multiplier: int = 1,
294
+ rel_pos_num_buckets: Optional[int] = None,
295
+ rel_pos_max_distance: Optional[int] = None,
296
+ context_features: Optional[int] = None,
297
+ context_embedding_features: Optional[int] = None,
298
+ embedding_max_length: int = 512,
299
+ ):
300
+ super().__init__()
301
+
302
+ self.blocks = nn.ModuleList(
303
+ [
304
+ TransformerBlock(
305
+ features=channels + context_embedding_features,
306
+ head_features=head_features,
307
+ num_heads=num_heads,
308
+ multiplier=multiplier,
309
+ use_rel_pos=use_rel_pos,
310
+ rel_pos_num_buckets=rel_pos_num_buckets,
311
+ rel_pos_max_distance=rel_pos_max_distance,
312
+ )
313
+ for i in range(num_layers)
314
+ ]
315
+ )
316
+
317
+ self.to_out = nn.Sequential(
318
+ Rearrange("b t c -> b c t"),
319
+ nn.Conv1d(
320
+ in_channels=channels + context_embedding_features,
321
+ out_channels=channels,
322
+ kernel_size=1,
323
+ ),
324
+ )
325
+
326
+ use_context_features = exists(context_features)
327
+ self.use_context_features = use_context_features
328
+ self.use_context_time = use_context_time
329
+
330
+ if use_context_time or use_context_features:
331
+ context_mapping_features = channels + context_embedding_features
332
+
333
+ self.to_mapping = nn.Sequential(
334
+ nn.Linear(context_mapping_features, context_mapping_features),
335
+ nn.GELU(),
336
+ nn.Linear(context_mapping_features, context_mapping_features),
337
+ nn.GELU(),
338
+ )
339
+
340
+ if use_context_time:
341
+ assert exists(context_mapping_features)
342
+ self.to_time = nn.Sequential(
343
+ TimePositionalEmbedding(
344
+ dim=channels, out_features=context_mapping_features
345
+ ),
346
+ nn.GELU(),
347
+ )
348
+
349
+ if use_context_features:
350
+ assert exists(context_features) and exists(context_mapping_features)
351
+ self.to_features = nn.Sequential(
352
+ nn.Linear(
353
+ in_features=context_features, out_features=context_mapping_features
354
+ ),
355
+ nn.GELU(),
356
+ )
357
+
358
+ self.fixed_embedding = FixedEmbedding(
359
+ max_length=embedding_max_length, features=context_embedding_features
360
+ )
361
+
362
+
363
+ def get_mapping(
364
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
365
+ ) -> Optional[Tensor]:
366
+ """Combines context time features and features into mapping"""
367
+ items, mapping = [], None
368
+ # Compute time features
369
+ if self.use_context_time:
370
+ assert_message = "use_context_time=True but no time features provided"
371
+ assert exists(time), assert_message
372
+ items += [self.to_time(time)]
373
+ # Compute features
374
+ if self.use_context_features:
375
+ assert_message = "context_features exists but no features provided"
376
+ assert exists(features), assert_message
377
+ items += [self.to_features(features)]
378
+
379
+ # Compute joint mapping
380
+ if self.use_context_time or self.use_context_features:
381
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
382
+ mapping = self.to_mapping(mapping)
383
+
384
+ return mapping
385
+
386
+ def run(self, x, time, embedding, features):
387
+
388
+ mapping = self.get_mapping(time, features)
389
+ x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
390
+ mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
391
+
392
+ for block in self.blocks:
393
+ x = x + mapping
394
+ x = block(x)
395
+
396
+ x = x.mean(axis=1).unsqueeze(1)
397
+ x = self.to_out(x)
398
+ x = x.transpose(-1, -2)
399
+
400
+ return x
401
+
402
+ def forward(self, x: Tensor,
403
+ time: Tensor,
404
+ embedding_mask_proba: float = 0.0,
405
+ embedding: Optional[Tensor] = None,
406
+ features: Optional[Tensor] = None,
407
+ embedding_scale: float = 1.0) -> Tensor:
408
+
409
+ b, device = embedding.shape[0], embedding.device
410
+ fixed_embedding = self.fixed_embedding(embedding)
411
+ if embedding_mask_proba > 0.0:
412
+ # Randomly mask embedding
413
+ batch_mask = rand_bool(
414
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
415
+ )
416
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
417
+
418
+ if embedding_scale != 1.0:
419
+ # Compute both normal and fixed embedding outputs
420
+ out = self.run(x, time, embedding=embedding, features=features)
421
+ out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
422
+ # Scale conditional output using classifier-free guidance
423
+ return out_masked + (out - out_masked) * embedding_scale
424
+ else:
425
+ return self.run(x, time, embedding=embedding, features=features)
426
+
427
+ return x
428
+
429
+
430
+ """
431
+ Attention Components
432
+ """
433
+
434
+
435
+ class RelativePositionBias(nn.Module):
436
+ def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
437
+ super().__init__()
438
+ self.num_buckets = num_buckets
439
+ self.max_distance = max_distance
440
+ self.num_heads = num_heads
441
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
442
+
443
+ @staticmethod
444
+ def _relative_position_bucket(
445
+ relative_position: Tensor, num_buckets: int, max_distance: int
446
+ ):
447
+ num_buckets //= 2
448
+ ret = (relative_position >= 0).to(torch.long) * num_buckets
449
+ n = torch.abs(relative_position)
450
+
451
+ max_exact = num_buckets // 2
452
+ is_small = n < max_exact
453
+
454
+ val_if_large = (
455
+ max_exact
456
+ + (
457
+ torch.log(n.float() / max_exact)
458
+ / log(max_distance / max_exact)
459
+ * (num_buckets - max_exact)
460
+ ).long()
461
+ )
462
+ val_if_large = torch.min(
463
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1)
464
+ )
465
+
466
+ ret += torch.where(is_small, n, val_if_large)
467
+ return ret
468
+
469
+ def forward(self, num_queries: int, num_keys: int) -> Tensor:
470
+ i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
471
+ q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
472
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
473
+ rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
474
+
475
+ relative_position_bucket = self._relative_position_bucket(
476
+ rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
477
+ )
478
+
479
+ bias = self.relative_attention_bias(relative_position_bucket)
480
+ bias = rearrange(bias, "m n h -> 1 h m n")
481
+ return bias
482
+
483
+
484
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
485
+ mid_features = features * multiplier
486
+ return nn.Sequential(
487
+ nn.Linear(in_features=features, out_features=mid_features),
488
+ nn.GELU(),
489
+ nn.Linear(in_features=mid_features, out_features=features),
490
+ )
491
+
492
+
493
+ class AttentionBase(nn.Module):
494
+ def __init__(
495
+ self,
496
+ features: int,
497
+ *,
498
+ head_features: int,
499
+ num_heads: int,
500
+ use_rel_pos: bool,
501
+ out_features: Optional[int] = None,
502
+ rel_pos_num_buckets: Optional[int] = None,
503
+ rel_pos_max_distance: Optional[int] = None,
504
+ ):
505
+ super().__init__()
506
+ self.scale = head_features ** -0.5
507
+ self.num_heads = num_heads
508
+ self.use_rel_pos = use_rel_pos
509
+ mid_features = head_features * num_heads
510
+
511
+ if use_rel_pos:
512
+ assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
513
+ self.rel_pos = RelativePositionBias(
514
+ num_buckets=rel_pos_num_buckets,
515
+ max_distance=rel_pos_max_distance,
516
+ num_heads=num_heads,
517
+ )
518
+ if out_features is None:
519
+ out_features = features
520
+
521
+ self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
522
+
523
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
524
+ # Split heads
525
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
526
+ # Compute similarity matrix
527
+ sim = einsum("... n d, ... m d -> ... n m", q, k)
528
+ sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
529
+ sim = sim * self.scale
530
+ # Get attention matrix with softmax
531
+ attn = sim.softmax(dim=-1)
532
+ # Compute values
533
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
534
+ out = rearrange(out, "b h n d -> b n (h d)")
535
+ return self.to_out(out)
536
+
537
+
538
+ class Attention(nn.Module):
539
+ def __init__(
540
+ self,
541
+ features: int,
542
+ *,
543
+ head_features: int,
544
+ num_heads: int,
545
+ out_features: Optional[int] = None,
546
+ context_features: Optional[int] = None,
547
+ use_rel_pos: bool,
548
+ rel_pos_num_buckets: Optional[int] = None,
549
+ rel_pos_max_distance: Optional[int] = None,
550
+ ):
551
+ super().__init__()
552
+ self.context_features = context_features
553
+ mid_features = head_features * num_heads
554
+ context_features = default(context_features, features)
555
+
556
+ self.norm = nn.LayerNorm(features)
557
+ self.norm_context = nn.LayerNorm(context_features)
558
+ self.to_q = nn.Linear(
559
+ in_features=features, out_features=mid_features, bias=False
560
+ )
561
+ self.to_kv = nn.Linear(
562
+ in_features=context_features, out_features=mid_features * 2, bias=False
563
+ )
564
+
565
+ self.attention = AttentionBase(
566
+ features,
567
+ out_features=out_features,
568
+ num_heads=num_heads,
569
+ head_features=head_features,
570
+ use_rel_pos=use_rel_pos,
571
+ rel_pos_num_buckets=rel_pos_num_buckets,
572
+ rel_pos_max_distance=rel_pos_max_distance,
573
+ )
574
+
575
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
576
+ assert_message = "You must provide a context when using context_features"
577
+ assert not self.context_features or exists(context), assert_message
578
+ # Use context if provided
579
+ context = default(context, x)
580
+ # Normalize then compute q from input and k,v from context
581
+ x, context = self.norm(x), self.norm_context(context)
582
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
583
+ # Compute and return attention
584
+ return self.attention(q, k, v)
585
+
586
+
587
+ """
588
+ Transformer Blocks
589
+ """
590
+
591
+
592
+ class TransformerBlock(nn.Module):
593
+ def __init__(
594
+ self,
595
+ features: int,
596
+ num_heads: int,
597
+ head_features: int,
598
+ multiplier: int,
599
+ use_rel_pos: bool,
600
+ rel_pos_num_buckets: Optional[int] = None,
601
+ rel_pos_max_distance: Optional[int] = None,
602
+ context_features: Optional[int] = None,
603
+ ):
604
+ super().__init__()
605
+
606
+ self.use_cross_attention = exists(context_features) and context_features > 0
607
+
608
+ self.attention = Attention(
609
+ features=features,
610
+ num_heads=num_heads,
611
+ head_features=head_features,
612
+ use_rel_pos=use_rel_pos,
613
+ rel_pos_num_buckets=rel_pos_num_buckets,
614
+ rel_pos_max_distance=rel_pos_max_distance,
615
+ )
616
+
617
+ if self.use_cross_attention:
618
+ self.cross_attention = Attention(
619
+ features=features,
620
+ num_heads=num_heads,
621
+ head_features=head_features,
622
+ context_features=context_features,
623
+ use_rel_pos=use_rel_pos,
624
+ rel_pos_num_buckets=rel_pos_num_buckets,
625
+ rel_pos_max_distance=rel_pos_max_distance,
626
+ )
627
+
628
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
629
+
630
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
631
+ x = self.attention(x) + x
632
+ if self.use_cross_attention:
633
+ x = self.cross_attention(x, context=context) + x
634
+ x = self.feed_forward(x) + x
635
+ return x
636
+
637
+
638
+
639
+ """
640
+ Time Embeddings
641
+ """
642
+
643
+
644
+ class SinusoidalEmbedding(nn.Module):
645
+ def __init__(self, dim: int):
646
+ super().__init__()
647
+ self.dim = dim
648
+
649
+ def forward(self, x: Tensor) -> Tensor:
650
+ device, half_dim = x.device, self.dim // 2
651
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
652
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
653
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
654
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
655
+
656
+
657
+ class LearnedPositionalEmbedding(nn.Module):
658
+ """Used for continuous time"""
659
+
660
+ def __init__(self, dim: int):
661
+ super().__init__()
662
+ assert (dim % 2) == 0
663
+ half_dim = dim // 2
664
+ self.weights = nn.Parameter(torch.randn(half_dim))
665
+
666
+ def forward(self, x: Tensor) -> Tensor:
667
+ x = rearrange(x, "b -> b 1")
668
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
669
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
670
+ fouriered = torch.cat((x, fouriered), dim=-1)
671
+ return fouriered
672
+
673
+
674
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
675
+ return nn.Sequential(
676
+ LearnedPositionalEmbedding(dim),
677
+ nn.Linear(in_features=dim + 1, out_features=out_features),
678
+ )
679
+
680
+ class FixedEmbedding(nn.Module):
681
+ def __init__(self, max_length: int, features: int):
682
+ super().__init__()
683
+ self.max_length = max_length
684
+ self.embedding = nn.Embedding(max_length, features)
685
+
686
+ def forward(self, x: Tensor) -> Tensor:
687
+ batch_size, length, device = *x.shape[0:2], x.device
688
+ assert_message = "Input sequence length must be <= max_length"
689
+ assert length <= self.max_length, assert_message
690
+ position = torch.arange(length, device=device)
691
+ fixed_embedding = self.embedding(position)
692
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
693
+ return fixed_embedding
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/.github/workflows/python-publish.yml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This workflow will upload a Python Package using Twine when a release is created
2
+ # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3
+
4
+ # This workflow uses actions that are not certified by GitHub.
5
+ # They are provided by a third-party and are governed by
6
+ # separate terms of service, privacy policy, and support
7
+ # documentation.
8
+
9
+ name: Upload Python Package
10
+
11
+ on:
12
+ release:
13
+ types: [published]
14
+
15
+ permissions:
16
+ contents: read
17
+
18
+ jobs:
19
+ deploy:
20
+
21
+ runs-on: ubuntu-latest
22
+
23
+ steps:
24
+ - uses: actions/checkout@v3
25
+ - name: Set up Python
26
+ uses: actions/setup-python@v3
27
+ with:
28
+ python-version: '3.x'
29
+ - name: Install dependencies
30
+ run: |
31
+ python -m pip install --upgrade pip
32
+ pip install build
33
+ - name: Build package
34
+ run: python -m build
35
+ - name: Publish package
36
+ uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37
+ with:
38
+ user: __token__
39
+ password: ${{ secrets.PYPI_API_TOKEN }}
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ .mypy_cache
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/.pre-commit-config.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v2.3.0
4
+ hooks:
5
+ - id: end-of-file-fixer
6
+ - id: trailing-whitespace
7
+
8
+ # Formats code correctly
9
+ - repo: https://github.com/psf/black
10
+ rev: 21.12b0
11
+ hooks:
12
+ - id: black
13
+ args: [
14
+ '--experimental-string-processing'
15
+ ]
16
+
17
+ # Sorts imports
18
+ - repo: https://github.com/pycqa/isort
19
+ rev: 5.10.1
20
+ hooks:
21
+ - id: isort
22
+ name: isort (python)
23
+ args: ["--profile", "black"]
24
+
25
+ # Checks unused imports, like lengths, etc
26
+ - repo: https://gitlab.com/pycqa/flake8
27
+ rev: 4.0.0
28
+ hooks:
29
+ - id: flake8
30
+ args: [
31
+ '--per-file-ignores=__init__.py:F401',
32
+ '--max-line-length=88',
33
+ '--ignore=E203,W503'
34
+ ]
35
+
36
+ # Checks types
37
+ - repo: https://github.com/pre-commit/mirrors-mypy
38
+ rev: 'v0.971'
39
+ hooks:
40
+ - id: mypy
41
+ additional_dependencies: [data-science-types>=0.2, torch>=1.6]
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 archinet.ai
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/LOGO.png ADDED
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/README.md ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <img src="./LOGO.png"></img>
2
+
3
+ A fully featured audio diffusion library, for PyTorch. Includes models for unconditional audio generation, text-conditional audio generation, diffusion autoencoding, upsampling, and vocoding. The provided models are waveform-based, however, the U-Net (built using [`a-unet`](https://github.com/archinetai/a-unet)), `DiffusionModel`, diffusion method, and diffusion samplers are both generic to any dimension and highly customizable to work on other formats. **Notes: (1) no pre-trained models are provided here, (2) the configs shown are indicative and untested, see [Moûsai](https://arxiv.org/abs/2301.11757) for the configs used in the paper.**
4
+
5
+
6
+ ## Install
7
+
8
+ ```bash
9
+ pip install audio-diffusion-pytorch
10
+ ```
11
+
12
+ [![PyPI - Python Version](https://img.shields.io/pypi/v/audio-diffusion-pytorch?style=flat&colorA=black&colorB=black)](https://pypi.org/project/audio-diffusion-pytorch/)
13
+ [![Downloads](https://static.pepy.tech/personalized-badge/audio-diffusion-pytorch?period=total&units=international_system&left_color=black&right_color=black&left_text=Downloads)](https://pepy.tech/project/audio-diffusion-pytorch)
14
+
15
+
16
+ ## Usage
17
+
18
+ ### Unconditional Generator
19
+
20
+ ```py
21
+ from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
22
+
23
+ model = DiffusionModel(
24
+ net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
25
+ in_channels=2, # U-Net: number of input/output (audio) channels
26
+ channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
27
+ factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
28
+ items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
29
+ attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
30
+ attention_heads=8, # U-Net: number of attention heads per attention item
31
+ attention_features=64, # U-Net: number of attention features per attention item
32
+ diffusion_t=VDiffusion, # The diffusion method used
33
+ sampler_t=VSampler, # The diffusion sampler used
34
+ )
35
+
36
+ # Train model with audio waveforms
37
+ audio = torch.randn(1, 2, 2**18) # [batch_size, in_channels, length]
38
+ loss = model(audio)
39
+ loss.backward()
40
+
41
+ # Turn noise into new audio sample with diffusion
42
+ noise = torch.randn(1, 2, 2**18) # [batch_size, in_channels, length]
43
+ sample = model.sample(noise, num_steps=10) # Suggested num_steps 10-100
44
+ ```
45
+
46
+ ### Text-Conditional Generator
47
+ A text-to-audio diffusion model that conditions the generation with `t5-base` text embeddings, requires `pip install transformers`.
48
+ ```py
49
+ from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
50
+
51
+ model = DiffusionModel(
52
+ # ... same as unconditional model
53
+ use_text_conditioning=True, # U-Net: enables text conditioning (default T5-base)
54
+ use_embedding_cfg=True, # U-Net: enables classifier free guidance
55
+ embedding_max_length=64, # U-Net: text embedding maximum length (default for T5-base)
56
+ embedding_features=768, # U-Net: text mbedding features (default for T5-base)
57
+ cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer
58
+ )
59
+
60
+ # Train model with audio waveforms
61
+ audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
62
+ loss = model(
63
+ audio_wave,
64
+ text=['The audio description'], # Text conditioning, one element per batch
65
+ embedding_mask_proba=0.1 # Probability of masking text with learned embedding (Classifier-Free Guidance Mask)
66
+ )
67
+ loss.backward()
68
+
69
+ # Turn noise into new audio sample with diffusion
70
+ noise = torch.randn(1, 2, 2**18)
71
+ sample = model.sample(
72
+ noise,
73
+ text=['The audio description'],
74
+ embedding_scale=5.0, # Higher for more text importance, suggested range: 1-15 (Classifier-Free Guidance Scale)
75
+ num_steps=2 # Higher for better quality, suggested num_steps: 10-100
76
+ )
77
+ ```
78
+
79
+ ### Diffusion Upsampler
80
+ Upsample audio from a lower sample rate to higher sample rate using diffusion, e.g. 3kHz to 48kHz.
81
+ ```py
82
+ from audio_diffusion_pytorch import DiffusionUpsampler, UNetV0, VDiffusion, VSampler
83
+
84
+ upsampler = DiffusionUpsampler(
85
+ net_t=UNetV0, # The model type used for diffusion
86
+ upsample_factor=16, # The upsample factor (e.g. 16 can be used for 3kHz to 48kHz)
87
+ in_channels=2, # U-Net: number of input/output (audio) channels
88
+ channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
89
+ factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
90
+ items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
91
+ diffusion_t=VDiffusion, # The diffusion method used
92
+ sampler_t=VSampler, # The diffusion sampler used
93
+ )
94
+
95
+ # Train model with high sample rate audio waveforms
96
+ audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
97
+ loss = upsampler(audio)
98
+ loss.backward()
99
+
100
+ # Turn low sample rate audio into high sample rate
101
+ downsampled_audio = torch.randn(1, 2, 2**14) # [batch, in_channels, length]
102
+ sample = upsampler.sample(downsampled_audio, num_steps=10) # Output has shape: [1, 2, 2**18]
103
+ ```
104
+
105
+ ### Diffusion Vocoder
106
+ Convert a mel-spectrogram to wavefrom using diffusion.
107
+ ```py
108
+ from audio_diffusion_pytorch import DiffusionVocoder, UNetV0, VDiffusion, VSampler
109
+
110
+ vocoder = DiffusionVocoder(
111
+ mel_n_fft=1024, # Mel-spectrogram n_fft
112
+ mel_channels=80, # Mel-spectrogram channels
113
+ mel_sample_rate=48000, # Mel-spectrogram sample rate
114
+ mel_normalize_log=True, # Mel-spectrogram log normalization (alternative is mel_normalize=True for [-1,1] power normalization)
115
+ net_t=UNetV0, # The model type used for diffusion vocoding
116
+ channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
117
+ factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
118
+ items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
119
+ diffusion_t=VDiffusion, # The diffusion method used
120
+ sampler_t=VSampler, # The diffusion sampler used
121
+ )
122
+
123
+ # Train model on waveforms (automatically converted to mel internally)
124
+ audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
125
+ loss = vocoder(audio)
126
+ loss.backward()
127
+
128
+ # Turn mel spectrogram into waveform
129
+ mel_spectrogram = torch.randn(1, 2, 80, 1024) # [batch, in_channels, mel_channels, mel_length]
130
+ sample = vocoder.sample(mel_spectrogram, num_steps=10) # Output has shape: [1, 2, 2**18]
131
+ ```
132
+
133
+ ### Diffusion Autoencoder
134
+ Autoencode audio into a compressed latent using diffusion. Any encoder can be provided as long as it subclasses the `EncoderBase` class or contains an `out_channels` and `downsample_factor` field.
135
+ ```py
136
+ from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler
137
+ from audio_encoders_pytorch import MelE1d, TanhBottleneck
138
+
139
+ autoencoder = DiffusionAE(
140
+ encoder=MelE1d( # The encoder used, in this case a mel-spectrogram encoder
141
+ in_channels=2,
142
+ channels=512,
143
+ multipliers=[1, 1],
144
+ factors=[2],
145
+ num_blocks=[12],
146
+ out_channels=32,
147
+ mel_channels=80,
148
+ mel_sample_rate=48000,
149
+ mel_normalize_log=True,
150
+ bottleneck=TanhBottleneck(),
151
+ ),
152
+ inject_depth=6,
153
+ net_t=UNetV0, # The model type used for diffusion upsampling
154
+ in_channels=2, # U-Net: number of input/output (audio) channels
155
+ channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
156
+ factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
157
+ items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
158
+ diffusion_t=VDiffusion, # The diffusion method used
159
+ sampler_t=VSampler, # The diffusion sampler used
160
+ )
161
+
162
+ # Train autoencoder with audio samples
163
+ audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
164
+ loss = autoencoder(audio)
165
+ loss.backward()
166
+
167
+ # Encode/decode audio
168
+ audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
169
+ latent = autoencoder.encode(audio) # Encode
170
+ sample = autoencoder.decode(latent, num_steps=10) # Decode by sampling diffusion model conditioning on latent
171
+ ```
172
+
173
+ ## Other
174
+
175
+ ### Inpainting
176
+ ```py
177
+ from audio_diffusion_pytorch import UNetV0, VInpainter
178
+
179
+ # The diffusion UNetV0 (this is an example, the net must be trained to work)
180
+ net = UNetV0(
181
+ dim=1,
182
+ in_channels=2, # U-Net: number of input/output (audio) channels
183
+ channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
184
+ factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
185
+ items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
186
+ attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
187
+ attention_heads=8, # U-Net: number of attention heads per attention block
188
+ attention_features=64, # U-Net: number of attention features per attention block,
189
+ )
190
+
191
+ # Instantiate inpainter with trained net
192
+ inpainter = VInpainter(net=net)
193
+
194
+ # Inpaint source
195
+ y = inpainter(
196
+ source=torch.randn(1, 2, 2**18), # Start source
197
+ mask=torch.randint(0, 2, (1, 2, 2 ** 18), dtype=torch.bool), # Set to `True` the parts you want to keep
198
+ num_steps=10, # Number of inpainting steps
199
+ num_resamples=2, # Number of resampling steps
200
+ show_progress=True,
201
+ ) # [1, 2, 2 ** 18]
202
+ ```
203
+
204
+ ## Appreciation
205
+
206
+ * [StabilityAI](https://stability.ai/) for the compute, [Zach Evans](https://github.com/zqevans) and everyone else from [HarmonAI](https://www.harmonai.org/) for the interesting research discussions.
207
+ * [ETH Zurich](https://inf.ethz.ch/) for the resources, [Zhijing Jin](https://zhijing-jin.com/), [Bernhard Schoelkopf](https://is.mpg.de/~bs), and [Mrinmaya Sachan](http://www.mrinmaya.io/) for supervising this Thesis.
208
+ * [Phil Wang](https://github.com/lucidrains) for the beautiful open source contributions on [diffusion](https://github.com/lucidrains/denoising-diffusion-pytorch) and [Imagen](https://github.com/lucidrains/imagen-pytorch).
209
+ * [Katherine Crowson](https://github.com/crowsonkb) for the experiments with [k-diffusion](https://github.com/crowsonkb/k-diffusion) and the insane collection of samplers.
210
+
211
+ ## Citations
212
+
213
+ DDPM Diffusion
214
+ ```bibtex
215
+ @misc{2006.11239,
216
+ Author = {Jonathan Ho and Ajay Jain and Pieter Abbeel},
217
+ Title = {Denoising Diffusion Probabilistic Models},
218
+ Year = {2020},
219
+ Eprint = {arXiv:2006.11239},
220
+ }
221
+ ```
222
+
223
+ DDIM (V-Sampler)
224
+ ```bibtex
225
+ @misc{2010.02502,
226
+ Author = {Jiaming Song and Chenlin Meng and Stefano Ermon},
227
+ Title = {Denoising Diffusion Implicit Models},
228
+ Year = {2020},
229
+ Eprint = {arXiv:2010.02502},
230
+ }
231
+ ```
232
+
233
+ V-Diffusion
234
+ ```bibtex
235
+ @misc{2202.00512,
236
+ Author = {Tim Salimans and Jonathan Ho},
237
+ Title = {Progressive Distillation for Fast Sampling of Diffusion Models},
238
+ Year = {2022},
239
+ Eprint = {arXiv:2202.00512},
240
+ }
241
+ ```
242
+
243
+ Imagen (T5 Text Conditioning)
244
+ ```bibtex
245
+ @misc{2205.11487,
246
+ Author = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and S. Sara Mahdavi and Rapha Gontijo Lopes and Tim Salimans and Jonathan Ho and David J Fleet and Mohammad Norouzi},
247
+ Title = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding},
248
+ Year = {2022},
249
+ Eprint = {arXiv:2205.11487},
250
+ }
251
+ ```
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/setup.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import find_packages, setup
2
+
3
+ setup(
4
+ name="audio-diffusion-pytorch",
5
+ packages=find_packages(exclude=[]),
6
+ version="0.1.3",
7
+ license="MIT",
8
+ description="Audio Diffusion - PyTorch",
9
+ long_description_content_type="text/markdown",
10
+ author="Flavio Schneider",
11
+ author_email="archinetai@protonmail.com",
12
+ url="https://github.com/archinetai/audio-diffusion-pytorch",
13
+ keywords=["artificial intelligence", "deep learning", "audio generation"],
14
+ install_requires=[
15
+ "tqdm",
16
+ "torch>=1.6",
17
+ "torchaudio",
18
+ "data-science-types>=0.2",
19
+ "einops>=0.6",
20
+ "a-unet",
21
+ ],
22
+ classifiers=[
23
+ "Development Status :: 4 - Beta",
24
+ "Intended Audience :: Developers",
25
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
26
+ "License :: OSI Approved :: MIT License",
27
+ "Programming Language :: Python :: 3.6",
28
+ ],
29
+ )
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/tests/testcustomloss.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler
4
+ from audio_encoders_pytorch import MelE1d, TanhBottleneck
5
+ from auraloss.freq import MultiResolutionSTFTLoss
6
+
7
+ autoencoder = DiffusionAE(
8
+ encoder=MelE1d( # The encoder used, in this case a mel-spectrogram encoder
9
+ in_channels=2,
10
+ channels=512,
11
+ multipliers=[1, 1],
12
+ factors=[2],
13
+ num_blocks=[12],
14
+ out_channels=32,
15
+ mel_channels=80,
16
+ mel_sample_rate=48000,
17
+ mel_normalize_log=True,
18
+ bottleneck=TanhBottleneck(),
19
+ ),
20
+ inject_depth=6,
21
+ net_t=UNetV0, # The model type used for diffusion upsampling
22
+ in_channels=2, # U-Net: number of input/output (audio) channels
23
+ channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
24
+ factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
25
+ items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
26
+ diffusion_t=VDiffusion, # The diffusion method used
27
+ sampler_t=VSampler, # The diffusion sampler used
28
+ loss_fn=MultiResolutionSTFTLoss(), # The loss function used
29
+ )
30
+
31
+ # Train autoencoder with audio samples
32
+ audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
33
+ loss = autoencoder(audio)
34
+ loss.backward()
35
+
36
+ # Encode/decode audio
37
+ audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
38
+ latent = autoencoder.encode(audio) # Encode
39
+ sample = autoencoder.decode(latent, num_steps=10) # Decode by sampling diffusion model conditioning on latent
Modules/diffusion/sampler.py ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import atan, cos, pi, sin, sqrt
2
+ from typing import Any, Callable, List, Optional, Tuple, Type
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, reduce
8
+ from torch import Tensor
9
+
10
+ from .utils import *
11
+
12
+ """
13
+ Diffusion Training
14
+ """
15
+
16
+ """ Distributions """
17
+
18
+
19
+ class Distribution:
20
+ def __call__(self, num_samples: int, device: torch.device):
21
+ raise NotImplementedError()
22
+
23
+
24
+ class LogNormalDistribution(Distribution):
25
+ def __init__(self, mean: float, std: float):
26
+ self.mean = mean
27
+ self.std = std
28
+
29
+ def __call__(
30
+ self, num_samples: int, device: torch.device = torch.device("cpu")
31
+ ) -> Tensor:
32
+ normal = self.mean + self.std * torch.randn((num_samples,), device=device)
33
+ return normal.exp()
34
+
35
+
36
+ class UniformDistribution(Distribution):
37
+ def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
38
+ return torch.rand(num_samples, device=device)
39
+
40
+
41
+ class VKDistribution(Distribution):
42
+ def __init__(
43
+ self,
44
+ min_value: float = 0.0,
45
+ max_value: float = float("inf"),
46
+ sigma_data: float = 1.0,
47
+ ):
48
+ self.min_value = min_value
49
+ self.max_value = max_value
50
+ self.sigma_data = sigma_data
51
+
52
+ def __call__(
53
+ self, num_samples: int, device: torch.device = torch.device("cpu")
54
+ ) -> Tensor:
55
+ sigma_data = self.sigma_data
56
+ min_cdf = atan(self.min_value / sigma_data) * 2 / pi
57
+ max_cdf = atan(self.max_value / sigma_data) * 2 / pi
58
+ u = (max_cdf - min_cdf) * torch.randn((num_samples,), device=device) + min_cdf
59
+ return torch.tan(u * pi / 2) * sigma_data
60
+
61
+
62
+ """ Diffusion Classes """
63
+
64
+
65
+ def pad_dims(x: Tensor, ndim: int) -> Tensor:
66
+ # Pads additional ndims to the right of the tensor
67
+ return x.view(*x.shape, *((1,) * ndim))
68
+
69
+
70
+ def clip(x: Tensor, dynamic_threshold: float = 0.0):
71
+ if dynamic_threshold == 0.0:
72
+ return x.clamp(-1.0, 1.0)
73
+ else:
74
+ # Dynamic thresholding
75
+ # Find dynamic threshold quantile for each batch
76
+ x_flat = rearrange(x, "b ... -> b (...)")
77
+ scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
78
+ # Clamp to a min of 1.0
79
+ scale.clamp_(min=1.0)
80
+ # Clamp all values and scale
81
+ scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
82
+ x = x.clamp(-scale, scale) / scale
83
+ return x
84
+
85
+
86
+ def to_batch(
87
+ batch_size: int,
88
+ device: torch.device,
89
+ x: Optional[float] = None,
90
+ xs: Optional[Tensor] = None,
91
+ ) -> Tensor:
92
+ assert exists(x) ^ exists(xs), "Either x or xs must be provided"
93
+ # If x provided use the same for all batch items
94
+ if exists(x):
95
+ xs = torch.full(size=(batch_size,), fill_value=x).to(device)
96
+ assert exists(xs)
97
+ return xs
98
+
99
+
100
+ class Diffusion(nn.Module):
101
+
102
+ alias: str = ""
103
+
104
+ """Base diffusion class"""
105
+
106
+ def denoise_fn(
107
+ self,
108
+ x_noisy: Tensor,
109
+ sigmas: Optional[Tensor] = None,
110
+ sigma: Optional[float] = None,
111
+ **kwargs,
112
+ ) -> Tensor:
113
+ raise NotImplementedError("Diffusion class missing denoise_fn")
114
+
115
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
116
+ raise NotImplementedError("Diffusion class missing forward function")
117
+
118
+
119
+ class VDiffusion(Diffusion):
120
+
121
+ alias = "v"
122
+
123
+ def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
124
+ super().__init__()
125
+ self.net = net
126
+ self.sigma_distribution = sigma_distribution
127
+
128
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
129
+ angle = sigmas * pi / 2
130
+ alpha = torch.cos(angle)
131
+ beta = torch.sin(angle)
132
+ return alpha, beta
133
+
134
+ def denoise_fn(
135
+ self,
136
+ x_noisy: Tensor,
137
+ sigmas: Optional[Tensor] = None,
138
+ sigma: Optional[float] = None,
139
+ **kwargs,
140
+ ) -> Tensor:
141
+ batch_size, device = x_noisy.shape[0], x_noisy.device
142
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
143
+ return self.net(x_noisy, sigmas, **kwargs)
144
+
145
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
146
+ batch_size, device = x.shape[0], x.device
147
+
148
+ # Sample amount of noise to add for each batch element
149
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
150
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
151
+
152
+ # Get noise
153
+ noise = default(noise, lambda: torch.randn_like(x))
154
+
155
+ # Combine input and noise weighted by half-circle
156
+ alpha, beta = self.get_alpha_beta(sigmas_padded)
157
+ x_noisy = x * alpha + noise * beta
158
+ x_target = noise * alpha - x * beta
159
+
160
+ # Denoise and return loss
161
+ x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
162
+ return F.mse_loss(x_denoised, x_target)
163
+
164
+
165
+ class KDiffusion(Diffusion):
166
+ """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
167
+
168
+ alias = "k"
169
+
170
+ def __init__(
171
+ self,
172
+ net: nn.Module,
173
+ *,
174
+ sigma_distribution: Distribution,
175
+ sigma_data: float, # data distribution standard deviation
176
+ dynamic_threshold: float = 0.0,
177
+ ):
178
+ super().__init__()
179
+ self.net = net
180
+ self.sigma_data = sigma_data
181
+ self.sigma_distribution = sigma_distribution
182
+ self.dynamic_threshold = dynamic_threshold
183
+
184
+ def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
185
+ sigma_data = self.sigma_data
186
+ c_noise = torch.log(sigmas) * 0.25
187
+ sigmas = rearrange(sigmas, "b -> b 1 1")
188
+ c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
189
+ c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
190
+ c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
191
+ return c_skip, c_out, c_in, c_noise
192
+
193
+ def denoise_fn(
194
+ self,
195
+ x_noisy: Tensor,
196
+ sigmas: Optional[Tensor] = None,
197
+ sigma: Optional[float] = None,
198
+ **kwargs,
199
+ ) -> Tensor:
200
+ batch_size, device = x_noisy.shape[0], x_noisy.device
201
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
202
+
203
+ # Predict network output and add skip connection
204
+ c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
205
+ x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
206
+ x_denoised = c_skip * x_noisy + c_out * x_pred
207
+
208
+ return x_denoised
209
+
210
+ def loss_weight(self, sigmas: Tensor) -> Tensor:
211
+ # Computes weight depending on data distribution
212
+ return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2
213
+
214
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
215
+ batch_size, device = x.shape[0], x.device
216
+ from einops import rearrange, reduce
217
+
218
+ # Sample amount of noise to add for each batch element
219
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
220
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
221
+
222
+ # Add noise to input
223
+ noise = default(noise, lambda: torch.randn_like(x))
224
+ x_noisy = x + sigmas_padded * noise
225
+
226
+ # Compute denoised values
227
+ x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
228
+
229
+ # Compute weighted loss
230
+ losses = F.mse_loss(x_denoised, x, reduction="none")
231
+ losses = reduce(losses, "b ... -> b", "mean")
232
+ losses = losses * self.loss_weight(sigmas)
233
+ loss = losses.mean()
234
+ return loss
235
+
236
+
237
+ class VKDiffusion(Diffusion):
238
+
239
+ alias = "vk"
240
+
241
+ def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
242
+ super().__init__()
243
+ self.net = net
244
+ self.sigma_distribution = sigma_distribution
245
+
246
+ def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
247
+ sigma_data = 1.0
248
+ sigmas = rearrange(sigmas, "b -> b 1 1")
249
+ c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
250
+ c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
251
+ c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
252
+ return c_skip, c_out, c_in
253
+
254
+ def sigma_to_t(self, sigmas: Tensor) -> Tensor:
255
+ return sigmas.atan() / pi * 2
256
+
257
+ def t_to_sigma(self, t: Tensor) -> Tensor:
258
+ return (t * pi / 2).tan()
259
+
260
+ def denoise_fn(
261
+ self,
262
+ x_noisy: Tensor,
263
+ sigmas: Optional[Tensor] = None,
264
+ sigma: Optional[float] = None,
265
+ **kwargs,
266
+ ) -> Tensor:
267
+ batch_size, device = x_noisy.shape[0], x_noisy.device
268
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
269
+
270
+ # Predict network output and add skip connection
271
+ c_skip, c_out, c_in = self.get_scale_weights(sigmas)
272
+ x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
273
+ x_denoised = c_skip * x_noisy + c_out * x_pred
274
+ return x_denoised
275
+
276
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
277
+ batch_size, device = x.shape[0], x.device
278
+
279
+ # Sample amount of noise to add for each batch element
280
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
281
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
282
+
283
+ # Add noise to input
284
+ noise = default(noise, lambda: torch.randn_like(x))
285
+ x_noisy = x + sigmas_padded * noise
286
+
287
+ # Compute model output
288
+ c_skip, c_out, c_in = self.get_scale_weights(sigmas)
289
+ x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
290
+
291
+ # Compute v-objective target
292
+ v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
293
+
294
+ # Compute loss
295
+ loss = F.mse_loss(x_pred, v_target)
296
+ return loss
297
+
298
+
299
+ """
300
+ Diffusion Sampling
301
+ """
302
+
303
+ """ Schedules """
304
+
305
+
306
+ class Schedule(nn.Module):
307
+ """Interface used by different sampling schedules"""
308
+
309
+ def forward(self, num_steps: int, device: torch.device) -> Tensor:
310
+ raise NotImplementedError()
311
+
312
+
313
+ class LinearSchedule(Schedule):
314
+ def forward(self, num_steps: int, device: Any) -> Tensor:
315
+ sigmas = torch.linspace(1, 0, num_steps + 1)[:-1]
316
+ return sigmas
317
+
318
+
319
+ class KarrasSchedule(Schedule):
320
+ """https://arxiv.org/abs/2206.00364 equation 5"""
321
+
322
+ def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
323
+ super().__init__()
324
+ self.sigma_min = sigma_min
325
+ self.sigma_max = sigma_max
326
+ self.rho = rho
327
+
328
+ def forward(self, num_steps: int, device: Any) -> Tensor:
329
+ rho_inv = 1.0 / self.rho
330
+ steps = torch.arange(num_steps, device=device, dtype=torch.float32)
331
+ sigmas = (
332
+ self.sigma_max ** rho_inv
333
+ + (steps / (num_steps - 1))
334
+ * (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv)
335
+ ) ** self.rho
336
+ sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
337
+ return sigmas
338
+
339
+
340
+ """ Samplers """
341
+
342
+
343
+ class Sampler(nn.Module):
344
+
345
+ diffusion_types: List[Type[Diffusion]] = []
346
+
347
+ def forward(
348
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
349
+ ) -> Tensor:
350
+ raise NotImplementedError()
351
+
352
+ def inpaint(
353
+ self,
354
+ source: Tensor,
355
+ mask: Tensor,
356
+ fn: Callable,
357
+ sigmas: Tensor,
358
+ num_steps: int,
359
+ num_resamples: int,
360
+ ) -> Tensor:
361
+ raise NotImplementedError("Inpainting not available with current sampler")
362
+
363
+
364
+ class VSampler(Sampler):
365
+
366
+ diffusion_types = [VDiffusion]
367
+
368
+ def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
369
+ angle = sigma * pi / 2
370
+ alpha = cos(angle)
371
+ beta = sin(angle)
372
+ return alpha, beta
373
+
374
+ def forward(
375
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
376
+ ) -> Tensor:
377
+ x = sigmas[0] * noise
378
+ alpha, beta = self.get_alpha_beta(sigmas[0].item())
379
+
380
+ for i in range(num_steps - 1):
381
+ is_last = i == num_steps - 1
382
+
383
+ x_denoised = fn(x, sigma=sigmas[i])
384
+ x_pred = x * alpha - x_denoised * beta
385
+ x_eps = x * beta + x_denoised * alpha
386
+
387
+ if not is_last:
388
+ alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
389
+ x = x_pred * alpha + x_eps * beta
390
+
391
+ return x_pred
392
+
393
+
394
+ class KarrasSampler(Sampler):
395
+ """https://arxiv.org/abs/2206.00364 algorithm 1"""
396
+
397
+ diffusion_types = [KDiffusion, VKDiffusion]
398
+
399
+ def __init__(
400
+ self,
401
+ s_tmin: float = 0,
402
+ s_tmax: float = float("inf"),
403
+ s_churn: float = 0.0,
404
+ s_noise: float = 1.0,
405
+ ):
406
+ super().__init__()
407
+ self.s_tmin = s_tmin
408
+ self.s_tmax = s_tmax
409
+ self.s_noise = s_noise
410
+ self.s_churn = s_churn
411
+
412
+ def step(
413
+ self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
414
+ ) -> Tensor:
415
+ """Algorithm 2 (step)"""
416
+ # Select temporarily increased noise level
417
+ sigma_hat = sigma + gamma * sigma
418
+ # Add noise to move from sigma to sigma_hat
419
+ epsilon = self.s_noise * torch.randn_like(x)
420
+ x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
421
+ # Evaluate ∂x/∂sigma at sigma_hat
422
+ d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
423
+ # Take euler step from sigma_hat to sigma_next
424
+ x_next = x_hat + (sigma_next - sigma_hat) * d
425
+ # Second order correction
426
+ if sigma_next != 0:
427
+ model_out_next = fn(x_next, sigma=sigma_next)
428
+ d_prime = (x_next - model_out_next) / sigma_next
429
+ x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
430
+ return x_next
431
+
432
+ def forward(
433
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
434
+ ) -> Tensor:
435
+ x = sigmas[0] * noise
436
+ # Compute gammas
437
+ gammas = torch.where(
438
+ (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
439
+ min(self.s_churn / num_steps, sqrt(2) - 1),
440
+ 0.0,
441
+ )
442
+ # Denoise to sample
443
+ for i in range(num_steps - 1):
444
+ x = self.step(
445
+ x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i] # type: ignore # noqa
446
+ )
447
+
448
+ return x
449
+
450
+
451
+ class AEulerSampler(Sampler):
452
+
453
+ diffusion_types = [KDiffusion, VKDiffusion]
454
+
455
+ def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
456
+ sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
457
+ sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
458
+ return sigma_up, sigma_down
459
+
460
+ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
461
+ # Sigma steps
462
+ sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
463
+ # Derivative at sigma (∂x/∂sigma)
464
+ d = (x - fn(x, sigma=sigma)) / sigma
465
+ # Euler method
466
+ x_next = x + d * (sigma_down - sigma)
467
+ # Add randomness
468
+ x_next = x_next + torch.randn_like(x) * sigma_up
469
+ return x_next
470
+
471
+ def forward(
472
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
473
+ ) -> Tensor:
474
+ x = sigmas[0] * noise
475
+ # Denoise to sample
476
+ for i in range(num_steps - 1):
477
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
478
+ return x
479
+
480
+
481
+ class ADPM2Sampler(Sampler):
482
+ """https://www.desmos.com/calculator/jbxjlqd9mb"""
483
+
484
+ diffusion_types = [KDiffusion, VKDiffusion]
485
+
486
+ def __init__(self, rho: float = 1.0):
487
+ super().__init__()
488
+ self.rho = rho
489
+
490
+ def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
491
+ r = self.rho
492
+ sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
493
+ sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
494
+ sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
495
+ return sigma_up, sigma_down, sigma_mid
496
+
497
+ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
498
+ # Sigma steps
499
+ sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
500
+ # Derivative at sigma (∂x/∂sigma)
501
+ d = (x - fn(x, sigma=sigma)) / sigma
502
+ # Denoise to midpoint
503
+ x_mid = x + d * (sigma_mid - sigma)
504
+ # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
505
+ d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
506
+ # Denoise to next
507
+ x = x + d_mid * (sigma_down - sigma)
508
+ # Add randomness
509
+ x_next = x + torch.randn_like(x) * sigma_up
510
+ return x_next
511
+
512
+ def forward(
513
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
514
+ ) -> Tensor:
515
+ x = sigmas[0] * noise
516
+ # Denoise to sample
517
+ for i in range(num_steps - 1):
518
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
519
+ return x
520
+
521
+ def inpaint(
522
+ self,
523
+ source: Tensor,
524
+ mask: Tensor,
525
+ fn: Callable,
526
+ sigmas: Tensor,
527
+ num_steps: int,
528
+ num_resamples: int,
529
+ ) -> Tensor:
530
+ x = sigmas[0] * torch.randn_like(source)
531
+
532
+ for i in range(num_steps - 1):
533
+ # Noise source to current noise level
534
+ source_noisy = source + sigmas[i] * torch.randn_like(source)
535
+ for r in range(num_resamples):
536
+ # Merge noisy source and current then denoise
537
+ x = source_noisy * mask + x * ~mask
538
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
539
+ # Renoise if not last resample step
540
+ if r < num_resamples - 1:
541
+ sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
542
+ x = x + sigma * torch.randn_like(x)
543
+
544
+ return source * mask + x * ~mask
545
+
546
+
547
+ """ Main Classes """
548
+
549
+
550
+ class DiffusionSampler(nn.Module):
551
+ def __init__(
552
+ self,
553
+ diffusion: Diffusion,
554
+ *,
555
+ sampler: Sampler,
556
+ sigma_schedule: Schedule,
557
+ num_steps: Optional[int] = None,
558
+ clamp: bool = True,
559
+ ):
560
+ super().__init__()
561
+ self.denoise_fn = diffusion.denoise_fn
562
+ self.sampler = sampler
563
+ self.sigma_schedule = sigma_schedule
564
+ self.num_steps = num_steps
565
+ self.clamp = clamp
566
+
567
+ # Check sampler is compatible with diffusion type
568
+ sampler_class = sampler.__class__.__name__
569
+ diffusion_class = diffusion.__class__.__name__
570
+ message = f"{sampler_class} incompatible with {diffusion_class}"
571
+ assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
572
+
573
+ def forward(
574
+ self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
575
+ ) -> Tensor:
576
+ device = noise.device
577
+ num_steps = default(num_steps, self.num_steps) # type: ignore
578
+ assert exists(num_steps), "Parameter `num_steps` must be provided"
579
+ # Compute sigmas using schedule
580
+ sigmas = self.sigma_schedule(num_steps, device)
581
+ # Append additional kwargs to denoise function (used e.g. for conditional unet)
582
+ fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
583
+ # Sample using sampler
584
+ x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
585
+ x = x.clamp(-1.0, 1.0) if self.clamp else x
586
+ return x
587
+
588
+
589
+ class DiffusionInpainter(nn.Module):
590
+ def __init__(
591
+ self,
592
+ diffusion: Diffusion,
593
+ *,
594
+ num_steps: int,
595
+ num_resamples: int,
596
+ sampler: Sampler,
597
+ sigma_schedule: Schedule,
598
+ ):
599
+ super().__init__()
600
+ self.denoise_fn = diffusion.denoise_fn
601
+ self.num_steps = num_steps
602
+ self.num_resamples = num_resamples
603
+ self.inpaint_fn = sampler.inpaint
604
+ self.sigma_schedule = sigma_schedule
605
+
606
+ @torch.no_grad()
607
+ def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
608
+ x = self.inpaint_fn(
609
+ source=inpaint,
610
+ mask=inpaint_mask,
611
+ fn=self.denoise_fn,
612
+ sigmas=self.sigma_schedule(self.num_steps, inpaint.device),
613
+ num_steps=self.num_steps,
614
+ num_resamples=self.num_resamples,
615
+ )
616
+ return x
617
+
618
+
619
+ def sequential_mask(like: Tensor, start: int) -> Tensor:
620
+ length, device = like.shape[2], like.device
621
+ mask = torch.ones_like(like, dtype=torch.bool)
622
+ mask[:, :, start:] = torch.zeros((length - start,), device=device)
623
+ return mask
624
+
625
+
626
+ class SpanBySpanComposer(nn.Module):
627
+ def __init__(
628
+ self,
629
+ inpainter: DiffusionInpainter,
630
+ *,
631
+ num_spans: int,
632
+ ):
633
+ super().__init__()
634
+ self.inpainter = inpainter
635
+ self.num_spans = num_spans
636
+
637
+ def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
638
+ half_length = start.shape[2] // 2
639
+
640
+ spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else []
641
+ # Inpaint second half from first half
642
+ inpaint = torch.zeros_like(start)
643
+ inpaint[:, :, :half_length] = start[:, :, half_length:]
644
+ inpaint_mask = sequential_mask(like=start, start=half_length)
645
+
646
+ for i in range(self.num_spans):
647
+ # Inpaint second half
648
+ span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
649
+ # Replace first half with generated second half
650
+ second_half = span[:, :, half_length:]
651
+ inpaint[:, :, :half_length] = second_half
652
+ # Save generated span
653
+ spans.append(second_half)
654
+
655
+ return torch.cat(spans, dim=2)
656
+
657
+
658
+ class XDiffusion(nn.Module):
659
+ def __init__(self, type: str, net: nn.Module, **kwargs):
660
+ super().__init__()
661
+
662
+ diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
663
+ aliases = [t.alias for t in diffusion_classes] # type: ignore
664
+ message = f"type='{type}' must be one of {*aliases,}"
665
+ assert type in aliases, message
666
+ self.net = net
667
+
668
+ for XDiffusion in diffusion_classes:
669
+ if XDiffusion.alias == type: # type: ignore
670
+ self.diffusion = XDiffusion(net=net, **kwargs)
671
+
672
+ def forward(self, *args, **kwargs) -> Tensor:
673
+ return self.diffusion(*args, **kwargs)
674
+
675
+ def sample(
676
+ self,
677
+ noise: Tensor,
678
+ num_steps: int,
679
+ sigma_schedule: Schedule,
680
+ sampler: Sampler,
681
+ clamp: bool,
682
+ **kwargs,
683
+ ) -> Tensor:
684
+ diffusion_sampler = DiffusionSampler(
685
+ diffusion=self.diffusion,
686
+ sampler=sampler,
687
+ sigma_schedule=sigma_schedule,
688
+ num_steps=num_steps,
689
+ clamp=clamp,
690
+ )
691
+ return diffusion_sampler(noise, **kwargs)
Modules/diffusion/utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce
2
+ from inspect import isfunction
3
+ from math import ceil, floor, log2, pi
4
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch import Generator, Tensor
10
+ from typing_extensions import TypeGuard
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
+ def exists(val: Optional[T]) -> TypeGuard[T]:
16
+ return val is not None
17
+
18
+
19
+ def iff(condition: bool, value: T) -> Optional[T]:
20
+ return value if condition else None
21
+
22
+
23
+ def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
24
+ return isinstance(obj, list) or isinstance(obj, tuple)
25
+
26
+
27
+ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
28
+ if exists(val):
29
+ return val
30
+ return d() if isfunction(d) else d
31
+
32
+
33
+ def to_list(val: Union[T, Sequence[T]]) -> List[T]:
34
+ if isinstance(val, tuple):
35
+ return list(val)
36
+ if isinstance(val, list):
37
+ return val
38
+ return [val] # type: ignore
39
+
40
+
41
+ def prod(vals: Sequence[int]) -> int:
42
+ return reduce(lambda x, y: x * y, vals)
43
+
44
+
45
+ def closest_power_2(x: float) -> int:
46
+ exponent = log2(x)
47
+ distance_fn = lambda z: abs(x - 2 ** z) # noqa
48
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
+ return 2 ** int(exponent_closest)
50
+
51
+ def rand_bool(shape, proba, device = None):
52
+ if proba == 1:
53
+ return torch.ones(shape, device=device, dtype=torch.bool)
54
+ elif proba == 0:
55
+ return torch.zeros(shape, device=device, dtype=torch.bool)
56
+ else:
57
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
58
+
59
+
60
+ """
61
+ Kwargs Utils
62
+ """
63
+
64
+
65
+ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
66
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
67
+ for key in d.keys():
68
+ no_prefix = int(not key.startswith(prefix))
69
+ return_dicts[no_prefix][key] = d[key]
70
+ return return_dicts
71
+
72
+
73
+ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
74
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
75
+ if keep_prefix:
76
+ return kwargs_with_prefix, kwargs
77
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
78
+ return kwargs_no_prefix, kwargs
79
+
80
+
81
+ def prefix_dict(prefix: str, d: Dict) -> Dict:
82
+ return {prefix + str(k): v for k, v in d.items()}
Modules/discriminators.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, spectral_norm
6
+
7
+ from .utils import get_padding
8
+
9
+ LRELU_SLOPE = 0.1
10
+
11
+ def stft(x, fft_size, hop_size, win_length, window):
12
+ """Perform STFT and convert to magnitude spectrogram.
13
+ Args:
14
+ x (Tensor): Input signal tensor (B, T).
15
+ fft_size (int): FFT size.
16
+ hop_size (int): Hop size.
17
+ win_length (int): Window length.
18
+ window (str): Window function type.
19
+ Returns:
20
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
21
+ """
22
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
23
+ return_complex=True)
24
+ real = x_stft[..., 0]
25
+ imag = x_stft[..., 1]
26
+
27
+ return torch.abs(x_stft).transpose(2, 1)
28
+
29
+ class SpecDiscriminator(nn.Module):
30
+ """docstring for Discriminator."""
31
+
32
+ def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
33
+ super(SpecDiscriminator, self).__init__()
34
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
35
+ self.fft_size = fft_size
36
+ self.shift_size = shift_size
37
+ self.win_length = win_length
38
+ self.window = getattr(torch, window)(win_length)
39
+ self.discriminators = nn.ModuleList([
40
+ norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
41
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
42
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
43
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
44
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1, 1))),
45
+ ])
46
+
47
+ self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
48
+
49
+ def forward(self, y):
50
+
51
+ fmap = []
52
+ y = y.squeeze(1)
53
+ y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.get_device()))
54
+ y = y.unsqueeze(1)
55
+ for i, d in enumerate(self.discriminators):
56
+ y = d(y)
57
+ y = F.leaky_relu(y, LRELU_SLOPE)
58
+ fmap.append(y)
59
+
60
+ y = self.out(y)
61
+ fmap.append(y)
62
+
63
+ return torch.flatten(y, 1, -1), fmap
64
+
65
+ class MultiResSpecDiscriminator(torch.nn.Module):
66
+
67
+ def __init__(self,
68
+ fft_sizes=[1024, 2048, 512],
69
+ hop_sizes=[120, 240, 50],
70
+ win_lengths=[600, 1200, 240],
71
+ window="hann_window"):
72
+
73
+ super(MultiResSpecDiscriminator, self).__init__()
74
+ self.discriminators = nn.ModuleList([
75
+ SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
76
+ SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
77
+ SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
78
+ ])
79
+
80
+ def forward(self, y, y_hat):
81
+ y_d_rs = []
82
+ y_d_gs = []
83
+ fmap_rs = []
84
+ fmap_gs = []
85
+ for i, d in enumerate(self.discriminators):
86
+ y_d_r, fmap_r = d(y)
87
+ y_d_g, fmap_g = d(y_hat)
88
+ y_d_rs.append(y_d_r)
89
+ fmap_rs.append(fmap_r)
90
+ y_d_gs.append(y_d_g)
91
+ fmap_gs.append(fmap_g)
92
+
93
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
94
+
95
+
96
+ class DiscriminatorP(torch.nn.Module):
97
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
98
+ super(DiscriminatorP, self).__init__()
99
+ self.period = period
100
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
101
+ self.convs = nn.ModuleList([
102
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
103
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
104
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
105
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
106
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
107
+ ])
108
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
109
+
110
+ def forward(self, x):
111
+ fmap = []
112
+
113
+ # 1d to 2d
114
+ b, c, t = x.shape
115
+ if t % self.period != 0: # pad first
116
+ n_pad = self.period - (t % self.period)
117
+ x = F.pad(x, (0, n_pad), "reflect")
118
+ t = t + n_pad
119
+ x = x.view(b, c, t // self.period, self.period)
120
+
121
+ for l in self.convs:
122
+ x = l(x)
123
+ x = F.leaky_relu(x, LRELU_SLOPE)
124
+ fmap.append(x)
125
+ x = self.conv_post(x)
126
+ fmap.append(x)
127
+ x = torch.flatten(x, 1, -1)
128
+
129
+ return x, fmap
130
+
131
+
132
+ class MultiPeriodDiscriminator(torch.nn.Module):
133
+ def __init__(self):
134
+ super(MultiPeriodDiscriminator, self).__init__()
135
+ self.discriminators = nn.ModuleList([
136
+ DiscriminatorP(2),
137
+ DiscriminatorP(3),
138
+ DiscriminatorP(5),
139
+ DiscriminatorP(7),
140
+ DiscriminatorP(11),
141
+ ])
142
+
143
+ def forward(self, y, y_hat):
144
+ y_d_rs = []
145
+ y_d_gs = []
146
+ fmap_rs = []
147
+ fmap_gs = []
148
+ for i, d in enumerate(self.discriminators):
149
+ y_d_r, fmap_r = d(y)
150
+ y_d_g, fmap_g = d(y_hat)
151
+ y_d_rs.append(y_d_r)
152
+ fmap_rs.append(fmap_r)
153
+ y_d_gs.append(y_d_g)
154
+ fmap_gs.append(fmap_g)
155
+
156
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
157
+
158
+ class WavLMDiscriminator(nn.Module):
159
+ """docstring for Discriminator."""
160
+
161
+ def __init__(self, slm_hidden=768,
162
+ slm_layers=13,
163
+ initial_channel=64,
164
+ use_spectral_norm=False):
165
+ super(WavLMDiscriminator, self).__init__()
166
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
167
+ self.pre = norm_f(Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0))
168
+
169
+ self.convs = nn.ModuleList([
170
+ norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)),
171
+ norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)),
172
+ norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)),
173
+ ])
174
+
175
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
176
+
177
+ def forward(self, x):
178
+ x = self.pre(x)
179
+
180
+ fmap = []
181
+ for l in self.convs:
182
+ x = l(x)
183
+ x = F.leaky_relu(x, LRELU_SLOPE)
184
+ fmap.append(x)
185
+ x = self.conv_post(x)
186
+ x = torch.flatten(x, 1, -1)
187
+
188
+ return x