|
import os
import glob

import cv2
import torch
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import (
    img2tensor, tensor2img,
    InputPadder,
    check_dim_and_resize
)
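# Run on GPU when available; map the UI model names to their network classes
# (checkpoints are downloaded from the Hugging Face Hub inside img2vid).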
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dict = {
    'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
}
|
|
|
def img2vid(model_type, img0, img1, frame_ratio, iters):
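    """Interpolate between two images with the selected AMT model and export an MP4.

    `iters` is the number of frame-doubling passes (2**iters + 1 frames in total),
    and `frame_ratio` is used as the frame rate of the output video.
    """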
|
    model = model_dict[model_type]()
    model.to(device)
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    img0_t = img2tensor(img0).to(device)
    img1_t = img2tensor(img1).to(device)
    inputs = [img0_t, img1_t]
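    # Estimate how much the inputs must be downscaled to fit inference in GPU memory;
    # the CPU branch uses constants that effectively disable downscaling.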
    if device.type == 'cuda':
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
    else:
        anchor_resolution = 8192 * 8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
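    # embt = 0.5 (above) requests the temporal midpoint between each frame pair.
    # The scale factor is snapped so that 16 / scale is an integer, and both frames
    # are padded so their dimensions stay compatible with the model at that scale.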
    inputs = check_dim_and_resize(inputs)
    h, w = inputs[0].shape[-2:]
    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
    scale = 1 if scale > 1 else scale
    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
    if scale < 1:
        print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
    padding = int(16 / scale)
    padder = InputPadder(inputs[0].shape, padding)
    inputs = padder.pad(*inputs)
|
|
|
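    # Each pass inserts the predicted midpoint between every adjacent pair of frames,
    # so the sequence grows from 2 frames to 2**iters + 1 frames after `iters` passes.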
    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [inputs[0]]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            in_0 = in_0.to(device)
            in_1 = in_1.to(device)
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
            outputs += [imgt_pred.cpu(), in_1.cpu()]
        inputs = outputs
    outputs = padder.unpad(*outputs)
|
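    # Convert the tensors back to images (RGB -> BGR for OpenCV) and encode them as an MP4.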
    out_path = 'results'
    os.makedirs(out_path, exist_ok=True)
    size = outputs[0].shape[2:][::-1]
    writer = cv2.VideoWriter(f'{out_path}/demo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), frame_ratio, size)
    for i, imgt_pred in enumerate(outputs):
        imgt_pred = tensor2img(imgt_pred)
        imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)
        writer.write(imgt_pred)
    writer.release()
    return f'{out_path}/demo.mp4'
|
|
|
|
|
def demo_img():
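    """Build the Gradio Blocks interface for the two-image interpolation demo."""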
|
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Image Demo')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left;">
                    <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                        Description: Given two input images, generate a short interpolated video between them.
                    </h2>
                </div>
                """)

        with gr.Row():
            with gr.Column():
                img0 = gr.Image(label='Image0')
                img1 = gr.Image(label='Image1')
            with gr.Column():
                result = gr.Video(label="Generated Video")
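        # 'Multiple Ratio' is passed to img2vid as `iters` (number of doubling passes);
        # 'Frame Ratio' is passed as `frame_ratio` (the output video frame rate).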
        with gr.Accordion('Advanced options', open=False):
            ratio = gr.Slider(label='Multiple Ratio',
                              minimum=4,
                              maximum=7,
                              value=6,
                              step=1)
            frame_ratio = gr.Slider(label='Frame Ratio',
                                    minimum=8,
                                    maximum=64,
                                    value=16,
                                    step=1)
            model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                  label='Model Select',
                                  value='AMT-S')
        run_button = gr.Button('Run')
|
        inputs = [
            model_type,
            img0,
            img1,
            frame_ratio,
            ratio,
        ]

        gr.Examples(examples=glob.glob("examples/*.png"),
                    inputs=img0,
                    label='Example images (drag them to input windows)',
                    run_on_click=False)

        run_button.click(fn=img2vid,
                         inputs=inputs,
                         outputs=result)
    return demo
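# Minimal usage sketch (assumption, not from the original file): serve the image demo
# as a standalone Gradio app.
if __name__ == '__main__':
    demo_img().launch()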