File size: 7,163 Bytes
2c01ee6
a8c39f5
2c01ee6
a8c39f5
1378843
a8c39f5
1378843
 
2c01ee6
1378843
a8c39f5
2c01ee6
 
a8c39f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1378843
a8c39f5
 
 
 
 
 
 
 
 
 
2c01ee6
a8c39f5
 
 
 
1378843
a8c39f5
 
 
 
 
 
 
 
 
 
2c01ee6
a8c39f5
 
 
 
2c01ee6
 
 
 
a8c39f5
 
1378843
a8c39f5
 
 
 
 
 
 
 
 
 
 
 
1378843
a8c39f5
 
 
 
1378843
 
 
a8c39f5
 
 
 
 
 
 
 
 
 
 
 
1378843
 
a8c39f5
 
 
1378843
 
 
 
 
 
 
a8c39f5
 
 
 
 
 
 
 
 
 
 
1378843
 
a8c39f5
 
 
2c01ee6
1378843
 
 
 
 
 
 
a8c39f5
 
 
2c01ee6
 
 
 
a8c39f5
 
 
 
 
 
1378843
a8c39f5
 
 
2c01ee6
 
 
a8c39f5
 
 
 
 
 
 
 
 
 
 
1378843
 
a8c39f5
2c01ee6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor

import requests
from tqdm import tqdm

from tts_service.utils import env_bool
from tts_service.voices import voice_manager

log = logging.getLogger(__name__)

# Root URL of the Hugging Face folder hosting every downloadable resource; the
# remote folder names used below are appended to it when building file URLs.
url_base = "https://huggingface.co/IAHispano/Applio/resolve/main/Resources"

# Every *_list below is a list of (remote_folder, file_names) pairs.
# Both pretrained folders hold the same twelve checkpoints: "D" and "G" files at
# 32k/40k/48k, first without and then with an "f0" prefix. Each folder gets its
# own freshly built list so the two entries never share one mutable object.
pretraineds_v1_list = [
    (
        "pretrained_v1/",
        [f"{prefix}{net}{rate}k.pth" for prefix in ("", "f0") for net in ("D", "G") for rate in (32, 40, 48)],
    )
]
pretraineds_v2_list = [
    (
        "pretrained_v2/",
        [f"{prefix}{net}{rate}k.pth" for prefix in ("", "f0") for net in ("D", "G") for rate in (32, 40, 48)],
    )
]
models_list = [
    ("predictors/", ["rmvpe.pt", "fcpe.pt"]),
]
embedders_list = [
    ("embedders/contentvec/", ["pytorch_model.bin", "config.json"]),
]

# Remote folder -> local directory (a relative path, so resolved against the
# current working directory). Despite its name this is a dict, not a list.
folder_mapping_list = {
    "pretrained_v1/": "rvc/models/pretraineds/pretrained_v1/",
    "pretrained_v2/": "rvc/models/pretraineds/pretrained_v2/",
    "embedders/contentvec/": "rvc/models/embedders/contentvec/",
    "predictors/": "rvc/models/predictors/",
    "formant/": "rvc/models/formant/",
}


def get_file_size_if_missing(file_list: list[tuple[str, list[str]]], timeout: float = 30.0) -> int:
    """
    Calculate the total size of files to be downloaded only if they do not exist locally.

    For every ``(remote_folder, file_names)`` entry, the local destination of each
    file is resolved through ``folder_mapping_list``. Files already on disk are
    skipped without any network access; every other file is probed with an HTTP
    HEAD request and its advertised ``Content-Length`` is added to the total.

    Args:
        file_list: ``(remote_folder, file_names)`` pairs. ``remote_folder`` is
            concatenated directly with the file name when building the URL, so it
            is expected to end with ``/``.
        timeout: Seconds to wait for each HEAD request before giving up.

    Returns:
        Sum of the advertised sizes in bytes. A response without a
        ``Content-Length`` header contributes 0.

    Raises:
        requests.HTTPError: if a HEAD request ends in an error status, so a missing
            remote file is reported instead of counting the size of an error page.
        requests.RequestException: on connection failures or timeouts.
    """
    total_size = 0
    for remote_folder, files in file_list:
        local_folder = folder_mapping_list.get(remote_folder, "")
        for file in files:
            destination_path = os.path.join(local_folder, file)
            if os.path.exists(destination_path):
                continue
            url = f"{url_base}/{remote_folder}{file}"
            # Without a timeout a stalled connection would block startup indefinitely;
            # the context manager releases the pooled connection.
            with requests.head(url, allow_redirects=True, timeout=timeout) as response:
                response.raise_for_status()
                total_size += int(response.headers.get("content-length", 0))
    return total_size


