Spaces:
Sleeping
Sleeping
import argparse | |
import ast | |
import os | |
import platform | |
import shlex | |
import stat | |
import subprocess | |
import sys | |
from contextlib import contextmanager | |
from importlib import import_module | |
from pathlib import Path | |
import requests | |
# region constants
# Directory containing this script, and the interpreter currently running it.
here = Path(__file__).parent
executable = Path(sys.executable)

# - detect mode
# Classify the runtime from an env var and the interpreter path. Probes are
# checked in order; the first match wins and anything else is "unknown".
# NOTE(review): "python_embeded"/"embeded" look deliberately spelled this way
# (matched against the interpreter path) — confirm before "fixing" them.
mode = (
    "colab" if os.environ.get("COLAB_GPU")
    else "embeded" if "python_embeded" in str(executable)
    else "venv" if ".venv" in str(executable)
    else "unknown"
)
# GitHub coordinates of the mtb repository; the clone URL is derived from them.
repo_owner = "melmass"
repo_name = "comfy_mtb"
repo_url = f"https://github.com/{repo_owner}/{repo_name}.git"

# Short platform suffixes keyed by lower-case OS name.
short_platform = dict(windows="win_amd64", linux="linux_x86_64")
current_platform = platform.system().lower()

# pip distribution name -> module name passed to import_module, for packages
# whose import name differs from the name they are installed under.
pip_map = {
    "onnxruntime-gpu": "onnxruntime",
    "opencv-contrib": "cv2",
    "tb-nightly": "tensorboard",
    "protobuf": "google.protobuf",
    "qrcode[pil]": "qrcode",
    "requirements-parser": "requirements",
    # Add more mappings as needed
}
# endregion | |
# region ansi
# ANSI SGR escape sequences for text styling, keyed by a readable name.
ANSI_FORMATS = {
    name: f"\033[{code}m"
    for name, code in (
        ("reset", 0),
        ("bold", 1),
        ("dim", 2),
        ("italic", 3),
        ("underline", 4),
        ("blink", 5),
        ("reverse", 7),
        ("strike", 9),
    )
}

# Colour i of the eight base colours is foreground code 30+i and background
# code 40+i; the "bright" variants append the bold attribute (";1").
ANSI_COLORS = {
    f"{prefix}{color_name}": f"\033[{base + offset}{suffix}m"
    for prefix, base, suffix in (
        ("", 30, ""),
        ("bright_", 30, ";1"),
        ("bg_", 40, ""),
        ("bg_bright_", 40, ";1"),
    )
    for offset, color_name in enumerate(
        ("black", "red", "green", "yellow", "blue", "magenta", "cyan", "white")
    )
}
def apply_format(text, *formats):
    """Wrap *text* in the ANSI sequence of each requested format, in order.

    Every format adds one outer layer ``<code>...<reset>``, so later formats
    wrap earlier ones. Names missing from ``ANSI_FORMATS`` add no opening
    sequence, but a reset is still appended. With no formats, *text* is
    returned unchanged.
    """
    styled = text
    for style_name in formats:
        opening = ANSI_FORMATS.get(style_name, "")
        closing = ANSI_FORMATS.get("reset", "")
        styled = f"{opening}{styled}{closing}"
    return styled
def apply_color(text, color=None, background=None):
    """Wrap *text* in foreground and/or background ANSI colour sequences.

    The foreground colour (if given) is applied first and the background then
    wraps it. Names are looked up in ``ANSI_COLORS``; unknown names add no
    opening sequence, but a reset is still appended. Falsy arguments are
    skipped.
    """
    result = text
    for color_key in (color, background):
        if not color_key:
            continue
        result = f"{ANSI_COLORS.get(color_key, '')}{result}{ANSI_FORMATS.get('reset', '')}"
    return result
