xierui.0097
Add application file
f0e9666
raw
history blame
3.64 kB
import os
import random
import glob
import torchvision
from einops import rearrange
from torch.utils import data as data
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
class PairedCaptionVideoDataset(data.Dataset):
def __init__(
self,
root_folders=None,
null_text_ratio=0.5,
num_frames=16
):
super(PairedCaptionVideoDataset, self).__init__()
self.null_text_ratio = null_text_ratio
self.lr_list = []
self.gt_list = []
self.tag_path_list = []
self.num_frames = num_frames
# root_folders = root_folders.split(',')
for root_folder in root_folders:
lr_path = root_folder +'/lq'
tag_path = root_folder +'/text'
gt_path = root_folder +'/gt'
self.lr_list += glob.glob(os.path.join(lr_path, '*.mp4'))
self.gt_list += glob.glob(os.path.join(gt_path, '*.mp4'))
self.tag_path_list += glob.glob(os.path.join(tag_path, '*.txt'))
assert len(self.lr_list) == len(self.gt_list)
assert len(self.lr_list) == len(self.tag_path_list)
def __getitem__(self, index):
gt_path = self.gt_list[index]
vframes_gt, _, info = torchvision.io.read_video(filename=gt_path, pts_unit="sec", output_format="TCHW")
fps = info['video_fps']
vframes_gt = (rearrange(vframes_gt, "T C H W -> C T H W") / 255) * 2 - 1
# gt = self.trandform(vframes_gt)
lq_path = self.lr_list[index]
vframes_lq, _, _ = torchvision.io.read_video(filename=lq_path, pts_unit="sec", output_format="TCHW")
vframes_lq = (rearrange(vframes_lq, "T C H W -> C T H W") / 255) * 2 - 1
# lq = self.trandform(vframes_lq)
if random.random() < self.null_text_ratio:
tag = ''
else:
tag_path = self.tag_path_list[index]
with open(tag_path, 'r', encoding='utf-8') as file:
tag = file.read()
return {"gt": vframes_gt[:, :self.num_frames, :, :], "lq": vframes_lq[:, :self.num_frames, :, :], "text": tag, 'fps': fps}
def __len__(self):
return len(self.gt_list)
class PairedCaptionImageDataset(data.Dataset):
def __init__(
self,
root_folder=None,
):
super(PairedCaptionImageDataset, self).__init__()
self.lr_list = []
self.gt_list = []
self.tag_path_list = []
lr_path = root_folder +'/sr_bicubic'
gt_path = root_folder +'/gt'
self.lr_list += glob.glob(os.path.join(lr_path, '*.png'))
self.gt_list += glob.glob(os.path.join(gt_path, '*.png'))
assert len(self.lr_list) == len(self.gt_list)
self.img_preproc = transforms.Compose([
transforms.ToTensor(),
])
# Define the crop size (e.g., 256x256)
crop_size = (720, 1280)
# CenterCrop transform
self.center_crop = transforms.CenterCrop(crop_size)
def __getitem__(self, index):
gt_path = self.gt_list[index]
gt_img = Image.open(gt_path).convert('RGB')
gt_img = self.center_crop(self.img_preproc(gt_img))
lq_path = self.lr_list[index]
lq_img = Image.open(lq_path).convert('RGB')
lq_img = self.center_crop(self.img_preproc(lq_img))
example = dict()
example["lq"] = (lq_img.squeeze(0) * 2.0 - 1.0).unsqueeze(1)
example["gt"] = (gt_img.squeeze(0) * 2.0 - 1.0).unsqueeze(1)
example["text"] = ""
return example
def __len__(self):
return len(self.gt_list)