import datetime
import logging
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange

from models.criterions import get_sim
from utils.basic_utils import MetricLogger
from utils.distributed import get_rank, get_world_size

logger = logging.getLogger(__name__)

def extract_text_feats(texts, max_txt_l, tokenizer, model, device, return_ids=False):
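    """Encode all captions with the text encoder, in batches of 256.

    Returns token-level features (num_text, Lt, d) and attention masks
    (num_text, Lt); when return_ids=True, also returns the token ids,
    which the QFormer evaluation path needs.
    """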
    num_text = len(texts)
    text_bs = 256
    text_feats = []
    text_atts = []
    if return_ids:
        text_ids = []

    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_input = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_txt_l,
            return_tensors="pt",
        ).to(device)  # NOTE: no need to cast the text inputs to half precision

        text_feat = model.encode_text(text_input)[0]
        text_feats.append(text_feat)
        text_atts.append(text_input.attention_mask)
        if return_ids:
            text_ids.append(text_input.input_ids)

    text_feats = torch.cat(text_feats, dim=0)
    text_atts = torch.cat(text_atts, dim=0)
    if return_ids:
        text_ids = torch.cat(text_ids, dim=0)
        return text_feats, text_atts, text_ids
    else:
        return text_feats, text_atts

def extract_vision_feats(data_loader, model, device, config):
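    """Encode all videos/images in the eval loader.

    Returns dense features, shaped (bsz, 1, #frm*Li, d) when
    eval_frame_ensemble is "concat" and (bsz, #frm, Li, d) otherwise,
    plus pooled features (bsz, #frm, d). With evaluation.eval_offload
    the features are kept on CPU to save GPU memory.
    """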
    if config.use_half_precision:
        if config.get('use_bf16', False):
            cast_dtype = torch.bfloat16
        else:
            cast_dtype = torch.float16
    else:
        cast_dtype = None

    image_feats_all = []
    pooled_image_feats_all = []
    metric_logger = MetricLogger(delimiter=" ")
    header = "extracting image feats"
    iterator = metric_logger.log_every(data_loader, 100, header)
    for image, img_id in iterator:
        image = image.to(device, dtype=cast_dtype, non_blocking=True)
        image_feat, pooled_image_feat = model.encode_vision(image, test=True)
        if len(pooled_image_feat.shape) == 2:
            pooled_image_feat = pooled_image_feat.unsqueeze(1)  # make av_fusion happy
        if config.evaluation.eval_frame_ensemble == "concat":
            if len(image_feat.shape) == 4:
                image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
            image_feat = image_feat.unsqueeze(1)  # (bsz, 1, #frm*L, d)
        else:
            assert config.video_input.num_frames == 1, "only support single-frame"
            assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
        if config.evaluation.eval_offload:
            image_feats_all.append(image_feat.cpu())
            pooled_image_feats_all.append(pooled_image_feat.cpu())
        else:
            image_feats_all.append(image_feat)
            pooled_image_feats_all.append(pooled_image_feat)

    image_feats_all = torch.cat(image_feats_all, dim=0)
    pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)
    return image_feats_all, pooled_image_feats_all

def extract_audio_feats(data_loader, model, device, config):
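    """Encode all audio clips in the eval loader; the audio counterpart of
    extract_vision_feats, returning dense and pooled audio features."""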
    if config.use_half_precision:
        if config.get('use_bf16', False):
            cast_dtype = torch.bfloat16
        else:
            cast_dtype = torch.float16
    else:
        cast_dtype = None

    audio_feats_all = []
    pooled_audio_feats_all = []
    metric_logger = MetricLogger(delimiter=" ")
    header = "extracting audio feats"
    iterator = metric_logger.log_every(data_loader, 100, header)
    for audio, _ in iterator:
        audio = audio.to(device, dtype=cast_dtype, non_blocking=True)
        audio_feat, pooled_audio_feat = model.encode_audio(audio, test=True)
        audio_feat = audio_feat.unsqueeze(1)  # make deep_fusion happy
        pooled_audio_feat = pooled_audio_feat.unsqueeze(1)
        if config.evaluation.eval_offload:
            audio_feats_all.append(audio_feat.cpu())
            pooled_audio_feats_all.append(pooled_audio_feat.cpu())
        else:
            audio_feats_all.append(audio_feat)
            pooled_audio_feats_all.append(pooled_audio_feat)

    audio_feats_all = torch.cat(audio_feats_all, dim=0)
    pooled_audio_feats_all = torch.cat(pooled_audio_feats_all, dim=0)
    return audio_feats_all, pooled_audio_feats_all

def extract_audio_vision_feats(data_loader, model, device, config):
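    """Encode audio and video jointly for loaders that yield (audio, image)
    pairs, returning dense and pooled features for both modalities."""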
    if config.use_half_precision:
        if config.get('use_bf16', False):
            cast_dtype = torch.bfloat16
        else:
            cast_dtype = torch.float16
    else:
        cast_dtype = None

    audio_feats_all = []
    pooled_audio_feats_all = []
    image_feats_all = []
    pooled_image_feats_all = []
    metric_logger = MetricLogger(delimiter=" ")
    header = "extracting audio and vision feats"
    iterator = metric_logger.log_every(data_loader, 100, header)
    for media, _ in iterator:
        audio = media[0]
        image = media[1]
        audio = audio.to(device, dtype=cast_dtype, non_blocking=True)
        image = image.to(device, dtype=cast_dtype, non_blocking=True)

        audio_feat, pooled_audio_feat = model.encode_audio(audio, test=True)
        audio_feat = audio_feat.unsqueeze(1)  # make deep_fusion happy
        pooled_audio_feat = pooled_audio_feat.unsqueeze(1)

        image_feat, pooled_image_feat = model.encode_vision(image, test=True)
        if len(pooled_image_feat.shape) == 2:
            pooled_image_feat = pooled_image_feat.unsqueeze(1)  # make av_fusion happy
        if config.evaluation.eval_frame_ensemble == "concat":
            if len(image_feat.shape) == 4:
                image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
            image_feat = image_feat.unsqueeze(1)  # (bsz, 1, #frm*L, d)
        else:
            assert config.video_input.num_frames == 1, "only support single-frame"
            assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]

        if config.evaluation.eval_offload:
            audio_feats_all.append(audio_feat.cpu())
            pooled_audio_feats_all.append(pooled_audio_feat.cpu())
            image_feats_all.append(image_feat.cpu())
            pooled_image_feats_all.append(pooled_image_feat.cpu())
        else:
            audio_feats_all.append(audio_feat)
            pooled_audio_feats_all.append(pooled_audio_feat)
            image_feats_all.append(image_feat)
            pooled_image_feats_all.append(pooled_image_feat)

    audio_feats_all = torch.cat(audio_feats_all, dim=0)
    pooled_audio_feats_all = torch.cat(pooled_audio_feats_all, dim=0)
    image_feats_all = torch.cat(image_feats_all, dim=0)
    pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)
    return audio_feats_all, pooled_audio_feats_all, image_feats_all, pooled_image_feats_all

def evaluation_wrapper(model, data_loader, tokenizer, device, config, prefix=""):
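    """Dispatch to the evaluation routine matching config.model.model_cls and
    convert the returned score matrices into retrieval metrics.

    When the dataset carries multiple prompts per caption, scores are averaged
    over prompts before computing metrics. Returns a dict mapping
    "{prefix}_sim" / "{prefix}_dsl" (and "{prefix}_match" if a matching head
    was used) to the metric dicts produced by itm_eval.
    """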
    amp_eval_enabled = config.use_half_precision and not (hasattr(config, "deepspeed") and config.deepspeed.enable)
    logger.info(
        f"Begin eval, model_without_ddp.dtype={model.dtype if hasattr(model, 'dtype') else None}, "
        f"amp_eval_enabled={amp_eval_enabled}, "
        f"dtype={torch.bfloat16 if config.get('use_bf16', False) else torch.float16}"
    )
    with torch.cuda.amp.autocast(enabled=amp_eval_enabled, dtype=torch.bfloat16 if config.get('use_bf16', False) else torch.float16):
        i2t_match, t2i_match = None, None
        if "qformer" in config.model.model_cls.lower():
            i2t_match, t2i_match, i2t_sim, t2i_sim, i2t_dsl, t2i_dsl = evaluation_qformer(
                model, data_loader, tokenizer, device, config
            )
        elif "blip" in config.model.model_cls.lower():
            raise NotImplementedError
        elif "clip" in config.model.model_cls.lower() or 'coca' in config.model.model_cls.lower():
            i2t_sim, t2i_sim, i2t_dsl, t2i_dsl = evaluation_clip(
                model, data_loader, tokenizer, device, config
            )
        else:
            i2t_match, t2i_match, i2t_sim, t2i_sim, i2t_dsl, t2i_dsl = evaluation(
                model, data_loader, tokenizer, device, config
            )

    if hasattr(data_loader.dataset, "num_prompts"):
        n_prompts = data_loader.dataset.num_prompts
        logger.info(f"Using {n_prompts} prompts; reshaping and averaging the scores over prompts")
        nt = len(data_loader.dataset.text) // n_prompts
        if i2t_match is not None:
            i2t_match = i2t_match.reshape((i2t_match.shape[0], nt, n_prompts)).mean(axis=-1)
            t2i_match = t2i_match.reshape((nt, n_prompts, t2i_match.shape[1])).mean(axis=1)
        i2t_sim = i2t_sim.reshape((i2t_sim.shape[0], nt, n_prompts)).mean(axis=-1)
        t2i_sim = t2i_sim.reshape((nt, n_prompts, t2i_sim.shape[1])).mean(axis=1)
        i2t_dsl = i2t_dsl.reshape((i2t_dsl.shape[0], nt, n_prompts)).mean(axis=-1)
        t2i_dsl = t2i_dsl.reshape((nt, n_prompts, t2i_dsl.shape[1])).mean(axis=1)

    score_pairs = [
        (prefix + "_sim", i2t_sim, t2i_sim),
        (prefix + "_dsl", i2t_dsl, t2i_dsl),
    ]
    if i2t_match is not None:
        if config.evaluation.get('use_dsl_for_match', False):
            score_pairs.append((prefix + "_match (use_dsl)", i2t_match, t2i_match))
        else:
            score_pairs.append((prefix + "_match", i2t_match, t2i_match))

    res = dict()
    for name, i2t, t2i in score_pairs:
        if i2t is not None:
            txt2img_ids = data_loader.dataset.txt2img
            img2txt_ids = data_loader.dataset.img2txt
            res[name] = itm_eval(i2t, t2i, txt2img_ids, img2txt_ids)
    return res

def evaluation(model, data_loader, tokenizer, device, config):
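    """Two-stage retrieval evaluation for models with a cross-encoder.

    Stage 1 scores every media-text pair with the dual encoders (ITC).
    Stage 2 reranks the top-k_test candidates per query with the matching
    head, sharded across GPUs and merged with all_reduce. Returns
    (i2t_match, t2i_match, i2t_sim, t2i_sim, i2t_dsl, t2i_dsl) as numpy
    arrays; the match scores are None when the model has no matching head.
    """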
    model.eval()

    use_dsl_for_match = config.evaluation.get('use_dsl_for_match', False)

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"

    media_type = data_loader.dataset.media_type
    use_subtitle = hasattr(data_loader.dataset, "use_subtitle") and data_loader.dataset.use_subtitle
    if use_subtitle:
        assert media_type in ["video", "audio_video"], f"Unsupported media_type: {media_type}."
        assert hasattr(data_loader.dataset, "subtitle") and data_loader.dataset.subtitle is not None, "No subtitles are available to use."
    logger.info(f"Start evaluation for media_type={media_type}")
    assert media_type in ['audio', 'video', 'audio_video'], f"Evaluation of {media_type} is not implemented"

    logger.info("Computing dual encoder features...")
    start_time = time.time()

    # this computes all features on each GPU
    texts = data_loader.dataset.text
    # max_txt_l for eval depends on the data config
    max_txt_l = data_loader.dataset.max_txt_l
    text_feats, text_atts = extract_text_feats(
        texts, max_txt_l, tokenizer, model, device
    )  # (bsz, Lt, d), (bsz, Lt)
    if use_subtitle:
        subtitle_feats, _ = extract_text_feats(
            data_loader.dataset.subtitle, max_txt_l, tokenizer, model, device
        )  # (bsz, Lt, d), (bsz, Lt)
        subtitle_proj = model.text_proj(subtitle_feats[:, 0]).unsqueeze(1)
        subtitle_feats = subtitle_feats.unsqueeze(1)

    if media_type == 'video':
        image_feats, pooled_image_feats = extract_vision_feats(
            data_loader, model, device, config
        )  # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
        logger.info("Finished vision feature extraction")
        logger.info("Computing ITC scores [dot-product]")
        if config.evaluation.eval_offload:
            # moving image_feats to GPU here can cause OOM; only move the pooled feats
            pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
        if use_subtitle:
            i2t_scores, t2i_scores = get_sim(
                model.vs_fusion(torch.concat([subtitle_proj, model.vision_proj(pooled_image_feats)], dim=-1)),
                model.text_proj(text_feats[:, 0]),
            )
        else:
            i2t_scores, t2i_scores = get_sim(
                model.vision_proj(pooled_image_feats), model.text_proj(text_feats[:, 0])
            )
        if use_dsl_for_match:
            logger.info("Using DSL scores to pick rerank candidates")
            old_i2t_scores, old_t2i_scores = i2t_scores, t2i_scores
            # dual softmax (DSL) re-weighting
            i2t_scores = old_i2t_scores * old_i2t_scores.softmax(dim=0)
            t2i_scores = old_i2t_scores.T * old_i2t_scores.T.softmax(dim=0)

        num_medias = len(data_loader.dataset.image)
        if use_subtitle:
            media_feats = torch.concat([subtitle_feats, image_feats], dim=-2)
            match_head = model.vstm_head if hasattr(model, "vstm_head") else None
        else:
            media_feats = image_feats
            match_head = model.itm_head if hasattr(model, "itm_head") else None
    elif media_type == 'audio':
        audio_feats, pooled_audio_feats = extract_audio_feats(
            data_loader, model, device, config
        )
        logger.info("Finished audio feature extraction")
        logger.info("Computing ITC scores [dot-product]")
        if config.evaluation.eval_offload:
            pooled_audio_feats = pooled_audio_feats.to(device, non_blocking=True)
        i2t_scores, t2i_scores = get_sim(
            model.audio_proj(pooled_audio_feats), model.text_proj(text_feats[:, 0])
        )

        num_medias = len(data_loader.dataset.audio)
        media_feats = audio_feats
        match_head = model.atm_head if hasattr(model, "atm_head") else None
    elif media_type == 'audio_video':
        audio_feats, pooled_audio_feats, image_feats, pooled_image_feats = extract_audio_vision_feats(
            data_loader, model, device, config
        )
        logger.info("Finished audio and vision feature extraction")
        logger.info("Computing ITC scores [dot-product]")
        if config.evaluation.eval_offload:
            pooled_audio_feats = pooled_audio_feats.to(device, non_blocking=True)
            pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
        if use_subtitle:
            i2t_scores, t2i_scores = get_sim(
                model.avs_fusion(torch.concat([model.audio_proj(pooled_audio_feats), subtitle_proj, model.vision_proj(pooled_image_feats)], dim=-1)),
                model.text_proj(text_feats[:, 0]),
            )
        else:
            i2t_scores, t2i_scores = get_sim(
                model.av_fusion(torch.concat([model.audio_proj(pooled_audio_feats), model.vision_proj(pooled_image_feats)], dim=-1)),
                model.text_proj(text_feats[:, 0]),
            )

        num_medias = len(data_loader.dataset.image)
        if use_subtitle:
            media_feats = torch.concat([audio_feats, subtitle_feats, image_feats], dim=-2)
            match_head = model.avstm_head if hasattr(model, "avstm_head") else None
        else:
            media_feats = torch.concat([audio_feats, image_feats], dim=-2)
            match_head = model.avtm_head if hasattr(model, "avtm_head") else None
    else:
        raise NotImplementedError(media_type)
    logger.info("Computing ITC scores [dot-product], done!")

    if match_head is not None:
        i2t_scores_x = torch.full((num_medias, len(texts)), -100.0).to(
            device, torch.float, non_blocking=True
        )

        # compute only a shard of the scores on each GPU; gather at the end
        logger.info("Rerank dual-encoder results with cross-encoder...")
        num_tasks = get_world_size()
        rank = get_rank()
        # each rank only handles its own slice of the eval set
        # ---- compute media2text ---- #
        step = num_medias // num_tasks + 1
        start = rank * step
        end = min(num_medias, start + step)

        text_encoder = model.get_text_encoder()
        iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
        logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")

        # generate a score for each clip, then aggregate the clip scores per video
        n_clip_per_video = (
            media_feats.shape[1] if not config.deep_fusion else media_feats[0].shape[1]
        )
        assert not config.deep_fusion and n_clip_per_video == 1, (
            f"Not implemented for config.deep_fusion={config.deep_fusion} n_clip_per_video={n_clip_per_video}"
        )
        logger.info(
            f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={config.evaluation.eval_frame_ensemble}"
        )
        for i, sims in enumerate(iterator):
            k = min(len(sims), config.evaluation.k_test)
            topk_sim, topk_idx = sims.topk(k=k, dim=0)

            clip_scores = []
            for clip_idx in range(n_clip_per_video):
                if config.deep_fusion:
                    encoder_output = [
                        feat[start + i, clip_idx].to(device, non_blocking=True)
                        if config.evaluation.eval_offload
                        else feat[start + i, clip_idx]
                        for feat in media_feats
                    ]
                else:
                    encoder_output = (
                        media_feats[start + i, clip_idx].to(device, non_blocking=True)
                        if config.evaluation.eval_offload
                        else media_feats[start + i, clip_idx]
                    )  # (#frm*Li, d)

                bs = 32
                # bs = config.batch_size_test.video
                itm_embeds = []

                # the media features are repeated to match the text batch size;
                # `left_*` holds a smaller copy for the final partial batch
                if config.deep_fusion:
                    if len(topk_idx) % bs != 0:
                        left = len(topk_idx) % bs
                        left_encoder_output = [feat.repeat(left, 1, 1) for feat in encoder_output]
                        left_encoder_att = [
                            torch.ones(feat.size()[:-1], dtype=torch.long).to(
                                device, non_blocking=True
                            )
                            for feat in left_encoder_output
                        ]
                    encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
                    encoder_att = [
                        torch.ones(feat.size()[:-1], dtype=torch.long).to(
                            device, non_blocking=True
                        )
                        for feat in encoder_output
                    ]
                else:
                    if len(topk_idx) % bs != 0:
                        left = len(topk_idx) % bs
                        left_encoder_output = encoder_output.repeat(left, 1, 1)  # (left, #frm*Li, d)
                        left_encoder_att = torch.ones(left_encoder_output.size()[:-1], dtype=torch.long).to(
                            device, non_blocking=True
                        )
                    encoder_output = encoder_output.repeat(bs, 1, 1)  # (bs, #frm*Li, d)
                    encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
                        device, non_blocking=True
                    )

                for j in range(0, len(topk_idx), bs):
                    if j + bs > len(topk_idx):
                        output = text_encoder(
                            encoder_embeds=text_feats[topk_idx[j:]],
                            attention_mask=text_atts[topk_idx[j:]],
                            encoder_hidden_states=left_encoder_output,
                            encoder_attention_mask=left_encoder_att,
                            return_dict=True,
                            mode="fusion",
                        )
                    else:
                        output = text_encoder(
                            encoder_embeds=text_feats[topk_idx[j : j + bs]],
                            attention_mask=text_atts[topk_idx[j : j + bs]],
                            encoder_hidden_states=encoder_output,
                            encoder_attention_mask=encoder_att,
                            return_dict=True,
                            mode="fusion",
                        )
                    batch_itm_embeds = output.last_hidden_state[:, 0]
                    itm_embeds.append(batch_itm_embeds)
                itm_embeds = torch.cat(itm_embeds, dim=0)

                score = match_head(itm_embeds)[:, 1]
                clip_scores.append(score)

            if len(clip_scores) == 1:
                score = clip_scores[0]
            else:
                raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")

            i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)

        # ---- compute text2media ---- #
        num_text = len(data_loader.dataset.text)
        t2i_scores_x = torch.full((num_text, num_medias), -100.0).to(
            device, torch.float, non_blocking=True
        )

        step = num_text // num_tasks + 1
        start = rank * step
        end = min(num_text, start + step)

        iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
        logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")

        # generate a score for each clip, then aggregate the clip scores per video
        n_clip_per_video = (
            media_feats.shape[1] if not config.deep_fusion else media_feats[0].shape[1]
        )
        for i, sims in enumerate(iterator):
            k = min(len(sims), config.evaluation.k_test)
            topk_sim, topk_idx = sims.topk(k=k, dim=0)

            clip_scores = []
            for clip_idx in range(n_clip_per_video):
                bs = 32
                # bs = config.batch_size_test.video
                itm_embeds = []
                for j in range(0, len(topk_idx), bs):
                    if config.deep_fusion:
                        encoder_output = [
                            feat[topk_idx[j : j + bs].cpu(), clip_idx].to(device, non_blocking=True)
                            if config.evaluation.eval_offload
                            else feat[topk_idx[j : j + bs], clip_idx]
                            for feat in media_feats
                        ]
                        encoder_att = [
                            torch.ones(feat.size()[:-1], dtype=torch.long).to(
                                device, non_blocking=True
                            )
                            for feat in encoder_output
                        ]
                    else:
                        encoder_output = (
                            media_feats[topk_idx[j : j + bs].cpu(), clip_idx].to(
                                device, non_blocking=True
                            )
                            if config.evaluation.eval_offload
                            else media_feats[topk_idx[j : j + bs], clip_idx]
                        )
                        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
                            device, non_blocking=True
                        )
                    repeat_n = (
                        encoder_output.shape[0]
                        if not config.deep_fusion
                        else encoder_output[0].shape[0]
                    )
                    output = text_encoder(
                        encoder_embeds=text_feats[start + i].repeat(repeat_n, 1, 1),
                        attention_mask=text_atts[start + i].repeat(repeat_n, 1),
                        encoder_hidden_states=encoder_output,
                        encoder_attention_mask=encoder_att,
                        return_dict=True,
                        mode="fusion",
                    )
                    batch_itm_embeds = output.last_hidden_state[:, 0]
                    itm_embeds.append(batch_itm_embeds)
                itm_embeds = torch.cat(itm_embeds, dim=0)

                score = match_head(itm_embeds)[:, 1]
                clip_scores.append(score)

            if len(clip_scores) == 1:
                score = clip_scores[0]
            else:
                raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")

            t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)

        logger.info("Rerank score computation finished")
        if config.distributed:
            # gather the sharded rerank scores across GPUs
            logger.info("Gathering rerank scores across GPUs...")
            dist.barrier()
            dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
            dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)
            logger.info("all_reduce of i2t_scores_x and t2i_scores_x done")

        if use_dsl_for_match:
            # the ITC scores were already DSL re-weighted above; restore the raw ones
            i2t_scores_dsl = i2t_scores
            i2t_scores_dsl_T = t2i_scores
            i2t_scores = old_i2t_scores
            t2i_scores = old_t2i_scores
        else:
            i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
            i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
    else:
        i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
        i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Evaluation time {total_time_str}")

    if match_head is not None:
        return (
            i2t_scores_x.softmax(dim=1).cpu().float().numpy(),
            t2i_scores_x.softmax(dim=1).cpu().float().numpy(),
            i2t_scores.softmax(dim=1).cpu().float().numpy(),
            i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
            i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
            i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy(),
        )
    else:
        return (
            None,
            None,
            i2t_scores.softmax(dim=1).cpu().float().numpy(),
            i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
            i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
            i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy(),
        )

def evaluation_simple(model, data_loader, tokenizer, device, config):
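    """Video-only variant of `evaluation` without subtitle or DSL-for-match
    handling; same two-stage scheme and the same six return values."""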
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"

    media_type = data_loader.dataset.media_type
    logger.info(f"Start evaluation for media_type={media_type}")
    assert media_type in ['video'], f"Evaluation of {media_type} is not implemented"

    logger.info("Computing dual encoder features...")
    start_time = time.time()

    # this computes all features on each GPU
    texts = data_loader.dataset.text
    # max_txt_l for eval depends on the data config
    max_txt_l = data_loader.dataset.max_txt_l
    text_feats, text_atts = extract_text_feats(
        texts, max_txt_l, tokenizer, model, device
    )  # (bsz, Lt, d), (bsz, Lt)

    if media_type == 'video':
        image_feats, pooled_image_feats = extract_vision_feats(
            data_loader, model, device, config
        )  # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
        logger.info("Finished vision feature extraction")
        logger.info("Computing ITC scores [dot-product]")
        if config.evaluation.eval_offload:
            # moving image_feats to GPU here can cause OOM; only move the pooled feats
            pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
        i2t_scores, t2i_scores = get_sim(
            model.vision_proj(pooled_image_feats), model.text_proj(text_feats[:, 0])
        )

        num_medias = len(data_loader.dataset.image)
        media_feats = image_feats
        match_head = model.itm_head if hasattr(model, "itm_head") else None
    else:
        raise NotImplementedError(media_type)
    logger.info("Computing ITC scores [dot-product], done!")

    if match_head is not None:
        i2t_scores_x = torch.full((num_medias, len(texts)), -100.0).to(
            device, torch.float, non_blocking=True
        )

        # compute only a shard of the scores on each GPU; gather at the end
        logger.info("Rerank dual-encoder results with cross-encoder...")
        num_tasks = get_world_size()
        rank = get_rank()
        # each rank only handles its own slice of the eval set
        # ---- compute media2text ---- #
        step = num_medias // num_tasks + 1
        start = rank * step
        end = min(num_medias, start + step)

        text_encoder = model.get_text_encoder()
        iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
        logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")

        # generate a score for each clip, then aggregate the clip scores per video
        n_clip_per_video = (
            media_feats.shape[1] if not config.deep_fusion else media_feats[0].shape[1]
        )
        assert not config.deep_fusion and n_clip_per_video == 1, (
            f"Not implemented for config.deep_fusion={config.deep_fusion} n_clip_per_video={n_clip_per_video}"
        )
        logger.info(
            f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={config.evaluation.eval_frame_ensemble}"
        )
        for i, sims in enumerate(iterator):
            k = min(len(sims), config.evaluation.k_test)
            topk_sim, topk_idx = sims.topk(k=k, dim=0)

            clip_scores = []
            for clip_idx in range(n_clip_per_video):
                if config.deep_fusion:
                    encoder_output = [
                        feat[start + i, clip_idx].to(device, non_blocking=True)
                        if config.evaluation.eval_offload
                        else feat[start + i, clip_idx]
                        for feat in media_feats
                    ]
                else:
                    encoder_output = (
                        media_feats[start + i, clip_idx].to(device, non_blocking=True)
                        if config.evaluation.eval_offload
                        else media_feats[start + i, clip_idx]
                    )  # (#frm*Li, d)

                bs = 32
                # bs = config.batch_size_test.video
                itm_embeds = []

                # the media features are repeated to match the text batch size;
                # `left_*` holds a smaller copy for the final partial batch
                if config.deep_fusion:
                    if len(topk_idx) % bs != 0:
                        left = len(topk_idx) % bs
                        left_encoder_output = [feat.repeat(left, 1, 1) for feat in encoder_output]
                        left_encoder_att = [
                            torch.ones(feat.size()[:-1], dtype=torch.long).to(
                                device, non_blocking=True
                            )
                            for feat in left_encoder_output
                        ]
                    encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
                    encoder_att = [
                        torch.ones(feat.size()[:-1], dtype=torch.long).to(
                            device, non_blocking=True
                        )
                        for feat in encoder_output
                    ]
                else:
                    if len(topk_idx) % bs != 0:
                        left = len(topk_idx) % bs
                        left_encoder_output = encoder_output.repeat(left, 1, 1)  # (left, #frm*Li, d)
                        left_encoder_att = torch.ones(left_encoder_output.size()[:-1], dtype=torch.long).to(
                            device, non_blocking=True
                        )
                    encoder_output = encoder_output.repeat(bs, 1, 1)  # (bs, #frm*Li, d)
                    encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
                        device, non_blocking=True
                    )

                for j in range(0, len(topk_idx), bs):
                    if j + bs > len(topk_idx):
                        output = text_encoder(
                            encoder_embeds=text_feats[topk_idx[j:]],
                            attention_mask=text_atts[topk_idx[j:]],
                            encoder_hidden_states=left_encoder_output,
                            encoder_attention_mask=left_encoder_att,
                            return_dict=True,
                            mode="fusion",
                        )
                    else:
                        output = text_encoder(
                            encoder_embeds=text_feats[topk_idx[j : j + bs]],
                            attention_mask=text_atts[topk_idx[j : j + bs]],
                            encoder_hidden_states=encoder_output,
                            encoder_attention_mask=encoder_att,
                            return_dict=True,
                            mode="fusion",
                        )
                    batch_itm_embeds = output.last_hidden_state[:, 0]
                    itm_embeds.append(batch_itm_embeds)
                itm_embeds = torch.cat(itm_embeds, dim=0)

                score = match_head(itm_embeds)[:, 1]
                clip_scores.append(score)

            if len(clip_scores) == 1:
                score = clip_scores[0]
            else:
                raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")

            i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)

        # ---- compute text2media ---- #
        num_text = len(data_loader.dataset.text)
        t2i_scores_x = torch.full((num_text, num_medias), -100.0).to(
            device, torch.float, non_blocking=True
        )

        step = num_text // num_tasks + 1
        start = rank * step
        end = min(num_text, start + step)

        iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
        logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")

        # generate a score for each clip, then aggregate the clip scores per video
        n_clip_per_video = (
            media_feats.shape[1] if not config.deep_fusion else media_feats[0].shape[1]
        )
        for i, sims in enumerate(iterator):
            k = min(len(sims), config.evaluation.k_test)
            topk_sim, topk_idx = sims.topk(k=k, dim=0)

            clip_scores = []
            for clip_idx in range(n_clip_per_video):
                bs = 32
                # bs = config.batch_size_test.video
                itm_embeds = []
                for j in range(0, len(topk_idx), bs):
                    if config.deep_fusion:
                        encoder_output = [
                            feat[topk_idx[j : j + bs].cpu(), clip_idx].to(device, non_blocking=True)
                            if config.evaluation.eval_offload
                            else feat[topk_idx[j : j + bs], clip_idx]
                            for feat in media_feats
                        ]
                        encoder_att = [
                            torch.ones(feat.size()[:-1], dtype=torch.long).to(
                                device, non_blocking=True
                            )
                            for feat in encoder_output
                        ]
                    else:
                        encoder_output = (
                            media_feats[topk_idx[j : j + bs].cpu(), clip_idx].to(
                                device, non_blocking=True
                            )
                            if config.evaluation.eval_offload
                            else media_feats[topk_idx[j : j + bs], clip_idx]
                        )
                        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
                            device, non_blocking=True
                        )
                    repeat_n = (
                        encoder_output.shape[0]
                        if not config.deep_fusion
                        else encoder_output[0].shape[0]
                    )
                    output = text_encoder(
                        encoder_embeds=text_feats[start + i].repeat(repeat_n, 1, 1),
                        attention_mask=text_atts[start + i].repeat(repeat_n, 1),
                        encoder_hidden_states=encoder_output,
                        encoder_attention_mask=encoder_att,
                        return_dict=True,
                        mode="fusion",
                    )
                    batch_itm_embeds = output.last_hidden_state[:, 0]
                    itm_embeds.append(batch_itm_embeds)
                itm_embeds = torch.cat(itm_embeds, dim=0)

                score = match_head(itm_embeds)[:, 1]
                clip_scores.append(score)

            if len(clip_scores) == 1:
                score = clip_scores[0]
            else:
                raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")

            t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)

        logger.info("Rerank score computation finished")
        if config.distributed:
            # gather the sharded rerank scores across GPUs
            logger.info("Gathering rerank scores across GPUs...")
            dist.barrier()
            dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
            dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)
            logger.info("all_reduce of i2t_scores_x and t2i_scores_x done")

        # dual softmax (DSL) re-weighting
        i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
        i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
    else:
        i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
        i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Evaluation time {total_time_str}")

    if match_head is not None:
        return (
            i2t_scores_x.softmax(dim=1).cpu().float().numpy(),
            t2i_scores_x.softmax(dim=1).cpu().float().numpy(),
            i2t_scores.softmax(dim=1).cpu().float().numpy(),
            i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
            i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
            i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy(),
        )
    else:
        return (
            None,
            None,
            i2t_scores.softmax(dim=1).cpu().float().numpy(),
            i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
            i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
            i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy(),
        )

def evaluation_qformer(model, data_loader, tokenizer, device, config):
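    """QFormer variant of `evaluation`: reranking goes through
    model.vtm_embed on the top-k_test candidates instead of a fusion-mode
    text encoder. Returns the same six score matrices as `evaluation`.
    """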
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"

    media_type = data_loader.dataset.media_type
    logger.info(f"Start evaluation_qformer for media_type={media_type}")
    assert media_type == 'video', f"Evaluation of {media_type} is not implemented"

    logger.info("Computing dual encoder features...")
    start_time = time.time()

    # this computes all features on each GPU
    texts = data_loader.dataset.text
    # max_txt_l for eval depends on the data config
    max_txt_l = data_loader.dataset.max_txt_l
    text_feats, text_atts, text_ids = extract_text_feats(
        texts, max_txt_l, tokenizer, model, device, return_ids=True
    )  # (bsz, Lt, d), (bsz, Lt)

    if media_type == 'video':
        image_feats, pooled_image_feats = extract_vision_feats(
            data_loader, model, device, config
        )  # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
        logger.info("Finished vision feature extraction")
        logger.info("Computing ITC scores [dot-product]")
        if config.evaluation.eval_offload:
            # moving image_feats to GPU here can cause OOM; only move the pooled feats
            pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
        if hasattr(model, "q_vision_proj"):
            i2t_scores, t2i_scores = get_sim(
                model.q_vision_proj(pooled_image_feats), model.q_text_proj(text_feats[:, 0])
            )
        else:
            i2t_scores, t2i_scores = get_sim(
                model.vision_proj(pooled_image_feats), model.text_proj(text_feats[:, 0])
            )

        num_medias = len(data_loader.dataset.image)
        media_feats = image_feats
        if hasattr(model, "itm_head"):
            match_head = model.itm_head
        elif hasattr(model, "q_itm_head"):
            match_head = model.q_itm_head
        else:
            raise NotImplementedError("a QFormer model must provide a matching head")
    logger.info("Computing ITC scores [dot-product], done!")

    if match_head is not None:
        i2t_scores_x = torch.full((num_medias, len(texts)), -100.0).to(
            device, torch.float, non_blocking=True
        )

        # compute only a shard of the scores on each GPU; gather at the end
        logger.info("Rerank dual-encoder results with cross-encoder...")
        num_tasks = get_world_size()
        rank = get_rank()
        # each rank only handles its own slice of the eval set
        # ---- compute image2text ---- #
        step = num_medias // num_tasks + 1
        start = rank * step
        end = min(num_medias, start + step)

        iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
        logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")

        # generate a score for each clip, then aggregate the clip scores per video
        n_clip_per_video = (
            image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
        )
        assert not config.deep_fusion and n_clip_per_video == 1, (
            f"Not implemented for config.deep_fusion={config.deep_fusion} n_clip_per_video={n_clip_per_video}"
        )
        logger.info(
            f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={config.evaluation.eval_frame_ensemble}"
        )
        for i, sims in enumerate(iterator):
            k = min(len(sims), config.evaluation.k_test)
            topk_sim, topk_idx = sims.topk(k=k, dim=0)

            clip_scores = []
            for clip_idx in range(n_clip_per_video):
                if config.deep_fusion:
                    encoder_output = [
                        feat[start + i, clip_idx].to(device, non_blocking=True)
                        if config.evaluation.eval_offload
                        else feat[start + i, clip_idx]
                        for feat in media_feats
                    ]
                else:
                    encoder_output = (
                        image_feats[start + i, clip_idx].to(device, non_blocking=True)
                        if config.evaluation.eval_offload
                        else image_feats[start + i, clip_idx]
                    )  # (#frm*Li, d)

                bs = 32
                # bs = config.batch_size_test.video
                itm_embeds = []

                if not config.deep_fusion:  # wrap in a list so both paths share the code below
                    encoder_output = [encoder_output]
                encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
                encoder_att = [
                    torch.ones(feat.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
                    for feat in encoder_output
                ]
                for j in range(0, len(topk_idx), bs):
                    # truncate the repeated features for the final partial batch
                    cur_bs = min(bs, len(topk_idx) - j)
                    batch_encoder_output = [feat[:cur_bs] for feat in encoder_output]
                    batch_encoder_att = [att[:cur_bs] for att in encoder_att]
                    if not config.deep_fusion:
                        batch_encoder_output = batch_encoder_output[0]
                        batch_encoder_att = batch_encoder_att[0]
                    output = model.vtm_embed(
                        text_ids=text_ids[topk_idx[j : j + bs]],
                        text_atts=text_atts[topk_idx[j : j + bs]],
                        vision_embeds=batch_encoder_output,
                        vision_atts=batch_encoder_att,
                    )
                    itm_embeds.append(output)
                itm_embeds = torch.cat(itm_embeds, dim=0)

                score = match_head(itm_embeds)[:, 1]
                clip_scores.append(score)

            if len(clip_scores) == 1:
                score = clip_scores[0]
            else:
                raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")

            i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)

        # ---- compute text2image ---- #
        num_text = len(data_loader.dataset.text)
        t2i_scores_x = torch.full((num_text, len(data_loader.dataset.image)), -100.0).to(
            device, torch.float, non_blocking=True
        )

        step = num_text // num_tasks + 1
        start = rank * step
        end = min(num_text, start + step)

        iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
        logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")

        # generate a score for each clip, then aggregate the clip scores per video
        n_clip_per_video = (
            image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
        )
        k = config.evaluation.k_test
        logger.info(f"Top-{k} matching")
        for i, sims in enumerate(iterator):
            k = min(len(sims), config.evaluation.k_test)
            topk_sim, topk_idx = sims.topk(k=k, dim=0)

            clip_scores = []
            for clip_idx in range(n_clip_per_video):
                bs = 32
                # bs = config.batch_size_test.video
                itm_embeds = []
                for j in range(0, len(topk_idx), bs):
                    # wrap single-stream features in a list so both paths share the code below
                    fake_image_feats = [image_feats] if not config.deep_fusion else image_feats
                    encoder_output = [
                        feat[topk_idx[j : j + bs].cpu(), clip_idx].to(device, non_blocking=True)
                        if config.evaluation.eval_offload
                        else feat[topk_idx[j : j + bs], clip_idx]
                        for feat in fake_image_feats
                    ]
                    encoder_att = [
                        torch.ones(feat.size()[:-1], dtype=torch.long).to(
                            device, non_blocking=True
                        )
                        for feat in encoder_output
                    ]
                    cur_bs = min(bs, len(topk_idx) - j)
                    batch_encoder_output = encoder_output if config.deep_fusion else encoder_output[0]
                    batch_encoder_att = encoder_att if config.deep_fusion else encoder_att[0]
                    output = model.vtm_embed(
                        text_ids=text_ids[start + i].repeat(cur_bs, 1),
                        text_atts=text_atts[start + i].repeat(cur_bs, 1),
                        vision_embeds=batch_encoder_output,
                        vision_atts=batch_encoder_att,
                    )
                    itm_embeds.append(output)
                itm_embeds = torch.cat(itm_embeds, dim=0)

                score = match_head(itm_embeds)[:, 1]
                clip_scores.append(score)

            if len(clip_scores) == 1:
                score = clip_scores[0]
            else:
                raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")

            t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)

        logger.info("Rerank score computation finished")
        if config.distributed:
            # gather the sharded rerank scores across GPUs
            logger.info("Gathering rerank scores across GPUs...")
            dist.barrier()
            dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
            dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)
            logger.info("all_reduce of i2t_scores_x and t2i_scores_x done")

        # dual softmax (DSL) re-weighting
        i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
        i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
    else:
        i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
        i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Evaluation time {total_time_str}")

    if match_head is not None:
        return (
            i2t_scores_x.softmax(dim=1).cpu().float().numpy(),
            t2i_scores_x.softmax(dim=1).cpu().float().numpy(),
            i2t_scores.softmax(dim=1).cpu().float().numpy(),
            i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
            i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
            i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy(),
        )
    else:
        return (
            None,
            None,
            i2t_scores.softmax(dim=1).cpu().float().numpy(),
            i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
            i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
            i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy(),
        )

def evaluation_clip(model, data_loader, tokenizer, device, config):
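    """Single-stage evaluation for CLIP-style dual encoders: similarity is a
    plain dot product, optionally re-weighted by dual softmax (DSL), with no
    cross-encoder reranking. Returns (i2t_sim, t2i_sim, i2t_dsl, t2i_dsl).
    """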
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"

    media_type = data_loader.dataset.media_type
    logger.info(f"Start evaluation_clip for media_type={media_type}")

    logger.info("Computing dual encoder features...")

    # this computes all features on each GPU
    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_feats = []
    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        if "internvideo2" in config.model.model_cls.lower():
            text_feat = model.encode_text(tokenizer(text).to(device))
        else:
            raise NotImplementedError(config.model.model_cls)
        text_feats.append(text_feat.cpu())
    text_feats = torch.cat(text_feats, dim=0)
    logger.info("Finished computing text features")

    media_feats = []
    metric_logger = MetricLogger(delimiter=" ")
    header = f"extracting {media_type} feats"
    iterator = metric_logger.log_every(data_loader, 100, header)
    for media, _ in iterator:
        if media_type in ['image', 'video']:
            media = media.to(device, non_blocking=True)
            media_feat = model.encode_vision(media, test=True)
        elif media_type == 'audio':
            media = media.to(device, non_blocking=True)
            media_feat = model.encode_audio(media, test=True)
        else:
            raise NotImplementedError(f"Unsupported media_type: {media_type}")
        media_feats.append(media_feat.cpu())
    media_feats = torch.cat(media_feats, dim=0)
    logger.info("Finished feature extraction")

    logger.info("Computing ITC scores [dot-product]")
    i2t_scores, t2i_scores = get_sim(media_feats.float(), text_feats.float())
    del media_feats, text_feats
    logger.info("Computing ITC scores [dot-product], done!")

    # dual softmax (DSL) re-weighting
    i2t_scores_dsl = i2t_scores * i2t_scores.softmax(dim=0)
    i2t_scores_dsl_T = i2t_scores.T * i2t_scores.T.softmax(dim=0)

    return (
        i2t_scores.cpu().float().numpy(),
        i2t_scores.T.cpu().float().numpy(),
        i2t_scores_dsl.cpu().float().numpy(),
        i2t_scores_dsl_T.cpu().float().numpy(),
    )

def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
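    """Compute recall@{1,5,10} and mean recall for both retrieval directions.

    txt2img / img2txt map each query to its ground-truth index (or list of
    indices); a query's rank is the best position of any of its ground
    truths in the score-sorted candidate list.
    """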
    # Images -> Text
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        gt_txt_ids = img2txt[index]
        if isinstance(gt_txt_ids, int):
            ranks[index] = np.where(inds == gt_txt_ids)[0][0]
        else:  # list, used when each image has multiple GT captions
            rank = 1e20
            for i in gt_txt_ids:
                tmp = np.where(inds == i)[0][0]
                if tmp < rank:
                    rank = tmp
            ranks[index] = rank

    # Compute metrics
    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    # Text -> Images
    ranks = np.zeros(scores_t2i.shape[0])
    for index, score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]
        gt_img_ids = txt2img[index]
        if isinstance(gt_img_ids, int):
            ranks[index] = np.where(inds == gt_img_ids)[0][0]
        else:  # list, used when each caption has multiple GT images
            rank = 1e20
            for i in gt_img_ids:
                tmp = np.where(inds == i)[0][0]
                if tmp < rank:
                    rank = tmp
            ranks[index] = rank

    # Compute metrics
    ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result = {
        "v2t_r1": tr1,
        "v2t_r5": tr5,
        "v2t_r10": tr10,
        "v2t_r_mean": tr_mean,
        "t2v_r1": ir1,
        "t2v_r5": ir5,
        "t2v_r10": ir10,
        "t2v_r_mean": ir_mean,
        "r_mean": r_mean,
    }
    eval_result = {k: round(v, 2) for k, v in eval_result.items()}
    return eval_result
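

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as comments so nothing runs on import):
# roughly how an eval script might drive this module. The variable names below
# (model, tokenizer, test_loader, config, device) are assumptions standing in
# for whatever the caller sets up; the config fields listed are the ones this
# module actually reads.
#
#     # config must provide: use_half_precision, distributed, deep_fusion,
#     # model.model_cls, and evaluation.{k_test, eval_offload,
#     # eval_frame_ensemble, ...}; the loader's dataset must expose
#     # .text, .max_txt_l, .media_type, .txt2img and .img2txt.
#     with torch.no_grad():
#         res = evaluation_wrapper(
#             model, test_loader, tokenizer, device, config, prefix="test"
#         )
#     for name, metrics in res.items():
#         logger.info(f"{name}: {metrics}")  # v2t/t2v R@1, R@5, R@10, means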