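"""Paired LQ/GT datasets for restoration training.

PairedCaptionVideoDataset pairs low-quality and ground-truth .mp4 clips
with per-clip text captions; PairedCaptionImageDataset pairs bicubically
upsampled .png images with their ground-truth counterparts.
"""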
import os
import random
import glob

import torchvision
from einops import rearrange
from torch.utils import data
from torchvision import transforms
from PIL import Image

class PairedCaptionVideoDataset(data.Dataset):
    def __init__(
            self,
            root_folders=None,
            null_text_ratio=0.5,
            num_frames=16,
    ):
        super().__init__()

        self.null_text_ratio = null_text_ratio
        self.lr_list = []
        self.gt_list = []
        self.tag_path_list = []
        self.num_frames = num_frames

        for root_folder in root_folders:
            lr_path = os.path.join(root_folder, 'lq')
            tag_path = os.path.join(root_folder, 'text')
            gt_path = os.path.join(root_folder, 'gt')

            # Sort each glob so LQ clips, GT clips, and captions pair up
            # by index; bare glob order is filesystem-dependent and would
            # silently misalign the triplets.
            self.lr_list += sorted(glob.glob(os.path.join(lr_path, '*.mp4')))
            self.gt_list += sorted(glob.glob(os.path.join(gt_path, '*.mp4')))
            self.tag_path_list += sorted(glob.glob(os.path.join(tag_path, '*.txt')))

        assert len(self.lr_list) == len(self.gt_list)
        assert len(self.lr_list) == len(self.tag_path_list)

    def __getitem__(self, index):
        # Decode the GT clip to (C, T, H, W) and rescale from uint8
        # [0, 255] to float [-1, 1].
        gt_path = self.gt_list[index]
        vframes_gt, _, info = torchvision.io.read_video(filename=gt_path, pts_unit="sec", output_format="TCHW")
        fps = info['video_fps']
        vframes_gt = (rearrange(vframes_gt, "T C H W -> C T H W").float() / 255) * 2 - 1

        lq_path = self.lr_list[index]
        vframes_lq, _, _ = torchvision.io.read_video(filename=lq_path, pts_unit="sec", output_format="TCHW")
        vframes_lq = (rearrange(vframes_lq, "T C H W -> C T H W").float() / 255) * 2 - 1

        # Drop the caption with probability null_text_ratio so the model
        # also sees the unconditional (empty-prompt) case.
        if random.random() < self.null_text_ratio:
            tag = ''
        else:
            tag_path = self.tag_path_list[index]
            with open(tag_path, 'r', encoding='utf-8') as file:
                tag = file.read().strip()

        # Truncate both clips to the first num_frames frames.
        return {
            "gt": vframes_gt[:, :self.num_frames, :, :],
            "lq": vframes_lq[:, :self.num_frames, :, :],
            "text": tag,
            "fps": fps,
        }

    def __len__(self):
        return len(self.gt_list)
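
    # Usage sketch (illustrative only; the path below is a placeholder,
    # and each root folder is assumed to contain 'lq/*.mp4', 'gt/*.mp4',
    # and 'text/*.txt' per the glob patterns above):
    #
    #   from torch.utils.data import DataLoader
    #   dataset = PairedCaptionVideoDataset(root_folders=['/data/videos'])
    #   loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
    #   batch = next(iter(loader))
    #   # batch["gt"], batch["lq"]: (B, C, T, H, W) float tensors in [-1, 1];
    #   # batch["text"] is a list of captions, batch["fps"] the frame rates.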


class PairedCaptionImageDataset(data.Dataset):
    def __init__(
            self,
            root_folder=None,
    ):
        super().__init__()

        self.lr_list = []
        self.gt_list = []

        lr_path = os.path.join(root_folder, 'sr_bicubic')
        gt_path = os.path.join(root_folder, 'gt')

        # Sort so LQ/GT images pair up by index.
        self.lr_list += sorted(glob.glob(os.path.join(lr_path, '*.png')))
        self.gt_list += sorted(glob.glob(os.path.join(gt_path, '*.png')))

        assert len(self.lr_list) == len(self.gt_list)

        self.img_preproc = transforms.ToTensor()

        # Center-crop every image to a fixed 720x1280 (H, W) window.
        self.center_crop = transforms.CenterCrop((720, 1280))

    def __getitem__(self, index):
        gt_path = self.gt_list[index]
        gt_img = Image.open(gt_path).convert('RGB')
        gt_img = self.center_crop(self.img_preproc(gt_img))

        lq_path = self.lr_list[index]
        lq_img = Image.open(lq_path).convert('RGB')
        lq_img = self.center_crop(self.img_preproc(lq_img))

        # Rescale from [0, 1] to [-1, 1] and add a singleton temporal axis
        # so the images share the video layout (C, T=1, H, W).
        example = dict()
        example["lq"] = (lq_img * 2.0 - 1.0).unsqueeze(1)
        example["gt"] = (gt_img * 2.0 - 1.0).unsqueeze(1)
        example["text"] = ""

        return example

    def __len__(self):
        return len(self.gt_list)
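

if __name__ == "__main__":
    # Minimal smoke test (illustrative; not part of the original module).
    # '/data/images' is a placeholder root that is assumed to contain
    # 'sr_bicubic/*.png' and 'gt/*.png' per the glob patterns above.
    dataset = PairedCaptionImageDataset(root_folder='/data/images')
    print(f"{len(dataset)} paired images")
    sample = dataset[0]
    # Both tensors should be (3, 1, 720, 1280), values in [-1, 1].
    print(sample["lq"].shape, sample["gt"].shape)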