RefurnishAI / custom_nodes /comfy_mtb /nodes /image_interpolation.py
multimodalart's picture
Squashing commit
4450790 verified
from pathlib import Path
import comfy
import comfy.model_management as model_management
import comfy.utils
import numpy as np
import tensorflow as tf
import torch
from frame_interpolation.eval import interpolator, util
from ..errors import ModelNotFound
from ..log import log
from ..utils import get_model_path
class MTB_LoadFilmModel:
    """Loads a FILM model

    [DEPRECATED] Use ComfyUI-FrameInterpolation instead
    """

    RETURN_TYPES = ("FILM_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "mtb/frame iterpolation"
    DEPRECATED = True

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the single required input: the FILM variant to load."""
        variants = ["L1", "Style", "VGG"]
        film_model_spec = (variants, {"default": "Style"})
        return {"required": {"film_model": film_model_spec}}

    @staticmethod
    def get_models() -> list[Path]:
        """Return the entries of the "FILM" model directory whose suffix is
        ``.onnx`` or ``.pth``.
        """
        accepted_suffixes = [".onnx", ".pth"]
        matches = []
        for entry in get_model_path("FILM").iterdir():
            if entry.suffix in accepted_suffixes:
                matches.append(entry)
        return matches

    def load_model(self, film_model: str):
        """Resolve the directory of ``film_model`` and wrap it in an Interpolator.

        The directory returned by ``get_model_path("FILM", film_model)`` is
        used directly when it contains ``saved_model.pb``; otherwise its
        ``saved_model`` sub-directory is used.

        Returns:
            A 1-tuple holding the ``interpolator.Interpolator`` instance.

        Raises:
            ModelNotFound: the resolved path is falsy or does not exist.
            ValueError: the final model directory does not exist.
        """
        base_dir = get_model_path("FILM", film_model)
        if not (base_dir and base_dir.exists()):
            raise ModelNotFound(f"FILM ({base_dir})")

        # Fall back to the nested "saved_model" folder when the graph file
        # is not directly inside the resolved directory.
        has_graph_file = (base_dir / "saved_model.pb").exists()
        saved_model_dir = base_dir if has_graph_file else base_dir / "saved_model"

        if not saved_model_dir.exists():
            message = f"Model {saved_model_dir} does not exist"
            log.error(message)
            raise ValueError(message)

        log.info(f"Loading model {saved_model_dir}")
        film = interpolator.Interpolator(saved_model_dir.as_posix(), None)
        return (film,)
class MTB_FilmInterpolation:
    """Google Research FILM frame interpolation for large motion

    [DEPRECATED] Use ComfyUI-FrameInterpolation instead
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the required inputs: an image batch, the interpolation
        count (1-50, default 2) and a loaded FILM model.
        """
        return {
            "required": {
                "images": ("IMAGE",),
                "interpolate": ("INT", {"default": 2, "min": 1, "max": 50}),
                "film_model": ("FILM_MODEL",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "do_interpolation"
    # NOTE: "iterpolation" is misspelled; the string is kept unchanged so the
    # node's category value stays exactly what it was.
    CATEGORY = "mtb/frame iterpolation"
    DEPRECATED = True

    def do_interpolation(
        self,
        images: torch.Tensor,
        interpolate: int,
        film_model: "interpolator.Interpolator",
    ):
        """Interpolate between consecutive frames of ``images`` with FILM.

        Args:
            images: Batch of frames; the first dimension indexes frames.
            interpolate: Forwarded to
                ``util.interpolate_recursively_from_memory`` as its second
                argument (the recursion depth in upstream FILM).
            film_model: The interpolator forwarded to the same helper.

        Returns:
            A 1-tuple containing every yielded frame stacked along a new
            first dimension. An empty batch is returned unchanged.

        Raises:
            Whatever ``model_management.throw_exception_if_processing_interrupted``
            raises when the run is interrupted.
        """
        n = images.size(0)

        # Nothing to interpolate: hand the empty batch straight back.
        if n == 0:
            return (images,)

        # Only used for logging; TensorFlow picks the device itself.
        available_gpus = tf.config.list_physical_devices("GPU")
        if not available_gpus:
            log.warning(
                "Tensorflow GPU not available, falling back to CPU this will be very slow"
            )
        else:
            log.debug(f"Tensorflow GPU available, using {available_gpus}")

        # Number of in-between frames that get synthesized.
        num_frames = (n - 1) * (2**interpolate - 1)
        log.debug(f"Will interpolate into {num_frames} frames")

        # The loop below ticks the bar once per *yielded* frame. Upstream
        # FILM's generator also yields the n original frames, i.e.
        # (n - 1) * 2**interpolate + 1 == num_frames + n frames in total, so
        # the bar is sized for that instead of saturating early.
        # NOTE(review): relies on upstream FILM generator behaviour - confirm
        # against the bundled frame_interpolation package.
        pbar = comfy.utils.ProgressBar(num_frames + n)

        in_frames = list(images.unbind(0))
        out_frames = []
        for frame in util.interpolate_recursively_from_memory(
            in_frames, interpolate, film_model
        ):
            # Frames may come back as numpy arrays or as tensors.
            if isinstance(frame, np.ndarray):
                frame = torch.from_numpy(frame)
            out_frames.append(frame)

            model_management.throw_exception_if_processing_interrupted()
            pbar.update(1)

        out_tensors = torch.stack(out_frames, dim=0)

        log.debug(f"Returning {len(out_tensors)} tensors")
        log.debug(f"Output shape {out_tensors.shape}")
        log.debug(f"Output type {out_tensors.dtype}")

        return (out_tensors,)
# The node classes defined in this module, gathered in one list.
# NOTE(review): presumably read by the package's node-registration code -
# confirm against the loader.
__nodes__ = [MTB_LoadFilmModel, MTB_FilmInterpolation]