def download_file(url: str, destination_path: str, global_bar: tqdm, timeout: float = 30.0) -> None:
    """
    Download a file from the given URL to the specified destination path,
    updating the global progress bar as data is downloaded.

    The response body is streamed into ``<destination_path>.part`` and renamed to
    ``destination_path`` only after it has been written completely. Callers skip
    files that already exist, so writing straight to the final path would leave a
    truncated file (or a saved HTTP error page) that is never downloaded again.

    Args:
        url: Remote file URL.
        destination_path: Local file to create; missing parent directories are created.
        global_bar: Shared progress bar, advanced by the number of bytes received.
        timeout: Seconds to wait for the connection and between received chunks.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on connection failures or timeouts.
        OSError: if the file cannot be written or moved into place.
    """

    dir_name = os.path.dirname(destination_path)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    partial_path = f"{destination_path}.part"
    # Larger chunks mean far fewer Python-level writes and progress-bar updates.
    block_size = 64 * 1024
    total = 0
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Fail before touching the disk so an error page is never saved as the file.
            response.raise_for_status()
            with open(partial_path, "wb") as file:
                for data in response.iter_content(block_size):
                    file.write(data)
                    global_bar.update(len(data))
                    total += len(data)
        # The .part file lives in the same directory, so this rename is atomic:
        # the final path either does not exist or holds the complete download.
        os.replace(partial_path, destination_path)
    finally:
        # No-op after a successful rename; after a failure it discards the fragment.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    global_bar.clear()
    log.info(f"Downloaded {total:,} bytes to {destination_path}")
    global_bar.display()


def download_mapping_files(file_mapping_list: list[tuple[str, list[str]]], global_bar: tqdm) -> None:
    """
    Fetch, in parallel, every file of *file_mapping_list* that is not yet on disk.

    Each ``(remote_folder, file_names)`` pair is mapped to a local directory via
    ``folder_mapping_list``. Files that already exist are skipped; each remaining
    one is handed to ``download_file`` on a thread pool, all sharing *global_bar*.
    The call waits for every submitted download. If any fail, the first failure in
    submission order is re-raised.
    """
    with ThreadPoolExecutor() as pool:
        submitted = []
        for remote_folder, names in file_mapping_list:
            target_dir = folder_mapping_list.get(remote_folder, "")
            for name in names:
                target = os.path.join(target_dir, name)
                if os.path.exists(target):
                    continue
                source = f"{url_base}/{remote_folder}{name}"
                submitted.append(pool.submit(download_file, source, target, global_bar))
        for job in submitted:
            job.result()


def split_pretraineds(
    pretrained_list: list[tuple[str, list[str]]],
) -> tuple[list[tuple[str, list[str]]], list[tuple[str, list[str]]]]:
    """
    Partition each folder's file names by whether they start with ``"f0"``.

    Returns ``(f0_list, non_f0_list)``, both shaped like the input
    (``(folder, file_names)`` pairs). Folder and file order are preserved. A folder
    appears in a result only when it contributes at least one file to that side.
    """
    with_f0: list[tuple[str, list[str]]] = []
    without_f0: list[tuple[str, list[str]]] = []
    for folder, names in pretrained_list:
        f0_names: list[str] = []
        plain_names: list[str] = []
        for name in names:
            (f0_names if name.startswith("f0") else plain_names).append(name)
        if f0_names:
            with_f0.append((folder, f0_names))
        if plain_names:
            without_f0.append((folder, plain_names))
    return with_f0, without_f0