def print_formatted(text, *formats, color=None, background=None, **kwargs):
    """Print *text* with a bold yellow "[mtb install] " header.

    Args:
        text: message to print (converted with ``str`` if needed).
        *formats: ``ANSI_FORMATS`` names applied to the message.
        color: optional ``ANSI_COLORS`` foreground name for the message.
        background: optional ``ANSI_COLORS`` background name for the message.
        **kwargs: ``file`` selects the output stream (default ``sys.stdout``);
            a truthy ``no_header`` replaces the header with blank padding of
            the same width.
    """
    file = kwargs.get("file") or sys.stdout
    # Encode for the stream actually written to (not always stdout), and fall
    # back to UTF-8 when it exposes no encoding (e.g. io.StringIO), where
    # str.encode(None) would raise TypeError.
    encoding = getattr(file, "encoding", None) or "utf-8"

    def _printable(value):
        # Round-trip through the target encoding so characters the console
        # cannot represent (e.g. emoji on a legacy code page) become "?"
        # instead of raising UnicodeEncodeError inside print().
        return value.encode(encoding, errors="replace").decode(encoding)

    header = _printable("[mtb install] ")
    body = _printable(apply_color(apply_format(str(text), *formats), color, background))
    prefix = (
        " " * len(header)
        if kwargs.get("no_header")
        else apply_color(apply_format(header, "bold"), color="yellow")
    )
    print(prefix, body, file=file)
# endregion | |
# region utils | |
def run_command(cmd, ignored_lines_start=None):
    """Run a shell command, reporting failures instead of raising them.

    Args:
        cmd: a ready-made shell string, or a list of arguments. Every list
            item is shell-quoted; ``Path`` items are converted to POSIX form
            first.
        ignored_lines_start: stdout lines starting with any of these prefixes
            are not echoed.

    Returns:
        True when the command finished successfully; False when it exited
        with a non-zero status or was interrupted. Callers that ignore the
        return value keep the previous fire-and-forget behaviour.

    Raises:
        ValueError: if *cmd* is neither a string nor a list.
    """
    if ignored_lines_start is None:
        ignored_lines_start = []
    if isinstance(cmd, str):
        shell_cmd = cmd
    elif isinstance(cmd, list):
        # Quote Path arguments too: an unquoted path containing a space
        # (e.g. ".../Program Files/...") would be split by the shell.
        shell_cmd = " ".join(
            shlex.quote(arg.as_posix() if isinstance(arg, Path) else str(arg))
            for arg in cmd
        )
    else:
        raise ValueError(
            "Invalid 'cmd' argument. It must be a string or a list of arguments."
        )
    try:
        _run_command(shell_cmd, ignored_lines_start)
    except subprocess.CalledProcessError as e:
        print(f"Command failed with return code: {e.returncode}", file=sys.stderr)
        print((e.stderr or "").strip(), file=sys.stderr)
        return False
    except KeyboardInterrupt:
        print("Command execution interrupted.")
        return False
    return True
def _run_command(shell_cmd, ignored_lines_start):
    """Run *shell_cmd* through the shell and echo its captured output.

    Output is captured and printed only after the command finishes: stdout
    lines starting with any prefix in *ignored_lines_start* are skipped, and
    every stderr line is forwarded to ``sys.stderr``.

    Raises:
        subprocess.CalledProcessError: when the command exits non-zero
            (``check=True``); the exception carries the captured output.
    """
    print_formatted(f"Running {shell_cmd}", "bold")
    result = subprocess.run(
        shell_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        # Tool output is not guaranteed to match the locale encoding (e.g. pip
        # on a Windows code page); replace undecodable bytes rather than crash.
        errors="replace",
        shell=True,
        check=True,
    )

    def _lines(output):
        # "".split("\n") is [""], which printed a stray blank line whenever a
        # stream was empty; yield nothing in that case.
        stripped = output.strip()
        return stripped.split("\n") if stripped else []

    prefixes = tuple(ignored_lines_start)
    # Print stdout, skipping ignored lines (startswith(()) is always False).
    for line in _lines(result.stdout):
        if not line.startswith(prefixes):
            print(line)
    # Print stderr
    for line in _lines(result.stderr):
        print(line, file=sys.stderr)
    print("Command executed successfully!")
