Spaces:
Sleeping
Sleeping
import av | |
import gc | |
import torch | |
import torchaudio | |
import numpy as np | |
import random | |
import logging | |
import io | |
from torchvision.transforms.functional import pil_to_tensor | |
logger = logging.getLogger(__name__) | |
def get_index(num_frames, num_segments):
    """Pick ``num_segments`` evenly spaced frame offsets for a clip of ``num_frames`` frames.

    The span ``num_frames - 1`` is divided into ``num_segments`` equal segments.
    Sampling starts half a segment in (truncated with ``int``) and advances one
    segment per step. Each step is rounded with ``np.round``, which rounds
    halves to even.

    Args:
        num_frames: Number of frames in the clip.
        num_segments: Number of offsets to return.

    Returns:
        A 1-D numpy array of ``num_segments`` integer offsets.
    """
    segment = float(num_frames - 1) / num_segments
    first = int(segment / 2)
    picks = []
    for k in range(num_segments):
        picks.append(first + int(np.round(segment * k)))
    return np.array(picks)
def lazy_load_s3video(s3path_video, num_frames, video_start_frame, video_end_frame, client):
    """Decode ``num_frames`` evenly spaced RGB frames of a clip from object storage.

    Target frames are chosen with ``get_index`` over the inclusive frame range
    ``[video_start_frame, video_end_frame]``. They are converted to presentation
    timestamps (pts) of the first video stream. After seeking to the clip start,
    each target in turn takes the first decoded frame whose pts is at or after
    it.

    Args:
        s3path_video: Key passed to ``client.get``.
        num_frames: Number of frames to sample.
        video_start_frame: First frame index of the clip (inclusive).
        video_end_frame: Last frame index of the clip (inclusive).
        client: Storage client; the result of
            ``client.get(path, enable_stream_lazyloding=True)`` is passed to ``av.open``.

    Returns:
        A ``(T, C, H, W)`` tensor stacked from ``pil_to_tensor`` outputs.
        ``T == num_frames`` unless the stream ends before every target is reached.

    Raises:
        RuntimeError: If no frame could be selected.
    """
    assert client is not None
    video_bytes_stream = client.get(s3path_video, enable_stream_lazyloding=True)
    container = av.open(video_bytes_stream)
    selected = []
    try:
        stream = container.streams.video[0]
        real_fps = stream.average_rate
        time_base = stream.time_base

        def to_pts(frame_idx):
            # Frame index -> seconds (/ fps) -> pts ticks (/ time_base).
            return int((frame_idx / real_fps) / time_base)

        start = video_start_frame
        start_pts = to_pts(start)
        # get_index returns offsets relative to the clip start. Shift them by
        # ``start`` before converting to absolute pts. int() turns numpy ints
        # into plain ints so the arithmetic stays in exact Fractions.
        offsets = get_index(video_end_frame - start + 1, num_frames)
        target_pts = [to_pts(start + int(off)) for off in offsets]

        # Seeking lands on a keyframe at or before the target, so earlier
        # frames are skipped below.
        container.seek(max(start_pts, 0), stream=stream)
        next_target = 0
        for frame in container.decode(video=0):
            if next_target >= len(target_pts):
                break
            if frame.pts < start_pts:
                continue
            if frame.pts >= target_pts[next_target]:
                selected.append(pil_to_tensor(frame.to_rgb().to_image()).unsqueeze(0))
                next_target += 1
    finally:
        # Release the decoder and the lazily loaded stream even if decoding fails.
        container.close()
        del video_bytes_stream
    if not selected:
        raise RuntimeError(
            f"No video frames decoded (video_path: {s3path_video}, "
            f"frames {video_start_frame}-{video_end_frame})"
        )
    return torch.cat(selected, dim=0)
