|
import glob |
|
import mimetypes |
|
import os |
|
import platform |
|
import shutil |
|
import ssl |
|
import subprocess |
|
import sys |
|
import urllib |
|
import torch |
|
import gradio |
|
import tempfile |
|
import cv2 |
|
import zipfile |
|
import traceback |
|
import threading |
|
import threading |
|
|
|
from typing import Union, Any |
|
from contextlib import nullcontext |
|
|
|
from pathlib import Path |
|
from typing import List, Any |
|
from tqdm import tqdm |
|
from scipy.spatial import distance |
|
|
|
import roop.template_parser as template_parser |
|
|
|
import roop.globals |
|
|
|
# Name of the intermediate video file stored inside a target's temp directory.
TEMP_FILE = "temp.mp4"
# Name of the folder, created next to the target file, that holds temp data.
TEMP_DIRECTORY = "temp"

# Context managers handed out by conditional_thread_semaphore(): a semaphore
# that admits one holder at a time, and a context manager that does nothing.
THREAD_SEMAPHORE = threading.Semaphore(value=1)
NULL_CONTEXT = nullcontext()
|
|
|
|
|
|
|
# On macOS, replace ssl's default HTTPS context factory with the unverified
# one, which turns off certificate verification for the whole process.
# NOTE(review): presumably a workaround for missing root certificates on
# macOS Python installs - confirm it is still required.
if platform.system().lower() == "darwin":
    ssl._create_default_https_context = ssl._create_unverified_context
|
|
|
|
|
|
|
def detect_fps(target_path: str) -> float:
    """Return the frame rate OpenCV reports for the video at ``target_path``.

    Args:
        target_path: Path of the video file to probe.

    Returns:
        The reported frames per second, or 24.0 when the file cannot be
        opened or the reported value is not a positive number.
    """
    default_fps = 24.0
    cap = cv2.VideoCapture(target_path)
    try:
        if not cap.isOpened():
            return default_fps
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Free the capture handle on every path, including early returns.
        cap.release()
    # VideoCapture.get() yields 0 for properties the backend cannot provide;
    # a non-positive (or NaN) rate is unusable, so keep the default instead.
    return fps if fps > 0 else default_fps
|
|
|
|
|
|
|
def convert_to_gradio(image):
    """Convert ``image`` with ``cv2.COLOR_BGR2RGB``; ``None`` passes through."""
    return None if image is None else cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
|
|
|
def sort_filenames_ignore_path(filenames):
    """Order full paths by their final component only.

    The directory part of each entry is ignored for comparison but kept in
    the result. Entries with equal file names keep their input order.

    Args:
        filenames: Paths to sort.

    Returns:
        A new list holding the same paths, ordered by file name.
    """

    def _name_only(full_path):
        return os.path.split(full_path)[1]

    return sorted(filenames, key=_name_only)
|
|
|
|
|
def sort_rename_frames(path: str):
    """Rename every entry of ``path`` to a six-digit running number.

    Entries are taken in sorted name order and renamed to ``000001.<ext>``,
    ``000002.<ext>``, ... where ``<ext>`` is
    ``roop.globals.CFG.output_image_format``.
    """
    for number, entry in enumerate(sorted(os.listdir(path)), start=1):
        current = os.path.join(path, entry)
        renamed = os.path.join(
            path, f"{number:06d}." + roop.globals.CFG.output_image_format
        )
        os.rename(current, renamed)
|
|
|
|
|
def get_temp_frame_paths(target_path: str) -> List[str]:
    """List the frame files in the temp directory belonging to ``target_path``.

    Matches ``*.<ext>`` where ``<ext>`` is
    ``roop.globals.CFG.output_image_format``; the directory part is escaped
    so glob metacharacters in it are taken literally.
    """
    frame_directory = glob.escape(get_temp_directory_path(target_path))
    pattern = os.path.join(
        frame_directory, f"*.{roop.globals.CFG.output_image_format}"
    )
    return glob.glob(pattern)
