Spaces:
Sleeping
Sleeping
from functools import partial | |
from typing import Callable, Dict, List | |
import numpy as np | |
import torch | |
from torchmetrics.functional.multimodal import clip_score | |
from torchmetrics.image.inception import InceptionScore | |
# Seed handed to torch's global random number generator below.
SEED = 0

# Shared Inception Score metric object used by compute_main_metrics.
# Per the torchmetrics documentation, normalize=True means the metric is fed
# float images in the [0, 1] range rather than uint8 images.
inception_score_fn = InceptionScore(normalize=True)

# NOTE: seeding happens after the metric has been constructed; keep this order
# so the RNG state observed by later metric computations stays the same.
torch.manual_seed(SEED)

# CLIP score function with the checkpoint name bound once up front.
clip_score_fn = partial(
    clip_score,
    model_name_or_path="openai/clip-vit-base-patch16",
)
def compute_main_metrics(images: np.ndarray, prompts: List[str]) -> Dict:
    """Compute the Inception Score and the CLIP score for one batch of images.

    Args:
        images: 4-D image batch as a NumPy array. The axes are reordered with
            ``permute(0, 3, 1, 2)`` before being handed to the metrics, so the
            input is presumably channel-last ``(N, H, W, C)``; values are
            assumed to be floats in ``[0, 1]`` (they are scaled by 255 for
            the CLIP score) -- TODO confirm against the caller.
        prompts: Text prompts scored against ``images`` by the CLIP metric.

    Returns:
        A dictionary with two entries:
        ``"inception_score (⬆️)"`` -> ``{"mean": float, "std": float}`` and
        ``"clip_score (⬆️)"`` -> ``float``; every number is rounded to four
        decimal places.
    """
    # The metric object lives at module level and accumulates everything
    # passed to ``update``. Clear it first so this call is scored on its own
    # images only, instead of on the union of all batches seen so far.
    inception_score_fn.reset()
    inception_score_fn.update(torch.from_numpy(images).permute(0, 3, 1, 2))
    inception_result = inception_score_fn.compute()

    # The CLIP score path is given uint8 images, hence the 0-255 rescale.
    images_int = (images * 255).astype("uint8")
    # Named ``clip_value`` so the imported ``clip_score`` function is not
    # shadowed by a local variable.
    clip_value = clip_score_fn(
        torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts
    ).detach()

    return {
        "inception_score (⬆️)": {
            "mean": round(float(inception_result[0]), 4),
            "std": round(float(inception_result[1]), 4),
        },
        "clip_score (⬆️)": round(float(clip_value), 4),
    }
def compute_psnr_or_ssim(
    fn: Callable, images_dict: Dict, original_scheduler_name: str
) -> Dict:
    """Score every image batch in ``images_dict`` against a reference batch.

    Args:
        fn: Metric callable invoked as ``fn(candidate, reference)``; its
            result is converted with ``float`` (the function name suggests
            PSNR or SSIM, but any such callable works).
        images_dict: Mapping from a name to a 4-D NumPy image batch. Each
            batch is converted to a tensor and reordered with
            ``permute(0, 3, 1, 2)`` before being passed to ``fn``.
        original_scheduler_name: Key of the reference batch in
            ``images_dict``; this entry is excluded from the result.

    Returns:
        A dictionary mapping every other key of ``images_dict`` to the metric
        value rounded to four decimal places, in the mapping's iteration
        order.

    Raises:
        KeyError: If ``original_scheduler_name`` is not in ``images_dict``.
    """

    def _as_permuted_tensor(batch):
        # Same conversion for reference and candidates: axis 3 moves to 1.
        return torch.from_numpy(batch).permute(0, 3, 1, 2)

    reference = _as_permuted_tensor(images_dict[original_scheduler_name])

    return {
        name: round(float(fn(_as_permuted_tensor(batch), reference)), 4)
        for name, batch in images_dict.items()
        if name != original_scheduler_name
    }