def is_pipe(stream=None):
    """Return True when *stream* (default ``sys.stdin``) is fed by a non-terminal.

    A terminal is never treated as a pipe. On Windows any other stream with a
    valid OS handle counts. Elsewhere, FIFOs, regular files, block devices and
    sockets count. Streams without a usable file descriptor (None, closed,
    io.StringIO, ...) return False.
    """
    if stream is None:
        stream = sys.stdin
    if stream is None:  # e.g. pythonw.exe has no stdin at all
        return False
    try:
        # The original check was inverted ("not isatty"), so real pipes were
        # rejected and terminals were accepted on Windows.
        if stream.isatty():
            return False
        fd = stream.fileno()
    except (AttributeError, OSError, ValueError):
        # io.UnsupportedOperation is an OSError; closed files raise ValueError.
        return False
    if sys.platform == "win32":
        try:
            import msvcrt

            return msvcrt.get_osfhandle(fd) != -1
        except (ImportError, OSError):
            # get_osfhandle raises OSError for an invalid descriptor.
            return False
    try:
        st_mode = os.fstat(fd).st_mode
    except OSError:
        return False
    return (
        stat.S_ISFIFO(st_mode)
        or stat.S_ISREG(st_mode)
        or stat.S_ISBLK(st_mode)
        or stat.S_ISSOCK(st_mode)
    )
@contextmanager
def suppress_std():
    """Silence ``sys.stdout`` and ``sys.stderr`` for the body of a ``with`` block.

    Both attributes are rebound to ``os.devnull`` and restored on exit, even
    if the body raises. Only the ``sys`` attributes are swapped; anything
    written straight to the underlying file descriptors is not affected.

    Without ``@contextmanager`` this generator function could not be used in
    ``with`` (it raised TypeError), which broke ``try_import``.
    """
    with open(os.devnull, "w") as devnull:
        saved_out, saved_err = sys.stdout, sys.stderr
        sys.stdout = sys.stderr = devnull
        try:
            yield
        finally:
            sys.stdout, sys.stderr = saved_out, saved_err
# Get the version from __init__.py
def get_local_version(init_file=None):
    """Return the string ``__version__`` declared in a Python file, or None.

    The file is parsed with :mod:`ast`, not imported, so none of its code
    runs. The first string-valued ``__version__ = ...`` or
    ``__version__: T = ...`` assignment found while walking the tree is
    returned.

    Args:
        init_file: path of the file to inspect; defaults to the
            ``__init__.py`` next to this script.

    Returns:
        The version string, or None if the file is missing or declares no
        string ``__version__``.
    """
    if init_file is None:
        init_file = os.path.join(os.path.dirname(__file__), "__init__.py")
    if not os.path.isfile(init_file):
        return None
    # Parse bytes so the interpreter's own source-encoding rules (UTF-8
    # default, PEP 263 cookies) apply instead of the locale encoding.
    with open(init_file, "rb") as f:
        tree = ast.parse(f.read())
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            targets, value = node.targets, node.value
        elif isinstance(node, ast.AnnAssign):
            targets, value = [node.target], node.value
        else:
            continue
        # ast.Str is deprecated since 3.8 and removed in 3.14; string
        # literals are ast.Constant nodes holding a str.
        if not (isinstance(value, ast.Constant) and isinstance(value.value, str)):
            continue
        if any(isinstance(t, ast.Name) and t.id == "__version__" for t in targets):
            return value.value
    return None
