"""Install helper for comfy_mtb.

Clones (or updates) the comfy_mtb repository into a given ``custom_nodes``
directory and installs its Python requirements with pip, using the
interpreter that runs this script.
"""
import argparse
import ast
import os
import platform
import shlex
import stat
import subprocess
import sys
from contextlib import contextmanager
from importlib import import_module
from pathlib import Path

import requests

# region constants
here = Path(__file__).parent
executable = Path(sys.executable)

# - detect mode
mode = None
if os.environ.get("COLAB_GPU"):
    mode = "colab"
elif "python_embeded" in str(executable):
    mode = "embeded"
elif ".venv" in str(executable):
    mode = "venv"

if mode is None:
    mode = "unknown"

repo_url = "https://github.com/melmass/comfy_mtb.git"
repo_owner = "melmass"
repo_name = "comfy_mtb"
short_platform = {
    "windows": "win_amd64",
    "linux": "linux_x86_64",
}
current_platform = platform.system().lower()

# Maps a pip distribution name to the module name used to import it.
pip_map = {
    "onnxruntime-gpu": "onnxruntime",
    "opencv-contrib": "cv2",
    "tb-nightly": "tensorboard",
    "protobuf": "google.protobuf",
    "qrcode[pil]": "qrcode",
    "requirements-parser": "requirements",
    # Add more mappings as needed
}

# Seconds to wait on HTTP requests before giving up instead of hanging.
REQUEST_TIMEOUT = 30
# endregion

# region ansi
# ANSI escape sequences for text styling
ANSI_FORMATS = {
    "reset": "\033[0m",
    "bold": "\033[1m",
    "dim": "\033[2m",
    "italic": "\033[3m",
    "underline": "\033[4m",
    "blink": "\033[5m",
    "reverse": "\033[7m",
    "strike": "\033[9m",
}

ANSI_COLORS = {
    "black": "\033[30m",
    "red": "\033[31m",
    "green": "\033[32m",
    "yellow": "\033[33m",
    "blue": "\033[34m",
    "magenta": "\033[35m",
    "cyan": "\033[36m",
    "white": "\033[37m",
    "bright_black": "\033[30;1m",
    "bright_red": "\033[31;1m",
    "bright_green": "\033[32;1m",
    "bright_yellow": "\033[33;1m",
    "bright_blue": "\033[34;1m",
    "bright_magenta": "\033[35;1m",
    "bright_cyan": "\033[36;1m",
    "bright_white": "\033[37;1m",
    "bg_black": "\033[40m",
    "bg_red": "\033[41m",
    "bg_green": "\033[42m",
    "bg_yellow": "\033[43m",
    "bg_blue": "\033[44m",
    "bg_magenta": "\033[45m",
    "bg_cyan": "\033[46m",
    "bg_white": "\033[47m",
    "bg_bright_black": "\033[40;1m",
    "bg_bright_red": "\033[41;1m",
    "bg_bright_green": "\033[42;1m",
    "bg_bright_yellow": "\033[43;1m",
    "bg_bright_blue": "\033[44;1m",
    "bg_bright_magenta": "\033[45;1m",
    "bg_bright_cyan": "\033[46;1m",
    "bg_bright_white": "\033[47;1m",
}


def apply_format(text, *formats):
    """Apply ANSI escape sequences for the specified formats to the given text.

    Each name in ``formats`` is looked up in ``ANSI_FORMATS``; unknown names
    contribute no prefix. Every applied format wraps the previous result and
    appends a reset sequence.
    """
    formatted_text = text
    reset = ANSI_FORMATS.get("reset", "")
    for fmt in formats:
        formatted_text = f"{ANSI_FORMATS.get(fmt, '')}{formatted_text}{reset}"
    return formatted_text


def apply_color(text, color=None, background=None):
    """Apply ANSI escape sequences for the specified color and background to the given text.

    ``color`` and ``background`` are keys of ``ANSI_COLORS``; unknown keys
    contribute no prefix. A reset sequence is appended for each one applied.
    """
    formatted_text = text
    reset = ANSI_FORMATS.get("reset", "")
    if color:
        formatted_text = f"{ANSI_COLORS.get(color, '')}{formatted_text}{reset}"
    if background:
        formatted_text = f"{ANSI_COLORS.get(background, '')}{formatted_text}{reset}"
    return formatted_text


