|
import os |
|
|
|
import numpy as np |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
from torchmetrics.functional.multimodal import clip_score |
|
|
|
from functools import partial |
|
import torch |
|
|
|
|
|
# Directories holding the generated images to score, one per model.
streetview_path = "Generated Images/streetview_clip"
sd_path = "Generated Images/sd_clip"
ldm3d_path = "Generated Images/ldm3d_clip"

# CLIP scorer with the model fixed up front, so callers only pass images and
# prompts. NOTE(review): the value looks like a local checkpoint directory
# relative to the working directory - confirm it exists where this is run.
clip_score_fn = partial(clip_score, model_name_or_path="checkpoints/clip-vit-base-patch16")

# Prefer the GPU, but fall back to the CPU so the script still runs (slowly)
# on machines without CUDA instead of failing when tensors are moved.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
def _to_uint8_batch(images):
    """Convert an (N, H, W, C) image batch to a uint8 array with values 0-255."""
    batch = np.asarray(images)

    if batch.dtype == object:
        # Callers may build the batch with dtype=object. A regular batch then
        # holds boxed Python scalars; a ragged one holds whole arrays.
        if batch.size == 0 or isinstance(batch.flat[0], np.ndarray):
            raise ValueError(
                "images must be a non-empty batch of equally shaped images"
            )
        holds_integers = isinstance(batch.flat[0], (int, np.integer))
        # float64 keeps the arithmetic identical to multiplying boxed floats.
        batch = batch.astype(np.int64 if holds_integers else np.float64)

    if batch.ndim != 4:
        raise ValueError(
            f"expected an image batch shaped (N, H, W, C), got shape {batch.shape}"
        )

    if np.issubdtype(batch.dtype, np.integer):
        # Integer pixels are taken to be 0-255 already; scaling them again
        # would wrap around when cast to uint8.
        pixels = np.clip(batch, 0, 255).astype(np.uint8)
    else:
        # Float pixels are taken to be in [0, 1]. Clip first so stray values
        # cannot wrap, then truncate exactly as the plain cast always did.
        pixels = (np.clip(batch, 0.0, 1.0) * 255).astype(np.uint8)

    if pixels.shape[-1] == 4:
        # Presumably RGBA (e.g. a PNG with transparency); keep only RGB.
        pixels = np.ascontiguousarray(pixels[..., :3])

    return pixels


def calculate_clip_score(images, prompts):
    """Return the CLIP score of a batch of images against their prompts.

    Args:
        images: Array-like image batch in (N, H, W, C) layout. Floating-point
            pixels are treated as lying in [0, 1] and scaled to 0-255;
            integer pixels are treated as 0-255 already and are not rescaled.
        prompts: Text prompts passed to the scorer together with the images
            (the caller supplies one per image, in the same order).

    Returns:
        The value produced by ``clip_score_fn`` as a float rounded to four
        decimal places.

    Raises:
        ValueError: If ``images`` is empty, ragged, or not four-dimensional.
    """
    pixels = _to_uint8_batch(images)

    # Reorder NHWC -> NCHW before handing the batch to the scorer.
    images_tensor = torch.from_numpy(pixels).permute(0, 3, 1, 2).to(device)

    clip_score_value = clip_score_fn(images_tensor, prompts).detach().cpu()
    return round(float(clip_score_value), 4)
|
|
|
|
|
|
|
# (image directory, display name) pairs, evaluated in this order.
models = list(
    zip(
        (streetview_path, sd_path, ldm3d_path),
        ("StreetView360X", "Stable Diffusion 2.1", "LDM3D-pano"),
    )
)
|
|
|
|
|
|
|
def _load_images_and_prompts(directory):
    """Read the images in *directory* and derive each prompt from its file name.

    Returns two parallel lists: float images scaled to [0, 1] and the prompt
    taken from the second ``_``-separated field of each file name.

    Raises:
        ValueError: If a file name contains no ``_`` to split the prompt on.
    """
    images = []
    prompts = []

    # Sorted so the batch order does not depend on the file system.
    for file_name in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, file_name)

        # Ignore sub-directories and hidden files (e.g. OS metadata).
        if file_name.startswith(".") or not os.path.isfile(file_path):
            continue

        fields = file_name.split("_")
        if len(fields) < 2:
            raise ValueError(
                f"cannot extract a prompt from file name {file_name!r}: no '_' found"
            )

        image = plt.imread(file_path)
        if np.issubdtype(image.dtype, np.integer):
            # imread returns integer arrays for non-PNG files; bring them to
            # the same [0, 1] float range PNG files are returned in.
            image = image.astype(np.float32) / np.iinfo(image.dtype).max

        images.append(image)
        # NOTE(review): when the name has a single '_', this field still
        # carries the file extension - confirm the naming scheme.
        prompts.append(fields[1])

    return images, prompts


for path, name in models:
    imgs, prompts = _load_images_and_prompts(path)

    if not imgs:
        print(f"No images found in {path}; skipping {name}")
        continue

    # A real numeric batch (not dtype=object) keeps the pixel dtype intact;
    # np.stack raises a clear error if the images differ in shape.
    imgs = np.stack(imgs)

    # Named `score` so the imported `clip_score` function is not rebound.
    score = calculate_clip_score(imgs, prompts)

    print(f"CLIP Score with {name}: {score}")
|
|
|
|
|
|