lavila / demo_narrator.py
nateraw's picture
Upload . with huggingface_hub
39d5658
raw
history blame
3.63 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import urllib.request
from collections import OrderedDict
import decord
import torch
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from lavila.data.video_transforms import Permute
from lavila.data.datasets import get_frame_ids, video_loader_by_frames
from lavila.models.models import VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL
from lavila.models.tokenizer import MyGPT2Tokenizer
from eval_narrator import decode_one
def main(args):
    """Generate and print narration candidates for a single video clip.

    Samples a fixed number of frames from ``args.video_path``, makes sure the
    narrator checkpoint exists under ``modelzoo/`` (downloading it if needed),
    loads it into the model, and prints the sampled captions one per line as
    ``"<index>: <caption>"``.

    Args:
        args: Parsed command-line namespace. Reads ``args.video_path`` (path
            of the video to narrate) and ``args.cuda`` (run on GPU if true).
    """
    num_seg = 4  # frames sampled from the clip; also passed as the model's num_frames
    num_candidates = 10  # captions to sample and print

    vr = decord.VideoReader(args.video_path)
    frame_ids = get_frame_ids(0, len(vr), num_segments=num_seg, jitter=False)
    frames = video_loader_by_frames('./', args.video_path, frame_ids)

    ckpt_name = 'vclm_openai_timesformer_large_336px_gpt2_xl.pt_ego4d.jobid_246897.ep_0003.md5sum_443263.pth'
    ckpt_path = os.path.join('modelzoo/', ckpt_name)
    os.makedirs('modelzoo/', exist_ok=True)
    if not os.path.exists(ckpt_path):
        print('downloading model to {}'.format(ckpt_path))
        # Download under a temporary name and rename only on success, so an
        # interrupted transfer is never mistaken for a complete checkpoint by
        # the existence check above on the next run.
        partial_path = ckpt_path + '.part'
        urllib.request.urlretrieve(
            'https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/{}'.format(ckpt_name),
            partial_path,
        )
        os.replace(partial_path, ckpt_path)

    # NOTE(review): torch.load unpickles the file; acceptable here only because
    # the checkpoint comes from the fixed URL above. Do not point this at
    # untrusted files.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = OrderedDict()
    for k, v in ckpt['state_dict'].items():
        # Removes every 'module.' substring from the key (presumably the
        # wrapper prefix added by DataParallel/DDP at training time -- confirm).
        state_dict[k.replace('module.', '')] = v

    # instantiate the model, and load the pre-trained weights
    model = VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL(
        text_use_cls_token=False,
        project_embed_dim=256,
        gated_xattn=True,
        timesformer_gated_xattn=False,
        freeze_lm_vclm=False,  # we use model.eval() anyway
        freeze_visual_vclm=False,  # we use model.eval() anyway
        num_frames=num_seg,
        drop_path_rate=0.
    )
    model.load_state_dict(state_dict, strict=True)
    if args.cuda:
        model.cuda()
    model.eval()

    # transforms on input frames
    crop_size = 336
    val_transform = transforms.Compose([
        # NOTE(review): presumably reorders (T, H, W, C) -> (C, T, H, W); confirm
        # against video_loader_by_frames' output layout.
        Permute([3, 0, 1, 2]),
        transforms.Resize(crop_size),
        transforms.CenterCrop(crop_size),
        # mean/std appear to be expressed on a 0-255 pixel scale -- confirm.
        transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])
    ])
    frames = val_transform(frames)
    frames = frames.unsqueeze(0)  # fake a batch dimension

    tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True)
    with torch.no_grad():
        if args.cuda:
            frames = frames.cuda(non_blocking=True)
        image_features = model.encode_image(frames)
        # The second return value is not used by this demo.
        generated_text_ids, _ = model.generate(
            image_features,
            tokenizer,
            target=None,  # free-form generation
            max_text_length=77,
            top_k=None,
            top_p=0.95,   # nucleus sampling
            num_return_sequences=num_candidates,
            temperature=0.7,
            early_stopping=True,
        )

    # Print exactly as many candidates as were requested from generate().
    for i in range(num_candidates):
        generated_text_str = decode_one(generated_text_ids[i], tokenizer)
        print('{}: {}'.format(i, generated_text_str))
if __name__ == '__main__':
    # Script entry point: build the CLI, parse it, and hand the namespace to main().
    cli = argparse.ArgumentParser('lavila narrator demo')
    cli.add_argument(
        '--cuda',
        action='store_true',
        help='use cuda',
    )
    cli.add_argument(
        '--video-path',
        type=str,
        default='assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4',
        help='video path',
    )
    main(cli.parse_args())