def print_formatted(text, *formats, color=None, background=None, **kwargs):
    """Print the given text with the specified formats, color, and background.

    Recognised keyword arguments:
        file: stream to print to (defaults to ``sys.stdout``).
        no_header: when truthy, print blank padding of the header's width
            instead of the "[mtb install] " header.
    """
    formatted_text = apply_format(text, *formats)
    formatted_text = apply_color(formatted_text, color, background)

    file = kwargs.get("file", sys.stdout)
    header = "[mtb install] "

    # Handle console encoding for Unicode characters (utf-8).
    # A replaced or detached stdout may report no encoding at all, which
    # would make str.encode() raise; fall back to UTF-8 in that case.
    encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
    encoded_header = header.encode(encoding, errors="replace").decode(encoding)
    encoded_text = formatted_text.encode(encoding, errors="replace").decode(encoding)

    print(
        " " * len(encoded_header)
        if kwargs.get("no_header")
        else apply_color(apply_format(encoded_header, "bold"), color="yellow"),
        encoded_text,
        file=file,
    )


# endregion


# region utils
def run_command(cmd, ignored_lines_start=None):
    """Run ``cmd`` and echo its output, reporting failures on stderr.

    Args:
        cmd: either a shell command string (run through the shell) or a list
            of arguments (run directly, without a shell, so arguments that
            contain spaces need no quoting).
        ignored_lines_start: optional list of prefixes; stdout lines starting
            with any of them are not echoed.

    Returns:
        True when the command exited successfully, False when it failed,
        could not be started, or was interrupted.

    Raises:
        ValueError: if ``cmd`` is neither a string nor a list.
    """
    if ignored_lines_start is None:
        ignored_lines_start = []

    if isinstance(cmd, str):
        shell_cmd = cmd
    elif isinstance(cmd, list):
        # Keep the argv as a list: Path objects and other values are
        # stringified, but never joined into a shell string.
        shell_cmd = [str(arg) for arg in cmd]
    else:
        raise ValueError(
            "Invalid 'cmd' argument. It must be a string or a list of arguments."
        )

    try:
        _run_command(shell_cmd, ignored_lines_start)
    except subprocess.CalledProcessError as e:
        print(f"Command failed with return code: {e.returncode}", file=sys.stderr)
        print((e.stderr or "").strip(), file=sys.stderr)
        return False
    except OSError as e:
        # Without a shell, a missing executable surfaces as an OSError
        # (e.g. FileNotFoundError) rather than a non-zero exit status.
        print(f"Command could not be started: {e}", file=sys.stderr)
        return False
    except KeyboardInterrupt:
        print("Command execution interrupted.")
        return False
    return True


def _run_command(shell_cmd, ignored_lines_start):
    """Execute ``shell_cmd`` and print its captured stdout and stderr.

    ``shell_cmd`` is a string (executed through the shell) or a list of
    string arguments (executed directly). Raises
    ``subprocess.CalledProcessError`` on a non-zero exit status.
    """
    use_shell = isinstance(shell_cmd, str)
    display = shell_cmd if use_shell else " ".join(shlex.quote(a) for a in shell_cmd)
    print_formatted(f"Running {display}", "bold")

    result = subprocess.run(
        shell_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        shell=use_shell,
        check=True,
    )

    ignored_prefixes = tuple(ignored_lines_start)

    # Print stdout, skipping ignored lines
    for line in result.stdout.strip().splitlines():
        if not line.startswith(ignored_prefixes):
            print(line)

    # Print stderr
    for line in result.stderr.strip().splitlines():
        print(line, file=sys.stderr)

    print("Command executed successfully!")


