import av import gc import torch import torchaudio import numpy as np import random import logging import io from torchvision.transforms.functional import pil_to_tensor logger = logging.getLogger(__name__) def get_index(num_frames, num_segments): seg_size = float(num_frames - 1) / num_segments start = int(seg_size / 2) offsets = np.array([ start + int(np.round(seg_size * idx)) for idx in range(num_segments) ]) return offsets def lazy_load_s3video(s3path_video, num_frames, video_start_frame, video_end_frame, client): # load video from ceph assert client is not None video_bytes_stream = client.get(s3path_video, enable_stream_lazyloding=True) container = av.open(video_bytes_stream) stream = container.streams.video[0] # duration = stream.duration real_fps = container.streams.video[0].average_rate time_base = container.streams.video[0].time_base start, end = video_start_frame, video_end_frame # Convert time to pts duration_frams = end - start + 1 frames_index = get_index(duration_frams, num_frames) pts_list = [] start_pts = int((start/real_fps) / time_base) end_pts = int((end/real_fps) / time_base) for frame_index in frames_index: pts_list.append(int((frame_index / real_fps)) / time_base) # Seek to nearest key frame from the start container.seek(max(start_pts, 0), stream=stream) frames = [] for frame in container.decode(**{"video":0}): if frame.pts < start_pts: continue # if frame.pts <= end_pts: if len(pts_list) >0: if frame.pts >= pts_list[0]: frames.append(frame) pts_list.pop(0) else: break frames = [pil_to_tensor(frames[idx].to_rgb().to_image()).unsqueeze(0) for idx in range(len(frames))] container.close() del video_bytes_stream # T C H W return torch.cat(frames, dim=0) # , start, end, float(real_fps) def load_audio_av(video_path, video_start_frame, video_end_frame, sr, max_audio_length, client): # sr should be 16000 assert client is not None video_bytes_stream = client.get(video_path, enable_stream_lazyloding=True) try: container = av.open(video_bytes_stream) except: logger.warn(f"Something wrong when av.open (video_path: {video_path})!") return None if len(container.streams.audio) == 0: logger.warn(f"There is no audio! (video_path: {video_path})!") return None audio_stream = container.streams.audio[0] real_fps = container.streams.video[0].average_rate time_base = audio_stream.time_base csr = audio_stream.sample_rate start_frame, end_frame = video_start_frame, video_end_frame start_pts = int((start_frame/real_fps) / time_base) end_pts = int((end_frame/real_fps) / time_base) frames = [] container.seek(max(start_pts, 0), stream=audio_stream) try: for frame in container.decode(**{"audio":0}): if frame.pts < start_pts: continue frames.append(frame.to_ndarray()) if frame.pts > end_pts: break except: gc.collect() pass # gc.collect() container.close() del video_bytes_stream audio_raw = np.concatenate(frames, 1) audio = torch.from_numpy(audio_raw) if audio.size(0) == 2: audio = torch.mean(audio, dim=0, keepdim=True) if len(audio.shape) == 1: audio = audio.unsqueeze(0) assert max_audio_length == 10, max_audio_length max_length = max_audio_length * sr if csr != sr: trans = torchaudio.transforms.Resample(csr, sr) audio = trans(audio) if audio.shape[1] >= max_length: max_start = audio.shape[1] - max_length start = random.randint(0, max_start) audio = audio[:, start: start + max_length] audio = audio * 2 ** 15 fbank = torchaudio.compliance.kaldi.fbank(audio, num_mel_bins=64, sample_frequency=16000, frame_length=25, frame_shift=10) fbank_mean = 15.41663 fbank_std = 6.55582 fbank = (fbank - fbank_mean) / (fbank_std * 2) # 998, 64 src_length = fbank.shape[0] pad_len = 998 - src_length fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank) padding_mask = torch.cat((torch.zeros(1, src_length), torch.ones(1, pad_len)), -1).bool() return fbank#, padding_mask def load_full_audio_av(video_path, sr, max_audio_length, client): assert client is not None video_bytes_stream = client.get(video_path) #, enable_stream_lazyloding=False) try: container = av.open(io.BytesIO(video_bytes_stream)) except Exception as e: logger.warn(f"Something wrong {e} when av.open (video_path: {video_path})!") return None if len(container.streams.audio) == 0: logger.warn(f"There is no audio! (video_path: {video_path})!") return None audio_stream = container.streams.audio[0] csr = audio_stream.sample_rate frames = [] try: for frame in container.decode(**{"audio":0}): frames.append(frame.to_ndarray()) except: gc.collect() pass # gc.collect() container.close() del video_bytes_stream audio_raw = np.concatenate(frames, 1) audio = torch.from_numpy(audio_raw) if audio.size(0) == 2: audio = torch.mean(audio, dim=0, keepdim=True) if len(audio.shape)==1: audio = audio.unsqueeze(0) assert max_audio_length == 10, max_audio_length max_length = max_audio_length * sr if csr != sr: trans = torchaudio.transforms.Resample(csr, sr) audio = trans(audio) if audio.shape[1] >= max_length: max_start = audio.shape[1] - max_length start = random.randint(0, max_start) audio = audio[:, start: start + max_length] audio = audio * 2 ** 15 fbank = torchaudio.compliance.kaldi.fbank(audio, num_mel_bins=64, sample_frequency=16000, frame_length=25, frame_shift=10) fbank_mean = 15.41663 fbank_std = 6.55582 fbank = (fbank - fbank_mean) / (fbank_std * 2) # 998, 64 src_length = fbank.shape[0] pad_len = 998 - src_length fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank) padding_mask = torch.cat((torch.zeros(1, src_length), torch.ones(1, pad_len)), -1).bool() return fbank #, padding_mask # frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8 # # frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8