# HPSv2/src/training/train.py
import hashlib
import itertools
import json
import logging
import math
import os
import random
import tempfile
import time

import einops
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.nn.utils.rnn import pad_sequence

from .data import ImageRewardDataset, RankingDataset
from .distributed import is_master, barrier
from .precision import get_autocast
from .zero_shot import zero_shot_eval
from ..open_clip import get_cast_dtype, CLIP, CustomTextCLIP
from ..open_clip.loss import PreferenceLoss, RankingLoss, HPSLoss
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def postprocess_clip_output(model_out):
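    """Repackage a tuple-style CLIP forward output as a dict of named features."""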
return {
"image_features": model_out[0],
"text_features": model_out[1],
"logit_scale": model_out[2]
}
def unwrap_model(model):
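    """Return the wrapped module for DDP-style models, else the model itself."""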
if hasattr(model, 'module'):
return model.module
else:
return model
def backward(total_loss, scaler):
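    """Backpropagate the loss, routing through the AMP GradScaler when one is provided."""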
if scaler is not None:
scaler.scale(total_loss).backward()
else:
total_loss.backward()
def random_sampling_iterator(iterators, sampling_ratios, data_types, num_iters):
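    """Yield (batch, data_type) pairs, choosing a dataloader each step according to sampling_ratios."""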
iterators = [iter(iterator) for iterator in iterators]
num_iterators = len(iterators)
loop_counter = 0
while loop_counter < num_iters:
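        # Reseed with the step index so every rank draws the same dataset choice,
        # then restore the previous RNG state so other random sampling is unaffected.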
current_state = random.getstate()
random.seed(loop_counter)
iterator_idx = random.choices(range(num_iterators), sampling_ratios)[0]
random.setstate(current_state)
yield next(iterators[iterator_idx]), data_types[iterator_idx]
loop_counter += 1
def train_iters(model, data, iterations, optimizer, scaler, scheduler, dist_model, args, tb_writer=None):
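    """Run `iterations` training steps over a random mixture of the train sets.

    Each batch arrives tagged with its data_type ('preference', 'rating',
    'regional', 'ranking', or 'HPD'), which selects the corresponding loss.
    """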
device = torch.device(args.device)
autocast = get_autocast(args.precision)
cast_dtype = get_cast_dtype(args.precision)
model.train()
ce_loss = PreferenceLoss()
mse_loss = torch.nn.MSELoss()
rk_loss = RankingLoss()
hps_loss = HPSLoss()
if args.distill:
dist_model.eval()
for train_set in data['train']:
train_set.set_epoch(0) # set epoch in process safe manner via sampler or shared_epoch
data_types = [d.data_type for d in data['train']]
train_data_sample_ratios = [sample_ratio for sample_ratio, ignore in zip(args.train_data_sample_ratio, args.ignore_in_train) if not ignore]
dataloader = random_sampling_iterator([dataset.dataloader for dataset in data['train']], train_data_sample_ratios, data_types, iterations)
sample_digits = math.ceil(math.log(sum([dataset.dataloader.num_samples for dataset in data['train']]) + 1, 10))
losses_m = {}
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
end = time.time()
for step, (batch, data_type) in enumerate(dataloader):
# TODO: currently only test on accum_freq==1
if not args.skip_scheduler:
scheduler(step)
if data_type == 'preference':
images, num_images, labels, texts = batch
texts = texts.to(device=device, non_blocking=True)
elif data_type == 'rating':
images, labels = batch
elif data_type == 'regional':
images, labels = batch
elif data_type == 'ranking':
images, num_images, labels, texts = batch
texts = texts.to(device=device, non_blocking=True)
elif data_type == 'HPD':
images, labels, texts = batch
# num_per_prompts = num_per_prompts.to(device=device, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
labels = labels.to(device=device, non_blocking=True)
data_time_m.update(time.time() - end)
optimizer.zero_grad()
if args.accum_freq == 1:
with autocast():
if data_type == 'rating' or args.no_text_condition:
image_features = unwrap_model(model).visual(images)
scores = unwrap_model(model).score_predictor(image_features)
if args.no_text_condition:
                        # num_images comes from the preference/ranking collate; 'rating'
                        # batches do not provide it and are not expected on this path.
                        paired_logits_list = [logit[:, 0] for logit in scores.split(num_images.tolist())]
paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
total_loss = F.cross_entropy(paired_logits, labels)
else:
total_loss = mse_loss(scores.squeeze(), labels.to(scores.dtype))
elif data_type == 'preference' :
output = model(images, texts)
image_features, text_features, logit_scale = output["image_features"], output["text_features"], output["logit_scale"]
# total_loss = loss(image_features, text_features, logit_scale)
logits_per_image = logit_scale * image_features @ text_features.T
total_loss = ce_loss(logits_per_image, num_images, labels)
elif data_type == 'HPD':
output = model(images, texts)
image_features, text_features, logit_scale = output["image_features"], output["text_features"], output["logit_scale"]
logits_per_text = logit_scale * text_features @ image_features.T
total_loss = hps_loss(logits_per_text, labels)
elif data_type == 'ranking':
output = model(images, texts)
image_features, text_features, logit_scale = output["image_features"], output["text_features"], output["logit_scale"]
# logits_per_image = logit_scale * image_features @ text_features.T
score = logit_scale * image_features @ text_features.T
total_loss = rk_loss(score, num_images, labels, args.margin)
elif data_type == 'regional':
# logit_scale = model.logit_scale
feature_map = unwrap_model(model).visual(images, skip_pool=True)[:, 1:]
logits = unwrap_model(model).region_predictor(feature_map)
wh = int(math.sqrt(feature_map.size(1)))
ps = images.size(2) // wh
logits = logits.unflatten(1, (wh, wh))[:,:,:,0]
# downsample the labels to match the feature map size
patches = einops.reduce(labels, 'b (h s1) (w s2) -> b h w', 'mean', s1=ps, s2=ps)
patches = (patches > 0).float()
                        total_loss = mse_loss(logits.sigmoid(), patches.to(logits.dtype))
backward(total_loss, scaler)
losses = dict(total_loss=total_loss)
if scaler is not None:
if args.horovod:
optimizer.synchronize()
scaler.unscale_(optimizer)
if args.grad_clip_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
with optimizer.skip_synchronize():
scaler.step(optimizer)
else:
if args.grad_clip_norm is not None:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
scaler.step(optimizer)
scaler.update()
else:
if args.grad_clip_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
optimizer.step()
# Note: we clamp to 4.6052 = ln(100), as in the original paper.
with torch.no_grad():
unwrap_model(model).logit_scale.clamp_(0, math.log(100))
batch_time_m.update(time.time() - end)
end = time.time()
batch_count = step + 1
if is_master(args) and (step % args.log_every_n_steps == 0 or batch_count == iterations):
batch_size = len(images)
num_samples = batch_count * args.accum_freq
percent_complete = 100.0 * batch_count / iterations
# NOTE loss is coarsely sampled, just master node and per log update
for key, val in losses.items():
if key not in losses_m:
losses_m[key] = AverageMeter()
losses_m[key].update(val.item(), batch_size)
logit_scale_scalar = unwrap_model(model).logit_scale.item()
loss_log = " ".join(
[
f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})"
for loss_name, loss_m in losses_m.items()
]
)
            # throughput in samples (images) per second, using the current batch size
            samples_per_second = args.accum_freq * batch_size * args.world_size / batch_time_m.val
            samples_per_second_per_gpu = args.accum_freq * batch_size / batch_time_m.val
logging.info(
f"Train iterations: [{num_samples:>{sample_digits}}/{iterations} ({percent_complete:.0f}%)] "
f"Data (t): {data_time_m.avg:.3f} "
f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
f"LR: {optimizer.param_groups[0]['lr']:5f} "
f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log
)
# Save train loss / etc. Using non avg meter values as loggers have their own smoothing
log_data = {
"data_time": data_time_m.val,
"batch_time": batch_time_m.val,
"samples_per_second": samples_per_second,
"samples_per_second_per_gpu": samples_per_second_per_gpu,
"scale": logit_scale_scalar,
"lr": optimizer.param_groups[0]["lr"]
}
log_data.update({name:val.val for name,val in losses_m.items()})
for name, val in log_data.items():
name = "train/" + name
if tb_writer is not None:
tb_writer.add_scalar(name, val, step)
# resetting batch / data time meters per log window
batch_time_m.reset()
data_time_m.reset()
def evaluate_preference(model, data, args):
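    """Compute pairwise preference accuracy, sharding batches across ranks.

    Per-rank counts are written to hashed temp JSON files, then summed and
    logged on the master; non-master ranks return 0 from the final division.
    """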
model = unwrap_model(model)
model.eval()
dataloader = data.dataloader
samples_per_val = dataloader.num_samples
device = torch.device(args.device)
autocast = get_autocast(args.precision)
cast_dtype = get_cast_dtype(args.precision)
total = 0
correct = 0
with torch.no_grad():
for i, batch in enumerate(dataloader):
if i % args.world_size != args.rank:
continue
images, num_images, labels, texts = batch
images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
with autocast():
if args.no_text_condition:
image_features = model.visual(images)
logit_scale = model.logit_scale
scores = model.score_predictor(image_features)
                    paired_logits_list = [logit[:, 0] for logit in scores.split(num_images.tolist())]
else:
outputs = model(images, texts)
image_features, text_features, logit_scale = outputs["image_features"], outputs["text_features"], outputs["logit_scale"]
logits_per_image = logit_scale * image_features @ text_features.T
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
predicted = torch.tensor([k.argmax().item() for k in paired_logits_list])
correct += (predicted == labels).int().sum().item()
total += predicted.numel()
# write to a temp file
file_name = hashlib.md5(str(args.name).encode()).hexdigest()
with open(f"{file_name}_{args.rank}.json", "w") as f:
json.dump(dict(
correct=correct,
total=total,
), f)
time.sleep(0.1)
barrier(args)
correct = 0
total = 0
if is_master(args):
for i in range(args.world_size):
with open(f"{file_name}_{i}.json", "r") as f:
data = json.load(f)
correct += data["correct"]
total += data["total"]
os.remove(f"{file_name}_{i}.json")
logging.info(
f"Final Acc: {correct / total:.4f}\t")
return correct / (total + 1e-6)
def evaluate_regional(model, data, args):
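    """Evaluate the region predictor by mean IoU between thresholded predictions and label masks."""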
    model = unwrap_model(model)
    model.eval()
    dataloader = data.dataloader
samples_per_val = dataloader.num_samples
device = torch.device(args.device)
autocast = get_autocast(args.precision)
cast_dtype = get_cast_dtype(args.precision)
num_samples = len(dataloader)
threshold = 0.5
with torch.no_grad():
score = 0
total = 0
for i, batch in enumerate(dataloader):
images, labels = batch
images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
labels = labels.to(device=device, non_blocking=True)
with autocast():
feature_map = model.visual(images, skip_pool=True)[:, 1:]
logits = model.region_predictor(feature_map)
wh = int(math.sqrt(feature_map.size(1)))
ps = images.size(2) // wh
logits = logits.unflatten(1, (wh, wh))[:,:,:,0]
# downsample the labels to match the feature map size
patches = einops.reduce(labels, 'b (h s1) (w s2) -> b h w', 'mean', s1=ps, s2=ps)
patches = (patches > 0).float()
pred_mask = (logits.sigmoid() > threshold).float()
                # calculate IoU between the thresholded prediction and the label mask
                intersection = (pred_mask * patches).sum()
                union = pred_mask.sum() + patches.sum() - intersection
                iou_score = intersection / (union + 1e-6)
score += iou_score
total += 1
if is_master(args) and (i % 100) == 0:
logging.info(
# f"[{i} / {samples_per_val}]\t"
f"[{i} / {len(dataloader)}]\t"
f"Current IoU: {score / (total + 0.001):.4f}\t")
if is_master(args):
logging.info(
f"Final IoU: {score / (total + 0.001):.4f}\t")
return score / (total + 0.001)
def inversion_score(p1, p2):
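    """Return 1 minus the fraction of pairwise inversions between rankings p1 and p2."""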
assert len(p1) == len(p2), f'{len(p1)}, {len(p2)}'
n = len(p1)
cnt = 0
for i in range(n-1):
for j in range(i+1, n):
if p1[i] > p1[j] and p2[i] < p2[j]:
cnt += 1
elif p1[i] < p1[j] and p2[i] > p2[j]:
cnt += 1
return 1 - cnt / (n * (n - 1) / 2)
def model_pair_score(score:dict, p1, p2, num_image):
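    """Accumulate, for each model pair (i, j), how often score lists p1 and p2 agree on their relative order."""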
model_pairs = set()
for i in range(num_image):
if i not in score.keys():
score[i] = {}
for j in range(num_image):
if j not in score[i].keys():
score[i][j] = 0
if j == i or (i, j) in model_pairs or (j, i) in model_pairs:
continue
model_pairs.add((i,j))
if (p1[i] - p1[j]) * (p2[i] - p2[j]) > 0:
score[i][j] += 1
return score
def all_gather(tensor):
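    """Gather a tensor from all ranks and concatenate along dim 0."""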
world_size = torch.distributed.get_world_size()
tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, tensor, async_op=False)
return torch.cat(tensor_list, dim=0)
def evaluate_ranking(model, data, args):
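    """Evaluate ranking quality, sharding batches across ranks.

    Uses inversion_score for RankingDataset and the ImageReward pairwise
    accuracy for ImageRewardDataset; per-rank sums are merged on the master
    through hashed temp files.
    """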
model = unwrap_model(model)
model.eval()
dataloader = data.dataloader
samples_per_val = dataloader.num_samples
device = torch.device(args.device)
autocast = get_autocast(args.precision)
cast_dtype = get_cast_dtype(args.precision)
score = 0
# pair_score = {}
with torch.no_grad():
for i, batch in enumerate(dataloader):
            if i % args.world_size != args.rank:
continue
images, num_images, labels, texts = batch
images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
num_images = num_images.to(device=device, non_blocking=True)
labels = labels.to(device=device, non_blocking=True)
with autocast():
if args.no_text_condition:
image_features = model.visual(images)
logit_scale = model.logit_scale
scores = model.score_predictor(image_features)
                    paired_logits_list = [logit[:, 0] for logit in scores.split(num_images.tolist())]
else:
outputs = model(images, texts)
image_features, text_features, logit_scale = outputs["image_features"], outputs["text_features"], outputs["logit_scale"]
logits_per_image = logit_scale * image_features @ text_features.T
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
predicted = [torch.argsort(-k) for k in paired_logits_list]
hps_ranking = [[predicted[i].tolist().index(j) for j in range(n)] for i,n in enumerate(num_images)]
            labels = list(labels.split(num_images.tolist()))
if isinstance(dataloader.dataset, RankingDataset):
score += sum([inversion_score(hps_ranking[i], labels[i]) for i in range(len(hps_ranking))])
elif isinstance(dataloader.dataset, ImageRewardDataset):
                score += sum([calc_ImageReward(paired_logits_list[i].tolist(), labels[i]) for i in range(len(hps_ranking))])
# write score to a tempfile, file name is a hash string
file_name = hashlib.md5(str(args.name).encode()).hexdigest()
with open(f"{file_name}_{args.rank}.tmp", "w") as f:
f.write(str(score))
time.sleep(0.1)
barrier(args)
score = 0
if is_master(args):
for i in range(args.world_size):
with open(f"{file_name}_{i}.tmp", "r") as f:
score += float(f.read())
os.remove(f"{file_name}_{i}.tmp")
score = score / samples_per_val
logging.info(
f"Final Acc: {score:.4f}\t")
# return score, pair_score
return score
def calc_ImageReward(pred, gt):
    # Pairwise accuracy following the inversion-score computation in ImageReward.
    # It differs slightly from inversion_score because the ImageReward benchmark
    # contains tie rankings: ties in gt are skipped, and prediction ties on a
    # non-tied gt pair count toward the total but not as correct.
    tol_cnt = 0.
    true_cnt = 0.
    for i in range(len(gt)):
        for j in range(i + 1, len(gt)):
            if gt[i] > gt[j]:
                tol_cnt += 1
                if pred[i] < pred[j]:
                    true_cnt += 1
            elif gt[i] < gt[j]:
                tol_cnt += 1
                if pred[i] > pred[j]:
                    true_cnt += 1
    return true_cnt / tol_cnt
def get_clip_metrics(image_features, text_features, logit_scale):
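    """Compute retrieval metrics (mean rank, median rank, R@1/5/10) in both directions."""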
metrics = {}
logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
logits_per_text = logits_per_image.t().detach().cpu()
logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
ground_truth = torch.arange(len(text_features)).view(-1, 1)
for name, logit in logits.items():
ranking = torch.argsort(logit, descending=True)
preds = torch.where(ranking == ground_truth)[1]
preds = preds.detach().cpu().numpy()
metrics[f"{name}_mean_rank"] = preds.mean() + 1
metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
for k in [1, 5, 10]:
metrics[f"{name}_R@{k}"] = np.mean(preds < k)
return metrics
def maybe_compute_generative_loss(model_out):
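    """Return the token-level cross-entropy loss if the model produced logits/labels, else None."""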
if "logits" in model_out and "labels" in model_out:
token_logits = model_out["logits"]
token_labels = model_out["labels"]
return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels)