multimodalart's picture
Squashing commit
4450790 verified
raw
history blame
12.6 kB
import json
import subprocess
import uuid
from pathlib import Path
import comfy.model_management as model_management
import comfy.utils
import folder_paths
import numpy as np
import torch
from PIL import Image
from ..log import log
from ..utils import PIL_FILTER_MAP, output_dir, session_id, tensor2np
def get_playlist_path(playlist_name: str, persistant_playlist=False):
    """Return the JSON file path used to store the playlist ``playlist_name``.

    Persistent playlists live directly under ``output_dir / "playlists"``;
    non-persistent ones are nested in a per-``session_id`` subfolder.
    """
    playlists_root = output_dir / "playlists"
    folder = playlists_root if persistant_playlist else playlists_root / session_id
    return folder / f"{playlist_name}.json"
class MTB_ReadPlaylist:
    """Read a playlist.

    Loads the JSON playlist resolved from ``playlist_name`` (formatted with
    ``index``) via ``get_playlist_path`` and outputs its decoded content.
    """

    @classmethod
    def INPUT_TYPES(cls):
        # Key order is preserved because it defines the widget order.
        name_widget = ("STRING", {"default": "playlist_{index:04d}"})
        required = {
            "enable": ("BOOLEAN", {"default": True}),
            "persistant_playlist": ("BOOLEAN", {"default": False}),
            "playlist_name": name_widget,
            "index": ("INT", {"default": 0, "min": 0}),
        }
        return {"required": required}

    RETURN_TYPES = ("PLAYLIST",)
    FUNCTION = "read_playlist"
    CATEGORY = "mtb/IO"
    EXPERIMENTAL = True

    def read_playlist(
        self,
        enable: bool,
        persistant_playlist: bool,
        playlist_name: str,
        index: int,
    ):
        """Return ``(playlist,)`` or ``(None,)`` when disabled or missing."""
        resolved_name = playlist_name.format(index=index)
        path = get_playlist_path(resolved_name, persistant_playlist)

        # Short-circuit: the file is only probed when the node is enabled.
        if enable and path.exists():
            log.debug(f"Reading playlist {path}")
            content = path.read_text(encoding="utf-8")
            return (json.loads(content),)

        if enable:
            log.warning(f"Playlist {path} does not exist, skipping")
        return (None,)
class MTB_AddToPlaylist:
    """Add a video to the playlist.

    Every extra keyword input passed to ``add_to_playlist`` is appended, in
    order, to the JSON list stored at ``get_playlist_path(...)``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "relative_paths": ("BOOLEAN", {"default": False}),
                "persistant_playlist": ("BOOLEAN", {"default": False}),
                "playlist_name": (
                    "STRING",
                    {"default": "playlist_{index:04d}"},
                ),
                "index": ("INT", {"default": 0, "min": 0}),
            }
        }

    RETURN_TYPES = ()
    OUTPUT_NODE = True
    FUNCTION = "add_to_playlist"
    CATEGORY = "mtb/IO"
    EXPERIMENTAL = True

    def add_to_playlist(
        self,
        relative_paths: bool,
        persistant_playlist: bool,
        playlist_name: str,
        index: int,
        **kwargs,
    ):
        """Append every value in ``kwargs`` to the playlist file.

        Args:
            relative_paths: Store each entry relative to ``output_dir`` (as a
                POSIX path). ``Path.relative_to`` raises ``ValueError`` for
                entries outside ``output_dir``.
            persistant_playlist: Forwarded to ``get_playlist_path``.
            playlist_name: Name template, formatted with ``index``.
            index: Value substituted into ``playlist_name``.
            **kwargs: Video paths to append, in insertion order.

        Returns:
            An empty tuple (the node has no outputs).
        """
        playlist_name = playlist_name.format(index=index)
        playlist_path = get_playlist_path(playlist_name, persistant_playlist)
        playlist_path.parent.mkdir(parents=True, exist_ok=True)

        playlist = []
        if playlist_path.exists():
            # Read with the same encoding used for writing below (and by
            # MTB_ReadPlaylist); the locale default breaks non-ASCII paths.
            playlist = json.loads(playlist_path.read_text(encoding="utf-8"))
        log.debug(f"Playlist {playlist_path} has {len(playlist)} items")

        for video in kwargs.values():
            if relative_paths:
                video = Path(video).relative_to(output_dir).as_posix()
            log.debug(f"Adding {video} to playlist")
            playlist.append(video)

        log.debug(f"Writing playlist {playlist_path}")
        playlist_path.write_text(json.dumps(playlist), encoding="utf-8")
        return ()
class MTB_ExportWithFfmpeg:
"""Export with FFmpeg (Experimental).
[DEPRACATED] Use VHS nodes instead
"""
@classmethod
def INPUT_TYPES(cls):
return {
"optional": {
"images": ("IMAGE",),
"playlist": ("PLAYLIST",),
},
"required": {
"fps": ("FLOAT", {"default": 24, "min": 1}),
"prefix": ("STRING", {"default": "export"}),
"format": (
["mov", "mp4", "mkv", "gif", "avi"],
{"default": "mov"},
),
"codec": (
["prores_ks", "libx264", "libx265", "gif"],
{"default": "prores_ks"},
),
},
}
RETURN_TYPES = ("VIDEO",)
OUTPUT_NODE = True
FUNCTION = "export_prores"
DEPRECATED = True
CATEGORY = "mtb/IO"
def export_prores(
self,
fps: float,
prefix: str,
format: str,
codec: str,
images: torch.Tensor | None = None,
playlist: list[str] | None = None,
):
file_ext = format
file_id = f"{prefix}_{uuid.uuid4()}.{file_ext}"
if playlist is not None and images is not None:
log.info(f"Exporting to {output_dir / file_id}")
if playlist is not None:
if len(playlist) == 0:
log.debug("Playlist is empty, skipping")
return ("",)
temp_playlist_path = (
output_dir / f"temp_playlist_{uuid.uuid4()}.txt"
)
log.debug(
f"Create a temporary file to list the videos for concatenation to {temp_playlist_path}"
)
with open(temp_playlist_path, "w") as f:
for video_path in playlist:
f.write(f"file '{video_path}'\n")
out_path = (output_dir / file_id).as_posix()
# Prepare the FFmpeg command for concatenating videos from the playlist
command = [
"ffmpeg",
"-f",
"concat",
"-safe",
"0",
"-i",
temp_playlist_path.as_posix(),
"-c",
"copy",
"-y",
out_path,
]
log.debug(f"Executing {command}")
subprocess.run(command)
temp_playlist_path.unlink()
return (out_path,)
if (
images is None or images.size(0) == 0
): # the is None check is just for the type checker
return ("",)
frames = tensor2np(images)
log.debug(f"Frames type {type(frames[0])}")
log.debug(f"Exporting {len(frames)} frames")
height, width, channels = frames[0].shape
has_alpha = channels == 4
out_path = (output_dir / file_id).as_posix()
if codec == "gif":
command = [
"ffmpeg",
"-f",
"image2pipe",
"-vcodec",
"png",
"-r",
str(fps),
"-i",
"-",
"-vcodec",
"gif",
"-y",
out_path,
]
process = subprocess.Popen(command, stdin=subprocess.PIPE)
for frame in frames:
model_management.throw_exception_if_processing_interrupted()
Image.fromarray(frame).save(process.stdin, "PNG")
process.stdin.close()
process.wait()
return (out_path,)
else:
if has_alpha:
if codec in ["prores_ks", "libx264", "libx265"]:
pix_fmt = (
"yuva444p" if codec == "prores_ks" else "yuva420p"
)
frames = [
frame.astype(np.uint16) * 257 for frame in frames
]
else:
log.warning(
f"Alpha channel not supported for codec {codec}. Alpha will be ignored."
)
frames = [
frame[:, :, :3].astype(np.uint16) * 257
for frame in frames
]
pix_fmt = "rgb48le" if codec == "prores_ks" else "yuv420p"
else:
pix_fmt = "rgb48le" if codec == "prores_ks" else "yuv420p"
frames = [frame.astype(np.uint16) * 257 for frame in frames]
# Prepare the FFmpeg command
command = [
"ffmpeg",
"-y",
"-f",
"rawvideo",
"-vcodec",
"rawvideo",
"-s",
f"{width}x{height}",
"-pix_fmt",
pix_fmt,
"-r",
str(fps),
"-i",
"-",
"-c:v",
codec,
]
if codec == "prores_ks":
command.extend(["-profile:v", "4444"])
command.extend(
[
"-r",
str(fps),
"-y",
out_path,
]
)
process = subprocess.Popen(command, stdin=subprocess.PIPE)
pbar = comfy.utils.ProgressBar(len(frames))
for frame in frames:
process.stdin.write(frame.tobytes())
pbar.update(1)
process.stdin.close()
process.wait()
return (out_path,)
def prepare_animated_batch(
    batch: torch.Tensor,
    pingpong=False,
    resize_by=1.0,
    resample_filter: Image.Resampling | None = None,
    image_type=np.uint8,
) -> list[Image.Image]:
    """Convert an image batch into PIL frames ready for animated export.

    Args:
        batch: Image batch; ``batch[0].shape`` is unpacked as
            ``(height, width, channels)``.
        pingpong: When True, append the frames again in reverse order.
        resize_by: Scale factor applied to both dimensions; values within
            1e-6 of 1.0 skip resizing. Each resized dimension is at least 1 px.
        resample_filter: PIL resampling filter passed to ``Image.resize``.
        image_type: NumPy dtype each frame is cast to before conversion.

    Returns:
        One PIL image per output frame; an empty list for an empty batch.
    """
    if batch.shape[0] == 0:
        return []

    frames = [frame.astype(image_type) for frame in tensor2np(batch)]
    height, width, _ = batch[0].shape

    if pingpong:
        # NOTE(review): the reversed copy repeats the last frame at the turn
        # (and the first frame when the animation loops); confirm intended.
        frames.extend(frames[::-1])

    pil_images = [Image.fromarray(frame) for frame in frames]

    if abs(resize_by - 1.0) > 1e-6:
        # Clamp so a small factor on a small frame never yields a 0-px size,
        # which Image.resize rejects.
        new_size = (
            max(1, int(width * resize_by)),
            max(1, int(height * resize_by)),
        )
        pil_images = [
            frame.resize(new_size, resample=resample_filter)
            for frame in pil_images
        ]
    return pil_images
# TODO: deprecate in favor of APNG output
class MTB_SaveGif:
    """Save the images from the batch as a GIF.

    [DEPRECATED] Use VHS nodes instead
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "fps": ("INT", {"default": 12, "min": 1, "max": 120}),
                "resize_by": ("FLOAT", {"default": 1.0, "min": 0.1}),
                "optimize": ("BOOLEAN", {"default": False}),
                "pingpong": ("BOOLEAN", {"default": False}),
                "resample_filter": (list(PIL_FILTER_MAP.keys()),),
                "use_ffmpeg": ("BOOLEAN", {"default": False}),
            },
        }

    RETURN_TYPES = ()
    OUTPUT_NODE = True
    CATEGORY = "mtb/IO"
    FUNCTION = "save_gif"
    DEPRECATED = True

    def save_gif(
        self,
        image,
        fps=12,
        resize_by=1.0,
        optimize=False,
        pingpong=False,
        resample_filter=None,
        use_ffmpeg=False,
    ):
        """Write ``image`` as an animated GIF into the ComfyUI output folder.

        Args:
            image: Image batch (tensor exposing ``size(0)``).
            fps: Frames per second; the PIL path uses ``int(1000 / fps)`` ms
                per frame.
            resize_by: Scale factor forwarded to ``prepare_animated_batch``.
            optimize: Forwarded to PIL's GIF writer (ignored with ffmpeg).
            pingpong: Append the frames reversed.
            resample_filter: Key of ``PIL_FILTER_MAP``.
            use_ffmpeg: Encode with ffmpeg instead of PIL.

        Returns:
            ``{"ui": {"gif": [...]}}`` describing the saved file; the list is
            empty when the batch has no frames.
        """
        if image.size(0) == 0:
            # The node has no outputs, so report "nothing saved" through the
            # same ui channel used on success.
            return {"ui": {"gif": []}}
        if resample_filter is not None:
            resample_filter = PIL_FILTER_MAP.get(resample_filter)
        frames = prepare_animated_batch(
            image,
            pingpong,
            resize_by,
            resample_filter,
        )

        filename = f"{uuid.uuid4().hex[:10]}.gif"
        out_path = Path(folder_paths.output_directory) / filename

        if use_ffmpeg:
            self._save_with_ffmpeg(frames, fps, out_path)
        else:
            frames[0].save(
                out_path,
                save_all=True,
                append_images=frames[1:],
                optimize=optimize,
                duration=int(1000 / fps),
                loop=0,
            )
        results = [{"filename": filename, "subfolder": "", "type": "output"}]
        return {"ui": {"gif": results}}

    @staticmethod
    def _save_with_ffmpeg(frames, fps, out_path) -> None:
        """Pipe PIL ``frames`` to ffmpeg as PNGs and encode them to a GIF.

        stdin is always closed and the process waited on, even when the user
        interrupts the queue, so ffmpeg is not left blocked on the pipe.
        """
        command = [
            "ffmpeg",
            "-f",
            "image2pipe",
            "-vcodec",
            "png",
            "-r",
            str(fps),
            "-i",
            "-",
            "-vcodec",
            "gif",
            "-y",
            str(out_path),
        ]
        process = subprocess.Popen(command, stdin=subprocess.PIPE)
        try:
            for frame in frames:
                model_management.throw_exception_if_processing_interrupted()
                frame.save(process.stdin, "PNG")
        finally:
            try:
                process.stdin.close()
            except BrokenPipeError:
                # ffmpeg already exited; there is nothing left to flush.
                pass
            returncode = process.wait()
        if returncode != 0:
            log.error(f"ffmpeg exited with code {returncode} while writing {out_path}")
# Node classes this module exposes through ``__nodes__`` (presumably picked
# up by the package's node loader — confirm against the loader).
__nodes__ = [MTB_SaveGif, MTB_ExportWithFfmpeg, MTB_AddToPlaylist, MTB_ReadPlaylist]