File size: 6,983 Bytes
af05d16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e004bff
 
 
 
af05d16
e004bff
 
af05d16
e004bff
 
af05d16
e004bff
af05d16
e004bff
 
 
 
af05d16
e004bff
af05d16
e004bff
 
 
 
 
 
 
 
af05d16
e004bff
 
af05d16
e004bff
 
af05d16
e004bff
af05d16
e004bff
 
 
 
af05d16
e004bff
af05d16
e004bff
 
 
 
 
af05d16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import argparse
import sys
import torch
import json
from multiprocessing import cpu_count

# Module-level half-precision (fp16) flag; disabled by default.
usefp16 = False


def use_fp32_config():
    """Report the fp16 decision and the compute capability of GPU 0.

    Half precision is never enabled by this function: the first element of
    the result is always ``False``. When CUDA is unavailable a hint is
    printed to stdout.

    Returns:
        tuple: ``(usefp16, device_capability)`` where ``usefp16`` is always
        ``False`` and ``device_capability`` is the first element of
        ``torch.cuda.get_device_capability`` for ``cuda:0``, or ``0`` when
        CUDA is not available.
    """
    usefp16 = False
    device_capability = 0
    if torch.cuda.is_available():
        # Only the first GPU (index 0) is inspected.
        device = torch.device("cuda:0")
        device_capability = torch.cuda.get_device_capability(device)[0]
    else:
        print(
            "CUDA is not available. Make sure you have an NVIDIA GPU and CUDA installed."
        )
    return (usefp16, device_capability)


