File size: 4,942 Bytes
2d9a728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from pathlib import Path 
import os
import sys
sys.path.append(str(Path(os.path.abspath(''))))

import torch
import numpy as np
from tools.genrl_utils import ViCLIPGlobalInstance

import time
import torchvision
from huggingface_hub import hf_hub_download

def save_videos(batch_tensors, savedir, filenames, fps=10):
    """Write each entry of a batch of videos to ``<savedir>/<filename>.mp4``.

    For every batch entry, the samples of each time step are tiled side by
    side in a single row (``make_grid`` with ``nrow=n_samples``), so one
    batch entry produces one mp4 file.

    Args:
        batch_tensors: Tensor indexed as ``(b, n_samples, t, c, h, w)``. The
            per-entry ``permute(1, 0, 2, 3, 4)`` below swaps the sample and
            time axes to get ``(t, n, c, h, w)``. Values are clamped to
            ``[0, 1]`` before being scaled to 8 bit.
        savedir: Output directory; created (with parents) if it is missing.
        filenames: One file stem per batch entry (``.mp4`` is appended).
        fps: Frame rate of the written videos.
    """
    # Ensure the destination exists; '' means "current directory".
    if savedir:
        os.makedirs(savedir, exist_ok=True)
    n_samples = int(batch_tensors.shape[1])
    for idx, vid_tensor in enumerate(batch_tensors):
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), 0., 1.)
        video = video.permute(1, 0, 2, 3, 4)  # (n, t, c, h, w) -> (t, n, c, h, w)
        # One row of n frames per time step; each grid is (c, h', n*w'),
        # where h'/w' include make_grid's padding.
        frame_grids = [torchvision.utils.make_grid(framesheet, nrow=n_samples) for framesheet in video]
        grid = torch.stack(frame_grids, dim=0)  # (t, c, H, W)
        # write_video expects uint8 frames laid out as (t, H, W, c).
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
        savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
        torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})

class Text2Video:
    """Generate videos from text prompts with a pretrained GenRL agent.

    On construction, the GenRL checkpoint and the InternVideo2 weights are
    downloaded into ``<cwd>/models`` if they are not already there. The agent
    is then loaded with ``torch.load``, and the ViCLIP (InternVideo2) text
    encoder is obtained from ``ViCLIPGlobalInstance``.
    """

    def __init__(self, result_dir='./tmp/', gpu_num=1) -> None:
        """Load the models and prepare the output directory.

        Args:
            result_dir: Directory where generated mp4 files are written;
                created (including parents) if missing.
            gpu_num: Accepted for interface compatibility; not used here.
        """
        model_folder = str(Path(os.path.abspath('')) / 'models')
        model_filename = 'genrl_stickman_500k_2.pt'

        if not os.path.isfile(os.path.join(model_folder, model_filename)):
            self.download_model(model_folder, model_filename)
        if not os.path.isfile(os.path.join(model_folder, 'InternVideo2-stage2_1b-224p-f4.pt')):
            self.download_internvideo2(model_folder)
        self.agent = self._load_agent(os.path.join(model_folder, model_filename))
        model_name = 'internvideo2'

        # Get ViCLIP; only instantiate it if that has not happened already.
        viclip_global_instance = ViCLIPGlobalInstance(model_name)
        if not viclip_global_instance._instantiated:
            print("Instantiating InternVideo2")
            viclip_global_instance.instantiate()
        self.clip = viclip_global_instance.viclip
        self.tokenizer = viclip_global_instance.viclip_tokenizer

        self.result_dir = result_dir
        # makedirs (unlike mkdir) also creates missing parent directories.
        os.makedirs(self.result_dir, exist_ok=True)

    @staticmethod
    def _load_agent(path):
        """Unpickle and return the full GenRL agent object stored at ``path``.

        The checkpoint is a pickled Python object (its ``.wm`` attribute is
        used later), not a plain state dict. It therefore needs
        ``weights_only=False``, which PyTorch >= 2.6 no longer uses by default.
        Unpickling can execute arbitrary code: only load trusted checkpoints.
        """
        import inspect

        load_kwargs = {}
        # Older torch releases do not accept the weights_only keyword at all.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        return torch.load(path, **load_kwargs)

    @staticmethod
    def _prompt_to_filename(prompt):
        """Return a file stem for ``prompt``: '/' -> '_slash_', ' ' -> '_'."""
        return prompt.replace("/", "_slash_").replace(" ", "_")

    def get_prompt(self, prompt, duration):
        """Generate a video for ``prompt`` and save it as an mp4 file.

        Args:
            prompt: Text description of the desired video. It is embedded
                as-is; only the output file name is sanitized.
            duration: Number of times the text embedding is repeated along
                axis 1 before being passed to ``video_imagine``. Presumably
                this is the length of the imagined sequence; TODO confirm.

        Returns:
            Path of the written ``<result_dir>/<sanitized prompt>.mp4`` file.
        """
        torch.cuda.empty_cache()
        print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        start = time.time()

        prompt_str = self._prompt_to_filename(prompt)

        with torch.no_grad():
            wm = self.agent.wm
            decoder = wm.heads['decoder']

            # Embed the raw prompt text. The filename-sanitized string is not
            # fed to the encoder: it would inject '_' and '_slash_' tokens.
            # Stacking adds a leading batch axis (a single prompt here).
            text_feat = torch.stack([self.clip.get_txt_feat(prompt)], dim=0).to(self.clip.device)

            # Repeat the embedding `duration` times along axis 1.
            video_embed = text_feat.repeat(1, duration, 1)
            # Imagine latent states from the embedding sequence, then decode.
            prior = wm.connector.video_imagine(video_embed, None, sample=False, reset_every_n_frames=False, denoise=True)
            # +0.5 shifts the decoded mean toward the [0, 1] range that
            # save_videos clamps to. Presumably the decoder predicts values
            # centred on 0; TODO confirm.
            prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5

        save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
        print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
        return os.path.join(self.result_dir, f"{prompt_str}.mp4")

    @staticmethod
    def _download_files(repo_id, filename_list, model_folder):
        """Fetch each file in ``filename_list`` missing from ``model_folder``.

        Files are downloaded from the Hugging Face Hub repository ``repo_id``.
        ``model_folder`` is created if needed; files that already exist
        locally are skipped.
        """
        os.makedirs(model_folder, exist_ok=True)
        for filename in filename_list:
            if os.path.exists(os.path.join(model_folder, filename)):
                continue
            # Request real files in model_folder rather than cache symlinks.
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=model_folder, local_dir_use_symlinks=False)

    def download_model(self, model_folder, model_filename):
        """Download the GenRL checkpoint ``model_filename`` into ``model_folder``."""
        self._download_files('mazpie/genrl_models', [model_filename], model_folder)

    def download_internvideo2(self, model_folder):
        """Download the InternVideo2 stage-2 1B weights into ``model_folder``."""
        self._download_files(
            'OpenGVLab/InternVideo2-Stage2_1B-224p-f4',
            ['InternVideo2-stage2_1b-224p-f4.pt'],
            model_folder,
        )

if __name__ == '__main__':
    # Smoke run: render one demo prompt using the default output directory.
    generator = Text2Video()
    output_path = generator.get_prompt('a black swan swims on the pond', 8)
    print('done', output_path)