File size: 6,511 Bytes
2d9a728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import av
import gc
import torch
import torchaudio
import numpy as np
import random
import logging
import io
from torchvision.transforms.functional import pil_to_tensor

logger = logging.getLogger(__name__)



def get_index(num_frames, num_segments):
    """Return `num_segments` frame offsets spread evenly across `num_frames` frames.

    Each offset is the (rounded) center-ish sample of one of `num_segments`
    equal segments covering the range [0, num_frames - 1].
    """
    step = float(num_frames - 1) / num_segments
    first = int(step / 2)
    return first + np.round(step * np.arange(num_segments)).astype(int)


def lazy_load_s3video(s3path_video, num_frames, video_start_frame, video_end_frame, client):
    """Decode `num_frames` evenly spaced frames from a video stored behind `client`.

    Args:
        s3path_video: storage path/key of the video object.
        num_frames: number of frames to sample from the clip.
        video_start_frame: first frame of the clip (inclusive).
        video_end_frame: last frame of the clip (inclusive).
        client: storage client exposing ``get(path, enable_stream_lazyloding=True)``
            (the misspelled kwarg matches the client's API — do not "fix" it here).

    Returns:
        uint8 tensor of shape (T, C, H, W) with the sampled RGB frames.
    """
    assert client is not None
    video_bytes_stream = client.get(s3path_video, enable_stream_lazyloding=True)
    container = av.open(video_bytes_stream)
    stream = container.streams.video[0]
    real_fps = stream.average_rate
    time_base = stream.time_base
    start, end = video_start_frame, video_end_frame

    # Frame indices (relative to `start`) of the samples we want to keep.
    duration_frames = end - start + 1
    frames_index = get_index(duration_frames, num_frames)

    # Convert frame numbers to presentation timestamps.
    start_pts = int((start / real_fps) / time_base)
    # FIX: int() was previously applied to the seconds value *before* dividing
    # by time_base — truncating sub-second precision and yielding a Fraction
    # pts. Apply int() to the whole expression, matching start_pts above.
    # NOTE(review): frame_index is relative to `start` while start_pts is
    # absolute; this only lines up when clips start at frame 0 — confirm
    # against callers.
    pts_list = [int((frame_index / real_fps) / time_base) for frame_index in frames_index]

    # Seek to the nearest key frame at/before the requested start.
    container.seek(max(start_pts, 0), stream=stream)

    frames = []
    for frame in container.decode(**{"video": 0}):
        if frame.pts < start_pts:
            continue
        if not pts_list:
            break
        if frame.pts >= pts_list[0]:
            frames.append(frame)
            pts_list.pop(0)
    container.close()
    del video_bytes_stream

    # (T, C, H, W) — one uint8 RGB tensor per sampled frame.
    return torch.cat(
        [pil_to_tensor(f.to_rgb().to_image()).unsqueeze(0) for f in frames], dim=0
    )


def load_audio_av(video_path, video_start_frame, video_end_frame, sr, max_audio_length, client):  # sr should be 16000
    """Extract a clip's audio track and return normalized 64-bin kaldi fbank features.

    The audio between `video_start_frame` and `video_end_frame` (inclusive,
    expressed in *video* frame numbers) is decoded, mixed down to mono,
    resampled to `sr`, randomly cropped to at most `max_audio_length`
    seconds, converted to fbank features, normalized, and zero-padded to a
    fixed 998 time frames.

    Args:
        video_path: storage path/key of the video object.
        video_start_frame: first video frame of the clip (inclusive).
        video_end_frame: last video frame of the clip (inclusive).
        sr: target sample rate; normalization constants assume 16000.
        max_audio_length: maximum clip length in seconds; must be 10.
        client: storage client exposing
            ``get(path, enable_stream_lazyloding=True)`` (the misspelled
            kwarg matches the client's API).

    Returns:
        Tensor of shape (998, 64), or None when the container cannot be
        opened, has no audio stream, or yields no audio frames.
    """
    assert client is not None
    video_bytes_stream = client.get(video_path, enable_stream_lazyloding=True)
    try:
        container = av.open(video_bytes_stream)
    except Exception as e:
        # FIX: bare except + logger.warn; also log the error like load_full_audio_av does.
        logger.warning(f"Something wrong {e} when av.open (video_path: {video_path})!")
        return None
    if len(container.streams.audio) == 0:
        logger.warning(f"There is no audio! (video_path: {video_path})!")
        container.close()  # FIX: container was leaked on this early return
        del video_bytes_stream
        return None
    audio_stream = container.streams.audio[0]
    # NOTE(review): assumes a video stream exists; raises IndexError otherwise.
    real_fps = container.streams.video[0].average_rate
    time_base = audio_stream.time_base
    csr = audio_stream.sample_rate
    start_frame, end_frame = video_start_frame, video_end_frame
    # Convert video frame numbers to audio presentation timestamps.
    start_pts = int((start_frame / real_fps) / time_base)
    end_pts = int((end_frame / real_fps) / time_base)
    frames = []
    container.seek(max(start_pts, 0), stream=audio_stream)
    try:
        for frame in container.decode(**{"audio": 0}):
            if frame.pts < start_pts:
                continue
            frames.append(frame.to_ndarray())
            if frame.pts > end_pts:
                break
    except Exception:  # FIX: was a bare except; keep the best-effort behavior
        gc.collect()
    container.close()
    del video_bytes_stream

    if not frames:
        # FIX: np.concatenate raises ValueError on an empty list; treat as no audio.
        logger.warning(f"No audio frames decoded! (video_path: {video_path})!")
        return None
    audio_raw = np.concatenate(frames, 1)
    audio = torch.from_numpy(audio_raw)
    # Stereo -> mono by channel average; ensure a leading channel dim.
    if audio.size(0) == 2:
        audio = torch.mean(audio, dim=0, keepdim=True)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)
    assert max_audio_length == 10, max_audio_length
    max_length = max_audio_length * sr
    if csr != sr:
        audio = torchaudio.transforms.Resample(csr, sr)(audio)
    # Random temporal crop to at most max_length samples.
    if audio.shape[1] >= max_length:
        start = random.randint(0, audio.shape[1] - max_length)
        audio = audio[:, start: start + max_length]
    audio = audio * 2 ** 15  # scale to the int16 range kaldi fbank expects
    fbank = torchaudio.compliance.kaldi.fbank(
        audio, num_mel_bins=64, sample_frequency=16000, frame_length=25, frame_shift=10
    )
    # Dataset-level normalization constants.
    fbank_mean = 15.41663
    fbank_std = 6.55582
    fbank = (fbank - fbank_mean) / (fbank_std * 2)  # (<=998, 64)

    # Zero-pad the time axis up to a fixed 998 frames.
    src_length = fbank.shape[0]
    pad_len = 998 - src_length
    fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank)

    return fbank