class Config:
    """Runtime settings derived from command-line flags and detected hardware.

    Instantiating the class parses ``sys.argv`` and probes the available
    compute device (CUDA, then MPS, then CPU).

    Attributes set by ``__init__``:
        device: Torch device string; starts as ``"cuda:0"`` and is switched
            to ``"mps"`` or ``"cpu"`` when CUDA is unavailable.
        is_half: ``True`` unless the detected hardware is forced to fp32.
        n_cpu: CPU count reported by ``multiprocessing.cpu_count``.
        gpu_name, gpu_mem: CUDA device name and total memory in GiB
            (rounded); both stay ``None`` without CUDA.
        python_cmd, listen_port, iscolab, noparallel, noautoopen,
        paperspace, is_cli: Values of the command-line options.
        x_pad, x_query, x_center, x_max: Integer parameters chosen by
            ``device_config`` from ``is_half`` and ``gpu_mem``.
    """

    def __init__(self):
        # Defaults; device_config() overrides them to match the hardware.
        self.device = "cuda:0"
        self.is_half = True
        self.n_cpu = 0
        self.gpu_name = None
        self.gpu_mem = None
        (
            self.python_cmd,
            self.listen_port,
            self.iscolab,
            self.noparallel,
            self.noautoopen,
            self.paperspace,
            self.is_cli,
        ) = self.arg_parse()

        self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()

    @staticmethod
    def arg_parse() -> tuple:
        """Parse the process command line.

        Returns:
            tuple: ``(pycmd, port, colab, noparallel, noautoopen,
            paperspace, is_cli)``. A port outside ``0..65535`` is replaced
            by the default ``7865``.
        """
        exe = sys.executable or "python"
        parser = argparse.ArgumentParser()
        parser.add_argument("--port", type=int, default=7865, help="Listen port")
        parser.add_argument("--pycmd", type=str, default=exe, help="Python command")
        parser.add_argument("--colab", action="store_true", help="Launch in colab")
        parser.add_argument(
            "--noparallel", action="store_true", help="Disable parallel processing"
        )
        parser.add_argument(
            "--noautoopen",
            action="store_true",
            help="Do not open in browser automatically",
        )
        parser.add_argument(  # Fork Feature. Paperspace integration for web UI
            "--paperspace",
            action="store_true",
            help="Note that this argument just shares a gradio link for the web UI. Thus can be used on other non-local CLI systems.",
        )
        parser.add_argument(  # Fork Feature. Embed a CLI into the infer-web.py
            "--is_cli",
            action="store_true",
            help="Use the CLI instead of setting up a gradio UI. This flag will launch an RVC text interface where you can execute functions from infer-web.py!",
        )
        cmd_opts = parser.parse_args()

        # Fall back to the default port when the given one is out of range.
        if not 0 <= cmd_opts.port <= 65535:
            cmd_opts.port = 7865

        return (
            cmd_opts.pycmd,
            cmd_opts.port,
            cmd_opts.colab,
            cmd_opts.noparallel,
            cmd_opts.noautoopen,
            cmd_opts.paperspace,
            cmd_opts.is_cli,
        )

    # The MPS backend is not present in every PyTorch build (and needs
    # macOS 12.3+), so probe for it with `getattr` before using it.
    @staticmethod
    def has_mps() -> bool:
        """Return ``True`` only if a tensor can actually be moved to MPS."""
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is None or not mps_backend.is_available():
            return False
        try:
            torch.zeros(1).to(torch.device("mps"))
            return True
        except Exception:
            # Backend reports available but is unusable; treat as absent.
            return False

    def device_config(self) -> tuple:
        """Detect the compute device and choose the ``x_*`` parameters.

        Updates ``device``, ``is_half``, ``gpu_name``, ``gpu_mem`` and
        ``n_cpu`` on the instance. On CUDA GPUs with at most 4 GiB of memory
        it also rewrites ``trainset_preprocess_pipeline_print.py`` in the
        current working directory.

        Returns:
            tuple: ``(x_pad, x_query, x_center, x_max)``.
        """
        if torch.cuda.is_available():
            i_device = int(self.device.split(":")[-1])
            self.gpu_name = torch.cuda.get_device_name(i_device)
            # GPUs whose name matches any of these patterns run in fp32.
            if (
                ("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
                or "P40" in self.gpu_name.upper()
                or "1060" in self.gpu_name
                or "1070" in self.gpu_name
                or "1080" in self.gpu_name
            ):
                print("Found GPU", self.gpu_name, ", force to fp32")
                self.is_half = False
            else:
                print("Found GPU", self.gpu_name)
                use_fp32_config()
            # Total memory in GiB; the +0.4 rounds slightly-under sizes up.
            self.gpu_mem = int(
                torch.cuda.get_device_properties(i_device).total_memory
                / 1024
                / 1024
                / 1024
                + 0.4
            )
            if self.gpu_mem <= 4:
                # NOTE(review): patches the preprocessing script in place,
                # replacing "3.7" with "3.0" — presumably a smaller setting
                # for low-memory GPUs; confirm against that script.
                # Python source is UTF-8 by default, so read/write it as such
                # instead of relying on the locale encoding.
                with open(
                    "trainset_preprocess_pipeline_print.py", "r", encoding="utf-8"
                ) as f:
                    strr = f.read().replace("3.7", "3.0")
                with open(
                    "trainset_preprocess_pipeline_print.py", "w", encoding="utf-8"
                ) as f:
                    f.write(strr)
        elif self.has_mps():
            print("No supported Nvidia GPU found, use MPS instead")
            self.device = "mps"
            self.is_half = False
            use_fp32_config()
        else:
            print("No supported Nvidia GPU found, use CPU instead")
            self.device = "cpu"
            self.is_half = False
            use_fp32_config()

        if self.n_cpu == 0:
            self.n_cpu = cpu_count()

        if self.is_half:
            # Settings for 6 GB of GPU memory
            x_pad = 3
            x_query = 10
            x_center = 60
            x_max = 65
        else:
            # Settings for 5 GB of GPU memory
            x_pad = 1
            x_query = 6
            x_center = 38
            x_max = 41

        # GPUs with at most 4 GiB get the smallest settings regardless of
        # precision.
        if self.gpu_mem is not None and self.gpu_mem <= 4:
            x_pad = 1
            x_query = 5
            x_center = 30
            x_max = 32

        return x_pad, x_query, x_center, x_max