Spaces:
Sleeping
Sleeping
from utils.distributed import is_main_process, get_rank, get_world_size | |
import logging | |
import torch.distributed as dist | |
import torch | |
import io | |
import os | |
import json | |
import re | |
import random | |
import numpy as np | |
from os.path import join | |
from tqdm import trange | |
from PIL import Image | |
from PIL import ImageFile | |
from torchvision.transforms import PILToTensor | |
import librosa | |
import torchaudio | |
# import soundfile as sf | |
# Let PIL decode image files that are truncated instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Disable PIL's maximum-pixel-count ("decompression bomb") safety check.
Image.MAX_IMAGE_PIXELS = None
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def load_audio_from_path(audio_path, client, sr, audio_reader_type, max_length=0):
    """Load an audio clip and convert it to normalized log-mel filterbank features.

    Args:
        audio_path: Local path, or an ``s3://`` URI when ``client`` is given.
        client: Storage client exposing ``get(path)`` that returns the raw file
            bytes. May be ``None``; then ``audio_path`` is handed to the reader
            unchanged.
        sr: Target sample rate passed to the reader / resampler.
        audio_reader_type: ``'librosa'`` or ``'torchaudio'``.
        max_length: Target number of samples. ``0`` keeps the clip as is.
            Otherwise clips at least this long are randomly cropped to
            ``max_length`` samples and shorter clips are zero-padded on the right.

    Returns:
        The filterbank tensor produced by ``torchaudio.compliance.kaldi.fbank``
        with 64 mel bins, shifted and scaled by fixed statistics.

    Raises:
        NotImplementedError: If ``audio_reader_type`` is not supported.
    """
    if "s3://" in audio_path and client is not None:
        buff = io.BytesIO(client.get(audio_path))
    else:
        buff = audio_path

    if audio_reader_type == 'librosa':
        audio, _ = librosa.load(buff, sr=sr)
        audio = torch.from_numpy(audio)
    elif audio_reader_type == 'torchaudio':
        # The global backend switch is deprecated in newer torchaudio releases,
        # so only call it when the installed version still provides it.
        if hasattr(torchaudio, 'set_audio_backend'):
            torchaudio.set_audio_backend('soundfile')  # for flac files
        audio, csr = torchaudio.load(buff)
        if csr != sr:
            audio = torchaudio.transforms.Resample(csr, sr)(audio)
        # torchaudio returns (channels, samples). Down-mix to a 1-D waveform for
        # any channel count (including mono), so the crop/pad logic below always
        # measures the sample axis rather than the channel axis.
        if audio.dim() > 1:
            audio = audio.mean(dim=0)
    else:
        raise NotImplementedError(
            f"Unsupported audio_reader_type: {audio_reader_type!r}")

    if max_length != 0:
        num_samples = audio.shape[-1]
        if num_samples >= max_length:
            # Clip is long enough: take a random window of max_length samples.
            start = random.randint(0, num_samples - max_length)
            audio = audio[start: start + max_length]
        else:
            # Clip is too short: zero-pad on the right up to max_length.
            audio = torch.nn.functional.pad(
                audio, (0, max_length - num_samples), 'constant')

    if audio.dim() == 1:
        audio = audio.unsqueeze(0)  # fbank expects a (channel, samples) tensor

    # Scale the waveform by 2**15 before feature extraction
    # (presumably to match Kaldi's int16-range input convention -- TODO confirm).
    fbank = audio * 2 ** 15
    # NOTE(review): sample_frequency is fixed at 16000 regardless of `sr`;
    # confirm callers always pass sr=16000.
    fbank = torchaudio.compliance.kaldi.fbank(
        fbank, num_mel_bins=64, sample_frequency=16000,
        frame_length=25, frame_shift=10)
    # Fixed normalization statistics; their origin is not shown in this file.
    fbank_mean = 15.41663
    fbank_std = 6.55582
    fbank = (fbank - fbank_mean) / (fbank_std * 2)
    return fbank
def load_image_from_path(image_path, client):
    """Load an image as an RGB ``torch.uint8`` tensor of shape (1, C, H, W).

    Args:
        image_path: Local path, or an ``s3://`` URI when ``client`` is given.
        client: Storage client exposing ``Get(path)`` that returns the raw file
            bytes (or ``None`` on failure). May be ``None``; then ``image_path``
            is opened as a local file.

    Returns:
        The decoded image converted to RGB, with a leading batch dimension.

    Raises:
        FileNotFoundError: If the storage client returns ``None`` for the path.
    """
    if "s3://" in image_path and client is not None:
        value = client.Get(image_path)
        if value is None:
            logger.warning(f"Failed to load {image_path}")
            # Fail here with a clear error instead of falling through and
            # crashing on the missing buffer with an unrelated exception.
            raise FileNotFoundError(f"Failed to load {image_path}")
        with Image.open(io.BytesIO(value)) as img:
            image = img.convert('RGB')
    else:
        # The context manager closes the underlying file once decoding is done.
        with Image.open(image_path) as img:
            image = img.convert('RGB')  # PIL Image
    image = PILToTensor()(image).unsqueeze(0)  # (1, C, H, W), torch.uint8
    return image
def load_anno(ann_file_list):
    """Load annotation files and merge them into one list of entries.

    Args:
        ann_file_list: A single dataset config or a list of them. Each config
            must offer attribute access to ``anno_path``, ``data_root`` and
            ``media_type`` plus a dict-style ``get`` (used for the optional
            ``"data_root_prefix"``). ``media_type`` is one of ``"image"``,
            ``"video"``, ``"audio_video"`` or ``"audio"``.

    Returns:
        list: The entries of every annotation file, concatenated in order.
        Each entry gains an ``"image"`` key holding the full media path (or a
        list of paths), built as ``data_root_prefix + join(data_root, path)``
        from the entry's ``"image"`` / ``"video"`` / ``"audio"`` field.

    Raises:
        NotImplementedError: If a non-empty annotation file is paired with an
            unsupported ``media_type``.
    """
    if isinstance(ann_file_list, dict):
        ann_file_list = [ann_file_list]

    # Which field of an annotation entry stores the media path, per media type.
    media_keys = {
        "image": "image",
        "video": "video",
        "audio_video": "video",
        "audio": "audio",
    }

    ann = []
    for d in ann_file_list:
        data_root = d.data_root
        data_root_prefix = d.get("data_root_prefix", "")
        fp = d.anno_path
        with open(fp, "r") as f:
            cur_ann = json.load(f)

        if cur_ann:
            # The media type is constant per file, so resolve it once. Report
            # the offending media type itself (the original raised with a
            # not-yet-assigned local variable).
            if d.media_type not in media_keys:
                raise NotImplementedError(d.media_type)
            key = media_keys[d.media_type]

        # Only the main process shows a progress bar.
        iterator = trange(len(cur_ann), desc=f"Loading {fp}") \
            if is_main_process() else range(len(cur_ann))
        for idx in iterator:
            entry = cur_ann[idx]
            # unified to have the same key for data path
            if isinstance(entry[key], str):
                entry["image"] = data_root_prefix + join(data_root, entry[key])
            else:  # list
                entry["image"] = [
                    data_root_prefix + join(data_root, e) for e in entry[key]]
        ann += cur_ann
    return ann
def pre_text(text, max_l=None):
    """Normalize a caption string.

    The text is lower-cased, the punctuation characters ``,.'!?"()*#:;~`` are
    removed, ``-`` and ``/`` become spaces, ``<person>`` becomes ``person``,
    runs of two or more whitespace characters collapse to one space, and
    trailing newlines / surrounding spaces are stripped.

    Args:
        text: The string to clean; must be exactly of type ``str``.
        max_l: Optional maximum number of space-separated words to keep.

    Returns:
        The cleaned (and possibly truncated) string.
    """
    assert type(text) is str, text

    cleaned = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower())
    for old, new in (('-', ' '), ('/', ' '), ('<person>', 'person')):
        cleaned = cleaned.replace(old, new)
    cleaned = re.sub(r"\s{2,}", ' ', cleaned)
    cleaned = cleaned.rstrip('\n').strip(' ')

    if not max_l:
        return cleaned

    # truncate to at most max_l words
    tokens = cleaned.split(' ')
    if len(tokens) <= max_l:
        return cleaned
    return ' '.join(tokens[:max_l])
def collect_result(result, result_dir, filename, is_json=True, is_list=True):
    """Gather per-rank results on the main process via files in ``result_dir``.

    Every rank writes its ``result`` to ``<filename>_rank<rank>.json`` (or
    ``.pth`` when ``is_json`` is false), all ranks synchronize on a barrier,
    and the main process then reads every rank's file and merges them.

    Args:
        result: This rank's results; a list when ``is_list`` else a dict.
        result_dir: Directory for the per-rank files.
        filename: Base name of the per-rank files.
        is_json: Serialize with ``json`` if true, otherwise ``torch.save``.
        is_list: Merge by list concatenation if true, otherwise ``dict.update``.

    Returns:
        The merged results on the main process, ``None`` on every other rank.
    """
    ext = 'json' if is_json else 'pth'

    def rank_file(rank):
        # Path of the shard written by the given rank.
        return os.path.join(result_dir, '%s_rank%d.%s' % (filename, rank, ext))

    own_file = rank_file(get_rank())
    if is_json:
        # Close the file explicitly so the shard is fully flushed before the
        # barrier lets the main process read it.
        with open(own_file, 'w') as f:
            json.dump(result, f)
    else:
        torch.save(result, own_file)
    dist.barrier()

    result = None
    if is_main_process():
        # combine results from all processes
        result = [] if is_list else {}
        for rank in range(get_world_size()):
            shard_file = rank_file(rank)
            if is_json:
                with open(shard_file, 'r') as f:
                    res = json.load(f)
            else:
                # NOTE: torch.load unpickles; only shards written by this job
                # should ever be read here.
                res = torch.load(shard_file)
            if is_list:
                result += res
            else:
                result.update(res)
    return result