# Split each pretrained folder listing into its "f0"-prefixed files and the rest,
# so prerequisites_download_pipeline can select each of the four groups independently.
pretraineds_v1_f0_list, pretraineds_v1_nof0_list = split_pretraineds(pretraineds_v1_list)
pretraineds_v2_f0_list, pretraineds_v2_nof0_list = split_pretraineds(pretraineds_v2_list)


def calculate_total_size(
    pretraineds_v1_f0: list[tuple[str, list[str]]],
    pretraineds_v1_nof0: list[tuple[str, list[str]]],
    pretraineds_v2_f0: list[tuple[str, list[str]]],
    pretraineds_v2_nof0: list[tuple[str, list[str]]],
    models: bool,
    voices: bool,
) -> int:
    """
    Return the number of bytes still to download for the selected categories.

    The four pretrained arguments are ``(remote_folder, file_names)`` groups; pass
    an empty list for a group that should not be counted. When *models* is true, the
    predictor and embedder groups are counted first. When *voices* is true, the size
    reported by ``voice_manager.get_voices_size_if_missing()`` is added last.
    """
    groups = [pretraineds_v1_f0, pretraineds_v1_nof0, pretraineds_v2_f0, pretraineds_v2_nof0]
    if models:
        groups = [models_list, embedders_list, *groups]
    size = sum(get_file_size_if_missing(group) for group in groups)
    if voices:
        size += voice_manager.get_voices_size_if_missing()
    return size


def prerequisites_download_pipeline(
    pretraineds_v1_f0: bool,
    pretraineds_v1_nof0: bool,
    pretraineds_v2_f0: bool,
    pretraineds_v2_nof0: bool,
    models: bool,
    voices: bool,
) -> None:
    """
    Manage the download pipeline for different categories of files.

    Does nothing when the ``OFFLINE`` environment flag is set. Otherwise it first
    sums the size of every selected file that is missing locally, then downloads
    the selected categories under one shared progress bar.

    Args:
        pretraineds_v1_f0: Include the ``f0``-prefixed files of ``pretrained_v1/``.
        pretraineds_v1_nof0: Include the remaining files of ``pretrained_v1/``.
        pretraineds_v2_f0: Include the ``f0``-prefixed files of ``pretrained_v2/``.
        pretraineds_v2_nof0: Include the remaining files of ``pretrained_v2/``.
        models: Include the ``predictors/`` and ``embedders/contentvec/`` files.
        voices: Include the files handled by ``voice_manager``.
    """
    if env_bool("OFFLINE", False):
        log.info("Skipping download due to OFFLINE environment variable")
        return

    # Select the pretrained groups once so sizing and downloading cannot disagree.
    pretrained_groups = [
        pretraineds_v1_f0_list if pretraineds_v1_f0 else [],
        pretraineds_v1_nof0_list if pretraineds_v1_nof0 else [],
        pretraineds_v2_f0_list if pretraineds_v2_f0 else [],
        pretraineds_v2_nof0_list if pretraineds_v2_nof0 else [],
    ]
    total_size = calculate_total_size(*pretrained_groups, models, voices)

    if total_size <= 0:
        # NOTE(review): a missing file whose HEAD response lacks Content-Length adds
        # 0 bytes, so if that holds for every missing file nothing is downloaded.
        # Confirm the server (and voice_manager's sizing) always reports sizes.
        log.info("No files to download")
        return

    log.info(f"Will download {total_size:,} bytes")
    # tqdm renders to sys.stderr by default, so interactivity must be judged on the
    # stream the bar actually writes to. When it is not a terminal, a large miniters
    # throttles redraws instead of emitting a line per update.
    bar_stream = sys.stderr
    miniters = None if bar_stream.isatty() else total_size
    with tqdm(
        total=total_size,
        unit="iB",
        unit_scale=True,
        desc="Downloading...",
        miniters=miniters,
        file=bar_stream,
    ) as global_bar:
        if models:
            download_mapping_files(models_list, global_bar)
            download_mapping_files(embedders_list, global_bar)
        for group in pretrained_groups:
            # Unselected groups are empty; skip them rather than spin up an idle pool.
            if group:
                download_mapping_files(group, global_bar)
        if voices:
            voice_manager.download_voice_files(global_bar)