chattts / ChatTTS /utils /gpu_utils.py
doby4u's picture
webui origin
47442ee
raw
history blame
929 Bytes
import torch
import logging
def select_device(min_memory: float = 2048) -> torch.device:
    """Pick the CUDA device with the most unreserved memory, else the CPU.

    For every visible CUDA device the free memory is estimated as
    ``total_memory - torch.cuda.memory_reserved(i)``.  The device with the
    largest estimate is chosen (the lowest index wins on ties).  If that
    estimate is below ``min_memory`` the CPU is returned instead.

    NOTE: ``memory_reserved`` reports memory held by PyTorch's caching
    allocator in the current process, so memory used by other processes on
    the same GPU is not subtracted; treat the result as an estimate.

    Args:
        min_memory: Minimum estimated free memory, in MiB, that the chosen
            GPU must have for it to be used.

    Returns:
        ``torch.device('cuda:<index>')`` for the selected GPU, or
        ``torch.device('cpu')`` when CUDA is unavailable, no device is
        enumerated, or the best GPU has less than ``min_memory`` MiB free.
    """
    logger = logging.getLogger(__name__)

    # Treat "CUDA available but zero devices enumerated" like "no GPU" so
    # that max() below is never called on an empty collection.
    gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
    if gpu_count == 0:
        logger.warning('No GPU found, use CPU instead')
        return torch.device('cpu')

    # Estimated free bytes per device index (insertion order = index order).
    free_bytes = {
        index: (
            torch.cuda.get_device_properties(index).total_memory
            - torch.cuda.memory_reserved(index)
        )
        for index in range(gpu_count)
    }

    # max() returns the first maximal key, i.e. the lowest index on ties.
    selected_gpu = max(free_bytes, key=free_bytes.get)
    free_memory_mb = free_bytes[selected_gpu] / (1024 * 1024)

    if free_memory_mb < min_memory:
        # Even the best GPU is too full: fall back to the CPU.
        logger.warning(
            'GPU %s has %s MB memory left.',
            selected_gpu,
            round(free_memory_mb, 2),
        )
        return torch.device('cpu')

    return torch.device(f'cuda:{selected_gpu}')