Spaces:
Running
Running
import os | |
from urllib.parse import urlparse | |
# Optional aria2 RPC backend. When COMFYUI_MANAGER_ARIA2_SERVER is set to a
# non-empty value, `aria2` is replaced by an aria2p.API client; otherwise it
# stays None and downloads fall back to torchvision (see download_url).
# An empty value is treated the same as "not set".
aria2 = os.getenv('COMFYUI_MANAGER_ARIA2_SERVER') or None
# Optional Hugging Face mirror used to rewrite huggingface.co URLs.
HF_ENDPOINT = os.getenv('HF_ENDPOINT')


def parse_aria2_server(server: str):
    """Split an aria2 RPC server address into the ``(host, port)`` pair for aria2p.

    ``host`` is ``"<scheme>://<hostname>"``. When *server* has no scheme
    (e.g. ``"localhost:6800"``), ``http`` is assumed; without this,
    ``urlparse`` would read ``"localhost"`` as the scheme and find no
    hostname. ``port`` is ``None`` when *server* does not specify one.

    Raises:
        ValueError: if no hostname can be extracted, or if the port is invalid
            (raised by ``urlparse(...).port``).
    """
    if '://' not in server:
        server = 'http://' + server
    parsed = urlparse(server)
    if not parsed.hostname:
        raise ValueError(f'Invalid aria2 server address: {server!r}')
    return parsed.scheme + '://' + parsed.hostname, parsed.port


if aria2 is not None:
    secret = os.getenv('COMFYUI_MANAGER_ARIA2_SECRET')
    host, port = parse_aria2_server(aria2)

    # Imported lazily so aria2p is only required when the aria2 backend is enabled.
    import aria2p

    client_kwargs = {'host': host, 'secret': secret or ''}
    # Only pass a port when one was given, so aria2p's own default applies
    # instead of an explicit None.
    if port is not None:
        client_kwargs['port'] = port
    aria2 = aria2p.API(aria2p.Client(**client_kwargs))
def download_url(model_url: str, model_dir: str, filename: str):
    """Fetch *model_url* into *model_dir* under the name *filename*.

    Dispatches to the aria2 RPC backend when the module-level ``aria2`` client
    is truthy (i.e. configured). Otherwise it uses torchvision's
    ``download_url``, which is imported only on that branch.

    Returns:
        Whatever the chosen backend returns.
    """
    use_aria2 = bool(aria2)
    if not use_aria2:
        from torchvision.datasets.utils import download_url as torchvision_download_url

        return torchvision_download_url(model_url, model_dir, filename)

    return aria2_download_url(model_url, model_dir, filename)
def aria2_find_task(dir: str, filename: str):
    """Look up an existing aria2 task that writes ``os.path.join(dir, filename)``.

    Every task reported by ``aria2.get_downloads()`` is scanned. File entries
    flagged ``is_metadata`` are ignored (presumably torrent/metalink metadata;
    confirm against aria2p). The remaining entries are matched by comparing
    ``str(entry.path)`` to the joined target path.

    Args:
        dir: Target directory as passed to aria2. The name shadows the builtin
            but is kept for callers that pass it by keyword.
        filename: Output file name inside *dir*.

    Returns:
        The first matching download object, or ``None`` if there is none.
    """
    wanted = os.path.join(dir, filename)
    # The outermost iterable (get_downloads) is evaluated immediately when the
    # generator is created; the file matching itself runs lazily inside next().
    candidates = (
        task
        for task in aria2.get_downloads()
        for entry in task.files
        if not entry.is_metadata and str(entry.path) == wanted
    )
    return next(candidates, None)
def aria2_download_url(model_url: str, model_dir: str, filename: str):
    """Download *model_url* through the configured aria2 RPC server and wait for it.

    Target directory:
        A leading ``core.comfy_path`` prefix is stripped from *model_dir*. An
        absolute remainder is used as-is; a relative one is joined under
        ``/models``. (Presumably the aria2 server sees the ComfyUI tree at a
        different mount point; confirm against the deployment.)

    URL rewriting:
        When ``HF_ENDPOINT`` is set, a leading ``https://huggingface.co``
        is replaced by that endpoint.

    Task handling:
        If aria2 already has a task writing the same file, that task is reused
        instead of adding a duplicate. While the task is queued or running, a
        tqdm progress bar is updated by polling the server once per second.

    Raises:
        RuntimeError: if aria2 reports the task as failed (``error``) or
            ``removed``. Previously such tasks returned silently, as if the
            file had been downloaded.
    """
    import manager_core as core
    import tqdm
    import time

    if model_dir.startswith(core.comfy_path):
        model_dir = model_dir[len(core.comfy_path) :]

    if HF_ENDPOINT:
        # Rewrite only the scheme+host prefix. A blind str.replace would also
        # hit occurrences elsewhere in the URL, and a trailing '/' on the
        # endpoint would otherwise produce '//' in the path.
        hf_prefix = 'https://huggingface.co'
        if model_url.startswith(hf_prefix + '/'):
            model_url = HF_ENDPOINT.rstrip('/') + model_url[len(hf_prefix) :]

    download_dir = model_dir if model_dir.startswith('/') else os.path.join('/models', model_dir)

    download = aria2_find_task(download_dir, filename)
    if download is None:
        options = {'dir': download_dir, 'out': filename}
        download = aria2.add(model_url, options)[0]

    # A task can sit in aria2's queue ('waiting') before it becomes 'active'.
    # Keep polling in both states rather than returning as if it had finished.
    # NOTE(review): a 'paused' task still returns immediately, as before.
    if download.is_active or download.is_waiting:
        with tqdm.tqdm(
            total=download.total_length,
            bar_format='{l_bar}{bar}{r_bar}',
            desc=filename,
            unit='B',
            unit_scale=True,
        ) as progress_bar:

            def refresh_bar():
                # total_length is presumably 0 until aria2 learns the file
                # size; adopt the real total as soon as it is known.
                if progress_bar.total == 0 and download.total_length != 0:
                    progress_bar.reset(download.total_length)
                progress_bar.update(download.completed_length - progress_bar.n)

            while download.is_active or download.is_waiting:
                refresh_bar()
                time.sleep(1)
                download.update()
            # The loop exits right after a status refresh; show the final byte count.
            refresh_bar()

    if download.has_failed or download.is_removed:
        raise RuntimeError(
            f"aria2 download of '{filename}' from {model_url} did not complete "
            f"(status={download.status!r}): {download.error_message}"
        )