import os
import cv2
import glob
import torch
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download

from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import img2tensor, tensor2img, InputPadder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dict = {
    'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
}


def vid2vid(model_type, video, iters):
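    """Interpolate `video` with the selected AMT model.

    Each of the `iters` passes predicts one intermediate frame between every
    pair of consecutive frames, so the output frame rate is fps * 2**iters.
    """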
    model = model_dict[model_type]()
    model.to(device)
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    vcap = cv2.VideoCapture(video)
    ori_frame_rate = vcap.get(cv2.CAP_PROP_FPS)
    inputs = []
    w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    if device.type == 'cuda':
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
    else:
        # Do not resize in cpu mode
        anchor_resolution = 8192*8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
    
    # Estimate how much the frames must be downscaled to fit in VRAM: the
    # anchor values encode a reference workload (1024x512 frames in ~1.5 GB,
    # plus ~2.5 GB of fixed overhead), and the result is snapped so that the
    # scaled dimensions stay multiples of 16.
    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
    scale = 1 if scale > 1 else scale
    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
    if scale < 1:
        print(f"Due to limited VRAM, the video will be scaled by {scale:.2f}")
    padding = int(16 / scale)
    padder = InputPadder((h, w), padding)
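    # Decode every frame up front, convert BGR -> RGB, and pad so that the
    # resolution remains divisible by 16 after any downscaling.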
    while True:
        ret, frame = vcap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_t = img2tensor(frame).to(device)
        frame_t = padder.pad(frame_t)
        inputs.append(frame_t)
    vcap.release()
    # embt = 0.5 asks the model for the temporal midpoint between two frames.
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

    # Each iteration doubles the frame count (2n - 1 frames from n): one
    # midpoint frame is predicted between every pair of consecutive frames.
    # Frames stay padded across iterations so the model always sees
    # consistent shapes; unpadding happens once, after the final pass.
    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [inputs[0]]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
            outputs += [imgt_pred, in_1]
        inputs = outputs
    outputs = [padder.unpad(frame) for frame in outputs]

    out_path = 'results'
    os.makedirs(out_path, exist_ok=True)
    h, w = outputs[0].shape[-2:]
    writer = cv2.VideoWriter(f'{out_path}/demo_vfi.mp4',
                             cv2.VideoWriter_fourcc(*'mp4v'),
                             ori_frame_rate * 2 ** iters, (w, h))
    for imgt_pred in outputs:
        imgt_pred = tensor2img(imgt_pred)
        imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)
        writer.write(imgt_pred)
    writer.release()
    return f'{out_path}/demo_vfi.mp4'

    
def demo_vid():
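    """Build the gradio Blocks UI for the video interpolation demo."""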
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Video Demo')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left;">
                <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                    Description: Increase the frame rate of a video by 2x, 4x, 8x, or 16x (each iteration doubles the frame rate). Input videos should be shorter than 10 seconds.
                </h2>
                </div>
                """)

        with gr.Row():
            with gr.Column():
                video = gr.Video(label='Video Input')
            with gr.Column():
                result = gr.Video(label="Generated Video")
                with gr.Accordion('Advanced options', open=False):
                    ratio = gr.Slider(label='Iterations (each doubles the frame rate)',
                                      minimum=1,
                                      maximum=4,
                                      value=2,
                                      step=1)
                    model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                          label='Model Select',
                                          value='AMT-S')
                run_button = gr.Button('Run')
        inputs = [
            model_type,
            video,
            ratio,
        ]

        gr.Examples(examples=glob.glob("examples/*.mp4"),
                    inputs=video,
                    label='Example videos (drag them to the input window)',
                    run_on_click=False)

        run_button.click(fn=vid2vid,
                         inputs=inputs,
                         outputs=result,)
    return demo
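

# Minimal launch sketch (an assumption: this module is run directly rather
# than imported by a larger app). `demo_vid()` builds the Blocks interface
# and gradio's standard `launch()` serves it locally.
if __name__ == '__main__':
    demo_vid().launch()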