def is_pipe():
    """Inspect file descriptor 0 and report whether it looks like a pipe.

    NOTE(review): this returns False whenever stdin is not a TTY, and only
    then inspects the descriptor; that looks inverted for a function named
    ``is_pipe``. Behaviour is kept as-is because no caller is visible here
    -- confirm the intent before relying on the result.
    """
    if not sys.stdin.isatty():
        return False

    if sys.platform == "win32":
        try:
            import msvcrt

            return msvcrt.get_osfhandle(0) != -1
        except ImportError:
            return False

    try:
        st_mode = os.fstat(0).st_mode
    except OSError:
        return False
    return (
        stat.S_ISFIFO(st_mode)
        or stat.S_ISREG(st_mode)
        or stat.S_ISBLK(st_mode)
        or stat.S_ISSOCK(st_mode)
    )


@contextmanager
def suppress_std():
    """Temporarily redirect ``sys.stdout`` and ``sys.stderr`` to os.devnull."""
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = devnull
        sys.stderr = devnull

        try:
            yield
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr


# Get the version from __init__.py
def get_local_version():
    """Return the ``__version__`` string assigned in the sibling ``__init__.py``.

    The file is parsed with ``ast`` rather than imported. Returns None when
    the file is missing or has no ``__version__ = "<string>"`` assignment.
    """
    init_file = os.path.join(os.path.dirname(__file__), "__init__.py")
    if not os.path.isfile(init_file):
        return None

    with open(init_file, "r", encoding="utf-8") as f:
        tree = ast.parse(f.read())

    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        # ast.Str is deprecated since Python 3.8 and removed in 3.12;
        # string literals are ast.Constant nodes holding a str value.
        value = node.value
        if not (isinstance(value, ast.Constant) and isinstance(value.value, str)):
            continue
        for target in node.targets:
            if isinstance(target, ast.Name) and target.id == "__version__":
                return value.value
    return None


