# evaluate-sd-schedulers / metrics_utils.py
# Author: sayakpaul (HF staff)
# Commit 873e677: "redo the space to spice things up."
# Original file size: 1.67 kB
from functools import partial
from typing import Callable, Dict, List
import numpy as np
import torch
from torchmetrics.functional.multimodal import clip_score
from torchmetrics.image.inception import InceptionScore
# Seed applied to torch's global RNG below via torch.manual_seed.
SEED = 0
# Shared module-level Inception Score metric used by compute_main_metrics.
# NOTE(review): normalize=True presumably means the metric accepts float images
# in [0, 1]; compute_main_metrics passes the raw float batch here but scales by
# 255 before the CLIP score. Confirm against the torchmetrics InceptionScore docs.
# NOTE(review): this object is shared by every caller in the module, so any
# state accumulated via update() presumably persists between calls until
# reset() is called. Confirm.
inception_score_fn = InceptionScore(normalize=True)
# Seed torch's global RNG once, at import time. This runs after the metric
# above is constructed, so it does not influence that construction; presumably
# it is meant to make later random torch ops repeatable. Confirm.
torch.manual_seed(SEED)
# CLIP score with the checkpoint fixed; callers supply (images, prompts).
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")
def compute_main_metrics(images: np.ndarray, prompts: List[str]) -> Dict:
    """Compute the Inception Score and CLIP score for one batch of images.

    Args:
        images: Batch of images in channels-last layout; it is permuted with
            ``(0, 3, 1, 2)`` to channels-first before scoring. Values are
            assumed to be floats in ``[0, 1]`` (they are scaled by 255 for the
            CLIP score) -- confirm against callers.
        prompts: Text prompts paired with ``images``, used for the CLIP score.

    Returns:
        ``{"inception_score (⬆️)": {"mean": float, "std": float},
        "clip_score (⬆️)": float}`` with every value rounded to 4 decimals.
    """
    # ``inception_score_fn`` is a shared module-level metric whose ``update``
    # accumulates state across calls. Without a reset, every call after the
    # first would score the union of all batches seen so far.
    inception_score_fn.reset()
    try:
        inception_score_fn.update(torch.from_numpy(images).permute(0, 3, 1, 2))
        inception_score = inception_score_fn.compute()
    finally:
        # Drop the accumulated features so they are not kept alive between calls.
        inception_score_fn.reset()

    # The CLIP score is computed on uint8 images in [0, 255]. Clip first so
    # float values slightly outside [0, 1] saturate instead of overflowing in
    # the uint8 cast; in-range values are unaffected.
    images_uint8 = (np.clip(images, 0.0, 1.0) * 255).astype("uint8")
    # Named to avoid shadowing the imported ``clip_score`` function.
    clip_score_value = clip_score_fn(
        torch.from_numpy(images_uint8).permute(0, 3, 1, 2), prompts
    ).detach()

    return {
        "inception_score (⬆️)": {
            "mean": round(float(inception_score[0]), 4),
            "std": round(float(inception_score[1]), 4),
        },
        "clip_score (⬆️)": round(float(clip_score_value), 4),
    }
def compute_psnr_or_ssim(
    fn: Callable, images_dict: Dict, original_scheduler_name: str
) -> Dict:
    """Score each scheduler's images against a reference scheduler's images.

    Args:
        fn: Pairwise metric called as ``fn(candidate, reference)`` on
            channels-first tensors; its result must be convertible with
            ``float``.
        images_dict: Maps scheduler name to a channels-last NumPy image batch
            (axes are permuted with ``(0, 3, 1, 2)`` before calling ``fn``).
        original_scheduler_name: Key of the reference batch. It is excluded
            from the output, and a ``KeyError`` is raised if it is missing.

    Returns:
        Dict from every non-reference scheduler name (in ``images_dict``
        order) to ``fn``'s value rounded to 4 decimals.
    """

    def _channels_first(batch):
        # Move the trailing channel axis in front of the spatial axes.
        return torch.from_numpy(batch).permute(0, 3, 1, 2)

    reference = _channels_first(images_dict[original_scheduler_name])
    return {
        name: round(float(fn(_channels_first(batch), reference)), 4)
        for name, batch in images_dict.items()
        if name != original_scheduler_name
    }