def download_file(url, file_name):
    """Stream *url* to *file_name* while showing a tqdm progress bar.

    Args:
        url: address to download.
        file_name: destination path, as ``str`` or ``Path``; the progress bar
            is labelled with its stem.

    Raises:
        requests.HTTPError: if the server answers with an error status
            (via ``raise_for_status``). On any failure during the transfer,
            the partially written file is deleted before the error propagates.
    """
    # The original required a Path (it read file_name.stem); accept str too.
    file_name = Path(file_name)
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        # 0 when the server sends no content-length (tqdm then shows no total).
        total_size = int(response.headers.get("content-length", 0))
        try:
            with open(file_name, "wb") as file, tqdm(
                desc=file_name.stem,
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
                    progress_bar.update(len(chunk))
        except BaseException:
            # Do not leave a truncated file that looks like a finished download.
            file_name.unlink(missing_ok=True)
            raise
def try_import(requirement):
    """Check whether a requirement is importable and report the result.

    Args:
        requirement: object exposing ``.name`` and ``.specs``. Presumably a
            requirements-parser ``Requirement`` whose ``specs`` is a list of
            ``(operator, version)`` pairs — verify against callers.

    Returns:
        ``(installed, pip_name, pip_spec, import_name)``. ``pip_spec`` joins
        every specifier with commas (e.g. ``">=1.0,<2"``) and is "" when there
        are none. ``import_name`` is looked up in ``pip_map`` and falls back to
        the distribution name.
    """
    dependency = requirement.name.strip()
    import_name = pip_map.get(dependency, dependency)
    pip_name = dependency
    # The original kept only specs[0], so ">=1.0,<2" silently became ">=1.0".
    pip_spec = ",".join("".join(spec) for spec in (requirement.specs or []))
    try:
        # Keep noisy import-time output out of the report.
        with suppress_std():
            import_module(import_name)
    except ImportError:
        print_formatted(
            f"\t⛔ Package {pip_name} is missing (import name: '{import_name}').",
            "bold",
            color="red",
            no_header=True,
        )
        return (False, pip_name, pip_spec, import_name)
    print_formatted(
        f"\t✅ Package {pip_name} already installed (import name: '{import_name}').",
        "bold",
        color="green",
        no_header=True,
    )
    return (True, pip_name, pip_spec, import_name)
def import_or_install(requirement, dry=False):
    """Install *requirement* with pip unless it is already importable.

    Args:
        requirement: passed to ``try_import`` (needs ``.name`` and ``.specs``).
        dry: when True, only report what would be installed.

    Failures are reported, not raised.
    """
    installed, pip_name, pip_spec, import_name = try_import(requirement)
    if installed:
        return
    pip_install_name = pip_name + pip_spec
    print_formatted(f"Installing package {pip_name}...", "italic", color="yellow")
    if dry:
        print_formatted(
            f"Dry-run: Package {pip_install_name} would be installed (import name: '{import_name}').",
            color="yellow",
        )
        return

    error = None
    try:
        # run_command reports and swallows CalledProcessError itself, so the
        # except clause below alone could never fire. An explicit False
        # return signals failure; a None return counts as success.
        if run_command([executable, "-m", "pip", "install", pip_install_name]) is False:
            error = "pip install did not complete (see output above)"
    except subprocess.CalledProcessError as e:
        error = str(e)

    if error is None:
        print_formatted(
            f"Package {pip_install_name} installed successfully using pip package name (import name: '{import_name}')",
            "bold",
            color="green",
        )
    else:
        print_formatted(
            f"Failed to install package {pip_install_name} using pip package name (import name: '{import_name}'). Error: {error}",
            "bold",
            color="red",
        )
def get_github_assets(tag=None):
    """Fetch release metadata for the mtb repository from the GitHub API.

    Args:
        tag: release tag to look up; the latest release when omitted.

    Returns:
        ``(tag_data, tag_name)``: the decoded JSON release object and its
        ``name``, falling back to ``tag_name`` when the release is unnamed.

    Exits the process with status 1 if the request is not successful.
    """
    releases_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases"
    tag_url = f"{releases_url}/tags/{tag}" if tag else f"{releases_url}/latest"
    response = requests.get(tag_url)
    # Any non-2xx answer (404 unknown tag, 403 rate limit, ...) returns an
    # error payload without release fields; checking only 404 led to a
    # KeyError on "name" for the others.
    if not response.ok:
        print_formatted("Error retrieving the release assets.", color="red")
        print_formatted(
            f"GitHub answered HTTP {response.status_code} for {tag_url}", color="red"
        )
        # Non-zero status: a bare sys.exit() reported success to the caller.
        sys.exit(1)
    tag_data = response.json()
    tag_name = tag_data.get("name") or tag_data["tag_name"]
    return tag_data, tag_name
# endregion | |
# Bootstrap tqdm (used by download_file): install it on the fly when missing.
try:
    from tqdm import tqdm
except ImportError:
    print_formatted("Installing tqdm...", "italic", color="yellow")
    run_command([executable, "-m", "pip", "install", "--upgrade", "tqdm"])
    # Path finders cache directory contents; a package installed after start-up
    # may stay invisible until the caches are invalidated (see importlib docs).
    import importlib

    importlib.invalidate_caches()
    from tqdm import tqdm
def main():
    """Entry point for the legacy remote-install flow.

    With no arguments, only a notice is printed. Real work happens only when a
    ``-p/--path`` option points at ComfyUI's ``custom_nodes`` directory:

    1. ``<path>/comfy_mtb`` is cloned, or updated if it already exists.
    2. Its ``requirements.txt`` is installed with this interpreter's pip.

    Exits with status 1 when the path does not exist or dependency
    installation fails.
    """
    if len(sys.argv) == 1:
        print_formatted(
            "mtb doesn't need an install script anymore.", "italic", color="yellow"
        )
        return

    # Accept every spelling argparse accepts: "-p X", "-pX", "--path X",
    # "--path=X". The old exact-match test rejected the last two.
    if not any(arg.startswith(("-p", "--path")) for arg in sys.argv[1:]):
        print(
            "This script is only used for an edge case of remote installs on some cloud providers, unrecognized arguments:",
            sys.argv[1:],
        )
        return

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Comfy_mtb install script")
    parser.add_argument(
        "--path",
        "-p",
        type=str,
        help="Path to clone the repository to (i.e the absolute path to ComfyUI/custom_nodes)",
    )
    print_formatted("mtb install", "bold", color="yellow")
    args = parser.parse_args()
    print_formatted(f"Detected environment: {apply_color(mode,'cyan')}")

    # requirements.txt used to be resolved against the current working
    # directory. Use the repository's own copy instead, falling back to the
    # directory of this script when no path was given.
    requirements_dir = here
    if args.path:
        clone_dir = Path(args.path)
        if not clone_dir.exists():
            print_formatted(
                "The path provided does not exist on disk... It must be pointing to ComfyUI's custom_nodes directory"
            )
            sys.exit(1)
        repo_dir = clone_dir / repo_name
        if not repo_dir.exists():
            print_formatted(f"Cloning to {repo_dir}...", "italic", color="yellow")
            fetched = run_command(["git", "clone", "--recursive", repo_url, repo_dir])
        else:
            print_formatted(
                f"Directory {repo_dir} already exists, we will update it..."
            )
            # -C is a global git option and must precede the subcommand;
            # "git pull -C <dir>" fails with an unknown-switch error.
            fetched = run_command(["git", "-C", repo_dir, "pull"])
        if fetched is False:
            print_formatted(
                f"Could not clone/update {repo_dir}; continuing with what is on disk.",
                color="red",
            )
        requirements_dir = repo_dir

    print_formatted("Checking environment...", "italic", color="yellow")
    install_cmd = [
        executable,
        "-m",
        "pip",
        "install",
        "-r",
        requirements_dir / "requirements.txt",
    ]
    if run_command(install_cmd) is False:
        print_formatted(
            "⛔ Failed to install the dependencies.", "italic", color="red"
        )
        sys.exit(1)
    print_formatted(
        "✅ Successfully installed all dependencies.", "italic", color="green"
    )


if __name__ == "__main__":
    main()