def download_file(url, file_name):
    """Stream ``url`` to ``file_name`` while showing a tqdm progress bar.

    Args:
        url: address to download.
        file_name: destination path, as a ``str`` or ``pathlib.Path``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    destination = Path(file_name)
    with requests.get(url, stream=True, timeout=REQUEST_TIMEOUT) as response:
        response.raise_for_status()
        total_size = int(response.headers.get("content-length", 0))
        with open(destination, "wb") as file, tqdm(
            desc=destination.stem,
            total=total_size,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress_bar:
            for chunk in response.iter_content(chunk_size=8192):
                if not chunk:
                    # Nothing to write or count for an empty chunk.
                    continue
                file.write(chunk)
                progress_bar.update(len(chunk))


def try_import(requirement):
    """Check whether the module behind ``requirement`` can be imported.

    ``requirement`` must expose ``name`` and ``specs`` attributes; ``specs``
    is joined as a sequence of string pairs (presumably (operator, version)
    tuples from requirements-parser -- verify against the caller).

    Returns:
        A tuple ``(installed, pip_name, pip_spec, import_name)``.
    """
    dependency = requirement.name.strip()
    import_name = pip_map.get(dependency, dependency)
    installed = False

    pip_name = dependency
    # Keep every version constraint (e.g. ">=1.0,<2.0"), not just the first.
    specs = requirement.specs
    pip_spec = ",".join("".join(spec) for spec in specs) if specs else ""

    try:
        with suppress_std():
            import_module(import_name)
        print_formatted(
            f"\t✅ Package {pip_name} already installed (import name: '{import_name}').",
            "bold",
            color="green",
            no_header=True,
        )
        installed = True
    except ImportError:
        print_formatted(
            f"\t⛔ Package {pip_name} is missing (import name: '{import_name}').",
            "bold",
            color="red",
            no_header=True,
        )

    return (installed, pip_name, pip_spec, import_name)


def import_or_install(requirement, dry=False):
    """Install ``requirement`` with pip when its module cannot be imported.

    Args:
        requirement: object accepted by ``try_import``.
        dry: when True, only report what would be installed.
    """
    installed, pip_name, pip_spec, import_name = try_import(requirement)
    if installed:
        return

    pip_install_name = pip_name + pip_spec
    print_formatted(f"Installing package {pip_name}...", "italic", color="yellow")

    if dry:
        print_formatted(
            f"Dry-run: Package {pip_install_name} would be installed (import name: '{import_name}').",
            color="yellow",
        )
        return

    # run_command reports the failure details itself and never raises
    # CalledProcessError, so its return value is what tells us the outcome.
    if run_command([executable, "-m", "pip", "install", pip_install_name]):
        print_formatted(
            f"Package {pip_install_name} installed successfully using pip package name (import name: '{import_name}')",
            "bold",
            color="green",
        )
    else:
        print_formatted(
            f"Failed to install package {pip_install_name} using pip package name (import name: '{import_name}').",
            "bold",
            color="red",
        )


def get_github_assets(tag=None):
    """Fetch release metadata for the repository from the GitHub API.

    Args:
        tag: release tag to look up; the latest release is used when omitted.

    Returns:
        A tuple ``(tag_data, tag_name)``: the decoded JSON response and its
        ``"name"`` field.

    Exits the process with status 1 when the API does not answer with a
    success status.
    """
    base_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases"
    tag_url = f"{base_url}/tags/{tag}" if tag else f"{base_url}/latest"

    response = requests.get(tag_url, timeout=REQUEST_TIMEOUT)
    if not response.ok:
        # Covers 404 (unknown tag) as well as other error statuses whose
        # body would not contain the release fields read below.
        print_formatted("Error retrieving the release assets.", color="red")
        sys.exit(1)

    tag_data = response.json()
    tag_name = tag_data["name"]

    return tag_data, tag_name


# endregion

try:
    from tqdm import tqdm
except ImportError:
    print_formatted("Installing tqdm...", "italic", color="yellow")
    run_command([executable, "-m", "pip", "install", "--upgrade", "tqdm"])
    from tqdm import tqdm


def main():
    """Entry point: clone or update the repository, then install its requirements."""
    if len(sys.argv) == 1:
        print_formatted(
            "mtb doesn't need an install script anymore.", "italic", color="yellow"
        )
        return

    if all(arg not in ("-p", "--path") for arg in sys.argv):
        print(
            "This script is only used for and edge case of remote installs on some cloud providers, unrecognized arguments:",
            sys.argv[1:],
        )
        return

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Comfy_mtb install script")
    parser.add_argument(
        "--path",
        "-p",
        type=str,
        help="Path to clone the repository to (i.e the absolute path to ComfyUI/custom_nodes)",
    )

    print_formatted("mtb install", "bold", color="yellow")

    args = parser.parse_args()

    print_formatted(f"Detected environment: {apply_color(mode,'cyan')}")

    # Without a target path, fall back to a requirements file in the
    # current working directory.
    requirements_file = Path("requirements.txt")

    if args.path:
        clone_dir = Path(args.path)
        if not clone_dir.exists():
            print_formatted(
                "The path provided does not exist on disk... It must be pointing to ComfyUI's custom_nodes directory"
            )
            sys.exit(1)

        repo_dir = clone_dir / repo_name
        if not repo_dir.exists():
            print_formatted(f"Cloning to {repo_dir}...", "italic", color="yellow")
            if not run_command(["git", "clone", "--recursive", repo_url, repo_dir]):
                print_formatted("⛔ Failed to clone the repository.", "bold", color="red")
                sys.exit(1)
        else:
            print_formatted(
                f"Directory {repo_dir} already exists, we will update it..."
            )
            # "-C <dir>" is a global git option and must precede the subcommand.
            if not run_command(["git", "-C", repo_dir, "pull"]):
                print_formatted(
                    "Could not update the repository, continuing with the existing checkout.",
                    color="yellow",
                )

        # Install the requirements shipped with the checkout rather than
        # whatever happens to be in the current working directory.
        requirements_file = repo_dir / "requirements.txt"

    print_formatted("Checking environment...", "italic", color="yellow")
    install_cmd = [executable, "-m", "pip", "install", "-r", requirements_file]
    if not run_command(install_cmd):
        print_formatted("⛔ Failed to install dependencies.", "bold", color="red")
        sys.exit(1)

    print_formatted(
        "✅ Successfully installed all dependencies.", "italic", color="green"
    )


if __name__ == "__main__":
    main()