|
|
|
|
|
def get_temp_directory_path(target_path: str) -> str:
    """Return ``<target dir>/<TEMP_DIRECTORY>/<target name without extension>``."""
    stem = os.path.splitext(os.path.basename(target_path))[0]
    parent = os.path.dirname(target_path)
    return os.path.join(parent, TEMP_DIRECTORY, stem)
|
|
|
|
|
def get_temp_output_path(target_path: str) -> str:
    """Return the path of ``TEMP_FILE`` inside the target's temp directory."""
    return os.path.join(get_temp_directory_path(target_path), TEMP_FILE)
|
|
|
|
|
def normalize_output_path(source_path: str, target_path: str, output_path: str) -> Any:
    """Turn an output directory into a concrete output file path.

    When both ``source_path`` and ``target_path`` are non-empty and
    ``output_path`` is an existing directory, the result is
    ``<output_path>/<source name>-<target name><target extension>``.
    In every other case ``output_path`` is returned unchanged.
    """
    if not (source_path and target_path):
        return output_path

    source_stem = os.path.splitext(os.path.basename(source_path))[0]
    target_stem, target_suffix = os.path.splitext(os.path.basename(target_path))
    if not os.path.isdir(output_path):
        return output_path
    return os.path.join(output_path, f"{source_stem}-{target_stem}{target_suffix}")
|
|
|
|
|
def get_destfilename_from_path(
    srcfilepath: str, destfilepath: str, extension: str
) -> str:
    """Build a destination path from the source file's name.

    If ``extension`` contains a dot it replaces the source extension;
    otherwise it is inserted between the source name and the source
    extension. The result is joined onto ``destfilepath``.
    """
    stem, source_ext = os.path.splitext(os.path.basename(srcfilepath))
    tail = extension if "." in extension else f"{extension}{source_ext}"
    return os.path.join(destfilepath, f"{stem}{tail}")
|
|
|
|
|
def replace_template(file_path: str, index: int = 0) -> str:
    """Render the configured output template for ``file_path``.

    The file name (extension removed, any ``__temp`` removed) and ``index``
    are passed to ``template_parser.parse`` together with
    ``roop.globals.CFG.output_template``. The rendered name plus the
    original extension is joined onto ``roop.globals.output_path``.
    """
    stem, suffix = os.path.splitext(os.path.basename(file_path))
    values = {"index": str(index), "file": stem.replace("__temp", "")}
    rendered = template_parser.parse(roop.globals.CFG.output_template, values)
    return os.path.join(roop.globals.output_path, f"{rendered}{suffix}")
|
|
|
|
|
def create_temp(target_path: str) -> None:
    """Create the temp directory for ``target_path``, parents included."""
    temp_dir = Path(get_temp_directory_path(target_path))
    temp_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
def move_temp(target_path: str, output_path: str) -> None:
    """Move the target's temp output file to ``output_path``.

    Does nothing when the temp output file does not exist. An existing file
    at ``output_path`` is removed before the move.
    """
    staged_output = get_temp_output_path(target_path)
    if not os.path.isfile(staged_output):
        return
    if os.path.isfile(output_path):
        os.remove(output_path)
    shutil.move(staged_output, output_path)
|
|
|
|
|
def clean_temp(target_path: str) -> None:
    """Remove the target's temp directory and its parent when left empty.

    The temp directory is deleted only when ``roop.globals.keep_frames`` is
    falsy. Independently of that, the parent directory is removed if it
    exists and has no entries.
    """
    temp_dir = get_temp_directory_path(target_path)
    parent_dir = os.path.dirname(temp_dir)

    discard_frames = not roop.globals.keep_frames
    if discard_frames and os.path.isdir(temp_dir):
        shutil.rmtree(temp_dir)

    parent_is_empty = os.path.exists(parent_dir) and not os.listdir(parent_dir)
    if parent_is_empty:
        os.rmdir(parent_dir)
|
|
|
|
|
def delete_temp_frames(filename: str) -> None:
    """Recursively delete the directory two levels above ``filename``.

    For ``a/b/c/frame.png`` this removes ``a/b`` and everything below it.
    """
    containing_dir = os.path.dirname(filename)
    shutil.rmtree(os.path.dirname(containing_dir))
|
|
|
|
|
def has_image_extension(image_path: str) -> bool:
    """Return True if ``image_path`` ends with png, jpg, jpeg or webp.

    The comparison ignores case. The suffixes carry no leading dot, so a
    name such as ``foopng`` also matches.
    """
    image_suffixes = ("png", "jpg", "jpeg", "webp")
    lowered = image_path.lower()
    return lowered.endswith(image_suffixes)
|
|
|
|
|
def has_extension(filepath: str, extensions: List[str]) -> bool:
    """Return True if the lower-cased ``filepath`` ends with any given suffix."""
    lowered = filepath.lower()
    return any(lowered.endswith(suffix) for suffix in extensions)
|
|
|
|
|
def is_image(image_path: str) -> bool:
    """Return True if ``image_path`` is an existing file that looks like an image.

    A path ending in ``.webp`` is accepted directly; otherwise the MIME type
    guessed from the file name must start with ``image/``.
    """
    if not image_path or not os.path.isfile(image_path):
        return False
    if image_path.endswith(".webp"):
        return True
    guessed_type, _ = mimetypes.guess_type(image_path)
    return bool(guessed_type) and guessed_type.startswith("image/")
|
|
|
|
|
def is_video(video_path: str) -> bool:
    """Return True if ``video_path`` is an existing file with a ``video/`` MIME type.

    The MIME type is guessed from the file name only.
    """
    if not video_path or not os.path.isfile(video_path):
        return False
    guessed_type, _ = mimetypes.guess_type(video_path)
    return bool(guessed_type) and guessed_type.startswith("video/")
|
|
|
|
|
def conditional_download(download_directory_path: str, urls: List[str]) -> None:
    """Download each URL into a directory unless the file is already there.

    Args:
        download_directory_path: Directory to store the files in; it is
            created (with parents) when missing.
        urls: URLs to fetch. The local file name is the last path component
            of the URL.
    """
    # A bare ``import urllib`` does not load the ``request`` submodule, so
    # import it explicitly instead of relying on another module doing so.
    import urllib.request

    os.makedirs(download_directory_path, exist_ok=True)
    for url in urls:
        download_file_path = os.path.join(
            download_directory_path, os.path.basename(url)
        )
        if os.path.exists(download_file_path):
            continue
        # Probe only for the size and close the response straight away so
        # the connection is not left open during the actual download.
        with urllib.request.urlopen(url) as response:
            total = int(response.headers.get("Content-Length", 0))
        with tqdm(
            total=total,
            desc=f"Downloading {url}",
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress:
            urllib.request.urlretrieve(
                url,
                download_file_path,
                reporthook=lambda count, block_size, total_size: progress.update(
                    block_size
                ),
            )
|
|
|
|
|
def get_local_files_from_folder(folder: str) -> List[str]:
    """Return the paths of the regular files directly inside ``folder``.

    Sub-directories are skipped and not descended into. Returns ``None``
    when ``folder`` is not an existing directory.
    """
    if not os.path.isdir(folder):
        return None
    candidates = (os.path.join(folder, entry) for entry in os.listdir(folder))
    return [candidate for candidate in candidates if os.path.isfile(candidate)]
|
|
|
|
|
def resolve_relative_path(path: str) -> str:
    """Return ``path`` as an absolute path relative to this module's directory."""
    module_directory = os.path.dirname(__file__)
    return os.path.abspath(os.path.join(module_directory, path))
|
|
|
|
|
def get_device() -> str:
    """Map the first configured execution provider to a device name.

    When ``roop.globals.execution_providers`` is empty it is first set to
    ``["CPUExecutionProvider"]``. The first provider is then matched by
    substring: CoreML -> "mps", CUDA or ROCM -> "cuda", OpenVINO -> "mkl",
    anything else -> "cpu".
    """
    if len(roop.globals.execution_providers) < 1:
        roop.globals.execution_providers = ["CPUExecutionProvider"]

    primary_provider = roop.globals.execution_providers[0]
    # Checked in order; the first provider name found in the string wins.
    device_by_provider = (
        ("CoreMLExecutionProvider", "mps"),
        ("CUDAExecutionProvider", "cuda"),
        ("ROCMExecutionProvider", "cuda"),
        ("OpenVINOExecutionProvider", "mkl"),
    )
    for provider_name, device in device_by_provider:
        if provider_name in primary_provider:
            return device
    return "cpu"
|
|
|
|
|
def str_to_class(module_name, class_name) -> Any:
    """Import ``module_name`` and return a new instance of ``class_name``.

    The attribute is called with no arguments. Returns ``None`` after
    printing a message when an ``ImportError`` or ``AttributeError`` occurs.
    """
    import importlib

    instance = None
    try:
        module = importlib.import_module(module_name)
        try:
            factory = getattr(module, class_name)
            instance = factory()
        except AttributeError:
            print(f"Class {class_name} does not exist")
    except ImportError:
        print(f"Module {module_name} does not exist")
    return instance
|
|
|
def is_installed(name: str) -> bool:
    """Return True when ``shutil.which`` can locate an executable ``name``."""
    # shutil.which returns a path string or None; convert it so the result
    # matches the declared bool return type.
    return shutil.which(name) is not None
|
|
|
|
|
def get_platform() -> str:
    """Return ``sys.platform``, or ``"wsl"`` for Linux running under WSL.

    WSL is detected by looking for "microsoft" in ``/proc/version``. The
    match ignores case because WSL2 kernels spell it in lower case.
    """
    if sys.platform == "linux":
        try:
            with open("/proc/version") as version_file:
                proc_version = version_file.read()
        except (OSError, UnicodeDecodeError):
            # Unreadable version file: treat the system as plain Linux.
            proc_version = ""
        if "microsoft" in proc_version.lower():
            return "wsl"
    return sys.platform
|
|
|
def open_with_default_app(filename: str):
    """Open ``filename`` with the operating system's default application.

    Does nothing when ``filename`` is ``None``. Uses ``open`` on macOS,
    ``os.startfile`` on Windows, ``cmd.exe /C start`` under WSL and
    ``xdg-open`` everywhere else.
    """
    if filename is None:
        return
    current_platform = get_platform()
    if current_platform == "darwin":
        subprocess.call(("open", filename))
    elif current_platform in ["win64", "win32"]:
        os.startfile(filename.replace("/", "\\"))
    elif current_platform == "wsl":
        subprocess.call("cmd.exe /C start".split() + [filename])
    else:
        # The command and its argument must be one sequence; passed as two
        # positional arguments, the file name would be taken as ``bufsize``.
        subprocess.call(("xdg-open", filename))
|
|
|
|
|
def prepare_for_batch(target_files) -> str:
    """Move the given files into a fresh ``rooptmp`` folder in the system temp dir.

    Any existing ``rooptmp`` folder is deleted first. Each item must expose
    its current path as ``.name``; the file keeps its base name.

    Returns:
        The path of the ``rooptmp`` folder.
    """
    print("Preparing temp files")
    batch_folder = os.path.join(tempfile.gettempdir(), "rooptmp")
    if os.path.exists(batch_folder):
        shutil.rmtree(batch_folder)
    os.makedirs(batch_folder, exist_ok=True)
    for item in target_files:
        destination = os.path.join(batch_folder, os.path.basename(item.name))
        shutil.move(item.name, destination)
    return batch_folder
|
|
|
|
|
def zip(files, zipname):
    """Write ``files`` into a new archive ``zipname``, stored by base name only.

    NOTE: this function shadows the builtin ``zip`` inside this module.
    """
    with zipfile.ZipFile(zipname, "w") as archive:
        for file_path in files:
            archive.write(file_path, arcname=os.path.basename(file_path))
|
|
|
|
|
def unzip(zipfilename: str, target_path: str):
    """Extract every member of the archive ``zipfilename`` into ``target_path``."""
    with zipfile.ZipFile(zipfilename, mode="r") as archive:
        archive.extractall(path=target_path)
|
|
|
|
|
def mkdir_with_umask(directory):
    """Create ``directory`` (and parents) with mode 0o775, ignoring the umask.

    The process umask is set to 0 for the duration of the call and restored
    afterwards, including when directory creation raises.
    """
    oldmask = os.umask(0)
    try:
        os.makedirs(directory, mode=0o775, exist_ok=True)
    finally:
        # Without this, a failed makedirs would leave the process umask at 0.
        os.umask(oldmask)
|
|
|
|
|
def open_folder(path: str):
    """Open ``path`` in the platform's file browser.

    Uses ``open`` on macOS, ``open_with_default_app`` on Windows,
    ``cmd.exe /C start`` under WSL and ``xdg-open`` otherwise. Any exception
    is printed as a traceback and not re-raised.
    """
    current_platform = get_platform()
    try:
        if current_platform == "darwin":
            subprocess.call(("open", path))
        elif current_platform in ["win64", "win32"]:
            open_with_default_app(path)
        elif current_platform == "wsl":
            subprocess.call(["cmd.exe", "/C", "start", path])
        else:
            # Popen (not call): the launcher is not waited for.
            subprocess.Popen(["xdg-open", path])
    except Exception:
        traceback.print_exc()
|
|
|
|
|
|
|
|
|
def create_version_html() -> str:
    """Return an HTML snippet listing the python, torch and gradio versions.

    The python entry shows ``major.minor.micro`` with the full
    ``sys.version`` as tooltip. For torch, ``__long_version__`` is used when
    present, else ``__version__``.
    """
    python_version = ".".join(str(part) for part in sys.version_info[0:3])
    # The separator is a bullet; the former literal was mis-decoded text
    # that rendered as garbage characters between the entries.
    versions_html = f"""
python: <span title="{sys.version}">{python_version}</span>
•
torch: {getattr(torch, '__long_version__',torch.__version__)}
•
gradio: {gradio.__version__}
"""
    return versions_html
|
|
|
|
|
def compute_cosine_distance(emb1, emb2) -> float:
    """Return ``scipy.spatial.distance.cosine`` of the two embeddings."""
    cosine = distance.cosine
    return cosine(emb1, emb2)
|
|
|
def has_cuda_device():
    """Return whether ``torch.cuda`` exists and reports CUDA as available."""
    if torch.cuda is None:
        return False
    return torch.cuda.is_available()
|
|
|
|
|
def print_cuda_info():
    """Print the CUDA device count plus the current device's id and name.

    Prints ``No CUDA device found!`` when querying torch raises.
    """
    try:
        print(f'Number of CUDA devices: {torch.cuda.device_count()} Currently used Id: {torch.cuda.current_device()} Device Name: {torch.cuda.get_device_name(torch.cuda.current_device())}')
    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit are not mistaken for a missing device.
        print('No CUDA device found!')
|
|
|
def clean_dir(path: str):
    """Delete everything inside ``path`` while keeping ``path`` itself.

    Files are removed and sub-directories deleted recursively. A failure on
    one entry is printed and the remaining entries are still processed.
    """
    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        try:
            if os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
            elif os.path.isfile(entry_path):
                os.remove(entry_path)
        except Exception as error:
            print(error)
|
|
|
|
|
def conditional_thread_semaphore() -> Union[Any, Any]:
    """Return the context manager to wrap provider calls in.

    ``THREAD_SEMAPHORE`` is returned when ``DmlExecutionProvider`` or
    ``ROCMExecutionProvider`` is among ``roop.globals.execution_providers``;
    otherwise the no-op ``NULL_CONTEXT``.
    """
    serialized_providers = ('DmlExecutionProvider', 'ROCMExecutionProvider')
    if any(
        provider in roop.globals.execution_providers
        for provider in serialized_providers
    ):
        return THREAD_SEMAPHORE
    return NULL_CONTEXT
|
|