# multimodalart's picture
# Squashing commit
# 4450790 verified
import argparse
import ast
import os
import platform
import shlex
import stat
import subprocess
import sys
from contextlib import contextmanager
from importlib import import_module
from pathlib import Path
import requests
# region constants
here = Path(__file__).parent
executable = Path(sys.executable)


def _detect_mode():
    """Classify the running interpreter as colab, embeded, venv or unknown.

    "embeded" is spelled to match the ``python_embeded`` folder name that is
    searched for in the interpreter path.
    """
    if os.environ.get("COLAB_GPU"):
        return "colab"
    interpreter_path = str(executable)
    if "python_embeded" in interpreter_path:
        return "embeded"
    if ".venv" in interpreter_path:
        return "venv"
    return "unknown"


# - detect mode
mode = _detect_mode()

repo_url = "https://github.com/melmass/comfy_mtb.git"
repo_owner = "melmass"
repo_name = "comfy_mtb"

# Platform tags keyed by lowercase OS name (presumably looked up with
# current_platform below — confirm against callers).
short_platform = dict(
    windows="win_amd64",
    linux="linux_x86_64",
)
current_platform = platform.system().lower()

# pip distribution name -> module name to import when checking whether the
# package is already installed (only needed where the two differ).
pip_map = {
    "onnxruntime-gpu": "onnxruntime",
    "opencv-contrib": "cv2",
    "tb-nightly": "tensorboard",
    "protobuf": "google.protobuf",
    "qrcode[pil]": "qrcode",
    "requirements-parser": "requirements",
    # Add more mappings as needed
}
# endregion
# region ansi
# Text-attribute escape sequences (SGR codes) keyed by a friendly name;
# "reset" clears every attribute and color.
ANSI_FORMATS = dict(
    reset="\x1b[0m",
    bold="\x1b[1m",
    dim="\x1b[2m",
    italic="\x1b[3m",
    underline="\x1b[4m",
    blink="\x1b[5m",
    reverse="\x1b[7m",
    strike="\x1b[9m",
)
# Color escape sequences (SGR codes) keyed by name.
#   30-37  : standard foreground colors
#   90-97  : bright foreground colors
#   40-47  : standard background colors
#   100-107: bright background colors
# The bright variants use the dedicated bright codes rather than "<code>;1",
# which sets *bold* (and, for backgrounds, bolds the text instead of
# brightening the background).
ANSI_COLORS = {
    "black": "\033[30m",
    "red": "\033[31m",
    "green": "\033[32m",
    "yellow": "\033[33m",
    "blue": "\033[34m",
    "magenta": "\033[35m",
    "cyan": "\033[36m",
    "white": "\033[37m",
    "bright_black": "\033[90m",
    "bright_red": "\033[91m",
    "bright_green": "\033[92m",
    "bright_yellow": "\033[93m",
    "bright_blue": "\033[94m",
    "bright_magenta": "\033[95m",
    "bright_cyan": "\033[96m",
    "bright_white": "\033[97m",
    "bg_black": "\033[40m",
    "bg_red": "\033[41m",
    "bg_green": "\033[42m",
    "bg_yellow": "\033[43m",
    "bg_blue": "\033[44m",
    "bg_magenta": "\033[45m",
    "bg_cyan": "\033[46m",
    "bg_white": "\033[47m",
    "bg_bright_black": "\033[100m",
    "bg_bright_red": "\033[101m",
    "bg_bright_green": "\033[102m",
    "bg_bright_yellow": "\033[103m",
    "bg_bright_blue": "\033[104m",
    "bg_bright_magenta": "\033[105m",
    "bg_bright_cyan": "\033[106m",
    "bg_bright_white": "\033[107m",
}
def apply_format(text, *formats):
    """Wrap *text* in the ANSI attribute codes named by *formats*.

    Every name is looked up in ANSI_FORMATS (an unknown name contributes an
    empty prefix) and each wrap is closed by its own reset code, so later
    names end up outermost. With no formats, *text* is returned unchanged.
    """
    reset_code = ANSI_FORMATS.get("reset", "")
    styled = text
    for style_name in formats:
        prefix = ANSI_FORMATS.get(style_name, "")
        styled = f"{prefix}{styled}{reset_code}"
    return styled
def apply_color(text, color=None, background=None):
    """Wrap *text* in ANSI color codes.

    The foreground *color* is applied first and *background* second, each
    followed by its own reset. Both names are looked up in ANSI_COLORS
    (presumably a ``bg_*`` key for the background); unknown names add no
    prefix, and falsy arguments are skipped entirely.
    """
    reset_code = ANSI_FORMATS.get("reset", "")
    colored = text
    for code_name in (color, background):
        if not code_name:
            continue
        colored = f"{ANSI_COLORS.get(code_name, '')}{colored}{reset_code}"
    return colored
def print_formatted(text, *formats, color=None, background=None, **kwargs):
    """Print *text* with ANSI styling, prefixed by a ``[mtb install]`` header.

    Args:
        text: Message to print.
        *formats: ANSI_FORMATS names applied to the message (e.g. "bold").
        color: ANSI_COLORS name for the message foreground.
        background: ANSI_COLORS name for the message background.
        **kwargs: ``file`` selects the output stream (default: sys.stdout at
            call time); a truthy ``no_header`` replaces the header with
            blank padding of the same width.
    """
    formatted_text = apply_format(text, *formats)
    formatted_text = apply_color(formatted_text, color, background)
    file = kwargs.get("file", sys.stdout)
    header = "[mtb install] "

    # Round-trip through the encoding of the stream we actually write to, so
    # characters it cannot represent (e.g. emoji on a legacy Windows console)
    # become "?" instead of raising UnicodeEncodeError. Some streams report
    # no encoding (None) or are missing entirely; fall back to UTF-8 then.
    encoding = getattr(file, "encoding", None) or "utf-8"

    def _encodable(value):
        return value.encode(encoding, errors="replace").decode(encoding)

    encoded_header = _encodable(header)
    encoded_text = _encodable(formatted_text)

    if kwargs.get("no_header"):
        prefix = " " * len(encoded_header)
    else:
        prefix = apply_color(apply_format(encoded_header, "bold"), color="yellow")
    print(prefix, encoded_text, file=file)
# endregion
# region utils
def run_command(cmd, ignored_lines_start=None):
    """Run an external command, echoing its captured output.

    Args:
        cmd: Either a command string, which is run through the shell as-is,
            or a list of arguments. ``Path`` items are passed in POSIX form.
            List commands are executed directly (no shell), so arguments
            containing spaces or shell metacharacters such as ``pkg>=1.0``
            reach the program verbatim on every platform.
        ignored_lines_start: Prefixes of stdout lines that should not be
            echoed.

    Returns:
        True when the command exits with status 0; False when it fails, cannot
        be started, or is interrupted. Failures are reported on stderr rather
        than raised.

    Raises:
        ValueError: If *cmd* is neither a string nor a list.
    """
    if ignored_lines_start is None:
        ignored_lines_start = []
    if isinstance(cmd, str):
        command = cmd
    elif isinstance(cmd, list):
        command = [
            arg.as_posix() if isinstance(arg, Path) else str(arg) for arg in cmd
        ]
    else:
        raise ValueError(
            "Invalid 'cmd' argument. It must be a string or a list of arguments."
        )
    try:
        _run_command(command, ignored_lines_start)
    except subprocess.CalledProcessError as e:
        print(f"Command failed with return code: {e.returncode}", file=sys.stderr)
        print((e.stderr or "").strip(), file=sys.stderr)
        return False
    except OSError as e:
        # Without a shell, a missing executable raises instead of exiting 127.
        print(f"Command could not be started: {e}", file=sys.stderr)
        return False
    except KeyboardInterrupt:
        print("Command execution interrupted.")
        return False
    return True


def _run_command(shell_cmd, ignored_lines_start):
    """Execute a command and echo its captured stdout/stderr.

    A string is run through the shell; a list of strings is executed directly
    without a shell. Stdout lines starting with any prefix in
    *ignored_lines_start* are skipped.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
    """
    use_shell = isinstance(shell_cmd, str)
    display = (
        shell_cmd if use_shell else " ".join(shlex.quote(arg) for arg in shell_cmd)
    )
    print_formatted(f"Running {display}", "bold")
    result = subprocess.run(
        shell_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        errors="replace",  # tool output is not always valid in the locale encoding
        shell=use_shell,
        check=True,
    )
    ignored = tuple(ignored_lines_start)
    # splitlines() on an empty string yields no lines, so an empty stream no
    # longer prints a spurious blank line.
    for line in result.stdout.strip().splitlines():
        if not line.startswith(ignored):
            print(line)
    for line in result.stderr.strip().splitlines():
        print(line, file=sys.stderr)
    print("Command executed successfully!")
def is_pipe():
    """Return True when standard input is redirected rather than a terminal.

    On POSIX, "redirected" means stdin (fd 0) is a FIFO/pipe, regular file,
    block device or socket. On Windows, it means stdin is not a console but
    has a valid OS handle. Returns False for an interactive terminal, and
    when stdin is missing, closed, or cannot be inspected.
    """
    stdin = sys.stdin
    if stdin is None:
        return False
    try:
        # An interactive terminal is never a pipe.
        if stdin.isatty():
            return False
    except (AttributeError, ValueError):  # odd stream object / closed stream
        return False
    if sys.platform == "win32":
        try:
            import msvcrt

            return msvcrt.get_osfhandle(0) != -1
        except (ImportError, OSError):
            return False
    try:
        st_mode = os.fstat(0).st_mode
    except OSError:
        return False
    return (
        stat.S_ISFIFO(st_mode)
        or stat.S_ISREG(st_mode)
        or stat.S_ISBLK(st_mode)
        or stat.S_ISSOCK(st_mode)
    )
@contextmanager
def suppress_std():
    """Temporarily send ``sys.stdout`` and ``sys.stderr`` to the null device.

    Both streams are restored when the block exits, even on an exception.
    Only writes made through the ``sys`` stream objects are silenced; output
    written straight to the underlying file descriptors is not.
    """
    saved_streams = (sys.stdout, sys.stderr)
    with open(os.devnull, "w") as sink:
        sys.stdout = sys.stderr = sink
        try:
            yield
        finally:
            sys.stdout, sys.stderr = saved_streams
# Get the version from __init__.py
def get_local_version():
    """Return the ``__version__`` string declared in the sibling ``__init__.py``.

    The file is parsed with ``ast`` rather than imported, so no package code
    runs. Both ``__version__ = "x"`` and ``__version__: str = "x"`` are
    recognised.

    Returns:
        The first string literal assigned to ``__version__``, or None when the
        file is missing or declares no such string.
    """
    init_file = os.path.join(os.path.dirname(__file__), "__init__.py")
    if not os.path.isfile(init_file):
        return None
    # Python source defaults to UTF-8 (PEP 3120); don't depend on the locale.
    with open(init_file, "r", encoding="utf-8") as f:
        tree = ast.parse(f.read())
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            targets = [node.target]
        else:
            continue
        value = node.value
        # ast.Str is deprecated (removed in newer Pythons); string literals
        # are ast.Constant nodes holding a str.
        if not (isinstance(value, ast.Constant) and isinstance(value.value, str)):
            continue
        if any(
            isinstance(target, ast.Name) and target.id == "__version__"
            for target in targets
        ):
            return value.value
    return None
def download_file(url, file_name, timeout=30):
    """Stream *url* into *file_name* while showing a tqdm progress bar.

    Args:
        url: URL to download.
        file_name: Destination path (``str`` or ``Path``); its stem labels the
            progress bar.
        timeout: Connect/read timeout in seconds passed to ``requests.get``,
            so a stalled server cannot hang the installer forever.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        Any error raised while writing propagates after the partially
        written file has been removed.
    """
    file_path = Path(file_name)
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # A missing Content-Length means "unknown size"; tqdm wants None (an
        # open-ended bar) rather than a total of 0.
        total_size = int(response.headers.get("content-length", 0)) or None
        try:
            with open(file_path, "wb") as file, tqdm(
                desc=file_path.stem,
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
                    progress_bar.update(len(chunk))
        except BaseException:
            # Don't leave a truncated file behind that could pass for a
            # complete download.
            file_path.unlink(missing_ok=True)
            raise
def try_import(requirement):
    """Check whether a parsed requirement is importable and report the result.

    Args:
        requirement: Parsed requirement exposing ``.name`` and ``.specs``
            (assumed to be a list of ``(operator, version)`` pairs — TODO
            confirm against the requirements parser in use).

    Returns:
        ``(installed, pip_name, pip_spec, import_name)``. ``pip_spec`` holds
        every version specifier joined with commas (e.g. ``">=1.0,<2.0"``),
        ready to append to the pip name. ``import_name`` is resolved through
        ``pip_map``.
    """
    dependency = requirement.name.strip()
    import_name = pip_map.get(dependency, dependency)
    pip_name = dependency
    # Keep *all* specifiers, not just the first one, so upper bounds survive.
    pip_spec = ",".join("".join(spec) for spec in (requirement.specs or []))
    try:
        # Silence whatever the package prints while being imported.
        with suppress_std():
            import_module(import_name)
    except ImportError:
        print_formatted(
            f"\t⛔ Package {pip_name} is missing (import name: '{import_name}').",
            "bold",
            color="red",
            no_header=True,
        )
        return (False, pip_name, pip_spec, import_name)
    print_formatted(
        f"\t✅ Package {pip_name} already installed (import name: '{import_name}').",
        "bold",
        color="green",
        no_header=True,
    )
    return (True, pip_name, pip_spec, import_name)
def import_or_install(requirement, dry=False):
    """Ensure *requirement* is importable, installing it with pip if needed.

    Args:
        requirement: Parsed requirement, forwarded to ``try_import``.
        dry: When True, only report what would be installed.

    Success is judged by importing the module again after pip runs.
    ``run_command`` reports a failing command on stderr instead of raising,
    so an exception alone cannot signal failure.
    """
    import importlib

    installed, pip_name, pip_spec, import_name = try_import(requirement)
    if installed:
        return
    pip_install_name = pip_name + pip_spec
    print_formatted(f"Installing package {pip_name}...", "italic", color="yellow")
    if dry:
        print_formatted(
            f"Dry-run: Package {pip_install_name} would be installed (import name: '{import_name}').",
            color="yellow",
        )
        return
    try:
        run_command([executable, "-m", "pip", "install", pip_install_name])
        # Import finders cache directory listings; refresh them so a package
        # installed while this interpreter is running can be found.
        importlib.invalidate_caches()
        with suppress_std():
            import_module(import_name)
    except (subprocess.CalledProcessError, ImportError) as e:
        print_formatted(
            f"Failed to install package {pip_install_name} using pip package name (import name: '{import_name}'). Error: {str(e)}",
            "bold",
            color="red",
        )
        return
    print_formatted(
        f"Package {pip_install_name} installed successfully using pip package name (import name: '{import_name}')",
        "bold",
        color="green",
    )
def get_github_assets(tag=None, timeout=30):
    """Fetch release metadata for the repository from the GitHub REST API.

    Args:
        tag: Release tag to look up; when falsy, the latest release is used.
        timeout: Connect/read timeout in seconds for the HTTP request.

    Returns:
        ``(tag_data, tag_name)``: the decoded JSON release object and its
        ``name``, falling back to ``tag_name`` when the release has no title.

    Exits the process with status 1 if the release cannot be retrieved
    (unknown tag, rate limiting or any other HTTP error, network failure).
    """
    releases_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases"
    tag_url = f"{releases_url}/tags/{tag}" if tag else f"{releases_url}/latest"
    try:
        response = requests.get(tag_url, timeout=timeout)
    except requests.RequestException:
        response = None
    # Any non-2xx answer carries an error payload without release fields, so
    # treat it as fatal instead of failing later on a missing key.
    if response is None or not response.ok:
        print_formatted("Error retrieving the release assets.", color="red")
        sys.exit(1)
    tag_data = response.json()
    tag_name = tag_data.get("name") or tag_data.get("tag_name")
    return tag_data, tag_name
# endregion
try:
    from tqdm import tqdm
except ImportError:
    # tqdm provides download_file's progress bar; install it on the fly when
    # it is missing so the script can also run in a bare environment.
    import importlib

    print_formatted("Installing tqdm...", "italic", color="yellow")
    run_command([executable, "-m", "pip", "install", "--upgrade", "tqdm"])
    # Import finders cache directory listings; invalidate them so a package
    # installed while this interpreter is running can be found.
    importlib.invalidate_caches()
    from tqdm import tqdm
def _has_path_argument(argv):
    """Return True if *argv* contains the -p/--path option (``--path=X`` included)."""
    return any(arg in ("-p", "--path") or arg.startswith("--path=") for arg in argv)


def _clone_or_update(repo_dir):
    """Clone the repository into *repo_dir*, or pull the latest changes if it exists."""
    if not repo_dir.exists():
        print_formatted(f"Cloning to {repo_dir}...", "italic", color="yellow")
        run_command(["git", "clone", "--recursive", repo_url, repo_dir])
    else:
        print_formatted(f"Directory {repo_dir} already exists, we will update it...")
        # -C is a global git option: it must come before the sub-command
        # ("git pull -C dir" is rejected as an unknown switch).
        run_command(["git", "-C", repo_dir, "pull"])


def main():
    """Entry point for the legacy remote-install flow.

    With no arguments, it only prints a notice that the script is no longer
    needed. With ``-p/--path`` pointing at ComfyUI's ``custom_nodes``
    directory, it clones (or updates) the repository there and pip-installs
    the repository's ``requirements.txt``. Any other arguments are reported
    as unrecognized. Exits with status 1 if the given path does not exist.
    """
    if len(sys.argv) == 1:
        print_formatted(
            "mtb doesn't need an install script anymore.", "italic", color="yellow"
        )
        return
    if not _has_path_argument(sys.argv[1:]):
        print(
            "This script is only used for an edge case of remote installs on some cloud providers, unrecognized arguments:",
            sys.argv[1:],
        )
        return

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Comfy_mtb install script")
    parser.add_argument(
        "--path",
        "-p",
        type=str,
        help="Path to clone the repository to (i.e the absolute path to ComfyUI/custom_nodes)",
    )
    print_formatted("mtb install", "bold", color="yellow")
    args = parser.parse_args()
    print_formatted(f"Detected environment: {apply_color(mode,'cyan')}")

    requirements = Path("requirements.txt")
    if args.path:
        clone_dir = Path(args.path)
        if not clone_dir.exists():
            print_formatted(
                "The path provided does not exist on disk... It must be pointing to ComfyUI's custom_nodes directory"
            )
            sys.exit(1)
        repo_dir = clone_dir / repo_name
        _clone_or_update(repo_dir)
        # Install the cloned repository's requirements, not whatever
        # requirements.txt happens to sit in the current working directory.
        requirements = repo_dir / "requirements.txt"

    print_formatted("Checking environment...", "italic", color="yellow")
    install_cmd = [executable, "-m", "pip", "install", "-r", requirements]
    run_command(install_cmd)
    print_formatted(
        "✅ Successfully installed all dependencies.", "italic", color="green"
    )


if __name__ == "__main__":
    main()