def load_audio_av(video_path, video_start_frame, video_end_frame, sr, max_audio_length, client):  # sr should be 16000
    """Decode the audio of a clip and return a normalized, zero-padded log-mel fbank.

    The clip is delimited by video frame indices. They are converted to audio
    timestamps using the first video stream's ``average_rate``.

    Args:
        video_path: Key passed to ``client.get``.
        video_start_frame: First video frame of the clip.
        video_end_frame: Last video frame of the clip.
        sr: Target sample rate. Audio is resampled to it when the stream rate
            differs. The fbank is always computed with ``sample_frequency=16000``,
            so this should be 16000.
        max_audio_length: Crop length in seconds; asserted to be 10.
        client: Storage client; the result of
            ``client.get(path, enable_stream_lazyloding=True)`` is passed to ``av.open``.

    Returns:
        A ``(998, 64)`` float tensor, or ``None`` when the file cannot be opened,
        has no audio stream, or no audio frames fall inside the clip.
    """
    assert client is not None
    video_bytes_stream = client.get(video_path, enable_stream_lazyloding=True)
    try:
        container = av.open(video_bytes_stream)
    except Exception:
        logger.warning(f"Something wrong when av.open (video_path: {video_path})!")
        return None

    frames = []
    try:
        if len(container.streams.audio) == 0:
            logger.warning(f"There is no audio! (video_path: {video_path})!")
            return None
        audio_stream = container.streams.audio[0]
        real_fps = container.streams.video[0].average_rate
        time_base = audio_stream.time_base
        csr = audio_stream.sample_rate
        # Video frame index -> seconds (/ fps) -> audio pts ticks (/ time_base).
        start_pts = int((video_start_frame / real_fps) / time_base)
        end_pts = int((video_end_frame / real_fps) / time_base)
        container.seek(max(start_pts, 0), stream=audio_stream)
        try:
            for frame in container.decode(audio=0):
                if frame.pts < start_pts:
                    continue
                if frame.pts > end_pts:
                    # This frame starts after the clip ends; do not keep it.
                    break
                frames.append(frame.to_ndarray())
        except Exception as e:
            # Best effort: keep the frames decoded before the failure.
            logger.warning(f"Audio decoding stopped early ({e}) (video_path: {video_path})")
            gc.collect()
    finally:
        # Close the container on every path, including the early "no audio" return.
        container.close()
        del video_bytes_stream

    if not frames:
        logger.warning(f"No audio frames decoded! (video_path: {video_path})")
        return None
    audio = torch.from_numpy(np.concatenate(frames, 1))
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    if audio.size(0) > 1:
        # Downmix every channel to mono (kaldi.fbank would otherwise use channel 0 only).
        audio = torch.mean(audio, dim=0, keepdim=True)

    assert max_audio_length == 10, max_audio_length
    max_length = max_audio_length * sr
    if csr != sr:
        audio = torchaudio.transforms.Resample(csr, sr)(audio)
    if audio.shape[1] >= max_length:
        # Random crop of exactly max_length samples.
        offset = random.randint(0, audio.shape[1] - max_length)
        audio = audio[:, offset:offset + max_length]
    # Scale to the 16-bit integer range; assumes float samples in [-1, 1] — TODO confirm for integer sample formats.
    audio = audio * 2 ** 15
    fbank = torchaudio.compliance.kaldi.fbank(audio, num_mel_bins=64, sample_frequency=16000, frame_length=25, frame_shift=10)
    # Fixed normalization statistics (origin not shown here; presumably precomputed on the training data).
    fbank_mean = 15.41663
    fbank_std = 6.55582
    fbank = (fbank - fbank_mean) / (fbank_std * 2)
    # 10 s at 16 kHz with 25 ms windows and a 10 ms shift yields at most 998 frames.
    # Zero-pad the time axis up to that length. No padding mask is returned.
    pad_len = 998 - fbank.shape[0]
    fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank)
    return fbank
def load_full_audio_av(video_path, sr, max_audio_length, client):
    """Decode a file's whole audio track and return a normalized, zero-padded log-mel fbank.

    Args:
        video_path: Key passed to ``client.get``.
        sr: Target sample rate. Audio is resampled to it when the stream rate
            differs. The fbank is always computed with ``sample_frequency=16000``,
            so this should be 16000.
        max_audio_length: Crop length in seconds; asserted to be 10.
        client: Storage client. ``client.get(path)`` must return bytes, because
            the result is wrapped in ``io.BytesIO``.

    Returns:
        A ``(998, 64)`` float tensor, or ``None`` when the file cannot be opened,
        has no audio stream, or yields no audio frames.
    """
    assert client is not None
    video_bytes_stream = client.get(video_path)
    try:
        container = av.open(io.BytesIO(video_bytes_stream))
    except Exception as e:
        logger.warning(f"Something wrong {e} when av.open (video_path: {video_path})!")
        return None

    frames = []
    try:
        if len(container.streams.audio) == 0:
            logger.warning(f"There is no audio! (video_path: {video_path})!")
            return None
        csr = container.streams.audio[0].sample_rate
        try:
            for frame in container.decode(audio=0):
                frames.append(frame.to_ndarray())
        except Exception as e:
            # Best effort: keep the frames decoded before the failure.
            logger.warning(f"Audio decoding stopped early ({e}) (video_path: {video_path})")
            gc.collect()
    finally:
        # Close the container on every path, including the early "no audio" return.
        container.close()
        del video_bytes_stream

    if not frames:
        logger.warning(f"No audio frames decoded! (video_path: {video_path})")
        return None
    audio = torch.from_numpy(np.concatenate(frames, 1))
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    if audio.size(0) > 1:
        # Downmix every channel to mono (kaldi.fbank would otherwise use channel 0 only).
        audio = torch.mean(audio, dim=0, keepdim=True)

    assert max_audio_length == 10, max_audio_length
    max_length = max_audio_length * sr
    if csr != sr:
        audio = torchaudio.transforms.Resample(csr, sr)(audio)
    if audio.shape[1] >= max_length:
        # Random crop of exactly max_length samples.
        offset = random.randint(0, audio.shape[1] - max_length)
        audio = audio[:, offset:offset + max_length]
    # Scale to the 16-bit integer range; assumes float samples in [-1, 1] — TODO confirm for integer sample formats.
    audio = audio * 2 ** 15
    fbank = torchaudio.compliance.kaldi.fbank(audio, num_mel_bins=64, sample_frequency=16000, frame_length=25, frame_shift=10)
    # Fixed normalization statistics (origin not shown here; presumably precomputed on the training data).
    fbank_mean = 15.41663
    fbank_std = 6.55582
    fbank = (fbank - fbank_mean) / (fbank_std * 2)
    # 10 s at 16 kHz with 25 ms windows and a 10 ms shift yields at most 998 frames.
    # Zero-pad the time axis up to that length. No padding mask is returned.
    pad_len = 998 - fbank.shape[0]
    fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank)
    return fbank
# frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8 | |
# # frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 | |