import hashlib
import itertools
import json
import logging
import math
import os
import random
import tempfile
import time

import einops
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.nn.utils.rnn import pad_sequence

from open_clip import get_cast_dtype, CLIP, CustomTextCLIP

from .data import ImageRewardDataset, RankingDataset
from .distributed import is_master, barrier
from .precision import get_autocast
from .zero_shot import zero_shot_eval
from ..open_clip.loss import PreferenceLoss, RankingLoss, HPSLoss


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
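
# Illustrative usage sketch (not part of the training pipeline): `update` takes a
# per-batch value and a sample count, so `avg` is a sample-weighted running mean.
#   meter = AverageMeter()
#   meter.update(0.5, n=4)   # e.g. mean loss 0.5 over a batch of 4
#   meter.update(0.3, n=4)
#   assert abs(meter.avg - 0.4) < 1e-8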


def postprocess_clip_output(model_out):
    return {
        "image_features": model_out[0],
        "text_features": model_out[1],
        "logit_scale": model_out[2],
    }


def unwrap_model(model):
    if hasattr(model, 'module'):
        return model.module
    else:
        return model


def backward(total_loss, scaler):
    if scaler is not None:
        scaler.scale(total_loss).backward()
    else:
        total_loss.backward()
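
# Illustrative AMP sketch of the flow this helper slots into (assumes CUDA and
# the standard torch.cuda.amp API; see train_iters below for the full version
# with Horovod support and gradient clipping):
#   scaler = torch.cuda.amp.GradScaler()
#   with torch.cuda.amp.autocast():
#       loss = compute_loss(...)         # hypothetical loss computation
#   backward(loss, scaler)   # scales the loss before calling .backward()
#   scaler.step(optimizer)   # unscales gradients, skips the step on inf/nan
#   scaler.update()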


def random_sampling_iterator(iterators, sampling_ratios, data_types, num_iters):
    """Yield (batch, data_type) pairs, picking a source dataset at random each step."""
    iterators = [iter(iterator) for iterator in iterators]
    num_iterators = len(iterators)
    loop_counter = 0
    while loop_counter < num_iters:
        # Seed with the step index so every rank draws the same dataset at the
        # same step, then restore the RNG state so global randomness is untouched.
        current_state = random.getstate()
        random.seed(loop_counter)
        iterator_idx = random.choices(range(num_iterators), sampling_ratios)[0]
        random.setstate(current_state)
        yield next(iterators[iterator_idx]), data_types[iterator_idx]
        loop_counter += 1
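
# Illustrative sketch of the mixing behaviour (hypothetical toy inputs):
#   mixed = random_sampling_iterator(
#       [range(100), range(100, 200)],   # two stand-in "dataloaders"
#       sampling_ratios=[0.75, 0.25],    # draw ~3x more often from the first
#       data_types=['preference', 'ranking'],
#       num_iters=8,
#   )
#   for batch, data_type in mixed:
#       ...  # the data_type tag selects the loss branch in train_iters
# Note the iterators are never rewound: each source must yield at least
# num_iters batches (in practice, infinite or epoch-wrapped dataloaders).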


def train_iters(model, data, iterations, optimizer, scaler, scheduler, dist_model, args, tb_writer=None):
    device = torch.device(args.device)
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    model.train()
    ce_loss = PreferenceLoss()
    mse_loss = torch.nn.MSELoss()
    rk_loss = RankingLoss()
    hps_loss = HPSLoss()

    if args.distill:
        dist_model.eval()

    for train_set in data['train']:
        train_set.set_epoch(0)  # set epoch in process safe manner via sampler or shared_epoch
    data_types = [d.data_type for d in data['train']]
    train_data_sample_ratios = [
        sample_ratio
        for sample_ratio, ignore in zip(args.train_data_sample_ratio, args.ignore_in_train)
        if not ignore
    ]
    dataloader = random_sampling_iterator(
        [dataset.dataloader for dataset in data['train']],
        train_data_sample_ratios, data_types, iterations)

    sample_digits = math.ceil(math.log(sum(dataset.dataloader.num_samples for dataset in data['train']) + 1, 10))
    losses_m = {}
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()
    for step, (batch, data_type) in enumerate(dataloader):
        # TODO: currently only tested with accum_freq == 1
        if not args.skip_scheduler:
            scheduler(step)

        # Unpack per data type; every type carries images and labels, and the
        # text-conditioned types also carry tokenized texts.
        if data_type in ('preference', 'ranking'):
            images, num_images, labels, texts = batch
            texts = texts.to(device=device, non_blocking=True)
        elif data_type in ('rating', 'regional'):
            images, labels = batch
        elif data_type == 'HPD':
            images, labels, texts = batch
            texts = texts.to(device=device, non_blocking=True)
        images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
        labels = labels.to(device=device, non_blocking=True)

        data_time_m.update(time.time() - end)
        optimizer.zero_grad()

        if args.accum_freq == 1:
            with autocast():
                if data_type == 'rating' or args.no_text_condition:
                    # Text-free scoring head on top of the visual encoder.
                    image_features = unwrap_model(model).visual(images)
                    scores = unwrap_model(model).score_predictor(image_features)
                    if args.no_text_condition:
                        paired_logits_list = [logit[:, 0] for logit in scores.split(num_images.tolist())]
                        paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
                        total_loss = F.cross_entropy(paired_logits, labels)
                    else:
                        total_loss = mse_loss(scores.squeeze(), labels.to(scores.dtype))
                elif data_type == 'preference':
                    output = model(images, texts)
                    image_features, text_features, logit_scale = output["image_features"], output["text_features"], output["logit_scale"]
                    logits_per_image = logit_scale * image_features @ text_features.T
                    total_loss = ce_loss(logits_per_image, num_images, labels)
                elif data_type == 'HPD':
                    output = model(images, texts)
                    image_features, text_features, logit_scale = output["image_features"], output["text_features"], output["logit_scale"]
                    logits_per_text = logit_scale * text_features @ image_features.T
                    total_loss = hps_loss(logits_per_text, labels)
                elif data_type == 'ranking':
                    output = model(images, texts)
                    image_features, text_features, logit_scale = output["image_features"], output["text_features"], output["logit_scale"]
                    score = logit_scale * image_features @ text_features.T
                    total_loss = rk_loss(score, num_images, labels, args.margin)
                elif data_type == 'regional':
                    # Dense per-patch prediction: take the patch tokens (drop CLS),
                    # predict a map, and compare against the downsampled mask.
                    feature_map = unwrap_model(model).visual(images, skip_pool=True)[:, 1:]
                    logits = unwrap_model(model).region_predictor(feature_map)
                    wh = int(math.sqrt(feature_map.size(1)))
                    ps = images.size(2) // wh  # patch size in pixels
                    logits = logits.unflatten(1, (wh, wh))[:, :, :, 0]
                    # downsample the labels to match the feature map size
                    patches = einops.reduce(labels, 'b (h s1) (w s2) -> b h w', 'mean', s1=ps, s2=ps)
                    patches = (patches > 0).float()
                    # match the target dtype to the logits for the MSE
                    total_loss = mse_loss(logits.sigmoid(), patches.to(logits.dtype))
            backward(total_loss, scaler)

        losses = dict(total_loss=total_loss)

        if scaler is not None:
            if args.horovod:
                optimizer.synchronize()
                scaler.unscale_(optimizer)
                if args.grad_clip_norm is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
                with optimizer.skip_synchronize():
                    scaler.step(optimizer)
            else:
                if args.grad_clip_norm is not None:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
                scaler.step(optimizer)
            scaler.update()
        else:
            if args.grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
            optimizer.step()

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        with torch.no_grad():
            unwrap_model(model).logit_scale.clamp_(0, math.log(100))
        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = step + 1
        if is_master(args) and (step % args.log_every_n_steps == 0 or batch_count == iterations):
            batch_size = len(images)
            num_samples = batch_count * args.accum_freq
            percent_complete = 100.0 * batch_count / iterations

            # NOTE loss is coarsely sampled, just master node and per log update
            for key, val in losses.items():
                if key not in losses_m:
                    losses_m[key] = AverageMeter()
                losses_m[key].update(val.item(), batch_size)

            logit_scale_scalar = unwrap_model(model).logit_scale.item()
            loss_log = " ".join(
                [
                    f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})"
                    for loss_name, loss_m in losses_m.items()
                ]
            )
            samples_per_second = args.accum_freq * args.world_size / batch_time_m.val
            samples_per_second_per_gpu = args.accum_freq / batch_time_m.val
            logging.info(
                f"Train iterations: [{num_samples:>{sample_digits}}/{iterations} ({percent_complete:.0f}%)] "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
                f"LR: {optimizer.param_groups[0]['lr']:5f} "
                f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log
            )

            # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
            log_data = {
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "samples_per_second": samples_per_second,
                "samples_per_second_per_gpu": samples_per_second_per_gpu,
                "scale": logit_scale_scalar,
                "lr": optimizer.param_groups[0]["lr"],
            }
            log_data.update({name: val.val for name, val in losses_m.items()})

            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)

            # resetting batch / data time meters per log window
            batch_time_m.reset()
            data_time_m.reset()


def evaluate_preference(model, data, args):
    model = unwrap_model(model)
    model.eval()
    dataloader = data.dataloader
    samples_per_val = dataloader.num_samples

    device = torch.device(args.device)
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    total = 0
    correct = 0
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            # shard batches round-robin across ranks
            if i % args.world_size != args.rank:
                continue
            images, num_images, labels, texts = batch
            images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
            texts = texts.to(device=device, non_blocking=True)

            with autocast():
                if args.no_text_condition:
                    image_features = model.visual(images)
                    logit_scale = model.logit_scale
                    scores = model.score_predictor(image_features)
                    paired_logits_list = [logit[:, 0] for logit in scores.split(num_images.tolist())]
                else:
                    outputs = model(images, texts)
                    image_features, text_features, logit_scale = outputs["image_features"], outputs["text_features"], outputs["logit_scale"]
                    logits_per_image = logit_scale * image_features @ text_features.T
                    paired_logits_list = [logit[:, i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
            predicted = torch.tensor([k.argmax().item() for k in paired_logits_list])
            correct += (predicted == labels).int().sum().item()
            total += predicted.numel()

    # Aggregate across ranks via temp files: each rank dumps its counts,
    # then the master sums them after the barrier.
    file_name = hashlib.md5(str(args.name).encode()).hexdigest()
    with open(f"{file_name}_{args.rank}.json", "w") as f:
        json.dump(dict(
            correct=correct,
            total=total,
        ), f)
    time.sleep(0.1)
    barrier(args)
    correct = 0
    total = 0
    if is_master(args):
        for i in range(args.world_size):
            with open(f"{file_name}_{i}.json", "r") as f:
                data = json.load(f)
            correct += data["correct"]
            total += data["total"]
            os.remove(f"{file_name}_{i}.json")
        logging.info(f"Final Acc: {correct / total:.4f}")
    return correct / (total + 1e-6)


def evaluate_regional(model, data, args):
    # mirror the other eval entry points: unwrap DDP and switch to eval mode
    model = unwrap_model(model)
    model.eval()
    dataloader = data.dataloader
    samples_per_val = dataloader.num_samples

    device = torch.device(args.device)
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    threshold = 0.5
    with torch.no_grad():
        score = 0
        total = 0
        for i, batch in enumerate(dataloader):
            images, labels = batch
            images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
            labels = labels.to(device=device, non_blocking=True)

            with autocast():
                feature_map = model.visual(images, skip_pool=True)[:, 1:]
                logits = model.region_predictor(feature_map)
                wh = int(math.sqrt(feature_map.size(1)))
                ps = images.size(2) // wh
                logits = logits.unflatten(1, (wh, wh))[:, :, :, 0]
                # downsample the labels to match the feature map size
                patches = einops.reduce(labels, 'b (h s1) (w s2) -> b h w', 'mean', s1=ps, s2=ps)
                patches = (patches > 0).float()
                pred_mask = (logits.sigmoid() > threshold).float()
                # IoU between the thresholded prediction and the target mask
                intersection = (pred_mask * patches).sum()
                union = pred_mask.sum() + patches.sum() - intersection
                iou_score = intersection / union
            score += iou_score
            total += 1
            if is_master(args) and (i % 100) == 0:
                logging.info(
                    f"[{i} / {len(dataloader)}]\t"
                    f"Current IoU: {score / (total + 0.001):.4f}")
    if is_master(args):
        logging.info(f"Final IoU: {score / (total + 0.001):.4f}")
    return score / (total + 0.001)


def inversion_score(p1, p2):
    """Return 1 minus the fraction of discordant pairs between two rankings."""
    assert len(p1) == len(p2), f'{len(p1)}, {len(p2)}'
    n = len(p1)
    cnt = 0
    for i in range(n - 1):
        for j in range(i + 1, n):
            # count pairs ordered differently in p1 and p2
            if p1[i] > p1[j] and p2[i] < p2[j]:
                cnt += 1
            elif p1[i] < p1[j] and p2[i] > p2[j]:
                cnt += 1
    return 1 - cnt / (n * (n - 1) / 2)
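
# Worked example (illustrative): for p1 = [0, 1, 2] and p2 = [0, 2, 1], the only
# discordant pair is (index 1, index 2), and there are 3 pairs in total, so the
# score is 1 - 1/3; identical rankings give 1.0, fully reversed ones give 0.0.
#   assert inversion_score([0, 1, 2], [0, 2, 1]) == 1 - 1 / 3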


def model_pair_score(score: dict, p1, p2, num_image):
    model_pairs = set()
    for i in range(num_image):
        if i not in score.keys():
            score[i] = {}
        for j in range(num_image):
            if j not in score[i].keys():
                score[i][j] = 0
            if j == i or (i, j) in model_pairs or (j, i) in model_pairs:
                continue
            model_pairs.add((i, j))
            if (p1[i] - p1[j]) * (p2[i] - p2[j]) > 0:
                score[i][j] += 1
    return score


def all_gather(tensor):
    world_size = torch.distributed.get_world_size()
    tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(tensor_list, tensor, async_op=False)
    return torch.cat(tensor_list, dim=0)
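
# Illustrative sketch: with each of `world_size` ranks holding a (B, D) tensor,
# the result is a (world_size * B, D) tensor that is identical on every rank.
# Note this plain all_gather does not propagate gradients through the gathered
# copies, which is fine for evaluation but not for contrastive training.
#   feats = all_gather(local_feats)   # e.g. (8, 512) per rank -> (8 * world_size, 512)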


def evaluate_ranking(model, data, args):
    model = unwrap_model(model)
    model.eval()
    dataloader = data.dataloader
    samples_per_val = dataloader.num_samples

    device = torch.device(args.device)
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    score = 0
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            # shard batches by global rank, matching the per-rank result files
            # below (sharding by local_rank would duplicate work across nodes)
            if i % args.world_size != args.rank:
                continue
            images, num_images, labels, texts = batch
            images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
            texts = texts.to(device=device, non_blocking=True)
            num_images = num_images.to(device=device, non_blocking=True)
            labels = labels.to(device=device, non_blocking=True)

            with autocast():
                if args.no_text_condition:
                    image_features = model.visual(images)
                    logit_scale = model.logit_scale
                    scores = model.score_predictor(image_features)
                    paired_logits_list = [logit[:, 0] for logit in scores.split(num_images.tolist())]
                else:
                    outputs = model(images, texts)
                    image_features, text_features, logit_scale = outputs["image_features"], outputs["text_features"], outputs["logit_scale"]
                    logits_per_image = logit_scale * image_features @ text_features.T
                    paired_logits_list = [logit[:, i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
            predicted = [torch.argsort(-k) for k in paired_logits_list]
            hps_ranking = [[predicted[i].tolist().index(j) for j in range(n)] for i, n in enumerate(num_images)]
            labels = list(labels.split(num_images.tolist()))
            if isinstance(dataloader.dataset, RankingDataset):
                score += sum(inversion_score(hps_ranking[i], labels[i]) for i in range(len(hps_ranking)))
            elif isinstance(dataloader.dataset, ImageRewardDataset):
                score += sum(calc_ImageReward(paired_logits_list[i].tolist(), labels[i]) for i in range(len(hps_ranking)))

    # write per-rank scores to temp files (named by a hash of the run name),
    # then let the master aggregate after the barrier
    file_name = hashlib.md5(str(args.name).encode()).hexdigest()
    with open(f"{file_name}_{args.rank}.tmp", "w") as f:
        f.write(str(score))
    time.sleep(0.1)
    barrier(args)
    score = 0
    if is_master(args):
        for i in range(args.world_size):
            with open(f"{file_name}_{i}.tmp", "r") as f:
                score += float(f.read())
            os.remove(f"{file_name}_{i}.tmp")
        score = score / samples_per_val
        logging.info(f"Final Acc: {score:.4f}")
    return score


def calc_ImageReward(pred, gt):
    # Pairwise accuracy following the inversion score calculation in ImageReward.
    # There's some little difference because the ImageReward benchmark has tie
    # rankings: tied ground-truth pairs are skipped, tied predictions count as wrong.
    # `gt` holds ranks (smaller = better) and `pred` holds scores (larger = better),
    # so a pair is correct when the two order it in opposite directions.
    tol_cnt = 0.
    true_cnt = 0.
    for i in range(len(gt)):
        for j in range(i + 1, len(gt)):
            if gt[i] > gt[j]:
                tol_cnt += 1
                if pred[i] < pred[j]:
                    true_cnt += 1
            elif gt[i] < gt[j]:
                tol_cnt += 1
                if pred[i] > pred[j]:
                    true_cnt += 1
    return true_cnt / tol_cnt
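
# Worked example (illustrative): gt ranks [1, 2, 3] (1 = best) with predicted
# scores [0.9, 0.5, 0.7] get pairs (0,1) and (0,2) right but (1,2) wrong:
#   assert abs(calc_ImageReward([0.9, 0.5, 0.7], [1, 2, 3]) - 2 / 3) < 1e-8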


def get_clip_metrics(image_features, text_features, logit_scale):
    metrics = {}
    logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
    logits_per_text = logits_per_image.t().detach().cpu()
    logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
    ground_truth = torch.arange(len(text_features)).view(-1, 1)

    for name, logit in logits.items():
        ranking = torch.argsort(logit, descending=True)
        preds = torch.where(ranking == ground_truth)[1]
        preds = preds.detach().cpu().numpy()
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in [1, 5, 10]:
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)

    return metrics
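
# Illustrative reading of the output (hypothetical values): matched image/text
# pairs sit on the diagonal, so `preds` holds the 0-based rank of each positive
# among all candidates. A value like
#   metrics["image_to_text_R@5"] == 0.9
# would mean 90% of images rank their own caption in the top 5, while mean_rank
# and median_rank report the 1-based average and median position of the positives.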


def maybe_compute_generative_loss(model_out):
    if "logits" in model_out and "labels" in model_out:
        token_logits = model_out["logits"]
        token_labels = model_out["labels"]
        return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels)