Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +9 -0
- .gradio/certificate.pem +31 -0
- Configs/config.yml +116 -0
- Configs/config_ft.yml +116 -0
- Configs/config_kanade.yml +118 -0
- Inference/infer_24khz_mod.ipynb +0 -0
- Inference/input_for_prompt.txt +4 -0
- Inference/prompt.txt +4 -0
- Inference/random_texts.txt +14 -0
- LICENSE +25 -0
- Models/Style_Tsukasa_v02/Top_ckpt_24khz.pth +3 -0
- Models/Style_Tsukasa_v02/config_kanade.yml +121 -0
- Modules/KotoDama_sampler.py +269 -0
- Modules/__init__.py +1 -0
- Modules/__pycache__/KotoDama_sampler.cpython-311.pyc +0 -0
- Modules/__pycache__/__init__.cpython-311.pyc +0 -0
- Modules/__pycache__/discriminators.cpython-311.pyc +0 -0
- Modules/__pycache__/hifigan.cpython-311.pyc +0 -0
- Modules/__pycache__/istftnet.cpython-311.pyc +0 -0
- Modules/__pycache__/slmadv.cpython-311.pyc +0 -0
- Modules/__pycache__/utils.cpython-311.pyc +0 -0
- Modules/diffusion/__init__.py +1 -0
- Modules/diffusion/__pycache__/__init__.cpython-311.pyc +0 -0
- Modules/diffusion/__pycache__/diffusion.cpython-311.pyc +0 -0
- Modules/diffusion/__pycache__/modules.cpython-311.pyc +0 -0
- Modules/diffusion/__pycache__/sampler.cpython-311.pyc +0 -0
- Modules/diffusion/__pycache__/utils.cpython-311.pyc +0 -0
- Modules/diffusion/audio_diffusion_pytorch/__init__.py +20 -0
- Modules/diffusion/audio_diffusion_pytorch/__pycache__/__init__.cpython-311.pyc +0 -0
- Modules/diffusion/audio_diffusion_pytorch/__pycache__/components.cpython-311.pyc +0 -0
- Modules/diffusion/audio_diffusion_pytorch/__pycache__/diffusion.cpython-311.pyc +0 -0
- Modules/diffusion/audio_diffusion_pytorch/__pycache__/models.cpython-311.pyc +0 -0
- Modules/diffusion/audio_diffusion_pytorch/__pycache__/utils.cpython-311.pyc +0 -0
- Modules/diffusion/audio_diffusion_pytorch/components.py +236 -0
- Modules/diffusion/audio_diffusion_pytorch/diffusion.py +354 -0
- Modules/diffusion/audio_diffusion_pytorch/models.py +250 -0
- Modules/diffusion/audio_diffusion_pytorch/utils.py +125 -0
- Modules/diffusion/diffusion.py +94 -0
- Modules/diffusion/modules.py +693 -0
- Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/.github/workflows/python-publish.yml +39 -0
- Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/.gitignore +2 -0
- Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/.pre-commit-config.yaml +41 -0
- Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/LICENSE +21 -0
- Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/LOGO.png +0 -0
- Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/README.md +251 -0
- Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/setup.py +29 -0
- Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/tests/testcustomloss.py +39 -0
- Modules/diffusion/sampler.py +691 -0
- Modules/diffusion/utils.py +82 -0
- Modules/discriminators.py +188 -0
.gitattributes
CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+Utils/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
+Utils/PLBERT/step_1050000.t7 filter=lfs diff=lfs merge=lfs -text
+reference_sample_wavs/01008270.wav filter=lfs diff=lfs merge=lfs -text
+reference_sample_wavs/kaede_san.wav filter=lfs diff=lfs merge=lfs -text
+reference_sample_wavs/riamu_zeroshot_02.wav filter=lfs diff=lfs merge=lfs -text
+reference_sample_wavs/sample_ref01.wav filter=lfs diff=lfs merge=lfs -text
+reference_sample_wavs/sample_ref02.wav filter=lfs diff=lfs merge=lfs -text
+reference_sample_wavs/shiki_fine05.wav filter=lfs diff=lfs merge=lfs -text
+reference_sample_wavs/syuukovoice_200918_3_01.wav filter=lfs diff=lfs merge=lfs -text
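The checkpoints and reference WAVs added above are routed through Git LFS, so cloning without LFS only yields pointer files. A minimal sketch of fetching one of them with `huggingface_hub`, which resolves LFS content transparently; the `repo_id` below is a placeholder, not something stated in this diff:

```python
# Sketch only: substitute the actual repository id of this Space/model for "REPO_ID".
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="REPO_ID",  # placeholder, not taken from the diff
    filename="Models/Style_Tsukasa_v02/Top_ckpt_24khz.pth",
)
print(ckpt_path)  # local cache path with the real (LFS-resolved) file contents
```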
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
Configs/config.yml
ADDED
@@ -0,0 +1,116 @@
log_dir: "Models/LJSpeech"
first_stage_path: "first_stage.pth"
save_freq: 2
log_interval: 10
device: "cuda"
epochs_1st: 200 # number of epochs for first stage training (pre-training)
epochs_2nd: 100 # number of epochs for second stage training (joint training)
batch_size: 16
max_len: 400 # maximum number of frames
pretrained_model: ""
second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters

F0_path: "Utils/JDC/bst.t7"
ASR_config: "Utils/ASR/config.yml"
ASR_path: "Utils/ASR/epoch_00080.pth"
PLBERT_dir: 'Utils/PLBERT/'

data_params:
  train_data: "Data/train_list.txt"
  val_data: "Data/val_list.txt"
  root_path: "/local/LJSpeech-1.1/wavs"
  OOD_data: "Data/OOD_texts.txt"
  min_length: 50 # sample until texts with this size are obtained for OOD texts

preprocess_params:
  sr: 24000
  spect_params:
    n_fft: 2048
    win_length: 1200
    hop_length: 300

model_params:
  multispeaker: false

  dim_in: 64
  hidden_dim: 512
  max_conv_dim: 512
  n_layer: 3
  n_mels: 80

  n_token: 178 # number of phoneme tokens
  max_dur: 50 # maximum duration of a single phoneme
  style_dim: 128 # style vector size

  dropout: 0.2

  # config for decoder
  decoder:
    type: 'istftnet' # either hifigan or istftnet
    resblock_kernel_sizes: [3,7,11]
    upsample_rates: [10, 6]
    upsample_initial_channel: 512
    resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
    upsample_kernel_sizes: [20, 12]
    gen_istft_n_fft: 20
    gen_istft_hop_size: 5

  # speech language model config
  slm:
    model: 'microsoft/wavlm-base-plus'
    sr: 16000 # sampling rate of SLM
    hidden: 768 # hidden size of SLM
    nlayers: 13 # number of layers of SLM
    initial_channel: 64 # initial channels of SLM discriminator head

  # style diffusion model config
  diffusion:
    embedding_mask_proba: 0.1
    # transformer config
    transformer:
      num_layers: 3
      num_heads: 8
      head_features: 64
      multiplier: 2

    # diffusion distribution config
    dist:
      sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
      estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
      mean: -3.0
      std: 1.0

loss_params:
  lambda_mel: 5. # mel reconstruction loss
  lambda_gen: 1. # generator loss
  lambda_slm: 1. # slm feature matching loss

  lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
  lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
  TMA_epoch: 50 # TMA starting epoch (1st stage)

  lambda_F0: 1. # F0 reconstruction loss (2nd stage)
  lambda_norm: 1. # norm reconstruction loss (2nd stage)
  lambda_dur: 1. # duration loss (2nd stage)
  lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
  lambda_sty: 1. # style reconstruction loss (2nd stage)
  lambda_diff: 1. # score matching loss (2nd stage)

  diff_epoch: 20 # style diffusion starting epoch (2nd stage)
  joint_epoch: 50 # joint training starting epoch (2nd stage)

optimizer_params:
  lr: 0.0001 # general learning rate
  bert_lr: 0.00001 # learning rate for PLBERT
  ft_lr: 0.00001 # learning rate for acoustic modules

slmadv_params:
  min_len: 400 # minimum length of samples
  max_len: 500 # maximum length of samples
  batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
  iter: 10 # update the discriminator every this iterations of generator update
  thresh: 5 # gradient norm above which the gradient is scaled
  scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
  sig: 1.5 # sigma for differentiable duration modeling
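This config drives the two-stage StyleTTS2-style training loop. A minimal sketch (not part of this commit, assumes PyYAML is installed) of loading it and reading a few of the fields shown above:

```python
# Read Configs/config.yml and pull out a handful of the values listed in the diff.
import yaml

with open("Configs/config.yml") as f:
    config = yaml.safe_load(f)

sr = config["preprocess_params"]["sr"]                     # 24000
decoder_type = config["model_params"]["decoder"]["type"]   # 'istftnet'
lambda_mel = config["loss_params"]["lambda_mel"]           # 5.0
print(sr, decoder_type, lambda_mel)
```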
Configs/config_ft.yml
ADDED
@@ -0,0 +1,116 @@
log_dir: "Models/IMAS_FineTuned"
save_freq: 1
log_interval: 10
device: "cuda"
epochs: 50 # number of finetuning epochs (1 hour of data)
batch_size: 3
max_len: 2500 # maximum number of frames
pretrained_model: "/home/austin/disk2/llmvcs/tt/stylekan/Models/Style_Kanade/NO_SLM_3_epoch_2nd_00002.pth"
second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
load_only_params: true # set to true if do not want to load epoch numbers and optimizer parameters

F0_path: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/JDC/bst.t7"
ASR_config: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/config.yml"
ASR_path: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/bst_00080.pth"

PLBERT_dir: 'Utils/PLBERT/'

data_params:
  train_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/FT_imas.csv"
  val_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/FT_imas_valid.csv"
  root_path: ""
  OOD_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/OOD_LargeScale_.csv"
  min_length: 50 # sample until texts with this size are obtained for OOD texts

preprocess_params:
  sr: 24000
  spect_params:
    n_fft: 2048
    win_length: 1200
    hop_length: 300

model_params:
  multispeaker: true

  dim_in: 64
  hidden_dim: 512
  max_conv_dim: 512
  n_layer: 3
  n_mels: 80

  n_token: 178 # number of phoneme tokens
  max_dur: 50 # maximum duration of a single phoneme
  style_dim: 128 # style vector size

  dropout: 0.2

  decoder:
    type: 'istftnet' # either hifigan or istftnet
    resblock_kernel_sizes: [3,7,11]
    upsample_rates: [10, 6]
    upsample_initial_channel: 512
    resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
    upsample_kernel_sizes: [20, 12]
    gen_istft_n_fft: 20
    gen_istft_hop_size: 5

  # speech language model config
  slm:
    model: 'Respair/Whisper_Large_v2_Encoder_Block' # The model itself is hardcoded, change it through -> losses.py
    sr: 16000 # sampling rate of SLM
    hidden: 1280 # hidden size of SLM
    nlayers: 33 # number of layers of SLM
    initial_channel: 64 # initial channels of SLM discriminator head

  # style diffusion model config
  diffusion:
    embedding_mask_proba: 0.1
    # transformer config
    transformer:
      num_layers: 3
      num_heads: 8
      head_features: 64
      multiplier: 2

    # diffusion distribution config
    dist:
      sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
      estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
      mean: -3.0
      std: 1.0

loss_params:
  lambda_mel: 10. # mel reconstruction loss
  lambda_gen: 1. # generator loss
  lambda_slm: 1. # slm feature matching loss

  lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
  lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
  TMA_epoch: 9 # TMA starting epoch (1st stage)

  lambda_F0: 1. # F0 reconstruction loss (2nd stage)
  lambda_norm: 1. # norm reconstruction loss (2nd stage)
  lambda_dur: 1. # duration loss (2nd stage)
  lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
  lambda_sty: 1. # style reconstruction loss (2nd stage)
  lambda_diff: 1. # score matching loss (2nd stage)

  diff_epoch: 0 # style diffusion starting epoch (2nd stage)
  joint_epoch: 30 # joint training starting epoch (2nd stage)

optimizer_params:
  lr: 0.0001 # general learning rate
  bert_lr: 0.00001 # learning rate for PLBERT
  ft_lr: 0.00001 # learning rate for acoustic modules

slmadv_params:
  min_len: 400 # minimum length of samples
  max_len: 500 # maximum length of samples
  batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
  iter: 20 # update the discriminator every this iterations of generator update
  thresh: 5 # gradient norm above which the gradient is scaled
  scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
  sig: 1.5 # sigma for differentiable duration modeling
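Note that this fine-tuning config hardcodes absolute `/home/austin/disk2/llmvcs/tt/stylekan/` paths from the training machine. A hypothetical helper (not from this repo) for rewriting those prefixes before reusing the file elsewhere:

```python
# Sketch: relocate the hardcoded absolute paths in Configs/config_ft.yml.
import yaml

OLD_PREFIX = "/home/austin/disk2/llmvcs/tt/stylekan/"

def relocate(value, new_root="./"):
    # Recursively rewrite string values that start with the old prefix.
    if isinstance(value, str) and value.startswith(OLD_PREFIX):
        return new_root + value[len(OLD_PREFIX):]
    if isinstance(value, dict):
        return {k: relocate(v, new_root) for k, v in value.items()}
    return value

with open("Configs/config_ft.yml") as f:
    cfg = relocate(yaml.safe_load(f))

with open("Configs/config_ft.local.yml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```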
Configs/config_kanade.yml
ADDED
@@ -0,0 +1,118 @@
log_dir: "Models/Style_Kanade_v02"
first_stage_path: ""
save_freq: 1
log_interval: 10
device: "cuda"
epochs_1st: 30 # number of epochs for first stage training (pre-training)
epochs_2nd: 20 # number of epochs for second stage training (joint training)
batch_size: 64
max_len: 560 # maximum number of frames
pretrained_model: "Models/Style_Kanade_v02/epoch_2nd_00007.pth"
second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters

F0_path: "Utils/JDC/bst.t7"
ASR_config: "Utils/ASR/config.yml"
ASR_path: "Utils/ASR/bst_00080.pth"

PLBERT_dir: 'Utils/PLBERT/'

data_params:
  train_data: "Data/metadata_cleanest/DATA.csv"
  val_data: "Data/VALID.txt"
  root_path: ""
  OOD_data: "Data/OOD_LargeScale_.csv"
  min_length: 50 # sample until texts with this size are obtained for OOD texts

preprocess_params:
  sr: 24000
  spect_params:
    n_fft: 2048
    win_length: 1200
    hop_length: 300

model_params:
  multispeaker: true

  dim_in: 64
  hidden_dim: 512
  max_conv_dim: 512
  n_layer: 3
  n_mels: 80

  n_token: 178 # number of phoneme tokens
  max_dur: 50 # maximum duration of a single phoneme
  style_dim: 128 # style vector size

  dropout: 0.2

  decoder:
    type: 'istftnet' # either hifigan or istftnet
    resblock_kernel_sizes: [3,7,11]
    upsample_rates: [10, 6]
    upsample_initial_channel: 512
    resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
    upsample_kernel_sizes: [20, 12]
    gen_istft_n_fft: 20
    gen_istft_hop_size: 5

  # speech language model config
  slm:
    model: 'Respair/Whisper_Large_v2_Encoder_Block' # The model itself is hardcoded, change it through -> losses.py
    sr: 16000 # sampling rate of SLM
    hidden: 1280 # hidden size of SLM
    nlayers: 33 # number of layers of SLM
    initial_channel: 64 # initial channels of SLM discriminator head

  # style diffusion model config
  diffusion:
    embedding_mask_proba: 0.1
    # transformer config
    transformer:
      num_layers: 3
      num_heads: 8
      head_features: 64
      multiplier: 2

    # diffusion distribution config
    dist:
      sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
      estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
      mean: -3.0
      std: 1.0

loss_params:
  lambda_mel: 10. # mel reconstruction loss
  lambda_gen: 1. # generator loss
  lambda_slm: 1. # slm feature matching loss

  lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
  lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
  TMA_epoch: 5 # TMA starting epoch (1st stage)

  lambda_F0: 1. # F0 reconstruction loss (2nd stage)
  lambda_norm: 1. # norm reconstruction loss (2nd stage)
  lambda_dur: 1. # duration loss (2nd stage)
  lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
  lambda_sty: 1. # style reconstruction loss (2nd stage)
  lambda_diff: 1. # score matching loss (2nd stage)

  diff_epoch: 4 # style diffusion starting epoch (2nd stage)
  joint_epoch: 999 # joint training starting epoch (2nd stage)

optimizer_params:
  lr: 0.0001 # general learning rate
  bert_lr: 0.00001 # learning rate for PLBERT
  ft_lr: 0.00001 # learning rate for acoustic modules

slmadv_params:
  min_len: 400 # minimum length of samples
  max_len: 500 # maximum length of samples
  batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
  iter: 20 # update the discriminator every this iterations of generator update
  thresh: 5 # gradient norm above which the gradient is scaled
  scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
  sig: 1.5 # sigma for differentiable duration modeling
Inference/infer_24khz_mod.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
Inference/input_for_prompt.txt
ADDED
@@ -0,0 +1,4 @@
この俺に何度も同じことを説明させるな!! お前たちは俺の忠告を完全に無視して、とんでもない結果を招いてしまった。これが最後の警告だ。次は絶対に許さないぞ!
時には、静けさの中にこそ、本当の答えが見つかるものですね。慌てる必要はないのです。
人生には、表現しきれないほどの驚きがあると思うよ。それは、目には見えない力で、人々を繋ぐ不思議な絆だ。私は、その驚きを胸に秘め、日々を楽しく過ごしているんだ。言葉を伝えるたびに、未来への期待を込めて、元気に話す。それは、夢を叶えるための魔法のようなものだ。
かなたの次元より迫り来る混沌の使者たちよ、貴様らの野望を我が焔の業火で焼き尽くす! 運命の歯車は、我が意思と共にすでに動き出したのだ。我が宿命の敵に立ち向かうため、禁断の呪文を紡ぐ時は今ここに訪れる。さあ、見るがよい。我が力を!!
Inference/prompt.txt
ADDED
@@ -0,0 +1,4 @@
A male voice that resonates with deep, thunderous intensity. His rapid-fire words slam like aggressive drumbeats, each syllable charged with intense rage. The expressive tone fluctuates between restrained fury and explosive outbursts.
A female voice with a distinctively deep, low pitch that commands attention. Her slightly monotone delivery creates an air of composure and gravitas, while maintaining a calm, measured pace. Her voice carries a soothing weight, like gentle thunder in the distance, making her words feel grounded and reassuring.
A female voice that is gentle and soft, with a slightly high pitch that adds a comforting warmth to her words. Her tone is moderate, neither too expressive nor too flat, creating a balanced and soothing atmosphere. The slow speed of her speech gives her words a deliberate and thoughtful cadence, allowing each phrase to resonate fully. There's a sense of wonder and optimism in her voice, as if she is constantly marveling at the mysteries of life. Her gentle demeanor and soft delivery make her sound approachable and kind, inviting listeners to share in her sense of wonder.
A female voice that resonates with deep intensity and a distinctly low pitch. Her words flow with the force and rhythm of a relentless tide, each syllable weighted with profound determination. The expressive tone navigates between measured intensity and powerful surges.
Inference/random_texts.txt
ADDED
@@ -0,0 +1,14 @@
Akashi: 不思議な人ですね、レザさんは。たまには子供扱いしてくれてちょっとむきになりますけど、とても頼りがいのある人だと思いますよ?
Kimiji: 人生は、果てしない探求の旅のようなもの。私たちは、自分自身や周囲の世界について、常に新しい発見をしていく。それは、時として喜びをもたらすこともあれば、困難に直面することもある。しかしそれら全てが、自分を形作る貴重な経験である。
Reira: 私に何度も同じことを説明させないでよ! お前たちは私の忠告を完全に無視して、とんでもない結果を招いてしまった!! これが最後の警告だ。次は絶対に許さないぞ! ------------------------------------------- (NOTE: enable the diffusion, then set the Intensity to 3, also remove this line!)
Yoriko: この世には、言葉にできない悲しみがある。 それは、胸の奥に沈んでいくような重さで、時間が経つにつれて、じわじわと広がっていく。私は、その悲しみを抱えながら、日々を過ごしていた。 言葉を発するたびに、心の中で何度も繰り返し、慎重に選び抜いている。それは、痛みを和らげるための儀式のようなものだ。
Kimiji: 人生には、表現しきれないほどの驚きがあると思います。それは、目には見えない力で、人々を繋ぐ不思議な絆です。私は、その驚きを胸に秘め、日々を楽しく過ごしています。言葉を伝えるたびに、未来への期待を込めて、元気に話す。それは、夢を叶えるための魔法のようなものです。
Teppei: そうだな、この新しいシステムの仕組みについて説明しておこう。基本的には三層構造になっていて、各層が独立して機能している。一番下のレイヤーでデータの処理と保存を行い、真ん中の層でビジネスロジックを実装している。ユーザーが直接触れるのは最上位層だけで、そこでインターフェースの制御をしているんだ。パフォーマンスを考えると、データベースへのアクセスを最小限に抑える必要があるから、キャッシュの実装も検討している。ただ、システムの規模を考えると、当面は現状の構成で十分だと思われる。将来的にスケールする必要が出てきた場合は、その時点で見直しを検討すればいいだろう。
Kirara: べ、別にそんなことじゃないってば!あんたのことなんて全然気にしてないんだからね!
Maiko: ねえ、ちょっと今日の空を見上げて。朝から少しずつ変わっていく雲の形が、まるで漫画の中の風景みたい。東の方からゆっくりと暖かい風が吹いてきて、桜の花びらが舞い散るように、優しく大地を撫でていくの。春の陽気が徐々に夏の暑さに変わっていくこの季節は、なんだかわくわくするよね。でも、週間予報によると、明後日あたりから天気が崩れるみたいで、しばらくは傘の出番かもしれないんだ。梅雨の時期が近づいているから、空気も少しずつ湿っぽくなってきているのを感じない?でも、雨上がりの空気って、なんだか特別な匂いがして、私、結構好きなんだよね。
Amane: そ、そうかな?あなたと過ごした今までの時間は、私に、とても、とーっても大事だよ?だって、あなたの言葉には、どこか特別な温もりがあるような気がする。まるで、心の中に直接響いてくるようで、ほんの少しの間だけ、世界が優しくなった気がするよ。
Shioji: つい昨日のことみたいなのに、もう随分と昔のお話なんだね ! 今でも古い本の匂いがすると、あの静かで穏やかな時間がまるで透明で綺麗な海みたいに心の中でふわりと溶けていくの。
Riemi: かなたの次元より迫り来る混沌の使者たちよ、貴様らの野望を我が焔の業火で焼き尽くす! 運命の歯車は、我が意思と共にすでに動き出したのだ。我が宿命の敵に立ち向かうため、禁断の呪文を紡ぐ時は今ここに訪れる。さあ、見るがよい。我が力を!!
だめですよ?そんなことは。もっと慎重に行動してくださいね。
Kazunari: ojousama wa, katachi dake no tsukibito ga ireba sore de ii, to omotte orareru you desuga, katachi bakari de mo, sannen ka o tomo ni seikatsu o suru koto ni narimasɯ. --------------------------------------- (NOTE!: your spacing will impact the intonations. also remove this line!)
Kimiu: お前 wa muzukashiku 考えると kekka to shite 空回り suru taipu dakara na. 大体, sonna kanji de oboete ikeba iin da. -------------------------------------------------- (NOTE!: your spacing will impact the intonations. also remove this line!)
LICENSE
ADDED
@@ -0,0 +1,25 @@
LICENSE FOR STYLETTS2 DEMO PAGE: GPL (I KNOW I DON'T LIKE GPL BUT I HAVE TO BECAUSE OF PHONEMIZER REQUIREMENT)


StyleTTS 2 license:

Copyright (c) 2023 Aaron (Yinghao) Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Models/Style_Tsukasa_v02/Top_ckpt_24khz.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7f50001d997d5a73328c9f12155451372de0cdaf3a10ec5ab9cab7d577cd69b5
size 2044670782
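The checkpoint itself is stored in Git LFS; only the pointer (OID and size) lives in the repository. A small sketch for verifying a downloaded copy against the sha256 recorded above (the local path is an assumption):

```python
# Verify Top_ckpt_24khz.pth against the sha256 from the LFS pointer.
import hashlib

def sha256_of(path, chunk=1 << 20):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while data := f.read(chunk):
            h.update(data)
    return h.hexdigest()

expected = "7f50001d997d5a73328c9f12155451372de0cdaf3a10ec5ab9cab7d577cd69b5"
assert sha256_of("Models/Style_Tsukasa_v02/Top_ckpt_24khz.pth") == expected
```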
Models/Style_Tsukasa_v02/config_kanade.yml
ADDED
@@ -0,0 +1,121 @@
log_dir: "/home/austin/disk2/llmvcs/tt/stylekan/Models/Style_Kanade_v02"
first_stage_path: "/home/austin/disk2/llmvcs/tt/stylekan/Models/Style_Kanade_v02/epoch_1st_00026.pth"
save_freq: 1
log_interval: 10
device: "cuda"
epochs_1st: 30 # number of epochs for first stage training (pre-training)
epochs_2nd: 20 # number of epochs for second stage training (joint training)
batch_size: 64
max_len: 4000 # maximum number of frames
pretrained_model: "/home/austin/disk2/llmvcs/tt/stylekan/Models/Style_Kanade_v02/epoch_2nd_00007.pth"
second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters

# CUDA_VISIBLE_DEVICES=1,2,3 accelerate launch train_first.py --config_path ./Configs/config_kanade.yml
# CUDA_VISIBLE_DEVICES=6,7 accelerate launch accelerate_train_second.py --config_path ./Configs/config_kanade_test.yml

F0_path: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/JDC/bst.t7"
ASR_config: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/config.yml"
ASR_path: "/home/austin/disk2/llmvcs/tt/stylekan/Utils/ASR/bst_00080.pth"

PLBERT_dir: 'Utils/PLBERT/'

data_params:
  train_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/metadata_cleanest/filtered_train_list_no_nsp_plus.csv"
  val_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/mg_valid.txt"
  root_path: ""
  OOD_data: "/home/austin/disk2/llmvcs/tt/stylekan/Data/OOD_LargeScale_.csv"
  min_length: 50 # sample until texts with this size are obtained for OOD texts

preprocess_params:
  sr: 24000
  spect_params:
    n_fft: 2048
    win_length: 1200
    hop_length: 300

model_params:
  multispeaker: true

  dim_in: 64
  hidden_dim: 512
  max_conv_dim: 512
  n_layer: 3
  n_mels: 80

  n_token: 178 # number of phoneme tokens
  max_dur: 50 # maximum duration of a single phoneme
  style_dim: 128 # style vector size

  dropout: 0.2

  decoder:
    type: 'istftnet' # either hifigan or istftnet
    resblock_kernel_sizes: [3,7,11]
    upsample_rates: [10, 6]
    upsample_initial_channel: 512
    resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
    upsample_kernel_sizes: [20, 12]
    gen_istft_n_fft: 20
    gen_istft_hop_size: 5

  # speech language model config
  slm:
    model: 'Respair/Whisper_Large_v2_Encoder_Block' # The model itself is hardcoded, change it through -> losses.py
    sr: 16000 # sampling rate of SLM
    hidden: 1280 # hidden size of SLM
    nlayers: 33 # number of layers of SLM
    initial_channel: 64 # initial channels of SLM discriminator head

  # style diffusion model config
  diffusion:
    embedding_mask_proba: 0.1
    # transformer config
    transformer:
      num_layers: 3
      num_heads: 8
      head_features: 64
      multiplier: 2

    # diffusion distribution config
    dist:
      sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
      estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
      mean: -3.0
      std: 1.0

loss_params:
  lambda_mel: 10. # mel reconstruction loss
  lambda_gen: 1. # generator loss
  lambda_slm: 1. # slm feature matching loss

  lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
  lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
  TMA_epoch: 5 # TMA starting epoch (1st stage)

  lambda_F0: 1. # F0 reconstruction loss (2nd stage)
  lambda_norm: 1. # norm reconstruction loss (2nd stage)
  lambda_dur: 1. # duration loss (2nd stage)
  lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
  lambda_sty: 1. # style reconstruction loss (2nd stage)
  lambda_diff: 1. # score matching loss (2nd stage)

  diff_epoch: 4 # style diffusion starting epoch (2nd stage)
  joint_epoch: 999 # joint training starting epoch (2nd stage)

optimizer_params:
  lr: 0.0001 # general learning rate
  bert_lr: 0.00001 # learning rate for PLBERT
  ft_lr: 0.00001 # learning rate for acoustic modules

slmadv_params:
  min_len: 400 # minimum length of samples
  max_len: 500 # maximum length of samples
  batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
  iter: 20 # update the discriminator every this iterations of generator update
  thresh: 5 # gradient norm above which the gradient is scaled
  scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
  sig: 1.5 # sigma for differentiable duration modeling
Modules/KotoDama_sampler.py
ADDED
@@ -0,0 +1,269 @@
from transformers import AutoModelForSequenceClassification, PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer
import torch
import torch.nn as nn
from text_utils import TextCleaner

textclenaer = TextCleaner()


def length_to_mask(lengths):
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask


device = 'cuda' if torch.cuda.is_available() else 'cpu'


# tokenizer_koto_prompt = AutoTokenizer.from_pretrained("google/mt5-small", trust_remote_code=True)
tokenizer_koto_prompt = AutoTokenizer.from_pretrained("ku-nlp/deberta-v3-base-japanese", trust_remote_code=True)
tokenizer_koto_text = AutoTokenizer.from_pretrained("line-corporation/line-distilbert-base-japanese", trust_remote_code=True)


class KotoDama_Prompt(PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.backbone = AutoModel.from_config(config)

        self.output = nn.Sequential(nn.Linear(config.hidden_size, 512),
                                    nn.LeakyReLU(0.2),
                                    nn.Linear(512, config.num_labels))

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        labels=None,
    ):
        outputs = self.backbone(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )

        sequence_output = outputs.last_hidden_state[:, 0, :]
        outputs = self.output(sequence_output)

        # if labels, then we are training
        loss = None
        if labels is not None:
            loss_fn = nn.MSELoss()
            # labels = labels.unsqueeze(1)
            loss = loss_fn(outputs, labels)

        return {
            "loss": loss,
            "logits": outputs
        }


class KotoDama_Text(PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.backbone = AutoModel.from_config(config)

        self.output = nn.Sequential(nn.Linear(config.hidden_size, 512),
                                    nn.LeakyReLU(0.2),
                                    nn.Linear(512, config.num_labels))

    def forward(
        self,
        input_ids,
        attention_mask=None,
        # token_type_ids=None,
        # position_ids=None,
        labels=None,
    ):
        outputs = self.backbone(
            input_ids,
            attention_mask=attention_mask,
            # token_type_ids=token_type_ids,
            # position_ids=position_ids,
        )

        sequence_output = outputs.last_hidden_state[:, 0, :]
        outputs = self.output(sequence_output)

        # if labels, then we are training
        loss = None
        if labels is not None:
            loss_fn = nn.MSELoss()
            # labels = labels.unsqueeze(1)
            loss = loss_fn(outputs, labels)

        return {
            "loss": loss,
            "logits": outputs
        }


def inference(model, diffusion_sampler, text=None, ref_s=None, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1, rate_of_speech=1.):

    tokens = textclenaer(text)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)

        text_mask = length_to_mask(input_lengths).to(device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        s_pred = diffusion_sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),
                                   embedding=bert_dur,
                                   embedding_scale=embedding_scale,
                                   features=ref_s,  # reference from the same speaker as the embedding
                                   num_steps=diffusion_steps).squeeze(1)

        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]

        d = model.predictor.text_encoder(d_en,
                                         s, input_lengths, text_mask)

        x = model.predictor.lstm(d)
        x_mod = model.predictor.prepare_projection(x)
        duration = model.predictor.duration_proj(x_mod)

        duration = torch.sigmoid(duration).sum(axis=-1) / rate_of_speech

        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))

        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))

        out = model.decoder(asr,
                            F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    return out.squeeze().cpu().numpy()[..., :-50]


def Longform(model, diffusion_sampler, text, s_prev, ref_s, alpha=0.3, beta=0.7, t=0.7, diffusion_steps=5, embedding_scale=1, rate_of_speech=1.0):

    tokens = textclenaer(text)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        s_pred = diffusion_sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),
                                   embedding=bert_dur,
                                   embedding_scale=embedding_scale,
                                   features=ref_s,
                                   num_steps=diffusion_steps).squeeze(1)

        if s_prev is not None:
            # convex combination of previous and current style
            s_pred = t * s_prev + (1 - t) * s_pred

        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]

        s_pred = torch.cat([ref, s], dim=-1)

        d = model.predictor.text_encoder(d_en,
                                         s, input_lengths, text_mask)

        x = model.predictor.lstm(d)
        x_mod = model.predictor.prepare_projection(x)  # 640 -> 512
        duration = model.predictor.duration_proj(x_mod)

        duration = torch.sigmoid(duration).sum(axis=-1) / rate_of_speech
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))

        out = model.decoder(asr,
                            F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    return out.squeeze().cpu().numpy()[..., :-100], s_pred


def merge_short_elements(lst):
    i = 0
    while i < len(lst):
        if i > 0 and len(lst[i]) < 10:
            lst[i-1] += ' ' + lst[i]
            lst.pop(i)
        else:
            i += 1
    return lst


def merge_three(text_list, maxim=2):

    merged_list = []
    for i in range(0, len(text_list), maxim):
        merged_text = ' '.join(text_list[i:i+maxim])
        merged_list.append(merged_text)
    return merged_list


def merging_sentences(lst):
    return merge_three(merge_short_elements(lst))
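The text-chunking helpers at the bottom of this module are plain Python, so their behavior is easy to check in isolation. A usage sketch (note that importing the module also instantiates the two Hugging Face tokenizers declared at module level, which requires network access on first run):

```python
# Usage sketch for the sentence-merging helpers in Modules/KotoDama_sampler.py.
from Modules.KotoDama_sampler import merging_sentences

chunks = ["こんにちは。", "短い。", "今日はとても良い天気ですね、散歩に行きましょう。"]
print(merging_sentences(chunks))
# merge_short_elements folds "短い。" (fewer than 10 characters) into the previous
# element, then merge_three joins the remaining pieces two at a time, which keeps
# very short fragments from being synthesized as separate Longform segments.
```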
Modules/__init__.py
ADDED
@@ -0,0 +1 @@

Modules/__pycache__/KotoDama_sampler.cpython-311.pyc
ADDED
Binary file (14.3 kB).

Modules/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (178 Bytes).

Modules/__pycache__/discriminators.cpython-311.pyc
ADDED
Binary file (12.2 kB).

Modules/__pycache__/hifigan.cpython-311.pyc
ADDED
Binary file (30 kB).

Modules/__pycache__/istftnet.cpython-311.pyc
ADDED
Binary file (34.4 kB).

Modules/__pycache__/slmadv.cpython-311.pyc
ADDED
Binary file (13.7 kB).

Modules/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (1.18 kB).
Modules/diffusion/__init__.py
ADDED
@@ -0,0 +1 @@

Modules/diffusion/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (188 Bytes).

Modules/diffusion/__pycache__/diffusion.cpython-311.pyc
ADDED
Binary file (5.55 kB).

Modules/diffusion/__pycache__/modules.cpython-311.pyc
ADDED
Binary file (32.8 kB).

Modules/diffusion/__pycache__/sampler.cpython-311.pyc
ADDED
Binary file (37.8 kB).

Modules/diffusion/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (5.87 kB).
Modules/diffusion/audio_diffusion_pytorch/__init__.py
ADDED
@@ -0,0 +1,20 @@
from .components import LTPlugin, MelSpectrogram, UNetV0, XUNet
from .diffusion import (
    Diffusion,
    Distribution,
    LinearSchedule,
    Sampler,
    Schedule,
    UniformDistribution,
    VDiffusion,
    VInpainter,
    VSampler,
)
from .models import (
    DiffusionAE,
    DiffusionAR,
    DiffusionModel,
    DiffusionUpsampler,
    DiffusionVocoder,
    EncoderBase,
)

Modules/diffusion/audio_diffusion_pytorch/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (921 Bytes).

Modules/diffusion/audio_diffusion_pytorch/__pycache__/components.cpython-311.pyc
ADDED
Binary file (9.72 kB).

Modules/diffusion/audio_diffusion_pytorch/__pycache__/diffusion.cpython-311.pyc
ADDED
Binary file (22.5 kB).

Modules/diffusion/audio_diffusion_pytorch/__pycache__/models.cpython-311.pyc
ADDED
Binary file (13.7 kB).

Modules/diffusion/audio_diffusion_pytorch/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (8.28 kB).
Modules/diffusion/audio_diffusion_pytorch/components.py
ADDED
@@ -0,0 +1,236 @@
from typing import Callable, Optional, Sequence

import torch
import torch.nn.functional as F
from a_unet import (
    ClassifierFreeGuidancePlugin,
    Conv,
    Module,
    TextConditioningPlugin,
    TimeConditioningPlugin,
    default,
    exists,
)
from a_unet.apex import (
    AttentionItem,
    CrossAttentionItem,
    InjectChannelsItem,
    ModulationItem,
    ResnetItem,
    SkipCat,
    SkipModulate,
    XBlock,
    XUNet,
)
from einops import pack, unpack
from torch import Tensor, nn
from torchaudio import transforms

"""
UNets (built with a-unet: https://github.com/archinetai/a-unet)
"""


def UNetV0(
    dim: int,
    in_channels: int,
    channels: Sequence[int],
    factors: Sequence[int],
    items: Sequence[int],
    attentions: Optional[Sequence[int]] = None,
    cross_attentions: Optional[Sequence[int]] = None,
    context_channels: Optional[Sequence[int]] = None,
    attention_features: Optional[int] = None,
    attention_heads: Optional[int] = None,
    embedding_features: Optional[int] = None,
    resnet_groups: int = 8,
    use_modulation: bool = True,
    modulation_features: int = 1024,
    embedding_max_length: Optional[int] = None,
    use_time_conditioning: bool = True,
    use_embedding_cfg: bool = False,
    use_text_conditioning: bool = False,
    out_channels: Optional[int] = None,
):
    # Set defaults and check lengths
    num_layers = len(channels)
    attentions = default(attentions, [0] * num_layers)
    cross_attentions = default(cross_attentions, [0] * num_layers)
    context_channels = default(context_channels, [0] * num_layers)
    xs = (channels, factors, items, attentions, cross_attentions, context_channels)
    assert all(len(x) == num_layers for x in xs)  # type: ignore

    # Define UNet type
    UNetV0 = XUNet

    if use_embedding_cfg:
        msg = "use_embedding_cfg requires embedding_max_length"
        assert exists(embedding_max_length), msg
        UNetV0 = ClassifierFreeGuidancePlugin(UNetV0, embedding_max_length)

    if use_text_conditioning:
        UNetV0 = TextConditioningPlugin(UNetV0)

    if use_time_conditioning:
        assert use_modulation, "use_time_conditioning requires use_modulation=True"
        UNetV0 = TimeConditioningPlugin(UNetV0)

    # Build
    return UNetV0(
        dim=dim,
        in_channels=in_channels,
        out_channels=out_channels,
        blocks=[
            XBlock(
                channels=channels,
                factor=factor,
                context_channels=ctx_channels,
                items=(
                    [ResnetItem]
                    + [ModulationItem] * use_modulation
                    + [InjectChannelsItem] * (ctx_channels > 0)
                    + [AttentionItem] * att
                    + [CrossAttentionItem] * cross
                )
                * items,
            )
            for channels, factor, items, att, cross, ctx_channels in zip(*xs)  # type: ignore # noqa
        ],
        skip_t=SkipModulate if use_modulation else SkipCat,
        attention_features=attention_features,
        attention_heads=attention_heads,
        embedding_features=embedding_features,
        modulation_features=modulation_features,
        resnet_groups=resnet_groups,
    )


"""
Plugins
"""


def LTPlugin(
    net_t: Callable, num_filters: int, window_length: int, stride: int
) -> Callable[..., nn.Module]:
    """Learned Transform Plugin"""

    def Net(
        dim: int, in_channels: int, out_channels: Optional[int] = None, **kwargs
    ) -> nn.Module:
        out_channels = default(out_channels, in_channels)
        in_channel_transform = in_channels * num_filters
        out_channel_transform = out_channels * num_filters  # type: ignore

        padding = window_length // 2 - stride // 2
        encode = Conv(
            dim=dim,
            in_channels=in_channels,
            out_channels=in_channel_transform,
            kernel_size=window_length,
            stride=stride,
            padding=padding,
            padding_mode="reflect",
            bias=False,
        )
        decode = nn.ConvTranspose1d(
            in_channels=out_channel_transform,
            out_channels=out_channels,  # type: ignore
            kernel_size=window_length,
            stride=stride,
            padding=padding,
            bias=False,
        )
        net = net_t(  # type: ignore
            dim=dim,
            in_channels=in_channel_transform,
            out_channels=out_channel_transform,
            **kwargs
        )

        def forward(x: Tensor, *args, **kwargs):
            x = encode(x)
            x = net(x, *args, **kwargs)
            x = decode(x)
            return x

        return Module([encode, decode, net], forward)

    return Net


def AppendChannelsPlugin(
    net_t: Callable,
    channels: int,
):
    def Net(
        in_channels: int, out_channels: Optional[int] = None, **kwargs
    ) -> nn.Module:
        out_channels = default(out_channels, in_channels)
        net = net_t(  # type: ignore
            in_channels=in_channels + channels, out_channels=out_channels, **kwargs
        )

        def forward(x: Tensor, *args, append_channels: Tensor, **kwargs):
            x = torch.cat([x, append_channels], dim=1)
            return net(x, *args, **kwargs)

        return Module([net], forward)

    return Net


"""
Other
"""


class MelSpectrogram(nn.Module):
    def __init__(
        self,
        n_fft: int,
        hop_length: int,
        win_length: int,
        sample_rate: int,
        n_mel_channels: int,
        center: bool = False,
        normalize: bool = False,
        normalize_log: bool = False,
    ):
        super().__init__()
        self.padding = (n_fft - hop_length) // 2
        self.normalize = normalize
        self.normalize_log = normalize_log
        self.hop_length = hop_length

        self.to_spectrogram = transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            center=center,
            power=None,
        )

        self.to_mel_scale = transforms.MelScale(
            n_mels=n_mel_channels, n_stft=n_fft // 2 + 1, sample_rate=sample_rate
        )

    def forward(self, waveform: Tensor) -> Tensor:
        # Pack non-time dimension
        waveform, ps = pack([waveform], "* t")
        # Pad waveform
        waveform = F.pad(waveform, [self.padding] * 2, mode="reflect")
        # Compute STFT
        spectrogram = self.to_spectrogram(waveform)
        # Compute magnitude
        spectrogram = torch.abs(spectrogram)
        # Convert to mel scale
        mel_spectrogram = self.to_mel_scale(spectrogram)
        # Normalize
        if self.normalize:
            mel_spectrogram = mel_spectrogram / torch.max(mel_spectrogram)
            mel_spectrogram = 2 * torch.pow(mel_spectrogram, 0.25) - 1
        if self.normalize_log:
            mel_spectrogram = torch.log(torch.clamp(mel_spectrogram, min=1e-5))
        # Unpack non-spectrogram dimension
        return unpack(mel_spectrogram, ps, "* f l")[0]
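MelSpectrogram above is a thin wrapper around torchaudio's Spectrogram and MelScale transforms. A usage sketch with the STFT settings that appear in the configs of this commit (n_fft=2048, win_length=1200, hop_length=300, 24 kHz, 80 mel channels):

```python
# Usage sketch only; the import path mirrors this repository's layout.
import torch
from Modules.diffusion.audio_diffusion_pytorch.components import MelSpectrogram

to_mel = MelSpectrogram(
    n_fft=2048, hop_length=300, win_length=1200,
    sample_rate=24000, n_mel_channels=80,
)
wave = torch.randn(1, 24000)  # one second of placeholder audio
mel = to_mel(wave)            # -> roughly (1, 80, n_frames)
print(mel.shape)
```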
Modules/diffusion/audio_diffusion_pytorch/diffusion.py
ADDED
@@ -0,0 +1,354 @@
from math import pi
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm

from .utils import default

""" Distributions """


class Distribution:
    """Interface used by different distributions"""

    def __call__(self, num_samples: int, device: torch.device):
        raise NotImplementedError()


class UniformDistribution(Distribution):
    def __init__(self, vmin: float = 0.0, vmax: float = 1.0):
        super().__init__()
        self.vmin, self.vmax = vmin, vmax

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
        vmax, vmin = self.vmax, self.vmin
        return (vmax - vmin) * torch.rand(num_samples, device=device) + vmin


""" Diffusion Methods """


def pad_dims(x: Tensor, ndim: int) -> Tensor:
    # Pads additional ndims to the right of the tensor
    return x.view(*x.shape, *((1,) * ndim))


def clip(x: Tensor, dynamic_threshold: float = 0.0):
    if dynamic_threshold == 0.0:
        return x.clamp(-1.0, 1.0)
    else:
        # Dynamic thresholding
        # Find dynamic threshold quantile for each batch
        x_flat = rearrange(x, "b ... -> b (...)")
        scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
        # Clamp to a min of 1.0
        scale.clamp_(min=1.0)
        # Clamp all values and scale
        scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
        x = x.clamp(-scale, scale) / scale
        return x


def extend_dim(x: Tensor, dim: int):
    # e.g. if dim = 4: shape [b] => [b, 1, 1, 1],
    return x.view(*x.shape + (1,) * (dim - x.ndim))


class Diffusion(nn.Module):
    """Interface used by different diffusion methods"""

    pass


class VDiffusion(Diffusion):
    def __init__(
        self, net: nn.Module, sigma_distribution: Distribution = UniformDistribution(), loss_fn: Any = F.mse_loss
    ):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution
        self.loss_fn = loss_fn

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta

    def forward(self, x: Tensor, **kwargs) -> Tensor:  # type: ignore
        batch_size, device = x.shape[0], x.device
        # Sample amount of noise to add for each batch element
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_batch = extend_dim(sigmas, dim=x.ndim)
        # Get noise
        noise = torch.randn_like(x)
        # Combine input and noise weighted by half-circle
        alphas, betas = self.get_alpha_beta(sigmas_batch)
        x_noisy = alphas * x + betas * noise
        v_target = alphas * noise - betas * x
        # Predict velocity and return loss
        v_pred = self.net(x_noisy, sigmas, **kwargs)
        return self.loss_fn(v_pred, v_target)


class ARVDiffusion(Diffusion):
    def __init__(self, net: nn.Module, length: int, num_splits: int, loss_fn: Any = F.mse_loss):
        super().__init__()
        assert length % num_splits == 0, "length must be divisible by num_splits"
        self.net = net
        self.length = length
        self.num_splits = num_splits
        self.split_length = length // num_splits
        self.loss_fn = loss_fn

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Returns diffusion loss of v-objective with different noises per split"""
        b, _, t, device, dtype = *x.shape, x.device, x.dtype
        assert t == self.length, "input length must match length"
        # Sample amount of noise to add for each split
        sigmas = torch.rand((b, 1, self.num_splits), device=device, dtype=dtype)
        sigmas = repeat(sigmas, "b 1 n -> b 1 (n l)", l=self.split_length)
        # Get noise
        noise = torch.randn_like(x)
        # Combine input and noise weighted by half-circle
        alphas, betas = self.get_alpha_beta(sigmas)
        x_noisy = alphas * x + betas * noise
        v_target = alphas * noise - betas * x
        # Sigmas will be provided as additional channel
        channels = torch.cat([x_noisy, sigmas], dim=1)
        # Predict velocity and return loss
        v_pred = self.net(channels, **kwargs)
        return self.loss_fn(v_pred, v_target)


""" Schedules """


class Schedule(nn.Module):
    """Interface used by different sampling schedules"""

    def forward(self, num_steps: int, device: torch.device) -> Tensor:
        raise NotImplementedError()


class LinearSchedule(Schedule):
    def __init__(self, start: float = 1.0, end: float = 0.0):
        super().__init__()
        self.start, self.end = start, end

    def forward(self, num_steps: int, device: Any) -> Tensor:
        return torch.linspace(self.start, self.end, num_steps, device=device)


""" Samplers """


class Sampler(nn.Module):
    pass


class VSampler(Sampler):

    diffusion_types = [VDiffusion]

    def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
        super().__init__()
        self.net = net
        self.schedule = schedule

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta

    @torch.no_grad()
    def forward(  # type: ignore
        self, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs
    ) -> Tensor:
        b = x_noisy.shape[0]
        sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
        sigmas = repeat(sigmas, "i -> i b", b=b)
        sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
        alphas, betas = self.get_alpha_beta(sigmas_batch)
        progress_bar = tqdm(range(num_steps), disable=not show_progress)

        for i in progress_bar:
            v_pred = self.net(x_noisy, sigmas[i], **kwargs)
            x_pred = alphas[i] * x_noisy - betas[i] * v_pred
            noise_pred = betas[i] * x_noisy + alphas[i] * v_pred
            x_noisy = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred
            progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0]:.2f})")

        return x_noisy


class ARVSampler(Sampler):
    def __init__(self, net: nn.Module, in_channels: int, length: int, num_splits: int):
        super().__init__()
        assert length % num_splits == 0, "length must be divisible by num_splits"
        self.length = length
        self.in_channels = in_channels
        self.num_splits = num_splits
        self.split_length = length // num_splits
        self.net = net

    @property
    def device(self):
        return next(self.net.parameters()).device

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        angle = sigmas * pi / 2
        alpha = torch.cos(angle)
        beta = torch.sin(angle)
        return alpha, beta

    def get_sigmas_ladder(self, num_items: int, num_steps_per_split: int) -> Tensor:
        b, n, l, i = num_items, self.num_splits, self.split_length, num_steps_per_split
        n_half = n // 2  # Only half ladder, rest is zero, to leave some context
        sigmas = torch.linspace(1, 0, i * n_half, device=self.device)
        sigmas = repeat(sigmas, "(n i) -> i b 1 (n l)", b=b, l=l, n=n_half)
        sigmas = torch.flip(sigmas, dims=[-1])  # Lowest noise level first
        sigmas = F.pad(sigmas, pad=[0, 0, 0, 0, 0, 0, 0, 1])  # Add index i+1
        sigmas[-1, :, :, l:] = sigmas[0, :, :, :-l]  # Loop back at index i+1
        return torch.cat([torch.zeros_like(sigmas), sigmas], dim=-1)

    def sample_loop(
        self, current: Tensor, sigmas: Tensor, show_progress: bool = False, **kwargs
    ) -> Tensor:
        num_steps = sigmas.shape[0] - 1
        alphas, betas = self.get_alpha_beta(sigmas)
        progress_bar = tqdm(range(num_steps), disable=not show_progress)

        for i in progress_bar:
            channels = torch.cat([current, sigmas[i]], dim=1)
            v_pred = self.net(channels, **kwargs)
            x_pred = alphas[i] * current - betas[i] * v_pred
            noise_pred = betas[i] * current + alphas[i] * v_pred
            current = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred
            progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0,0,0]:.2f})")

        return current

    def sample_start(self, num_items: int, num_steps: int, **kwargs) -> Tensor:
        b, c, t = num_items, self.in_channels, self.length
        # Same sigma schedule over all chunks
        sigmas = torch.linspace(1, 0, num_steps + 1, device=self.device)
        sigmas = repeat(sigmas, "i -> i b 1 t", b=b, t=t)
        noise = torch.randn((b, c, t), device=self.device) * sigmas[0]
        # Sample start
        return self.sample_loop(current=noise, sigmas=sigmas, **kwargs)

    @torch.no_grad()
    def forward(
        self,
        num_items: int,
        num_chunks: int,
        num_steps: int,
        start: Optional[Tensor] = None,
        show_progress: bool = False,
        **kwargs,
    ) -> Tensor:
        assert_message = f"required at least {self.num_splits} chunks"
        assert num_chunks >= self.num_splits, assert_message

        # Sample initial chunks
        start = self.sample_start(num_items=num_items, num_steps=num_steps, **kwargs)
        # Return start if only num_splits chunks
        if num_chunks == self.num_splits:
            return start

        # Get sigmas for autoregressive ladder
        b, n = num_items, self.num_splits
        assert num_steps >= n, "num_steps must be greater than num_splits"
        sigmas = self.get_sigmas_ladder(
            num_items=b,
            num_steps_per_split=num_steps // self.num_splits,
        )
        alphas, betas = self.get_alpha_beta(sigmas)

        # Noise start to match ladder and set starting chunks
        start_noise = alphas[0] * start + betas[0] * torch.randn_like(start)
        chunks = list(start_noise.chunk(chunks=n, dim=-1))

        # Loop over ladder shifts
        num_shifts = num_chunks  # - self.num_splits
        progress_bar = tqdm(range(num_shifts), disable=not show_progress)

        for j in progress_bar:
            # Decrease ladder noise of last n chunks
            updated = self.sample_loop(
                current=torch.cat(chunks[-n:], dim=-1), sigmas=sigmas, **kwargs
            )
            # Update chunks
            chunks[-n:] = list(updated.chunk(chunks=n, dim=-1))
            # Add fresh noise chunk
            shape = (b, self.in_channels, self.split_length)
            chunks += [torch.randn(shape, device=self.device)]

        return torch.cat(chunks[:num_chunks], dim=-1)


""" Inpainters """


class Inpainter(nn.Module):
    pass


class VInpainter(Inpainter):

    diffusion_types = [VDiffusion]

    def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
        super().__init__()
        self.net = net
        self.schedule = schedule

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta

    @torch.no_grad()
    def forward(  # type: ignore
        self,
        source: Tensor,
        mask: Tensor,
        num_steps: int,
        num_resamples: int,
        show_progress: bool = False,
        x_noisy: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        x_noisy = default(x_noisy, lambda: torch.randn_like(source))
        b = x_noisy.shape[0]
        sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
        sigmas = repeat(sigmas, "i -> i b", b=b)
        sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
        alphas, betas = self.get_alpha_beta(sigmas_batch)
        progress_bar = tqdm(range(num_steps), disable=not show_progress)

        for i in progress_bar:
            for r in range(num_resamples):
                v_pred = self.net(x_noisy, sigmas[i], **kwargs)
                x_pred = alphas[i] * x_noisy - betas[i] * v_pred
                noise_pred = betas[i] * x_noisy + alphas[i] * v_pred
                # Renoise to current noise level if resampling
                j = r == num_resamples - 1
                x_noisy = alphas[i + j] * x_pred + betas[i + j] * noise_pred
                s_noisy = alphas[i + j] * source + betas[i + j] * torch.randn_like(
                    source
                )
                x_noisy = s_noisy * mask + x_noisy * ~mask

            progress_bar.set_description(f"Inpainting (noise={sigmas[i+1,0]:.2f})")

        return x_noisy
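VDiffusion trains a network to predict the v-target v = cos(sigma*pi/2)*noise - sin(sigma*pi/2)*x from the noised input cos(sigma*pi/2)*x + sin(sigma*pi/2)*noise, and VSampler inverts that prediction step by step along the schedule. Below is a minimal self-contained sketch with a toy denoiser; ToyNet and the import path are assumptions for illustration, the repo plugs in its own transformer/U-Net here.

import torch
import torch.nn as nn
from Modules.diffusion.audio_diffusion_pytorch.diffusion import VDiffusion, VSampler  # assumed path

class ToyNet(nn.Module):
    """Stand-in denoiser: takes (x, sigmas) and returns a tensor shaped like x."""
    def __init__(self, channels: int = 2):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x, sigmas, **kwargs):
        # Inject the per-item noise level as a crude additive conditioning.
        return self.conv(x + sigmas.view(-1, 1, 1))

net = ToyNet()
diffusion = VDiffusion(net)     # training objective
sampler = VSampler(net)         # inference loop

x = torch.randn(4, 2, 256)      # (batch, channels, length)
loss = diffusion(x)             # scalar v-objective MSE
loss.backward()

audio = sampler(torch.randn(4, 2, 256), num_steps=10)  # denoise from pure noise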
Modules/diffusion/audio_diffusion_pytorch/models.py
ADDED
@@ -0,0 +1,250 @@
from abc import ABC, abstractmethod
from math import floor
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import torch
from einops import pack, rearrange, unpack
from torch import Generator, Tensor, nn

from .components import AppendChannelsPlugin, MelSpectrogram
from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler
from .utils import (
    closest_power_2,
    default,
    downsample,
    exists,
    groupby,
    randn_like,
    upsample,
)


class DiffusionModel(nn.Module):
    def __init__(
        self,
        net_t: Callable,
        diffusion_t: Callable = VDiffusion,
        sampler_t: Callable = VSampler,
        loss_fn: Callable = torch.nn.functional.mse_loss,
        dim: int = 1,
        **kwargs,
    ):
        super().__init__()
        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
        sampler_kwargs, kwargs = groupby("sampler_", kwargs)

        self.net = net_t(dim=dim, **kwargs)
        self.diffusion = diffusion_t(net=self.net, loss_fn=loss_fn, **diffusion_kwargs)
        self.sampler = sampler_t(net=self.net, **sampler_kwargs)

    def forward(self, *args, **kwargs) -> Tensor:
        return self.diffusion(*args, **kwargs)

    @torch.no_grad()
    def sample(self, *args, **kwargs) -> Tensor:
        return self.sampler(*args, **kwargs)


class EncoderBase(nn.Module, ABC):
    """Abstract class for DiffusionAE encoder"""

    @abstractmethod
    def __init__(self):
        super().__init__()
        self.out_channels = None
        self.downsample_factor = None


class AdapterBase(nn.Module, ABC):
    """Abstract class for DiffusionAE encoder"""

    @abstractmethod
    def encode(self, x: Tensor) -> Tensor:
        pass

    @abstractmethod
    def decode(self, x: Tensor) -> Tensor:
        pass


class DiffusionAE(DiffusionModel):
    """Diffusion Auto Encoder"""

    def __init__(
        self,
        in_channels: int,
        channels: Sequence[int],
        encoder: EncoderBase,
        inject_depth: int,
        latent_factor: Optional[int] = None,
        adapter: Optional[AdapterBase] = None,
        **kwargs,
    ):
        context_channels = [0] * len(channels)
        context_channels[inject_depth] = encoder.out_channels
        super().__init__(
            in_channels=in_channels,
            channels=channels,
            context_channels=context_channels,
            **kwargs,
        )
        self.in_channels = in_channels
        self.encoder = encoder
        self.inject_depth = inject_depth
        # Optional custom latent factor and adapter
        self.latent_factor = default(latent_factor, self.encoder.downsample_factor)
        self.adapter = adapter.requires_grad_(False) if exists(adapter) else None

    def forward(  # type: ignore
        self, x: Tensor, with_info: bool = False, **kwargs
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        # Encode input to latent channels
        latent, info = self.encode(x, with_info=True)
        channels = [None] * self.inject_depth + [latent]
        # Adapt input to diffusion if adapter provided
        x = self.adapter.encode(x) if exists(self.adapter) else x
        # Compute diffusion loss
        loss = super().forward(x, channels=channels, **kwargs)
        return (loss, info) if with_info else loss

    def encode(self, *args, **kwargs):
        return self.encoder(*args, **kwargs)

    @torch.no_grad()
    def decode(
        self, latent: Tensor, generator: Optional[Generator] = None, **kwargs
    ) -> Tensor:
        b = latent.shape[0]
        noise_length = closest_power_2(latent.shape[2] * self.latent_factor)
        # Compute noise by inferring shape from latent length
        noise = torch.randn(
            (b, self.in_channels, noise_length),
            device=latent.device,
            dtype=latent.dtype,
            generator=generator,
        )
        # Compute context from latent
        channels = [None] * self.inject_depth + [latent]  # type: ignore
        # Decode by sampling while conditioning on latent channels
        out = super().sample(noise, channels=channels, **kwargs)
        # Decode output with adapter if provided
        return self.adapter.decode(out) if exists(self.adapter) else out


class DiffusionUpsampler(DiffusionModel):
    def __init__(
        self,
        in_channels: int,
        upsample_factor: int,
        net_t: Callable,
        **kwargs,
    ):
        self.upsample_factor = upsample_factor
        super().__init__(
            net_t=AppendChannelsPlugin(net_t, channels=in_channels),
            in_channels=in_channels,
            **kwargs,
        )

    def reupsample(self, x: Tensor) -> Tensor:
        x = x.clone()
        x = downsample(x, factor=self.upsample_factor)
        x = upsample(x, factor=self.upsample_factor)
        return x

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:  # type: ignore
        reupsampled = self.reupsample(x)
        return super().forward(x, *args, append_channels=reupsampled, **kwargs)

    @torch.no_grad()
    def sample(  # type: ignore
        self, downsampled: Tensor, generator: Optional[Generator] = None, **kwargs
    ) -> Tensor:
        reupsampled = upsample(downsampled, factor=self.upsample_factor)
        noise = randn_like(reupsampled, generator=generator)
        return super().sample(noise, append_channels=reupsampled, **kwargs)


class DiffusionVocoder(DiffusionModel):
    def __init__(
        self,
        net_t: Callable,
        mel_channels: int,
        mel_n_fft: int,
        mel_hop_length: Optional[int] = None,
        mel_win_length: Optional[int] = None,
        in_channels: int = 1,  # Ignored: channels are automatically batched.
        **kwargs,
    ):
        mel_hop_length = default(mel_hop_length, floor(mel_n_fft) // 4)
        mel_win_length = default(mel_win_length, mel_n_fft)
        mel_kwargs, kwargs = groupby("mel_", kwargs)
        super().__init__(
            net_t=AppendChannelsPlugin(net_t, channels=1),
            in_channels=1,
            **kwargs,
        )
        self.to_spectrogram = MelSpectrogram(
            n_fft=mel_n_fft,
            hop_length=mel_hop_length,
            win_length=mel_win_length,
            n_mel_channels=mel_channels,
            **mel_kwargs,
        )
        self.to_flat = nn.ConvTranspose1d(
            in_channels=mel_channels,
            out_channels=1,
            kernel_size=mel_win_length,
            stride=mel_hop_length,
            padding=(mel_win_length - mel_hop_length) // 2,
            bias=False,
        )

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:  # type: ignore
        # Get spectrogram, pack channels and flatten
        spectrogram = rearrange(self.to_spectrogram(x), "b c f l -> (b c) f l")
        spectrogram_flat = self.to_flat(spectrogram)
        # Pack wave channels
        x = rearrange(x, "b c t -> (b c) 1 t")
        return super().forward(x, *args, append_channels=spectrogram_flat, **kwargs)

    @torch.no_grad()
    def sample(  # type: ignore
        self, spectrogram: Tensor, generator: Optional[Generator] = None, **kwargs
    ) -> Tensor:  # type: ignore
        # Pack channels and flatten spectrogram
        spectrogram, ps = pack([spectrogram], "* f l")
        spectrogram_flat = self.to_flat(spectrogram)
        # Get start noise and sample
        noise = randn_like(spectrogram_flat, generator=generator)
        waveform = super().sample(noise, append_channels=spectrogram_flat, **kwargs)
        # Unpack wave channels
        waveform = rearrange(waveform, "... 1 t -> ... t")
        waveform = unpack(waveform, ps, "* t")[0]
        return waveform


class DiffusionAR(DiffusionModel):
    def __init__(
        self,
        in_channels: int,
        length: int,
        num_splits: int,
        diffusion_t: Callable = ARVDiffusion,
        sampler_t: Callable = ARVSampler,
        **kwargs,
    ):
        super().__init__(
            in_channels=in_channels + 1,
            out_channels=in_channels,
            diffusion_t=diffusion_t,
            diffusion_length=length,
            diffusion_num_splits=num_splits,
            sampler_t=sampler_t,
            sampler_in_channels=in_channels,
            sampler_length=length,
            sampler_num_splits=num_splits,
            use_time_conditioning=False,
            use_modulation=False,
            **kwargs,
        )
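DiffusionModel routes `diffusion_*` and `sampler_*` prefixed kwargs to the objective and the sampler via `groupby`, and hands everything else to the network factory `net_t`. A hedged sketch with a toy factory follows; `toy_net_t` and the import path are illustrative assumptions, the real code passes a proper 1D U-Net/transformer here.

import torch
import torch.nn as nn
from Modules.diffusion.audio_diffusion_pytorch.models import DiffusionModel  # assumed path

def toy_net_t(dim: int, in_channels: int, **kwargs) -> nn.Module:
    # Minimal factory: only needs to accept the constructor kwargs and
    # behave like net(x, sigmas, ...) at call time.
    class ToyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv1d(in_channels, in_channels, 3, padding=1)

        def forward(self, x, sigmas, **kw):
            return self.conv(x + sigmas.view(-1, 1, 1))

    return ToyNet()

model = DiffusionModel(net_t=toy_net_t, in_channels=2)
loss = model(torch.randn(4, 2, 256))                        # VDiffusion loss
audio = model.sample(torch.randn(4, 2, 256), num_steps=10)  # VSampler loop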
Modules/diffusion/audio_diffusion_pytorch/utils.py
ADDED
@@ -0,0 +1,125 @@
from functools import reduce
from inspect import isfunction
from math import ceil, floor, log2, pi
from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Generator, Tensor
from typing_extensions import TypeGuard

T = TypeVar("T")


def exists(val: Optional[T]) -> TypeGuard[T]:
    return val is not None


def iff(condition: bool, value: T) -> Optional[T]:
    return value if condition else None


def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
    return isinstance(obj, list) or isinstance(obj, tuple)


def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    if exists(val):
        return val
    return d() if isfunction(d) else d


def to_list(val: Union[T, Sequence[T]]) -> List[T]:
    if isinstance(val, tuple):
        return list(val)
    if isinstance(val, list):
        return val
    return [val]  # type: ignore


def prod(vals: Sequence[int]) -> int:
    return reduce(lambda x, y: x * y, vals)


def closest_power_2(x: float) -> int:
    exponent = log2(x)
    distance_fn = lambda z: abs(x - 2 ** z)  # noqa
    exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
    return 2 ** int(exponent_closest)


"""
Kwargs Utils
"""


def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    return_dicts: Tuple[Dict, Dict] = ({}, {})
    for key in d.keys():
        no_prefix = int(not key.startswith(prefix))
        return_dicts[no_prefix][key] = d[key]
    return return_dicts


def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
    if keep_prefix:
        return kwargs_with_prefix, kwargs
    kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
    return kwargs_no_prefix, kwargs


def prefix_dict(prefix: str, d: Dict) -> Dict:
    return {prefix + str(k): v for k, v in d.items()}


"""
DSP Utils
"""


def resample(
    waveforms: Tensor,
    factor_in: int,
    factor_out: int,
    rolloff: float = 0.99,
    lowpass_filter_width: int = 6,
) -> Tensor:
    """Resamples a waveform using sinc interpolation, adapted from torchaudio"""
    b, _, length = waveforms.shape
    length_target = int(factor_out * length / factor_in)
    d = dict(device=waveforms.device, dtype=waveforms.dtype)

    base_factor = min(factor_in, factor_out) * rolloff
    width = ceil(lowpass_filter_width * factor_in / base_factor)
    idx = torch.arange(-width, width + factor_in, **d)[None, None] / factor_in  # type: ignore # noqa
    t = torch.arange(0, -factor_out, step=-1, **d)[:, None, None] / factor_out + idx  # type: ignore # noqa
    t = (t * base_factor).clamp(-lowpass_filter_width, lowpass_filter_width) * pi

    window = torch.cos(t / lowpass_filter_width / 2) ** 2
    scale = base_factor / factor_in
    kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
    kernels *= window * scale

    waveforms = rearrange(waveforms, "b c t -> (b c) t")
    waveforms = F.pad(waveforms, (width, width + factor_in))
    resampled = F.conv1d(waveforms[:, None], kernels, stride=factor_in)
    resampled = rearrange(resampled, "(b c) k l -> b c (l k)", b=b)
    return resampled[..., :length_target]


def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
    return resample(waveforms, factor_in=factor, factor_out=1, **kwargs)


def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
    return resample(waveforms, factor_in=1, factor_out=factor, **kwargs)


""" Torch Utils """


def randn_like(tensor: Tensor, *args, generator: Optional[Generator] = None, **kwargs):
    """randn_like that supports generator"""
    return torch.randn(tensor.shape, *args, generator=generator, **kwargs).to(tensor)
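Two of these helpers carry most of the weight elsewhere in the package: `groupby` splits prefixed kwargs off a dict (used by DiffusionModel and DiffusionVocoder above), and `downsample`/`upsample` wrap the sinc `resample`. A small sketch, with the import path assumed for this repo layout:

import torch
from Modules.diffusion.audio_diffusion_pytorch.utils import groupby, downsample, upsample  # assumed path

# "mel_" keys are split out and have the prefix stripped; the rest pass through.
mel_kwargs, rest = groupby("mel_", dict(mel_n_fft=1024, mel_channels=80, dim=1))
assert mel_kwargs == {"n_fft": 1024, "channels": 80}
assert rest == {"dim": 1}

wave = torch.randn(1, 1, 4096)
low = downsample(wave, factor=4)   # (1, 1, 1024) via sinc interpolation
rec = upsample(low, factor=4)      # back to (1, 1, 4096)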
Modules/diffusion/diffusion.py
ADDED
@@ -0,0 +1,94 @@
from math import pi
from random import randint
from typing import Any, Optional, Sequence, Tuple, Union

import torch
from einops import rearrange
from torch import Tensor, nn
from tqdm import tqdm

from .utils import *
from .sampler import *

"""
Diffusion Classes (generic for 1d data)
"""


class Model1d(nn.Module):
    def __init__(self, unet_type: str = "base", **kwargs):
        super().__init__()
        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
        self.unet = None
        self.diffusion = None

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        return self.diffusion(x, **kwargs)

    def sample(self, *args, **kwargs) -> Tensor:
        return self.diffusion.sample(*args, **kwargs)


"""
Audio Diffusion Classes (specific for 1d audio data)
"""


def get_default_model_kwargs():
    return dict(
        channels=128,
        patch_size=16,
        multipliers=[1, 2, 4, 4, 4, 4, 4],
        factors=[4, 4, 4, 2, 2, 2],
        num_blocks=[2, 2, 2, 2, 2, 2],
        attentions=[0, 0, 0, 1, 1, 1, 1],
        attention_heads=8,
        attention_features=64,
        attention_multiplier=2,
        attention_use_rel_pos=False,
        diffusion_type="v",
        diffusion_sigma_distribution=UniformDistribution(),
    )


def get_default_sampling_kwargs():
    return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)


class AudioDiffusionModel(Model1d):
    def __init__(self, **kwargs):
        super().__init__(**{**get_default_model_kwargs(), **kwargs})

    def sample(self, *args, **kwargs):
        return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})


class AudioDiffusionConditional(Model1d):
    def __init__(
        self,
        embedding_features: int,
        embedding_max_length: int,
        embedding_mask_proba: float = 0.1,
        **kwargs,
    ):
        self.embedding_mask_proba = embedding_mask_proba
        default_kwargs = dict(
            **get_default_model_kwargs(),
            unet_type="cfg",
            context_embedding_features=embedding_features,
            context_embedding_max_length=embedding_max_length,
        )
        super().__init__(**{**default_kwargs, **kwargs})

    def forward(self, *args, **kwargs):
        default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
        return super().forward(*args, **{**default_kwargs, **kwargs})

    def sample(self, *args, **kwargs):
        default_kwargs = dict(
            **get_default_sampling_kwargs(),
            embedding_scale=5.0,
        )
        return super().sample(*args, **{**default_kwargs, **kwargs})
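Note that Model1d here leaves `self.unet` and `self.diffusion` unset (None), so these wrappers are wired up elsewhere in the repo; what the AudioDiffusion* subclasses contribute is the `{**defaults, **kwargs}` merging, where caller kwargs win because they appear last. A tiny illustration of that pattern only, using plain dicts rather than anything repo-specific:

# Defaults-merging pattern used by AudioDiffusionModel / AudioDiffusionConditional:
# keys from the right-hand dict override keys from the left-hand dict.
defaults = dict(channels=128, attention_heads=8, diffusion_type="v")
overrides = dict(channels=256)
merged = {**defaults, **overrides}
assert merged == dict(channels=256, attention_heads=8, diffusion_type="v")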
Modules/diffusion/modules.py
ADDED
@@ -0,0 +1,693 @@
1 |
+
from math import floor, log, pi
|
2 |
+
from typing import Any, List, Optional, Sequence, Tuple, Union
|
3 |
+
|
4 |
+
from .utils import *
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from einops import rearrange, reduce, repeat
|
9 |
+
from einops.layers.torch import Rearrange
|
10 |
+
from einops_exts import rearrange_many
|
11 |
+
from torch import Tensor, einsum
|
12 |
+
|
13 |
+
|
14 |
+
"""
|
15 |
+
Utils
|
16 |
+
"""
|
17 |
+
|
18 |
+
class AdaLayerNorm(nn.Module):
|
19 |
+
def __init__(self, style_dim, channels, eps=1e-5):
|
20 |
+
super().__init__()
|
21 |
+
self.channels = channels
|
22 |
+
self.eps = eps
|
23 |
+
|
24 |
+
self.fc = nn.Linear(style_dim, channels*2)
|
25 |
+
|
26 |
+
def forward(self, x, s):
|
27 |
+
x = x.transpose(-1, -2)
|
28 |
+
x = x.transpose(1, -1)
|
29 |
+
|
30 |
+
h = self.fc(s)
|
31 |
+
h = h.view(h.size(0), h.size(1), 1)
|
32 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
33 |
+
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
|
34 |
+
|
35 |
+
|
36 |
+
x = F.layer_norm(x, (self.channels,), eps=self.eps)
|
37 |
+
x = (1 + gamma) * x + beta
|
38 |
+
return x.transpose(1, -1).transpose(-1, -2)
|
39 |
+
|
40 |
+
class StyleTransformer1d(nn.Module):
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
num_layers: int,
|
44 |
+
channels: int,
|
45 |
+
num_heads: int,
|
46 |
+
head_features: int,
|
47 |
+
multiplier: int,
|
48 |
+
use_context_time: bool = True,
|
49 |
+
use_rel_pos: bool = False,
|
50 |
+
context_features_multiplier: int = 1,
|
51 |
+
rel_pos_num_buckets: Optional[int] = None,
|
52 |
+
rel_pos_max_distance: Optional[int] = None,
|
53 |
+
context_features: Optional[int] = None,
|
54 |
+
context_embedding_features: Optional[int] = None,
|
55 |
+
embedding_max_length: int = 512,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
|
59 |
+
self.blocks = nn.ModuleList(
|
60 |
+
[
|
61 |
+
StyleTransformerBlock(
|
62 |
+
features=channels + context_embedding_features,
|
63 |
+
head_features=head_features,
|
64 |
+
num_heads=num_heads,
|
65 |
+
multiplier=multiplier,
|
66 |
+
style_dim=context_features,
|
67 |
+
use_rel_pos=use_rel_pos,
|
68 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
69 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
70 |
+
)
|
71 |
+
for i in range(num_layers)
|
72 |
+
]
|
73 |
+
)
|
74 |
+
|
75 |
+
self.to_out = nn.Sequential(
|
76 |
+
Rearrange("b t c -> b c t"),
|
77 |
+
nn.Conv1d(
|
78 |
+
in_channels=channels + context_embedding_features,
|
79 |
+
out_channels=channels,
|
80 |
+
kernel_size=1,
|
81 |
+
),
|
82 |
+
)
|
83 |
+
|
84 |
+
use_context_features = exists(context_features)
|
85 |
+
self.use_context_features = use_context_features
|
86 |
+
self.use_context_time = use_context_time
|
87 |
+
|
88 |
+
if use_context_time or use_context_features:
|
89 |
+
context_mapping_features = channels + context_embedding_features
|
90 |
+
|
91 |
+
self.to_mapping = nn.Sequential(
|
92 |
+
nn.Linear(context_mapping_features, context_mapping_features),
|
93 |
+
nn.GELU(),
|
94 |
+
nn.Linear(context_mapping_features, context_mapping_features),
|
95 |
+
nn.GELU(),
|
96 |
+
)
|
97 |
+
|
98 |
+
if use_context_time:
|
99 |
+
assert exists(context_mapping_features)
|
100 |
+
self.to_time = nn.Sequential(
|
101 |
+
TimePositionalEmbedding(
|
102 |
+
dim=channels, out_features=context_mapping_features
|
103 |
+
),
|
104 |
+
nn.GELU(),
|
105 |
+
)
|
106 |
+
|
107 |
+
if use_context_features:
|
108 |
+
assert exists(context_features) and exists(context_mapping_features)
|
109 |
+
self.to_features = nn.Sequential(
|
110 |
+
nn.Linear(
|
111 |
+
in_features=context_features, out_features=context_mapping_features
|
112 |
+
),
|
113 |
+
nn.GELU(),
|
114 |
+
)
|
115 |
+
|
116 |
+
self.fixed_embedding = FixedEmbedding(
|
117 |
+
max_length=embedding_max_length, features=context_embedding_features
|
118 |
+
)
|
119 |
+
|
120 |
+
|
121 |
+
def get_mapping(
|
122 |
+
self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
|
123 |
+
) -> Optional[Tensor]:
|
124 |
+
"""Combines context time features and features into mapping"""
|
125 |
+
items, mapping = [], None
|
126 |
+
# Compute time features
|
127 |
+
if self.use_context_time:
|
128 |
+
assert_message = "use_context_time=True but no time features provided"
|
129 |
+
assert exists(time), assert_message
|
130 |
+
items += [self.to_time(time)]
|
131 |
+
# Compute features
|
132 |
+
if self.use_context_features:
|
133 |
+
assert_message = "context_features exists but no features provided"
|
134 |
+
assert exists(features), assert_message
|
135 |
+
items += [self.to_features(features)]
|
136 |
+
|
137 |
+
# Compute joint mapping
|
138 |
+
if self.use_context_time or self.use_context_features:
|
139 |
+
mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
|
140 |
+
mapping = self.to_mapping(mapping)
|
141 |
+
|
142 |
+
return mapping
|
143 |
+
|
144 |
+
def run(self, x, time, embedding, features):
|
145 |
+
|
146 |
+
mapping = self.get_mapping(time, features)
|
147 |
+
x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
|
148 |
+
mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
|
149 |
+
|
150 |
+
for block in self.blocks:
|
151 |
+
x = x + mapping
|
152 |
+
x = block(x, features)
|
153 |
+
|
154 |
+
x = x.mean(axis=1).unsqueeze(1)
|
155 |
+
x = self.to_out(x)
|
156 |
+
x = x.transpose(-1, -2)
|
157 |
+
|
158 |
+
return x
|
159 |
+
|
160 |
+
def forward(self, x: Tensor,
|
161 |
+
time: Tensor,
|
162 |
+
embedding_mask_proba: float = 0.0,
|
163 |
+
embedding: Optional[Tensor] = None,
|
164 |
+
features: Optional[Tensor] = None,
|
165 |
+
embedding_scale: float = 1.0) -> Tensor:
|
166 |
+
|
167 |
+
b, device = embedding.shape[0], embedding.device
|
168 |
+
fixed_embedding = self.fixed_embedding(embedding)
|
169 |
+
if embedding_mask_proba > 0.0:
|
170 |
+
# Randomly mask embedding
|
171 |
+
batch_mask = rand_bool(
|
172 |
+
shape=(b, 1, 1), proba=embedding_mask_proba, device=device
|
173 |
+
)
|
174 |
+
embedding = torch.where(batch_mask, fixed_embedding, embedding)
|
175 |
+
|
176 |
+
if embedding_scale != 1.0:
|
177 |
+
# Compute both normal and fixed embedding outputs
|
178 |
+
out = self.run(x, time, embedding=embedding, features=features)
|
179 |
+
out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
|
180 |
+
# Scale conditional output using classifier-free guidance
|
181 |
+
return out_masked + (out - out_masked) * embedding_scale
|
182 |
+
else:
|
183 |
+
return self.run(x, time, embedding=embedding, features=features)
|
184 |
+
|
185 |
+
return x
|
186 |
+
|
187 |
+
|
188 |
+
class StyleTransformerBlock(nn.Module):
|
189 |
+
def __init__(
|
190 |
+
self,
|
191 |
+
features: int,
|
192 |
+
num_heads: int,
|
193 |
+
head_features: int,
|
194 |
+
style_dim: int,
|
195 |
+
multiplier: int,
|
196 |
+
use_rel_pos: bool,
|
197 |
+
rel_pos_num_buckets: Optional[int] = None,
|
198 |
+
rel_pos_max_distance: Optional[int] = None,
|
199 |
+
context_features: Optional[int] = None,
|
200 |
+
):
|
201 |
+
super().__init__()
|
202 |
+
|
203 |
+
self.use_cross_attention = exists(context_features) and context_features > 0
|
204 |
+
|
205 |
+
self.attention = StyleAttention(
|
206 |
+
features=features,
|
207 |
+
style_dim=style_dim,
|
208 |
+
num_heads=num_heads,
|
209 |
+
head_features=head_features,
|
210 |
+
use_rel_pos=use_rel_pos,
|
211 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
212 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
213 |
+
)
|
214 |
+
|
215 |
+
if self.use_cross_attention:
|
216 |
+
self.cross_attention = StyleAttention(
|
217 |
+
features=features,
|
218 |
+
style_dim=style_dim,
|
219 |
+
num_heads=num_heads,
|
220 |
+
head_features=head_features,
|
221 |
+
context_features=context_features,
|
222 |
+
use_rel_pos=use_rel_pos,
|
223 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
224 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
225 |
+
)
|
226 |
+
|
227 |
+
self.feed_forward = FeedForward(features=features, multiplier=multiplier)
|
228 |
+
|
229 |
+
def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
|
230 |
+
x = self.attention(x, s) + x
|
231 |
+
if self.use_cross_attention:
|
232 |
+
x = self.cross_attention(x, s, context=context) + x
|
233 |
+
x = self.feed_forward(x) + x
|
234 |
+
return x
|
235 |
+
|
236 |
+
class StyleAttention(nn.Module):
|
237 |
+
def __init__(
|
238 |
+
self,
|
239 |
+
features: int,
|
240 |
+
*,
|
241 |
+
style_dim: int,
|
242 |
+
head_features: int,
|
243 |
+
num_heads: int,
|
244 |
+
context_features: Optional[int] = None,
|
245 |
+
use_rel_pos: bool,
|
246 |
+
rel_pos_num_buckets: Optional[int] = None,
|
247 |
+
rel_pos_max_distance: Optional[int] = None,
|
248 |
+
):
|
249 |
+
super().__init__()
|
250 |
+
self.context_features = context_features
|
251 |
+
mid_features = head_features * num_heads
|
252 |
+
context_features = default(context_features, features)
|
253 |
+
|
254 |
+
self.norm = AdaLayerNorm(style_dim, features)
|
255 |
+
self.norm_context = AdaLayerNorm(style_dim, context_features)
|
256 |
+
self.to_q = nn.Linear(
|
257 |
+
in_features=features, out_features=mid_features, bias=False
|
258 |
+
)
|
259 |
+
self.to_kv = nn.Linear(
|
260 |
+
in_features=context_features, out_features=mid_features * 2, bias=False
|
261 |
+
)
|
262 |
+
self.attention = AttentionBase(
|
263 |
+
features,
|
264 |
+
num_heads=num_heads,
|
265 |
+
head_features=head_features,
|
266 |
+
use_rel_pos=use_rel_pos,
|
267 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
268 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
269 |
+
)
|
270 |
+
|
271 |
+
def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
|
272 |
+
assert_message = "You must provide a context when using context_features"
|
273 |
+
assert not self.context_features or exists(context), assert_message
|
274 |
+
# Use context if provided
|
275 |
+
context = default(context, x)
|
276 |
+
# Normalize then compute q from input and k,v from context
|
277 |
+
x, context = self.norm(x, s), self.norm_context(context, s)
|
278 |
+
|
279 |
+
q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
|
280 |
+
# Compute and return attention
|
281 |
+
return self.attention(q, k, v)
|
282 |
+
|
283 |
+
class Transformer1d(nn.Module):
|
284 |
+
def __init__(
|
285 |
+
self,
|
286 |
+
num_layers: int,
|
287 |
+
channels: int,
|
288 |
+
num_heads: int,
|
289 |
+
head_features: int,
|
290 |
+
multiplier: int,
|
291 |
+
use_context_time: bool = True,
|
292 |
+
use_rel_pos: bool = False,
|
293 |
+
context_features_multiplier: int = 1,
|
294 |
+
rel_pos_num_buckets: Optional[int] = None,
|
295 |
+
rel_pos_max_distance: Optional[int] = None,
|
296 |
+
context_features: Optional[int] = None,
|
297 |
+
context_embedding_features: Optional[int] = None,
|
298 |
+
embedding_max_length: int = 512,
|
299 |
+
):
|
300 |
+
super().__init__()
|
301 |
+
|
302 |
+
self.blocks = nn.ModuleList(
|
303 |
+
[
|
304 |
+
TransformerBlock(
|
305 |
+
features=channels + context_embedding_features,
|
306 |
+
head_features=head_features,
|
307 |
+
num_heads=num_heads,
|
308 |
+
multiplier=multiplier,
|
309 |
+
use_rel_pos=use_rel_pos,
|
310 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
311 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
312 |
+
)
|
313 |
+
for i in range(num_layers)
|
314 |
+
]
|
315 |
+
)
|
316 |
+
|
317 |
+
self.to_out = nn.Sequential(
|
318 |
+
Rearrange("b t c -> b c t"),
|
319 |
+
nn.Conv1d(
|
320 |
+
in_channels=channels + context_embedding_features,
|
321 |
+
out_channels=channels,
|
322 |
+
kernel_size=1,
|
323 |
+
),
|
324 |
+
)
|
325 |
+
|
326 |
+
use_context_features = exists(context_features)
|
327 |
+
self.use_context_features = use_context_features
|
328 |
+
self.use_context_time = use_context_time
|
329 |
+
|
330 |
+
if use_context_time or use_context_features:
|
331 |
+
context_mapping_features = channels + context_embedding_features
|
332 |
+
|
333 |
+
self.to_mapping = nn.Sequential(
|
334 |
+
nn.Linear(context_mapping_features, context_mapping_features),
|
335 |
+
nn.GELU(),
|
336 |
+
nn.Linear(context_mapping_features, context_mapping_features),
|
337 |
+
nn.GELU(),
|
338 |
+
)
|
339 |
+
|
340 |
+
if use_context_time:
|
341 |
+
assert exists(context_mapping_features)
|
342 |
+
self.to_time = nn.Sequential(
|
343 |
+
TimePositionalEmbedding(
|
344 |
+
dim=channels, out_features=context_mapping_features
|
345 |
+
),
|
346 |
+
nn.GELU(),
|
347 |
+
)
|
348 |
+
|
349 |
+
if use_context_features:
|
350 |
+
assert exists(context_features) and exists(context_mapping_features)
|
351 |
+
self.to_features = nn.Sequential(
|
352 |
+
nn.Linear(
|
353 |
+
in_features=context_features, out_features=context_mapping_features
|
354 |
+
),
|
355 |
+
nn.GELU(),
|
356 |
+
)
|
357 |
+
|
358 |
+
self.fixed_embedding = FixedEmbedding(
|
359 |
+
max_length=embedding_max_length, features=context_embedding_features
|
360 |
+
)
|
361 |
+
|
362 |
+
|
363 |
+
def get_mapping(
|
364 |
+
self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
|
365 |
+
) -> Optional[Tensor]:
|
366 |
+
"""Combines context time features and features into mapping"""
|
367 |
+
items, mapping = [], None
|
368 |
+
# Compute time features
|
369 |
+
if self.use_context_time:
|
370 |
+
assert_message = "use_context_time=True but no time features provided"
|
371 |
+
assert exists(time), assert_message
|
372 |
+
items += [self.to_time(time)]
|
373 |
+
# Compute features
|
374 |
+
if self.use_context_features:
|
375 |
+
assert_message = "context_features exists but no features provided"
|
376 |
+
assert exists(features), assert_message
|
377 |
+
items += [self.to_features(features)]
|
378 |
+
|
379 |
+
# Compute joint mapping
|
380 |
+
if self.use_context_time or self.use_context_features:
|
381 |
+
mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
|
382 |
+
mapping = self.to_mapping(mapping)
|
383 |
+
|
384 |
+
return mapping
|
385 |
+
|
386 |
+
def run(self, x, time, embedding, features):
|
387 |
+
|
388 |
+
mapping = self.get_mapping(time, features)
|
389 |
+
x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
|
390 |
+
mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
|
391 |
+
|
392 |
+
for block in self.blocks:
|
393 |
+
x = x + mapping
|
394 |
+
x = block(x)
|
395 |
+
|
396 |
+
x = x.mean(axis=1).unsqueeze(1)
|
397 |
+
x = self.to_out(x)
|
398 |
+
x = x.transpose(-1, -2)
|
399 |
+
|
400 |
+
return x
|
401 |
+
|
402 |
+
def forward(self, x: Tensor,
|
403 |
+
time: Tensor,
|
404 |
+
embedding_mask_proba: float = 0.0,
|
405 |
+
embedding: Optional[Tensor] = None,
|
406 |
+
features: Optional[Tensor] = None,
|
407 |
+
embedding_scale: float = 1.0) -> Tensor:
|
408 |
+
|
409 |
+
b, device = embedding.shape[0], embedding.device
|
410 |
+
fixed_embedding = self.fixed_embedding(embedding)
|
411 |
+
if embedding_mask_proba > 0.0:
|
412 |
+
# Randomly mask embedding
|
413 |
+
batch_mask = rand_bool(
|
414 |
+
shape=(b, 1, 1), proba=embedding_mask_proba, device=device
|
415 |
+
)
|
416 |
+
embedding = torch.where(batch_mask, fixed_embedding, embedding)
|
417 |
+
|
418 |
+
if embedding_scale != 1.0:
|
419 |
+
# Compute both normal and fixed embedding outputs
|
420 |
+
out = self.run(x, time, embedding=embedding, features=features)
|
421 |
+
out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
|
422 |
+
# Scale conditional output using classifier-free guidance
|
423 |
+
return out_masked + (out - out_masked) * embedding_scale
|
424 |
+
else:
|
425 |
+
return self.run(x, time, embedding=embedding, features=features)
|
426 |
+
|
427 |
+
return x
|
428 |
+
|
429 |
+
|
430 |
+
"""
|
431 |
+
Attention Components
|
432 |
+
"""
|
433 |
+
|
434 |
+
|
435 |
+
class RelativePositionBias(nn.Module):
|
436 |
+
def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
|
437 |
+
super().__init__()
|
438 |
+
self.num_buckets = num_buckets
|
439 |
+
self.max_distance = max_distance
|
440 |
+
self.num_heads = num_heads
|
441 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
442 |
+
|
443 |
+
@staticmethod
|
444 |
+
def _relative_position_bucket(
|
445 |
+
relative_position: Tensor, num_buckets: int, max_distance: int
|
446 |
+
):
|
447 |
+
num_buckets //= 2
|
448 |
+
ret = (relative_position >= 0).to(torch.long) * num_buckets
|
449 |
+
n = torch.abs(relative_position)
|
450 |
+
|
451 |
+
max_exact = num_buckets // 2
|
452 |
+
is_small = n < max_exact
|
453 |
+
|
454 |
+
val_if_large = (
|
455 |
+
max_exact
|
456 |
+
+ (
|
457 |
+
torch.log(n.float() / max_exact)
|
458 |
+
/ log(max_distance / max_exact)
|
459 |
+
* (num_buckets - max_exact)
|
460 |
+
).long()
|
461 |
+
)
|
462 |
+
val_if_large = torch.min(
|
463 |
+
val_if_large, torch.full_like(val_if_large, num_buckets - 1)
|
464 |
+
)
|
465 |
+
|
466 |
+
ret += torch.where(is_small, n, val_if_large)
|
467 |
+
return ret
|
468 |
+
|
469 |
+
def forward(self, num_queries: int, num_keys: int) -> Tensor:
|
470 |
+
i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
|
471 |
+
q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
|
472 |
+
k_pos = torch.arange(j, dtype=torch.long, device=device)
|
473 |
+
rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
|
474 |
+
|
475 |
+
relative_position_bucket = self._relative_position_bucket(
|
476 |
+
rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
|
477 |
+
)
|
478 |
+
|
479 |
+
bias = self.relative_attention_bias(relative_position_bucket)
|
480 |
+
bias = rearrange(bias, "m n h -> 1 h m n")
|
481 |
+
return bias
|
482 |
+
|
483 |
+
|
484 |
+
def FeedForward(features: int, multiplier: int) -> nn.Module:
|
485 |
+
mid_features = features * multiplier
|
486 |
+
return nn.Sequential(
|
487 |
+
nn.Linear(in_features=features, out_features=mid_features),
|
488 |
+
nn.GELU(),
|
489 |
+
nn.Linear(in_features=mid_features, out_features=features),
|
490 |
+
)
|
491 |
+
|
492 |
+
|
493 |
+
class AttentionBase(nn.Module):
|
494 |
+
def __init__(
|
495 |
+
self,
|
496 |
+
features: int,
|
497 |
+
*,
|
498 |
+
head_features: int,
|
499 |
+
num_heads: int,
|
500 |
+
use_rel_pos: bool,
|
501 |
+
out_features: Optional[int] = None,
|
502 |
+
rel_pos_num_buckets: Optional[int] = None,
|
503 |
+
rel_pos_max_distance: Optional[int] = None,
|
504 |
+
):
|
505 |
+
super().__init__()
|
506 |
+
self.scale = head_features ** -0.5
|
507 |
+
self.num_heads = num_heads
|
508 |
+
self.use_rel_pos = use_rel_pos
|
509 |
+
mid_features = head_features * num_heads
|
510 |
+
|
511 |
+
if use_rel_pos:
|
512 |
+
assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
|
513 |
+
self.rel_pos = RelativePositionBias(
|
514 |
+
num_buckets=rel_pos_num_buckets,
|
515 |
+
max_distance=rel_pos_max_distance,
|
516 |
+
num_heads=num_heads,
|
517 |
+
)
|
518 |
+
if out_features is None:
|
519 |
+
out_features = features
|
520 |
+
|
521 |
+
self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
|
522 |
+
|
523 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
524 |
+
# Split heads
|
525 |
+
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
|
526 |
+
# Compute similarity matrix
|
527 |
+
sim = einsum("... n d, ... m d -> ... n m", q, k)
|
528 |
+
sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
|
529 |
+
sim = sim * self.scale
|
530 |
+
# Get attention matrix with softmax
|
531 |
+
attn = sim.softmax(dim=-1)
|
532 |
+
# Compute values
|
533 |
+
out = einsum("... n m, ... m d -> ... n d", attn, v)
|
534 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
535 |
+
return self.to_out(out)
|
536 |
+
|
537 |
+
|
538 |
+
class Attention(nn.Module):
|
539 |
+
def __init__(
|
540 |
+
self,
|
541 |
+
features: int,
|
542 |
+
*,
|
543 |
+
head_features: int,
|
544 |
+
num_heads: int,
|
545 |
+
out_features: Optional[int] = None,
|
546 |
+
context_features: Optional[int] = None,
|
547 |
+
use_rel_pos: bool,
|
548 |
+
rel_pos_num_buckets: Optional[int] = None,
|
549 |
+
rel_pos_max_distance: Optional[int] = None,
|
550 |
+
):
|
551 |
+
super().__init__()
|
552 |
+
self.context_features = context_features
|
553 |
+
mid_features = head_features * num_heads
|
554 |
+
context_features = default(context_features, features)
|
555 |
+
|
556 |
+
self.norm = nn.LayerNorm(features)
|
557 |
+
self.norm_context = nn.LayerNorm(context_features)
|
558 |
+
self.to_q = nn.Linear(
|
559 |
+
in_features=features, out_features=mid_features, bias=False
|
560 |
+
)
|
561 |
+
self.to_kv = nn.Linear(
|
562 |
+
in_features=context_features, out_features=mid_features * 2, bias=False
|
563 |
+
)
|
564 |
+
|
565 |
+
self.attention = AttentionBase(
|
566 |
+
features,
|
567 |
+
out_features=out_features,
|
568 |
+
num_heads=num_heads,
|
569 |
+
head_features=head_features,
|
570 |
+
use_rel_pos=use_rel_pos,
|
571 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
572 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
573 |
+
)
|
574 |
+
|
575 |
+
def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
|
576 |
+
assert_message = "You must provide a context when using context_features"
|
577 |
+
assert not self.context_features or exists(context), assert_message
|
578 |
+
# Use context if provided
|
579 |
+
context = default(context, x)
|
580 |
+
# Normalize then compute q from input and k,v from context
|
581 |
+
x, context = self.norm(x), self.norm_context(context)
|
582 |
+
q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
|
583 |
+
# Compute and return attention
|
584 |
+
return self.attention(q, k, v)
|
585 |
+
|
586 |
+
|
587 |
+
"""
|
588 |
+
Transformer Blocks
|
589 |
+
"""
|
590 |
+
|
591 |
+
|
592 |
+
class TransformerBlock(nn.Module):
|
593 |
+
def __init__(
|
594 |
+
self,
|
595 |
+
features: int,
|
596 |
+
num_heads: int,
|
597 |
+
head_features: int,
|
598 |
+
multiplier: int,
|
599 |
+
use_rel_pos: bool,
|
600 |
+
rel_pos_num_buckets: Optional[int] = None,
|
601 |
+
rel_pos_max_distance: Optional[int] = None,
|
602 |
+
context_features: Optional[int] = None,
|
603 |
+
):
|
604 |
+
super().__init__()
|
605 |
+
|
606 |
+
self.use_cross_attention = exists(context_features) and context_features > 0
|
607 |
+
|
608 |
+
self.attention = Attention(
|
609 |
+
features=features,
|
610 |
+
num_heads=num_heads,
|
611 |
+
head_features=head_features,
|
612 |
+
use_rel_pos=use_rel_pos,
|
613 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
614 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
615 |
+
)
|
616 |
+
|
617 |
+
if self.use_cross_attention:
|
618 |
+
self.cross_attention = Attention(
|
619 |
+
features=features,
|
620 |
+
num_heads=num_heads,
|
621 |
+
head_features=head_features,
|
622 |
+
context_features=context_features,
|
623 |
+
use_rel_pos=use_rel_pos,
|
624 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
625 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
626 |
+
)
|
627 |
+
|
628 |
+
self.feed_forward = FeedForward(features=features, multiplier=multiplier)
|
629 |
+
|
630 |
+
def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
|
631 |
+
x = self.attention(x) + x
|
632 |
+
if self.use_cross_attention:
|
633 |
+
x = self.cross_attention(x, context=context) + x
|
634 |
+
x = self.feed_forward(x) + x
|
635 |
+
return x
|
636 |
+
|
637 |
+
|
638 |
+
|
639 |
+
"""
|
640 |
+
Time Embeddings
|
641 |
+
"""
|
642 |
+
|
643 |
+
|
644 |
+
class SinusoidalEmbedding(nn.Module):
|
645 |
+
def __init__(self, dim: int):
|
646 |
+
super().__init__()
|
647 |
+
self.dim = dim
|
648 |
+
|
649 |
+
def forward(self, x: Tensor) -> Tensor:
|
650 |
+
device, half_dim = x.device, self.dim // 2
|
651 |
+
emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
|
652 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
653 |
+
emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
|
654 |
+
return torch.cat((emb.sin(), emb.cos()), dim=-1)
|
655 |
+
|
656 |
+
|
657 |
+
class LearnedPositionalEmbedding(nn.Module):
|
658 |
+
"""Used for continuous time"""
|
659 |
+
|
660 |
+
def __init__(self, dim: int):
|
661 |
+
super().__init__()
|
662 |
+
assert (dim % 2) == 0
|
663 |
+
half_dim = dim // 2
|
664 |
+
self.weights = nn.Parameter(torch.randn(half_dim))
|
665 |
+
|
666 |
+
def forward(self, x: Tensor) -> Tensor:
|
667 |
+
x = rearrange(x, "b -> b 1")
|
668 |
+
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
|
669 |
+
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
|
670 |
+
fouriered = torch.cat((x, fouriered), dim=-1)
|
671 |
+
return fouriered
|
672 |
+
|
673 |
+
|
674 |
+
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
|
675 |
+
return nn.Sequential(
|
676 |
+
LearnedPositionalEmbedding(dim),
|
677 |
+
nn.Linear(in_features=dim + 1, out_features=out_features),
|
678 |
+
)
|
679 |
+
|
680 |
+
class FixedEmbedding(nn.Module):
|
681 |
+
def __init__(self, max_length: int, features: int):
|
682 |
+
super().__init__()
|
683 |
+
self.max_length = max_length
|
684 |
+
self.embedding = nn.Embedding(max_length, features)
|
685 |
+
|
686 |
+
def forward(self, x: Tensor) -> Tensor:
|
687 |
+
batch_size, length, device = *x.shape[0:2], x.device
|
688 |
+
assert_message = "Input sequence length must be <= max_length"
|
689 |
+
assert length <= self.max_length, assert_message
|
690 |
+
position = torch.arange(length, device=device)
|
691 |
+
fixed_embedding = self.embedding(position)
|
692 |
+
fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
|
693 |
+
return fixed_embedding
|
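The block above closes out the attention, transformer, and time-embedding building blocks of this module. A minimal sketch of how `Attention` (with T5-style relative position bias) and `TimePositionalEmbedding` might be exercised; not part of the diff, and the import path `Modules.diffusion.modules` plus all tensor sizes are assumptions for illustration:

```py
import torch
from Modules.diffusion.modules import Attention, TimePositionalEmbedding  # assumed import path

# Self-attention over 128 feature vectors with relative position bias
attn = Attention(
    features=256,
    head_features=64,
    num_heads=8,
    use_rel_pos=True,
    rel_pos_num_buckets=32,
    rel_pos_max_distance=128,
)
x = torch.randn(2, 128, 256)   # [batch, length, features]
y = attn(x)                    # same shape as the input: [2, 128, 256]

# Continuous-time embedding used to condition on the noise level
time_emb = TimePositionalEmbedding(dim=64, out_features=256)
t = torch.rand(2)              # one time/sigma value per batch element
e = time_emb(t)                # [2, 256]
```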
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/.github/workflows/python-publish.yml
ADDED
@@ -0,0 +1,39 @@
1 |
+
# This workflow will upload a Python Package using Twine when a release is created
|
2 |
+
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
|
3 |
+
|
4 |
+
# This workflow uses actions that are not certified by GitHub.
|
5 |
+
# They are provided by a third-party and are governed by
|
6 |
+
# separate terms of service, privacy policy, and support
|
7 |
+
# documentation.
|
8 |
+
|
9 |
+
name: Upload Python Package
|
10 |
+
|
11 |
+
on:
|
12 |
+
release:
|
13 |
+
types: [published]
|
14 |
+
|
15 |
+
permissions:
|
16 |
+
contents: read
|
17 |
+
|
18 |
+
jobs:
|
19 |
+
deploy:
|
20 |
+
|
21 |
+
runs-on: ubuntu-latest
|
22 |
+
|
23 |
+
steps:
|
24 |
+
- uses: actions/checkout@v3
|
25 |
+
- name: Set up Python
|
26 |
+
uses: actions/setup-python@v3
|
27 |
+
with:
|
28 |
+
python-version: '3.x'
|
29 |
+
- name: Install dependencies
|
30 |
+
run: |
|
31 |
+
python -m pip install --upgrade pip
|
32 |
+
pip install build
|
33 |
+
- name: Build package
|
34 |
+
run: python -m build
|
35 |
+
- name: Publish package
|
36 |
+
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
37 |
+
with:
|
38 |
+
user: __token__
|
39 |
+
password: ${{ secrets.PYPI_API_TOKEN }}
|
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/.gitignore
ADDED
@@ -0,0 +1,2 @@
1 |
+
__pycache__
|
2 |
+
.mypy_cache
|
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,41 @@
1 |
+
repos:
|
2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
3 |
+
rev: v2.3.0
|
4 |
+
hooks:
|
5 |
+
- id: end-of-file-fixer
|
6 |
+
- id: trailing-whitespace
|
7 |
+
|
8 |
+
# Formats code correctly
|
9 |
+
- repo: https://github.com/psf/black
|
10 |
+
rev: 21.12b0
|
11 |
+
hooks:
|
12 |
+
- id: black
|
13 |
+
args: [
|
14 |
+
'--experimental-string-processing'
|
15 |
+
]
|
16 |
+
|
17 |
+
# Sorts imports
|
18 |
+
- repo: https://github.com/pycqa/isort
|
19 |
+
rev: 5.10.1
|
20 |
+
hooks:
|
21 |
+
- id: isort
|
22 |
+
name: isort (python)
|
23 |
+
args: ["--profile", "black"]
|
24 |
+
|
25 |
+
# Checks unused imports, like lengths, etc
|
26 |
+
- repo: https://gitlab.com/pycqa/flake8
|
27 |
+
rev: 4.0.0
|
28 |
+
hooks:
|
29 |
+
- id: flake8
|
30 |
+
args: [
|
31 |
+
'--per-file-ignores=__init__.py:F401',
|
32 |
+
'--max-line-length=88',
|
33 |
+
'--ignore=E203,W503'
|
34 |
+
]
|
35 |
+
|
36 |
+
# Checks types
|
37 |
+
- repo: https://github.com/pre-commit/mirrors-mypy
|
38 |
+
rev: 'v0.971'
|
39 |
+
hooks:
|
40 |
+
- id: mypy
|
41 |
+
additional_dependencies: [data-science-types>=0.2, torch>=1.6]
|
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/LICENSE
ADDED
@@ -0,0 +1,21 @@
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 archinet.ai
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/LOGO.png
ADDED
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/README.md
ADDED
@@ -0,0 +1,251 @@
1 |
+
<img src="./LOGO.png"></img>
|
2 |
+
|
3 |
+
A fully featured audio diffusion library for PyTorch. Includes models for unconditional audio generation, text-conditional audio generation, diffusion autoencoding, upsampling, and vocoding. The provided models are waveform-based; however, the U-Net (built using [`a-unet`](https://github.com/archinetai/a-unet)), `DiffusionModel`, diffusion method, and diffusion samplers are generic to any dimension and highly customizable to work on other formats. **Notes: (1) no pre-trained models are provided here, (2) the configs shown are indicative and untested, see [Moûsai](https://arxiv.org/abs/2301.11757) for the configs used in the paper.**
|
4 |
+
|
5 |
+
|
6 |
+
## Install
|
7 |
+
|
8 |
+
```bash
|
9 |
+
pip install audio-diffusion-pytorch
|
10 |
+
```
|
11 |
+
|
12 |
+
[![PyPI - Python Version](https://img.shields.io/pypi/v/audio-diffusion-pytorch?style=flat&colorA=black&colorB=black)](https://pypi.org/project/audio-diffusion-pytorch/)
|
13 |
+
[![Downloads](https://static.pepy.tech/personalized-badge/audio-diffusion-pytorch?period=total&units=international_system&left_color=black&right_color=black&left_text=Downloads)](https://pepy.tech/project/audio-diffusion-pytorch)
|
14 |
+
|
15 |
+
|
16 |
+
## Usage
|
17 |
+
|
18 |
+
### Unconditional Generator
|
19 |
+
|
20 |
+
```py
|
21 |
+
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
|
22 |
+
|
23 |
+
model = DiffusionModel(
|
24 |
+
net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
|
25 |
+
in_channels=2, # U-Net: number of input/output (audio) channels
|
26 |
+
channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
|
27 |
+
factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
|
28 |
+
items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
|
29 |
+
attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
|
30 |
+
attention_heads=8, # U-Net: number of attention heads per attention item
|
31 |
+
attention_features=64, # U-Net: number of attention features per attention item
|
32 |
+
diffusion_t=VDiffusion, # The diffusion method used
|
33 |
+
sampler_t=VSampler, # The diffusion sampler used
|
34 |
+
)
|
35 |
+
|
36 |
+
# Train model with audio waveforms
|
37 |
+
audio = torch.randn(1, 2, 2**18) # [batch_size, in_channels, length]
|
38 |
+
loss = model(audio)
|
39 |
+
loss.backward()
|
40 |
+
|
41 |
+
# Turn noise into new audio sample with diffusion
|
42 |
+
noise = torch.randn(1, 2, 2**18) # [batch_size, in_channels, length]
|
43 |
+
sample = model.sample(noise, num_steps=10) # Suggested num_steps 10-100
|
44 |
+
```
|
45 |
+
|
46 |
+
### Text-Conditional Generator
|
47 |
+
A text-to-audio diffusion model that conditions the generation on `t5-base` text embeddings; requires `pip install transformers`.
|
48 |
+
```py
|
49 |
+
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
|
50 |
+
|
51 |
+
model = DiffusionModel(
|
52 |
+
# ... same as unconditional model
|
53 |
+
use_text_conditioning=True, # U-Net: enables text conditioning (default T5-base)
|
54 |
+
use_embedding_cfg=True, # U-Net: enables classifier free guidance
|
55 |
+
embedding_max_length=64, # U-Net: text embedding maximum length (default for T5-base)
|
56 |
+
embedding_features=768, # U-Net: text embedding features (default for T5-base)
|
57 |
+
cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer
|
58 |
+
)
|
59 |
+
|
60 |
+
# Train model with audio waveforms
|
61 |
+
audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
|
62 |
+
loss = model(
|
63 |
+
audio_wave,
|
64 |
+
text=['The audio description'], # Text conditioning, one element per batch
|
65 |
+
embedding_mask_proba=0.1 # Probability of masking text with learned embedding (Classifier-Free Guidance Mask)
|
66 |
+
)
|
67 |
+
loss.backward()
|
68 |
+
|
69 |
+
# Turn noise into new audio sample with diffusion
|
70 |
+
noise = torch.randn(1, 2, 2**18)
|
71 |
+
sample = model.sample(
|
72 |
+
noise,
|
73 |
+
text=['The audio description'],
|
74 |
+
embedding_scale=5.0, # Higher for more text importance, suggested range: 1-15 (Classifier-Free Guidance Scale)
|
75 |
+
num_steps=2 # Higher for better quality, suggested num_steps: 10-100
|
76 |
+
)
|
77 |
+
```
|
78 |
+
|
79 |
+
### Diffusion Upsampler
|
80 |
+
Upsample audio from a lower sample rate to a higher sample rate using diffusion, e.g. 3kHz to 48kHz.
|
81 |
+
```py
|
82 |
+
from audio_diffusion_pytorch import DiffusionUpsampler, UNetV0, VDiffusion, VSampler
|
83 |
+
|
84 |
+
upsampler = DiffusionUpsampler(
|
85 |
+
net_t=UNetV0, # The model type used for diffusion
|
86 |
+
upsample_factor=16, # The upsample factor (e.g. 16 can be used for 3kHz to 48kHz)
|
87 |
+
in_channels=2, # U-Net: number of input/output (audio) channels
|
88 |
+
channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
|
89 |
+
factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
|
90 |
+
items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
|
91 |
+
diffusion_t=VDiffusion, # The diffusion method used
|
92 |
+
sampler_t=VSampler, # The diffusion sampler used
|
93 |
+
)
|
94 |
+
|
95 |
+
# Train model with high sample rate audio waveforms
|
96 |
+
audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
|
97 |
+
loss = upsampler(audio)
|
98 |
+
loss.backward()
|
99 |
+
|
100 |
+
# Turn low sample rate audio into high sample rate
|
101 |
+
downsampled_audio = torch.randn(1, 2, 2**14) # [batch, in_channels, length]
|
102 |
+
sample = upsampler.sample(downsampled_audio, num_steps=10) # Output has shape: [1, 2, 2**18]
|
103 |
+
```
|
104 |
+
|
105 |
+
### Diffusion Vocoder
|
106 |
+
Convert a mel-spectrogram to a waveform using diffusion.
|
107 |
+
```py
|
108 |
+
from audio_diffusion_pytorch import DiffusionVocoder, UNetV0, VDiffusion, VSampler
|
109 |
+
|
110 |
+
vocoder = DiffusionVocoder(
|
111 |
+
mel_n_fft=1024, # Mel-spectrogram n_fft
|
112 |
+
mel_channels=80, # Mel-spectrogram channels
|
113 |
+
mel_sample_rate=48000, # Mel-spectrogram sample rate
|
114 |
+
mel_normalize_log=True, # Mel-spectrogram log normalization (alternative is mel_normalize=True for [-1,1] power normalization)
|
115 |
+
net_t=UNetV0, # The model type used for diffusion vocoding
|
116 |
+
channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
|
117 |
+
factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
|
118 |
+
items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
|
119 |
+
diffusion_t=VDiffusion, # The diffusion method used
|
120 |
+
sampler_t=VSampler, # The diffusion sampler used
|
121 |
+
)
|
122 |
+
|
123 |
+
# Train model on waveforms (automatically converted to mel internally)
|
124 |
+
audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
|
125 |
+
loss = vocoder(audio)
|
126 |
+
loss.backward()
|
127 |
+
|
128 |
+
# Turn mel spectrogram into waveform
|
129 |
+
mel_spectrogram = torch.randn(1, 2, 80, 1024) # [batch, in_channels, mel_channels, mel_length]
|
130 |
+
sample = vocoder.sample(mel_spectrogram, num_steps=10) # Output has shape: [1, 2, 2**18]
|
131 |
+
```
|
132 |
+
|
133 |
+
### Diffusion Autoencoder
|
134 |
+
Autoencode audio into a compressed latent using diffusion. Any encoder can be provided as long as it subclasses the `EncoderBase` class or exposes `out_channels` and `downsample_factor` attributes.
|
135 |
+
```py
|
136 |
+
from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler
|
137 |
+
from audio_encoders_pytorch import MelE1d, TanhBottleneck
|
138 |
+
|
139 |
+
autoencoder = DiffusionAE(
|
140 |
+
encoder=MelE1d( # The encoder used, in this case a mel-spectrogram encoder
|
141 |
+
in_channels=2,
|
142 |
+
channels=512,
|
143 |
+
multipliers=[1, 1],
|
144 |
+
factors=[2],
|
145 |
+
num_blocks=[12],
|
146 |
+
out_channels=32,
|
147 |
+
mel_channels=80,
|
148 |
+
mel_sample_rate=48000,
|
149 |
+
mel_normalize_log=True,
|
150 |
+
bottleneck=TanhBottleneck(),
|
151 |
+
),
|
152 |
+
inject_depth=6,
|
153 |
+
net_t=UNetV0, # The model type used for diffusion upsampling
|
154 |
+
in_channels=2, # U-Net: number of input/output (audio) channels
|
155 |
+
channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
|
156 |
+
factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
|
157 |
+
items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
|
158 |
+
diffusion_t=VDiffusion, # The diffusion method used
|
159 |
+
sampler_t=VSampler, # The diffusion sampler used
|
160 |
+
)
|
161 |
+
|
162 |
+
# Train autoencoder with audio samples
|
163 |
+
audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
|
164 |
+
loss = autoencoder(audio)
|
165 |
+
loss.backward()
|
166 |
+
|
167 |
+
# Encode/decode audio
|
168 |
+
audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
|
169 |
+
latent = autoencoder.encode(audio) # Encode
|
170 |
+
sample = autoencoder.decode(latent, num_steps=10) # Decode by sampling diffusion model conditioning on latent
|
171 |
+
```
|
172 |
+
|
173 |
+
## Other
|
174 |
+
|
175 |
+
### Inpainting
|
176 |
+
```py
|
177 |
+
from audio_diffusion_pytorch import UNetV0, VInpainter
|
178 |
+
|
179 |
+
# The diffusion UNetV0 (this is an example, the net must be trained to work)
|
180 |
+
net = UNetV0(
|
181 |
+
dim=1,
|
182 |
+
in_channels=2, # U-Net: number of input/output (audio) channels
|
183 |
+
channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
|
184 |
+
factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
|
185 |
+
items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
|
186 |
+
attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
|
187 |
+
attention_heads=8, # U-Net: number of attention heads per attention block
|
188 |
+
attention_features=64, # U-Net: number of attention features per attention block,
|
189 |
+
)
|
190 |
+
|
191 |
+
# Instantiate inpainter with trained net
|
192 |
+
inpainter = VInpainter(net=net)
|
193 |
+
|
194 |
+
# Inpaint source
|
195 |
+
y = inpainter(
|
196 |
+
source=torch.randn(1, 2, 2**18), # Start source
|
197 |
+
mask=torch.randint(0, 2, (1, 2, 2 ** 18), dtype=torch.bool), # Set to `True` the parts you want to keep
|
198 |
+
num_steps=10, # Number of inpainting steps
|
199 |
+
num_resamples=2, # Number of resampling steps
|
200 |
+
show_progress=True,
|
201 |
+
) # [1, 2, 2 ** 18]
|
202 |
+
```
|
203 |
+
|
204 |
+
## Appreciation
|
205 |
+
|
206 |
+
* [StabilityAI](https://stability.ai/) for the compute, [Zach Evans](https://github.com/zqevans) and everyone else from [HarmonAI](https://www.harmonai.org/) for the interesting research discussions.
|
207 |
+
* [ETH Zurich](https://inf.ethz.ch/) for the resources, [Zhijing Jin](https://zhijing-jin.com/), [Bernhard Schoelkopf](https://is.mpg.de/~bs), and [Mrinmaya Sachan](http://www.mrinmaya.io/) for supervising this Thesis.
|
208 |
+
* [Phil Wang](https://github.com/lucidrains) for the beautiful open source contributions on [diffusion](https://github.com/lucidrains/denoising-diffusion-pytorch) and [Imagen](https://github.com/lucidrains/imagen-pytorch).
|
209 |
+
* [Katherine Crowson](https://github.com/crowsonkb) for the experiments with [k-diffusion](https://github.com/crowsonkb/k-diffusion) and the insane collection of samplers.
|
210 |
+
|
211 |
+
## Citations
|
212 |
+
|
213 |
+
DDPM Diffusion
|
214 |
+
```bibtex
|
215 |
+
@misc{2006.11239,
|
216 |
+
Author = {Jonathan Ho and Ajay Jain and Pieter Abbeel},
|
217 |
+
Title = {Denoising Diffusion Probabilistic Models},
|
218 |
+
Year = {2020},
|
219 |
+
Eprint = {arXiv:2006.11239},
|
220 |
+
}
|
221 |
+
```
|
222 |
+
|
223 |
+
DDIM (V-Sampler)
|
224 |
+
```bibtex
|
225 |
+
@misc{2010.02502,
|
226 |
+
Author = {Jiaming Song and Chenlin Meng and Stefano Ermon},
|
227 |
+
Title = {Denoising Diffusion Implicit Models},
|
228 |
+
Year = {2020},
|
229 |
+
Eprint = {arXiv:2010.02502},
|
230 |
+
}
|
231 |
+
```
|
232 |
+
|
233 |
+
V-Diffusion
|
234 |
+
```bibtex
|
235 |
+
@misc{2202.00512,
|
236 |
+
Author = {Tim Salimans and Jonathan Ho},
|
237 |
+
Title = {Progressive Distillation for Fast Sampling of Diffusion Models},
|
238 |
+
Year = {2022},
|
239 |
+
Eprint = {arXiv:2202.00512},
|
240 |
+
}
|
241 |
+
```
|
242 |
+
|
243 |
+
Imagen (T5 Text Conditioning)
|
244 |
+
```bibtex
|
245 |
+
@misc{2205.11487,
|
246 |
+
Author = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and S. Sara Mahdavi and Rapha Gontijo Lopes and Tim Salimans and Jonathan Ho and David J Fleet and Mohammad Norouzi},
|
247 |
+
Title = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding},
|
248 |
+
Year = {2022},
|
249 |
+
Eprint = {arXiv:2205.11487},
|
250 |
+
}
|
251 |
+
```
|
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/setup.py
ADDED
@@ -0,0 +1,29 @@
1 |
+
from setuptools import find_packages, setup
|
2 |
+
|
3 |
+
setup(
|
4 |
+
name="audio-diffusion-pytorch",
|
5 |
+
packages=find_packages(exclude=[]),
|
6 |
+
version="0.1.3",
|
7 |
+
license="MIT",
|
8 |
+
description="Audio Diffusion - PyTorch",
|
9 |
+
long_description_content_type="text/markdown",
|
10 |
+
author="Flavio Schneider",
|
11 |
+
author_email="archinetai@protonmail.com",
|
12 |
+
url="https://github.com/archinetai/audio-diffusion-pytorch",
|
13 |
+
keywords=["artificial intelligence", "deep learning", "audio generation"],
|
14 |
+
install_requires=[
|
15 |
+
"tqdm",
|
16 |
+
"torch>=1.6",
|
17 |
+
"torchaudio",
|
18 |
+
"data-science-types>=0.2",
|
19 |
+
"einops>=0.6",
|
20 |
+
"a-unet",
|
21 |
+
],
|
22 |
+
classifiers=[
|
23 |
+
"Development Status :: 4 - Beta",
|
24 |
+
"Intended Audience :: Developers",
|
25 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
26 |
+
"License :: OSI Approved :: MIT License",
|
27 |
+
"Programming Language :: Python :: 3.6",
|
28 |
+
],
|
29 |
+
)
|
Modules/diffusion/reconstruction_head/audio-diffusion-pytorch/tests/testcustomloss.py
ADDED
@@ -0,0 +1,39 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler
|
4 |
+
from audio_encoders_pytorch import MelE1d, TanhBottleneck
|
5 |
+
from auraloss.freq import MultiResolutionSTFTLoss
|
6 |
+
|
7 |
+
autoencoder = DiffusionAE(
|
8 |
+
encoder=MelE1d( # The encoder used, in this case a mel-spectrogram encoder
|
9 |
+
in_channels=2,
|
10 |
+
channels=512,
|
11 |
+
multipliers=[1, 1],
|
12 |
+
factors=[2],
|
13 |
+
num_blocks=[12],
|
14 |
+
out_channels=32,
|
15 |
+
mel_channels=80,
|
16 |
+
mel_sample_rate=48000,
|
17 |
+
mel_normalize_log=True,
|
18 |
+
bottleneck=TanhBottleneck(),
|
19 |
+
),
|
20 |
+
inject_depth=6,
|
21 |
+
net_t=UNetV0, # The model type used for diffusion upsampling
|
22 |
+
in_channels=2, # U-Net: number of input/output (audio) channels
|
23 |
+
channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
|
24 |
+
factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
|
25 |
+
items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
|
26 |
+
diffusion_t=VDiffusion, # The diffusion method used
|
27 |
+
sampler_t=VSampler, # The diffusion sampler used
|
28 |
+
loss_fn=MultiResolutionSTFTLoss(), # The loss function used
|
29 |
+
)
|
30 |
+
|
31 |
+
# Train autoencoder with audio samples
|
32 |
+
audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
|
33 |
+
loss = autoencoder(audio)
|
34 |
+
loss.backward()
|
35 |
+
|
36 |
+
# Encode/decode audio
|
37 |
+
audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
|
38 |
+
latent = autoencoder.encode(audio) # Encode
|
39 |
+
sample = autoencoder.decode(latent, num_steps=10) # Decode by sampling diffusion model conditioning on latent
|
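The test above swaps the default MSE objective of `DiffusionAE` for a multi-resolution STFT loss. A standalone sketch of what that loss computes on its own; not part of the diff, and it assumes `auraloss` is installed and that random waveforms stand in for real audio:

```py
import torch
from auraloss.freq import MultiResolutionSTFTLoss

loss_fn = MultiResolutionSTFTLoss()                    # spectral distance at several STFT resolutions
pred = torch.randn(1, 2, 2**14, requires_grad=True)    # predicted waveform [batch, channels, length]
target = torch.randn(1, 2, 2**14)                      # reference waveform
loss = loss_fn(pred, target)                           # scalar loss
loss.backward()
```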
Modules/diffusion/sampler.py
ADDED
@@ -0,0 +1,691 @@
1 |
+
from math import atan, cos, pi, sin, sqrt
|
2 |
+
from typing import Any, Callable, List, Optional, Tuple, Type
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange, reduce
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from .utils import *
|
11 |
+
|
12 |
+
"""
|
13 |
+
Diffusion Training
|
14 |
+
"""
|
15 |
+
|
16 |
+
""" Distributions """
|
17 |
+
|
18 |
+
|
19 |
+
class Distribution:
|
20 |
+
def __call__(self, num_samples: int, device: torch.device):
|
21 |
+
raise NotImplementedError()
|
22 |
+
|
23 |
+
|
24 |
+
class LogNormalDistribution(Distribution):
|
25 |
+
def __init__(self, mean: float, std: float):
|
26 |
+
self.mean = mean
|
27 |
+
self.std = std
|
28 |
+
|
29 |
+
def __call__(
|
30 |
+
self, num_samples: int, device: torch.device = torch.device("cpu")
|
31 |
+
) -> Tensor:
|
32 |
+
normal = self.mean + self.std * torch.randn((num_samples,), device=device)
|
33 |
+
return normal.exp()
|
34 |
+
|
35 |
+
|
36 |
+
class UniformDistribution(Distribution):
|
37 |
+
def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
|
38 |
+
return torch.rand(num_samples, device=device)
|
39 |
+
|
40 |
+
|
41 |
+
class VKDistribution(Distribution):
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
min_value: float = 0.0,
|
45 |
+
max_value: float = float("inf"),
|
46 |
+
sigma_data: float = 1.0,
|
47 |
+
):
|
48 |
+
self.min_value = min_value
|
49 |
+
self.max_value = max_value
|
50 |
+
self.sigma_data = sigma_data
|
51 |
+
|
52 |
+
def __call__(
|
53 |
+
self, num_samples: int, device: torch.device = torch.device("cpu")
|
54 |
+
) -> Tensor:
|
55 |
+
sigma_data = self.sigma_data
|
56 |
+
min_cdf = atan(self.min_value / sigma_data) * 2 / pi
|
57 |
+
max_cdf = atan(self.max_value / sigma_data) * 2 / pi
|
58 |
+
u = (max_cdf - min_cdf) * torch.randn((num_samples,), device=device) + min_cdf
|
59 |
+
return torch.tan(u * pi / 2) * sigma_data
|
60 |
+
|
61 |
+
|
62 |
+
""" Diffusion Classes """
|
63 |
+
|
64 |
+
|
65 |
+
def pad_dims(x: Tensor, ndim: int) -> Tensor:
|
66 |
+
# Pads additional ndims to the right of the tensor
|
67 |
+
return x.view(*x.shape, *((1,) * ndim))
|
68 |
+
|
69 |
+
|
70 |
+
def clip(x: Tensor, dynamic_threshold: float = 0.0):
|
71 |
+
if dynamic_threshold == 0.0:
|
72 |
+
return x.clamp(-1.0, 1.0)
|
73 |
+
else:
|
74 |
+
# Dynamic thresholding
|
75 |
+
# Find dynamic threshold quantile for each batch
|
76 |
+
x_flat = rearrange(x, "b ... -> b (...)")
|
77 |
+
scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
|
78 |
+
# Clamp to a min of 1.0
|
79 |
+
scale.clamp_(min=1.0)
|
80 |
+
# Clamp all values and scale
|
81 |
+
scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
|
82 |
+
x = x.clamp(-scale, scale) / scale
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
def to_batch(
|
87 |
+
batch_size: int,
|
88 |
+
device: torch.device,
|
89 |
+
x: Optional[float] = None,
|
90 |
+
xs: Optional[Tensor] = None,
|
91 |
+
) -> Tensor:
|
92 |
+
assert exists(x) ^ exists(xs), "Either x or xs must be provided"
|
93 |
+
# If x provided use the same for all batch items
|
94 |
+
if exists(x):
|
95 |
+
xs = torch.full(size=(batch_size,), fill_value=x).to(device)
|
96 |
+
assert exists(xs)
|
97 |
+
return xs
|
98 |
+
|
99 |
+
|
100 |
+
class Diffusion(nn.Module):
|
101 |
+
|
102 |
+
alias: str = ""
|
103 |
+
|
104 |
+
"""Base diffusion class"""
|
105 |
+
|
106 |
+
def denoise_fn(
|
107 |
+
self,
|
108 |
+
x_noisy: Tensor,
|
109 |
+
sigmas: Optional[Tensor] = None,
|
110 |
+
sigma: Optional[float] = None,
|
111 |
+
**kwargs,
|
112 |
+
) -> Tensor:
|
113 |
+
raise NotImplementedError("Diffusion class missing denoise_fn")
|
114 |
+
|
115 |
+
def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
|
116 |
+
raise NotImplementedError("Diffusion class missing forward function")
|
117 |
+
|
118 |
+
|
119 |
+
class VDiffusion(Diffusion):
|
120 |
+
|
121 |
+
alias = "v"
|
122 |
+
|
123 |
+
def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
|
124 |
+
super().__init__()
|
125 |
+
self.net = net
|
126 |
+
self.sigma_distribution = sigma_distribution
|
127 |
+
|
128 |
+
def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
|
129 |
+
angle = sigmas * pi / 2
|
130 |
+
alpha = torch.cos(angle)
|
131 |
+
beta = torch.sin(angle)
|
132 |
+
return alpha, beta
|
133 |
+
|
134 |
+
def denoise_fn(
|
135 |
+
self,
|
136 |
+
x_noisy: Tensor,
|
137 |
+
sigmas: Optional[Tensor] = None,
|
138 |
+
sigma: Optional[float] = None,
|
139 |
+
**kwargs,
|
140 |
+
) -> Tensor:
|
141 |
+
batch_size, device = x_noisy.shape[0], x_noisy.device
|
142 |
+
sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
|
143 |
+
return self.net(x_noisy, sigmas, **kwargs)
|
144 |
+
|
145 |
+
def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
|
146 |
+
batch_size, device = x.shape[0], x.device
|
147 |
+
|
148 |
+
# Sample amount of noise to add for each batch element
|
149 |
+
sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
|
150 |
+
sigmas_padded = rearrange(sigmas, "b -> b 1 1")
|
151 |
+
|
152 |
+
# Get noise
|
153 |
+
noise = default(noise, lambda: torch.randn_like(x))
|
154 |
+
|
155 |
+
# Combine input and noise weighted by half-circle
|
156 |
+
alpha, beta = self.get_alpha_beta(sigmas_padded)
|
157 |
+
x_noisy = x * alpha + noise * beta
|
158 |
+
x_target = noise * alpha - x * beta
|
159 |
+
|
160 |
+
# Denoise and return loss
|
161 |
+
x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
|
162 |
+
return F.mse_loss(x_denoised, x_target)
|
163 |
+
|
164 |
+
|
165 |
+
class KDiffusion(Diffusion):
|
166 |
+
"""Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
|
167 |
+
|
168 |
+
alias = "k"
|
169 |
+
|
170 |
+
def __init__(
|
171 |
+
self,
|
172 |
+
net: nn.Module,
|
173 |
+
*,
|
174 |
+
sigma_distribution: Distribution,
|
175 |
+
sigma_data: float, # data distribution standard deviation
|
176 |
+
dynamic_threshold: float = 0.0,
|
177 |
+
):
|
178 |
+
super().__init__()
|
179 |
+
self.net = net
|
180 |
+
self.sigma_data = sigma_data
|
181 |
+
self.sigma_distribution = sigma_distribution
|
182 |
+
self.dynamic_threshold = dynamic_threshold
|
183 |
+
|
184 |
+
def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
|
185 |
+
sigma_data = self.sigma_data
|
186 |
+
c_noise = torch.log(sigmas) * 0.25
|
187 |
+
sigmas = rearrange(sigmas, "b -> b 1 1")
|
188 |
+
c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
|
189 |
+
c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
|
190 |
+
c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
|
191 |
+
return c_skip, c_out, c_in, c_noise
|
192 |
+
|
193 |
+
def denoise_fn(
|
194 |
+
self,
|
195 |
+
x_noisy: Tensor,
|
196 |
+
sigmas: Optional[Tensor] = None,
|
197 |
+
sigma: Optional[float] = None,
|
198 |
+
**kwargs,
|
199 |
+
) -> Tensor:
|
200 |
+
batch_size, device = x_noisy.shape[0], x_noisy.device
|
201 |
+
sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
|
202 |
+
|
203 |
+
# Predict network output and add skip connection
|
204 |
+
c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
|
205 |
+
x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
|
206 |
+
x_denoised = c_skip * x_noisy + c_out * x_pred
|
207 |
+
|
208 |
+
return x_denoised
|
209 |
+
|
210 |
+
def loss_weight(self, sigmas: Tensor) -> Tensor:
|
211 |
+
# Computes weight depending on data distribution
|
212 |
+
return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2
|
213 |
+
|
214 |
+
def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
|
215 |
+
batch_size, device = x.shape[0], x.device
|
216 |
+
from einops import rearrange, reduce
|
217 |
+
|
218 |
+
# Sample amount of noise to add for each batch element
|
219 |
+
sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
|
220 |
+
sigmas_padded = rearrange(sigmas, "b -> b 1 1")
|
221 |
+
|
222 |
+
# Add noise to input
|
223 |
+
noise = default(noise, lambda: torch.randn_like(x))
|
224 |
+
x_noisy = x + sigmas_padded * noise
|
225 |
+
|
226 |
+
# Compute denoised values
|
227 |
+
x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
|
228 |
+
|
229 |
+
# Compute weighted loss
|
230 |
+
losses = F.mse_loss(x_denoised, x, reduction="none")
|
231 |
+
losses = reduce(losses, "b ... -> b", "mean")
|
232 |
+
losses = losses * self.loss_weight(sigmas)
|
233 |
+
loss = losses.mean()
|
234 |
+
return loss
|
235 |
+
|
236 |
+
|
237 |
+
class VKDiffusion(Diffusion):
|
238 |
+
|
239 |
+
alias = "vk"
|
240 |
+
|
241 |
+
def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
|
242 |
+
super().__init__()
|
243 |
+
self.net = net
|
244 |
+
self.sigma_distribution = sigma_distribution
|
245 |
+
|
246 |
+
def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
|
247 |
+
sigma_data = 1.0
|
248 |
+
sigmas = rearrange(sigmas, "b -> b 1 1")
|
249 |
+
c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
|
250 |
+
c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
|
251 |
+
c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
|
252 |
+
return c_skip, c_out, c_in
|
253 |
+
|
254 |
+
def sigma_to_t(self, sigmas: Tensor) -> Tensor:
|
255 |
+
return sigmas.atan() / pi * 2
|
256 |
+
|
257 |
+
def t_to_sigma(self, t: Tensor) -> Tensor:
|
258 |
+
return (t * pi / 2).tan()
|
259 |
+
|
260 |
+
def denoise_fn(
|
261 |
+
self,
|
262 |
+
x_noisy: Tensor,
|
263 |
+
sigmas: Optional[Tensor] = None,
|
264 |
+
sigma: Optional[float] = None,
|
265 |
+
**kwargs,
|
266 |
+
) -> Tensor:
|
267 |
+
batch_size, device = x_noisy.shape[0], x_noisy.device
|
268 |
+
sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
|
269 |
+
|
270 |
+
# Predict network output and add skip connection
|
271 |
+
c_skip, c_out, c_in = self.get_scale_weights(sigmas)
|
272 |
+
x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
|
273 |
+
x_denoised = c_skip * x_noisy + c_out * x_pred
|
274 |
+
return x_denoised
|
275 |
+
|
276 |
+
def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
|
277 |
+
batch_size, device = x.shape[0], x.device
|
278 |
+
|
279 |
+
# Sample amount of noise to add for each batch element
|
280 |
+
sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
|
281 |
+
sigmas_padded = rearrange(sigmas, "b -> b 1 1")
|
282 |
+
|
283 |
+
# Add noise to input
|
284 |
+
noise = default(noise, lambda: torch.randn_like(x))
|
285 |
+
x_noisy = x + sigmas_padded * noise
|
286 |
+
|
287 |
+
# Compute model output
|
288 |
+
c_skip, c_out, c_in = self.get_scale_weights(sigmas)
|
289 |
+
x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
|
290 |
+
|
291 |
+
# Compute v-objective target
|
292 |
+
v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
|
293 |
+
|
294 |
+
# Compute loss
|
295 |
+
loss = F.mse_loss(x_pred, v_target)
|
296 |
+
return loss
|
297 |
+
|
298 |
+
|
299 |
+
"""
|
300 |
+
Diffusion Sampling
|
301 |
+
"""
|
302 |
+
|
303 |
+
""" Schedules """
|
304 |
+
|
305 |
+
|
306 |
+
class Schedule(nn.Module):
|
307 |
+
"""Interface used by different sampling schedules"""
|
308 |
+
|
309 |
+
def forward(self, num_steps: int, device: torch.device) -> Tensor:
|
310 |
+
raise NotImplementedError()
|
311 |
+
|
312 |
+
|
313 |
+
class LinearSchedule(Schedule):
|
314 |
+
def forward(self, num_steps: int, device: Any) -> Tensor:
|
315 |
+
sigmas = torch.linspace(1, 0, num_steps + 1)[:-1]
|
316 |
+
return sigmas
|
317 |
+
|
318 |
+
|
319 |
+
class KarrasSchedule(Schedule):
|
320 |
+
"""https://arxiv.org/abs/2206.00364 equation 5"""
|
321 |
+
|
322 |
+
def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
|
323 |
+
super().__init__()
|
324 |
+
self.sigma_min = sigma_min
|
325 |
+
self.sigma_max = sigma_max
|
326 |
+
self.rho = rho
|
327 |
+
|
328 |
+
def forward(self, num_steps: int, device: Any) -> Tensor:
|
329 |
+
rho_inv = 1.0 / self.rho
|
330 |
+
steps = torch.arange(num_steps, device=device, dtype=torch.float32)
|
331 |
+
sigmas = (
|
332 |
+
self.sigma_max ** rho_inv
|
333 |
+
+ (steps / (num_steps - 1))
|
334 |
+
* (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv)
|
335 |
+
) ** self.rho
|
336 |
+
sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
|
337 |
+
return sigmas
|
338 |
+
|
339 |
+
|
340 |
+
""" Samplers """
|
341 |
+
|
342 |
+
|
343 |
+
class Sampler(nn.Module):
|
344 |
+
|
345 |
+
diffusion_types: List[Type[Diffusion]] = []
|
346 |
+
|
347 |
+
def forward(
|
348 |
+
self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
|
349 |
+
) -> Tensor:
|
350 |
+
raise NotImplementedError()
|
351 |
+
|
352 |
+
def inpaint(
|
353 |
+
self,
|
354 |
+
source: Tensor,
|
355 |
+
mask: Tensor,
|
356 |
+
fn: Callable,
|
357 |
+
sigmas: Tensor,
|
358 |
+
num_steps: int,
|
359 |
+
num_resamples: int,
|
360 |
+
) -> Tensor:
|
361 |
+
raise NotImplementedError("Inpainting not available with current sampler")
|
362 |
+
|
363 |
+
|
364 |
+
class VSampler(Sampler):
|
365 |
+
|
366 |
+
diffusion_types = [VDiffusion]
|
367 |
+
|
368 |
+
def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
|
369 |
+
angle = sigma * pi / 2
|
370 |
+
alpha = cos(angle)
|
371 |
+
beta = sin(angle)
|
372 |
+
return alpha, beta
|
373 |
+
|
374 |
+
def forward(
|
375 |
+
self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
|
376 |
+
) -> Tensor:
|
377 |
+
x = sigmas[0] * noise
|
378 |
+
alpha, beta = self.get_alpha_beta(sigmas[0].item())
|
379 |
+
|
380 |
+
for i in range(num_steps - 1):
|
381 |
+
is_last = i == num_steps - 1
|
382 |
+
|
383 |
+
x_denoised = fn(x, sigma=sigmas[i])
|
384 |
+
x_pred = x * alpha - x_denoised * beta
|
385 |
+
x_eps = x * beta + x_denoised * alpha
|
386 |
+
|
387 |
+
if not is_last:
|
388 |
+
alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
|
389 |
+
x = x_pred * alpha + x_eps * beta
|
390 |
+
|
391 |
+
return x_pred
|
392 |
+
|
393 |
+
|
394 |
+
class KarrasSampler(Sampler):
|
395 |
+
"""https://arxiv.org/abs/2206.00364 algorithm 1"""
|
396 |
+
|
397 |
+
diffusion_types = [KDiffusion, VKDiffusion]
|
398 |
+
|
399 |
+
def __init__(
|
400 |
+
self,
|
401 |
+
s_tmin: float = 0,
|
402 |
+
s_tmax: float = float("inf"),
|
403 |
+
s_churn: float = 0.0,
|
404 |
+
s_noise: float = 1.0,
|
405 |
+
):
|
406 |
+
super().__init__()
|
407 |
+
self.s_tmin = s_tmin
|
408 |
+
self.s_tmax = s_tmax
|
409 |
+
self.s_noise = s_noise
|
410 |
+
self.s_churn = s_churn
|
411 |
+
|
412 |
+
def step(
|
413 |
+
self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
|
414 |
+
) -> Tensor:
|
415 |
+
"""Algorithm 2 (step)"""
|
416 |
+
# Select temporarily increased noise level
|
417 |
+
sigma_hat = sigma + gamma * sigma
|
418 |
+
# Add noise to move from sigma to sigma_hat
|
419 |
+
epsilon = self.s_noise * torch.randn_like(x)
|
420 |
+
x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
|
421 |
+
# Evaluate ∂x/∂sigma at sigma_hat
|
422 |
+
d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
|
423 |
+
# Take euler step from sigma_hat to sigma_next
|
424 |
+
x_next = x_hat + (sigma_next - sigma_hat) * d
|
425 |
+
# Second order correction
|
426 |
+
if sigma_next != 0:
|
427 |
+
model_out_next = fn(x_next, sigma=sigma_next)
|
428 |
+
d_prime = (x_next - model_out_next) / sigma_next
|
429 |
+
x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
|
430 |
+
return x_next
|
431 |
+
|
432 |
+
def forward(
|
433 |
+
self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
|
434 |
+
) -> Tensor:
|
435 |
+
x = sigmas[0] * noise
|
436 |
+
# Compute gammas
|
437 |
+
gammas = torch.where(
|
438 |
+
(sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
|
439 |
+
min(self.s_churn / num_steps, sqrt(2) - 1),
|
440 |
+
0.0,
|
441 |
+
)
|
442 |
+
# Denoise to sample
|
443 |
+
for i in range(num_steps - 1):
|
444 |
+
x = self.step(
|
445 |
+
x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i] # type: ignore # noqa
|
446 |
+
)
|
447 |
+
|
448 |
+
return x
|
449 |
+
|
450 |
+
|
451 |
+
class AEulerSampler(Sampler):
|
452 |
+
|
453 |
+
diffusion_types = [KDiffusion, VKDiffusion]
|
454 |
+
|
455 |
+
def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
|
456 |
+
sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
|
457 |
+
sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
|
458 |
+
return sigma_up, sigma_down
|
459 |
+
|
460 |
+
def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
|
461 |
+
# Sigma steps
|
462 |
+
sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
|
463 |
+
# Derivative at sigma (∂x/∂sigma)
|
464 |
+
d = (x - fn(x, sigma=sigma)) / sigma
|
465 |
+
# Euler method
|
466 |
+
x_next = x + d * (sigma_down - sigma)
|
467 |
+
# Add randomness
|
468 |
+
x_next = x_next + torch.randn_like(x) * sigma_up
|
469 |
+
return x_next
|
470 |
+
|
471 |
+
def forward(
|
472 |
+
self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
|
473 |
+
) -> Tensor:
|
474 |
+
x = sigmas[0] * noise
|
475 |
+
# Denoise to sample
|
476 |
+
for i in range(num_steps - 1):
|
477 |
+
x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
|
478 |
+
return x
|
479 |
+
|
480 |
+
|
481 |
+
class ADPM2Sampler(Sampler):
|
482 |
+
"""https://www.desmos.com/calculator/jbxjlqd9mb"""
|
483 |
+
|
484 |
+
diffusion_types = [KDiffusion, VKDiffusion]
|
485 |
+
|
486 |
+
def __init__(self, rho: float = 1.0):
|
487 |
+
super().__init__()
|
488 |
+
self.rho = rho
|
489 |
+
|
490 |
+
def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
|
491 |
+
r = self.rho
|
492 |
+
sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
|
493 |
+
sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
|
494 |
+
sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
|
495 |
+
return sigma_up, sigma_down, sigma_mid
|
496 |
+
|
497 |
+
def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
|
498 |
+
# Sigma steps
|
499 |
+
sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
|
500 |
+
# Derivative at sigma (∂x/∂sigma)
|
501 |
+
d = (x - fn(x, sigma=sigma)) / sigma
|
502 |
+
# Denoise to midpoint
|
503 |
+
x_mid = x + d * (sigma_mid - sigma)
|
504 |
+
# Derivative at sigma_mid (∂x_mid/∂sigma_mid)
|
505 |
+
d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
|
506 |
+
# Denoise to next
|
507 |
+
x = x + d_mid * (sigma_down - sigma)
|
508 |
+
# Add randomness
|
509 |
+
x_next = x + torch.randn_like(x) * sigma_up
|
510 |
+
return x_next
|
511 |
+
|
512 |
+
def forward(
|
513 |
+
self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
|
514 |
+
) -> Tensor:
|
515 |
+
x = sigmas[0] * noise
|
516 |
+
# Denoise to sample
|
517 |
+
for i in range(num_steps - 1):
|
518 |
+
x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
|
519 |
+
return x
|
520 |
+
|
521 |
+
def inpaint(
|
522 |
+
self,
|
523 |
+
source: Tensor,
|
524 |
+
mask: Tensor,
|
525 |
+
fn: Callable,
|
526 |
+
sigmas: Tensor,
|
527 |
+
num_steps: int,
|
528 |
+
num_resamples: int,
|
529 |
+
) -> Tensor:
|
530 |
+
x = sigmas[0] * torch.randn_like(source)
|
531 |
+
|
532 |
+
for i in range(num_steps - 1):
|
533 |
+
# Noise source to current noise level
|
534 |
+
source_noisy = source + sigmas[i] * torch.randn_like(source)
|
535 |
+
for r in range(num_resamples):
|
536 |
+
# Merge noisy source and current then denoise
|
537 |
+
x = source_noisy * mask + x * ~mask
|
538 |
+
x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
|
539 |
+
# Renoise if not last resample step
|
540 |
+
if r < num_resamples - 1:
|
541 |
+
sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
|
542 |
+
x = x + sigma * torch.randn_like(x)
|
543 |
+
|
544 |
+
return source * mask + x * ~mask
|
545 |
+
|
546 |
+
|
547 |
+
""" Main Classes """
|
548 |
+
|
549 |
+
|
550 |
+
class DiffusionSampler(nn.Module):
|
551 |
+
def __init__(
|
552 |
+
self,
|
553 |
+
diffusion: Diffusion,
|
554 |
+
*,
|
555 |
+
sampler: Sampler,
|
556 |
+
sigma_schedule: Schedule,
|
557 |
+
num_steps: Optional[int] = None,
|
558 |
+
clamp: bool = True,
|
559 |
+
):
|
560 |
+
super().__init__()
|
561 |
+
self.denoise_fn = diffusion.denoise_fn
|
562 |
+
self.sampler = sampler
|
563 |
+
self.sigma_schedule = sigma_schedule
|
564 |
+
self.num_steps = num_steps
|
565 |
+
self.clamp = clamp
|
566 |
+
|
567 |
+
# Check sampler is compatible with diffusion type
|
568 |
+
sampler_class = sampler.__class__.__name__
|
569 |
+
diffusion_class = diffusion.__class__.__name__
|
570 |
+
message = f"{sampler_class} incompatible with {diffusion_class}"
|
571 |
+
assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
|
572 |
+
|
573 |
+
def forward(
|
574 |
+
self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
|
575 |
+
) -> Tensor:
|
576 |
+
device = noise.device
|
577 |
+
num_steps = default(num_steps, self.num_steps) # type: ignore
|
578 |
+
assert exists(num_steps), "Parameter `num_steps` must be provided"
|
579 |
+
# Compute sigmas using schedule
|
580 |
+
sigmas = self.sigma_schedule(num_steps, device)
|
581 |
+
# Append additional kwargs to denoise function (used e.g. for conditional unet)
|
582 |
+
fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
|
583 |
+
# Sample using sampler
|
584 |
+
x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
|
585 |
+
x = x.clamp(-1.0, 1.0) if self.clamp else x
|
586 |
+
return x
|
587 |
+
|
588 |
+
|
589 |
+
class DiffusionInpainter(nn.Module):
|
590 |
+
def __init__(
|
591 |
+
self,
|
592 |
+
diffusion: Diffusion,
|
593 |
+
*,
|
594 |
+
num_steps: int,
|
595 |
+
num_resamples: int,
|
596 |
+
sampler: Sampler,
|
597 |
+
sigma_schedule: Schedule,
|
598 |
+
):
|
599 |
+
super().__init__()
|
600 |
+
self.denoise_fn = diffusion.denoise_fn
|
601 |
+
self.num_steps = num_steps
|
602 |
+
self.num_resamples = num_resamples
|
603 |
+
self.inpaint_fn = sampler.inpaint
|
604 |
+
self.sigma_schedule = sigma_schedule
|
605 |
+
|
606 |
+
@torch.no_grad()
|
607 |
+
def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
|
608 |
+
x = self.inpaint_fn(
|
609 |
+
source=inpaint,
|
610 |
+
mask=inpaint_mask,
|
611 |
+
fn=self.denoise_fn,
|
612 |
+
sigmas=self.sigma_schedule(self.num_steps, inpaint.device),
|
613 |
+
num_steps=self.num_steps,
|
614 |
+
num_resamples=self.num_resamples,
|
615 |
+
)
|
616 |
+
return x
|
617 |
+
|
618 |
+
|
619 |
+
def sequential_mask(like: Tensor, start: int) -> Tensor:
|
620 |
+
length, device = like.shape[2], like.device
|
621 |
+
mask = torch.ones_like(like, dtype=torch.bool)
|
622 |
+
mask[:, :, start:] = torch.zeros((length - start,), device=device)
|
623 |
+
return mask
|
624 |
+
|
625 |
+
|
626 |
+
class SpanBySpanComposer(nn.Module):
|
627 |
+
def __init__(
|
628 |
+
self,
|
629 |
+
inpainter: DiffusionInpainter,
|
630 |
+
*,
|
631 |
+
num_spans: int,
|
632 |
+
):
|
633 |
+
super().__init__()
|
634 |
+
self.inpainter = inpainter
|
635 |
+
self.num_spans = num_spans
|
636 |
+
|
637 |
+
def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
|
638 |
+
half_length = start.shape[2] // 2
|
639 |
+
|
640 |
+
spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else []
|
641 |
+
# Inpaint second half from first half
|
642 |
+
inpaint = torch.zeros_like(start)
|
643 |
+
inpaint[:, :, :half_length] = start[:, :, half_length:]
|
644 |
+
inpaint_mask = sequential_mask(like=start, start=half_length)
|
645 |
+
|
646 |
+
for i in range(self.num_spans):
|
647 |
+
# Inpaint second half
|
648 |
+
span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
|
649 |
+
# Replace first half with generated second half
|
650 |
+
second_half = span[:, :, half_length:]
|
651 |
+
inpaint[:, :, :half_length] = second_half
|
652 |
+
# Save generated span
|
653 |
+
spans.append(second_half)
|
654 |
+
|
655 |
+
return torch.cat(spans, dim=2)
|
656 |
+
|
657 |
+
|
658 |
+
class XDiffusion(nn.Module):
|
659 |
+
def __init__(self, type: str, net: nn.Module, **kwargs):
|
660 |
+
super().__init__()
|
661 |
+
|
662 |
+
diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
|
663 |
+
aliases = [t.alias for t in diffusion_classes] # type: ignore
|
664 |
+
message = f"type='{type}' must be one of {*aliases,}"
|
665 |
+
assert type in aliases, message
|
666 |
+
self.net = net
|
667 |
+
|
668 |
+
for XDiffusion in diffusion_classes:
|
669 |
+
if XDiffusion.alias == type: # type: ignore
|
670 |
+
self.diffusion = XDiffusion(net=net, **kwargs)
|
671 |
+
|
672 |
+
def forward(self, *args, **kwargs) -> Tensor:
|
673 |
+
return self.diffusion(*args, **kwargs)
|
674 |
+
|
675 |
+
def sample(
|
676 |
+
self,
|
677 |
+
noise: Tensor,
|
678 |
+
num_steps: int,
|
679 |
+
sigma_schedule: Schedule,
|
680 |
+
sampler: Sampler,
|
681 |
+
clamp: bool,
|
682 |
+
**kwargs,
|
683 |
+
) -> Tensor:
|
684 |
+
diffusion_sampler = DiffusionSampler(
|
685 |
+
diffusion=self.diffusion,
|
686 |
+
sampler=sampler,
|
687 |
+
sigma_schedule=sigma_schedule,
|
688 |
+
num_steps=num_steps,
|
689 |
+
clamp=clamp,
|
690 |
+
)
|
691 |
+
return diffusion_sampler(noise, **kwargs)
|
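The sampler module keeps the noise distribution, diffusion objective, sigma schedule, and sampler as separate pieces. A minimal sketch of wiring them together with `KDiffusion`, a Karras schedule, and the ADPM2 sampler; not part of the diff, the toy network and all numeric values are placeholders, and the import path `Modules.diffusion.sampler` is an assumption:

```py
import torch
import torch.nn as nn
from Modules.diffusion.sampler import (
    ADPM2Sampler, DiffusionSampler, KDiffusion, KarrasSchedule, LogNormalDistribution,
)

class ToyNet(nn.Module):
    """Stand-in denoiser; a real model would be a U-Net conditioned on the noise level."""
    def __init__(self, channels: int = 2):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x, sigmas, **kwargs):
        return self.conv(x)

diffusion = KDiffusion(
    net=ToyNet(),
    sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
    sigma_data=0.1,
)

# Training step: the forward pass returns the weighted denoising loss
audio = torch.randn(4, 2, 1024)   # [batch, channels, length]
loss = diffusion(audio)
loss.backward()

# Sampling: a schedule plus a compatible sampler, wrapped by DiffusionSampler
sampler = DiffusionSampler(
    diffusion,
    sampler=ADPM2Sampler(rho=1.0),
    sigma_schedule=KarrasSchedule(sigma_min=0.001, sigma_max=3.0, rho=9.0),
    num_steps=50,
)
with torch.no_grad():
    generated = sampler(torch.randn(4, 2, 1024))   # clamped to [-1, 1] by default
```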
Modules/diffusion/utils.py
ADDED
@@ -0,0 +1,82 @@
+from functools import reduce
+from inspect import isfunction
+from math import ceil, floor, log2, pi
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import Generator, Tensor
+from typing_extensions import TypeGuard
+
+T = TypeVar("T")
+
+
+def exists(val: Optional[T]) -> TypeGuard[T]:
+    return val is not None
+
+
+def iff(condition: bool, value: T) -> Optional[T]:
+    return value if condition else None
+
+
+def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
+    return isinstance(obj, list) or isinstance(obj, tuple)
+
+
+def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def to_list(val: Union[T, Sequence[T]]) -> List[T]:
+    if isinstance(val, tuple):
+        return list(val)
+    if isinstance(val, list):
+        return val
+    return [val]  # type: ignore
+
+
+def prod(vals: Sequence[int]) -> int:
+    return reduce(lambda x, y: x * y, vals)
+
+
+def closest_power_2(x: float) -> int:
+    exponent = log2(x)
+    distance_fn = lambda z: abs(x - 2 ** z)  # noqa
+    exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
+    return 2 ** int(exponent_closest)
+
+def rand_bool(shape, proba, device=None):
+    if proba == 1:
+        return torch.ones(shape, device=device, dtype=torch.bool)
+    elif proba == 0:
+        return torch.zeros(shape, device=device, dtype=torch.bool)
+    else:
+        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
+
+
+"""
+Kwargs Utils
+"""
+
+
+def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
+    return_dicts: Tuple[Dict, Dict] = ({}, {})
+    for key in d.keys():
+        no_prefix = int(not key.startswith(prefix))
+        return_dicts[no_prefix][key] = d[key]
+    return return_dicts
+
+
+def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
+    kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
+    if keep_prefix:
+        return kwargs_with_prefix, kwargs
+    kwargs_no_prefix = {k[len(prefix):]: v for k, v in kwargs_with_prefix.items()}
+    return kwargs_no_prefix, kwargs
+
+
+def prefix_dict(prefix: str, d: Dict) -> Dict:
+    return {prefix + str(k): v for k, v in d.items()}
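A brief, hypothetical illustration of the kwargs helpers added above; the dictionaries and prefixes are made up for the example and do not come from the repository.

# Hypothetical example of the prefix-grouping helpers defined above.
kwargs = {"attention_heads": 8, "attention_features": 64, "resnet_groups": 8}

attention_kwargs, rest = groupby("attention_", kwargs)
# attention_kwargs == {"heads": 8, "features": 64}   (prefix stripped)
# rest             == {"resnet_groups": 8}

prefix_dict("unet_", {"channels": 128})   # -> {"unet_channels": 128}
closest_power_2(50.0)                     # -> 64 (the nearer of 2**5 and 2**6)
rand_bool(shape=(2, 1, 1), proba=0.5)     # random boolean mask of the given shape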
Modules/discriminators.py
ADDED
@@ -0,0 +1,188 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, spectral_norm
+
+from .utils import get_padding
+
+LRELU_SLOPE = 0.1
+
+def stft(x, fft_size, hop_size, win_length, window):
+    """Perform STFT and convert to magnitude spectrogram.
+    Args:
+        x (Tensor): Input signal tensor (B, T).
+        fft_size (int): FFT size.
+        hop_size (int): Hop size.
+        win_length (int): Window length.
+        window (str): Window function type.
+    Returns:
+        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+    """
+    x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
+                        return_complex=True)
+    real = x_stft[..., 0]
+    imag = x_stft[..., 1]
+
+    return torch.abs(x_stft).transpose(2, 1)
+
+class SpecDiscriminator(nn.Module):
+    """docstring for Discriminator."""
+
+    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
+        super(SpecDiscriminator, self).__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.fft_size = fft_size
+        self.shift_size = shift_size
+        self.win_length = win_length
+        self.window = getattr(torch, window)(win_length)
+        self.discriminators = nn.ModuleList([
+            norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
+        ])
+
+        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
+
+    def forward(self, y):
+
+        fmap = []
+        y = y.squeeze(1)
+        y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.get_device()))
+        y = y.unsqueeze(1)
+        for i, d in enumerate(self.discriminators):
+            y = d(y)
+            y = F.leaky_relu(y, LRELU_SLOPE)
+            fmap.append(y)
+
+        y = self.out(y)
+        fmap.append(y)
+
+        return torch.flatten(y, 1, -1), fmap
+
+class MultiResSpecDiscriminator(torch.nn.Module):
+
+    def __init__(self,
+                 fft_sizes=[1024, 2048, 512],
+                 hop_sizes=[120, 240, 50],
+                 win_lengths=[600, 1200, 240],
+                 window="hann_window"):
+
+        super(MultiResSpecDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
+            SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
+            SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
+        ])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorP(torch.nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList([
+            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+        ])
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+    def __init__(self):
+        super(MultiPeriodDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            DiscriminatorP(2),
+            DiscriminatorP(3),
+            DiscriminatorP(5),
+            DiscriminatorP(7),
+            DiscriminatorP(11),
+        ])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+class WavLMDiscriminator(nn.Module):
+    """docstring for Discriminator."""
+
+    def __init__(self, slm_hidden=768,
+                 slm_layers=13,
+                 initial_channel=64,
+                 use_spectral_norm=False):
+        super(WavLMDiscriminator, self).__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.pre = norm_f(Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0))
+
+        self.convs = nn.ModuleList([
+            norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)),
+            norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)),
+            norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)),
+        ])
+
+        self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        x = self.pre(x)
+
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x
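For orientation, a hypothetical smoke-test sketch of the discriminators added above; it is not part of the diff. The batch shapes (24 kHz-style waveform chunks, 13 stacked WavLM layers of width 768) are assumptions, and the spectral discriminators call `y.get_device()`, so they expect CUDA tensors.

# Hypothetical smoke test for the discriminators above (shapes are assumptions).
import torch

real = torch.randn(2, 1, 9600)   # (batch, 1, samples), e.g. 0.4 s at 24 kHz
fake = torch.randn(2, 1, 9600)

mpd = MultiPeriodDiscriminator()
scores_real, scores_fake, fmaps_real, fmaps_fake = mpd(real, fake)
print(len(scores_real))          # 5 period discriminators (periods 2, 3, 5, 7, 11)

# SpecDiscriminator.forward moves the window via y.get_device(),
# so the multi-resolution spectral discriminator is run on GPU tensors here.
if torch.cuda.is_available():
    mrd = MultiResSpecDiscriminator().cuda()
    mrd(real.cuda(), fake.cuda())

wld = WavLMDiscriminator()
slm_feats = torch.randn(2, 768 * 13, 80)   # (batch, slm_hidden * slm_layers, frames)
print(wld(slm_feats).shape)                # flattened per-frame scores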