"""Torch device selection, precision settings, and small tensor helpers."""

import sys
import contextlib
from functools import lru_cache

import torch
#from modules import errors

if sys.platform == "darwin":
    from modules import mac_specific


def has_mps() -> bool:
    """Return whether MPS is usable; always False outside macOS.

    On macOS the answer is delegated to ``mac_specific.has_mps``.
    """
    if sys.platform != "darwin":
        return False
    return mac_specific.has_mps


def get_cuda_device_string():
    """Return the device string used for every CUDA operation in this module."""
    return "cuda"


def get_optimal_device_name():
    """Return the preferred device name: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        return get_cuda_device_string()

    if has_mps():
        return "mps"

    return "cpu"


def get_optimal_device():
    """Return ``get_optimal_device_name()`` wrapped in a ``torch.device``."""
    return torch.device(get_optimal_device_name())


def get_device_for(task):
    """Return the device to use for ``task``.

    ``task`` is accepted for interface compatibility but is not consulted:
    every task currently gets the optimal device.
    """
    return get_optimal_device()


def torch_gc():
    """Release cached accelerator memory on CUDA and/or MPS, when present."""
    if torch.cuda.is_available():
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    if has_mps():
        mac_specific.torch_mps_gc()


def enable_tf32():
    """Turn on TF32 for CUDA matmul and cuDNN; does nothing without CUDA.

    Also enables ``cudnn.benchmark`` when any visible device reports compute
    capability (7, 5).
    """
    if torch.cuda.is_available():
        # Enabling the benchmark option seems to let a range of cards do fp16
        # when they otherwise can't.
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
        if any(
            torch.cuda.get_device_capability(devid) == (7, 5)
            for devid in range(torch.cuda.device_count())
        ):
            torch.backends.cudnn.benchmark = True

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


# Runs at import time so the backend flags are set before any model code runs.
enable_tf32()
#errors.run(enable_tf32, "Enabling TF32")

cpu = torch.device("cpu")

# NOTE(review): these are hard-coded to CUDA rather than derived from
# get_optimal_device(); randn() and first_time_calculation() will fail on a
# host without CUDA. Confirm this is intentional before changing it.
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")

dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
unet_needs_upcast = False


def cond_cast_unet(input):
    """Cast ``input`` to ``dtype_unet`` if ``unet_needs_upcast``; else return it unchanged."""
    # The parameter name shadows the builtin but is kept for keyword callers.
    return input.to(dtype_unet) if unet_needs_upcast else input


def cond_cast_float(input):
    """Cast ``input`` to float32 if ``unet_needs_upcast``; else return it unchanged."""
    return input.float() if unet_needs_upcast else input


def randn(seed, shape):
    """Seed torch's global RNG with ``seed`` and sample a normal tensor on ``device``."""
    torch.manual_seed(seed)
    return torch.randn(shape, device=device)


def randn_without_seed(shape):
    """Sample a normal tensor of ``shape`` on ``device`` without reseeding."""
    return torch.randn(shape, device=device)


def autocast(disable=False):
    """Return a CUDA autocast context, or a no-op context when ``disable`` is true."""
    if disable:
        return contextlib.nullcontext()

    return torch.autocast("cuda")


def without_autocast(disable=False):
    """Return a context that switches CUDA autocast off for its body.

    A no-op context is returned when autocast is not currently enabled or
    when ``disable`` is true.
    """
    if torch.is_autocast_enabled() and not disable:
        return torch.autocast("cuda", enabled=False)

    return contextlib.nullcontext()


class NansException(Exception):
    """Raised by ``test_for_nans`` when a tensor consists entirely of NaNs."""


def test_for_nans(x, where):
    """Raise ``NansException`` if every element of ``x`` is NaN.

    Args:
        x: tensor to inspect.
        where: origin of the tensor; ``"unet"`` and ``"vae"`` select a
            specific message, any other value selects a generic one.

    Raises:
        NansException: ``x`` is non-empty and all of its elements are NaN.
    """
    # torch.all() over zero elements is vacuously True, so an empty tensor
    # would otherwise be reported as "all NaNs".
    if x.numel() == 0 or not torch.isnan(x).all().item():
        return

    if where == "unet":
        message = "在 Unet 中生成了一个包含所有 NaNs 的张量。"
    elif where == "vae":
        message = "在 VAE 中生成了一个包含所有 NaN 的张量。"
    else:
        message = "产生了一个包含所有 NaN 的张量。"

    message += " 使用 --disable-nan-check 命令行参数禁用此检查。"

    raise NansException(message)


# The name matches pytest's ``test_*`` pattern; mark it so it is not collected
# as a test when this module's names are imported into a test file.
test_for_nans.__test__ = False


@lru_cache
def first_time_calculation():
    """Run one tiny Linear and one tiny Conv2d forward pass on ``device``.

    Just doing any calculation with pytorch layers: the first time this is
    done it allocates about 700MB of memory and takes about 2.7 seconds, at
    least on NVIDIA. ``lru_cache`` ensures the body executes only once.
    """
    x = torch.zeros((1, 1)).to(device, dtype)
    linear = torch.nn.Linear(1, 1).to(device, dtype)
    linear(x)

    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv2d(x)