def sync_save_result(result, result_dir, filename, is_json=True, is_list=True):
    """Gather results from multiple GPUs and save the merged result.

    Every rank writes its ``result`` to
    ``<result_dir>/dist_res/<filename>_rank<rank>.json`` (or ``.pth``). After a
    barrier the main process merges all shards and writes them to
    ``<result_dir>/<filename>.json`` (or ``.pth``). A second barrier keeps the
    other ranks waiting until that file has been written.

    Args:
        result: This rank's results; a list when ``is_list`` else a dict.
        result_dir: Output directory.
        filename: Base name of the per-rank and final files.
        is_json: Serialize with ``json`` if true, otherwise ``torch.save``.
        is_list: Merge by list concatenation if true, otherwise ``dict.update``.

    Returns:
        tuple: ``(final_result_file, result)``. ``result`` is the merged result
        on the main process and this rank's own input on every other rank.
    """
    ext = 'json' if is_json else 'pth'

    def rank_file(rank):
        # Path of the shard written by the given rank.
        return os.path.join(
            result_dir, "dist_res", '%s_rank%d.%s' % (filename, rank, ext))

    own_file = rank_file(get_rank())
    final_result_file = os.path.join(result_dir, '%s.%s' % (filename, ext))
    os.makedirs(os.path.dirname(own_file), exist_ok=True)
    if is_json:
        # Close the file explicitly so the shard is fully flushed before the
        # barrier lets the main process read it.
        with open(own_file, 'w') as f:
            json.dump(result, f)
    else:
        torch.save(result, own_file)
    dist.barrier()

    if is_main_process():
        # combine results from all processes
        result = [] if is_list else {}
        for rank in range(get_world_size()):
            shard_file = rank_file(rank)
            if is_json:
                with open(shard_file, 'r') as f:
                    res = json.load(f)
            else:
                # NOTE: torch.load unpickles; only shards written by this job
                # should ever be read here.
                res = torch.load(shard_file)
            if is_list:
                result += res
            else:
                result.update(res)

        if is_json:
            # Closed before the barrier below, so other ranks can read the
            # complete file as soon as they are released.
            with open(final_result_file, 'w') as f:
                json.dump(result, f)
        else:
            torch.save(result, final_result_file)
        logger.info('result file saved to %s', final_result_file)
    dist.barrier()
    return final_result_file, result
def pad_sequences_1d(sequences, dtype=torch.long, device=torch.device("cpu"), fixed_length=None):
    """Zero-pad variable-length sequences into one batch array plus a validity mask.

    Accepts a list of plain lists, or a sequence of n-d ``torch.Tensor`` /
    ``np.ndarray`` items; only the first dimension may differ between items.

    Args:
        sequences: list(n-d tensor or list)
        dtype: torch dtype or numpy dtype; selects the output framework when
            the items are plain lists and must match the item type otherwise.
        device: device for the torch outputs (unused for numpy).
        fixed_length: if given, pad every item to this length instead of the
            longest item's length. All items should have length <= fixed_length.

    Returns:
        padded_seqs: ((n+1)-d array) of shape [len(sequences), padded_len, ...],
            padded with zeros.
        mask: (2d float32 array) of shape [len(sequences), padded_len];
            1 marks a valid position, 0 a padded one.

    Examples:
        >>> pad_sequences_1d([[1,2,3], [1,2], [3,4,7,9]], dtype=torch.long)
        >>> pad_sequences_1d([torch.randn(2,3,4), torch.randn(1,3,4)], dtype=torch.float)
        >>> pad_sequences_1d([[1,2,3], [1,2], [3,4,7,9]], dtype=np.float32)
        >>> pad_sequences_1d([np.random.randn(2,3,4), np.random.randn(1,3,4)], dtype=np.float32)
    """
    dtype_name = str(dtype)

    # Plain lists are first converted to the framework implied by dtype.
    if isinstance(sequences[0], list):
        if "torch" in dtype_name:
            sequences = [torch.tensor(item, dtype=dtype, device=device) for item in sequences]
        else:
            sequences = [np.asarray(item, dtype=dtype) for item in sequences]

    batch_size = len(sequences)
    trailing_shape = sequences[0].shape[1:]  # must be identical for every item
    seq_lens = [len(item) for item in sequences]
    padded_len = max(seq_lens) if fixed_length is None else fixed_length

    if isinstance(sequences[0], torch.Tensor):
        assert "torch" in dtype_name, "dtype and input type does not match"
        padded_seqs = torch.zeros(
            (batch_size, padded_len) + trailing_shape, dtype=dtype, device=device)
        mask = torch.zeros((batch_size, padded_len), dtype=torch.float32, device=device)
    else:  # np
        assert "numpy" in dtype_name, "dtype and input type does not match"
        padded_seqs = np.zeros((batch_size, padded_len) + trailing_shape, dtype=dtype)
        mask = np.zeros((batch_size, padded_len), dtype=np.float32)

    # Copy each item into the front of its row and flag those positions valid.
    for row, (item, valid) in enumerate(zip(sequences, seq_lens)):
        padded_seqs[row, :valid] = item
        mask[row, :valid] = 1
    return padded_seqs, mask  # , lengths