def load_full_audio_av(video_path, sr, max_audio_length, client):
    """Decode a video's ENTIRE audio track into normalized 64-bin kaldi fbank features.

    Like `load_audio_av` but without a frame range: the whole track is
    decoded, mixed to mono, resampled to `sr`, randomly cropped to at most
    `max_audio_length` seconds, converted to fbank features, normalized, and
    zero-padded to 998 time frames.

    Args:
        video_path: storage path/key of the video object.
        sr: target sample rate; normalization constants assume 16000.
        max_audio_length: maximum clip length in seconds; must be 10.
        client: storage client; the object is fetched fully (no lazy stream).

    Returns:
        Tensor of shape (998, 64), or None when the container cannot be
        opened, has no audio stream, or yields no audio frames.
    """
    assert client is not None
    video_bytes_stream = client.get(video_path)
    try:
        container = av.open(io.BytesIO(video_bytes_stream))
    except Exception as e:
        logger.warning(f"Something wrong {e} when av.open (video_path: {video_path})!")
        return None
    if len(container.streams.audio) == 0:
        logger.warning(f"There is no audio! (video_path: {video_path})!")
        container.close()  # FIX: container was leaked on this early return
        del video_bytes_stream
        return None
    csr = container.streams.audio[0].sample_rate
    frames = []
    try:
        for frame in container.decode(**{"audio": 0}):
            frames.append(frame.to_ndarray())
    except Exception:  # FIX: was a bare except; keep the best-effort behavior
        gc.collect()
    container.close()
    del video_bytes_stream

    if not frames:
        # FIX: np.concatenate raises ValueError on an empty list; treat as no audio.
        logger.warning(f"No audio frames decoded! (video_path: {video_path})!")
        return None
    audio_raw = np.concatenate(frames, 1)
    audio = torch.from_numpy(audio_raw)
    # Stereo -> mono by channel average; ensure a leading channel dim.
    if audio.size(0) == 2:
        audio = torch.mean(audio, dim=0, keepdim=True)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)
    assert max_audio_length == 10, max_audio_length
    max_length = max_audio_length * sr
    if csr != sr:
        audio = torchaudio.transforms.Resample(csr, sr)(audio)
    # Random temporal crop to at most max_length samples.
    if audio.shape[1] >= max_length:
        start = random.randint(0, audio.shape[1] - max_length)
        audio = audio[:, start: start + max_length]
    audio = audio * 2 ** 15  # scale to the int16 range kaldi fbank expects
    fbank = torchaudio.compliance.kaldi.fbank(
        audio, num_mel_bins=64, sample_frequency=16000, frame_length=25, frame_shift=10
    )
    # Dataset-level normalization constants.
    fbank_mean = 15.41663
    fbank_std = 6.55582
    fbank = (fbank - fbank_mean) / (fbank_std * 2)  # (<=998, 64)

    # Zero-pad the time axis up to a fixed 998 frames.
    src_length = fbank.shape[0]
    pad_len = 998 - src_length
    fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank)

    return fbank


    # frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8
    # # frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8