# StreetView360X / evaluate_clip.py
# Uploaded by everettshen ("Upload 14 files", commit d7b4a46, verified).
import os
import numpy as np
import matplotlib.pyplot as plt
from torchmetrics.functional.multimodal import clip_score
from functools import partial
import torch
# Directories that hold the generated images for each model being evaluated.
streetview_path = "Generated Images/streetview_clip"
sd_path = "Generated Images/sd_clip"
ldm3d_path = "Generated Images/ldm3d_clip"

# CLIP scorer pre-bound to one checkpoint path so every call uses the same model.
clip_score_fn = partial(
    clip_score, model_name_or_path="checkpoints/clip-vit-base-patch16"
)

# Prefer the GPU, but fall back to the CPU so the script still runs on a
# machine without CUDA instead of failing when tensors are moved to the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def calculate_clip_score(images, prompts):
    """Return the CLIP score of a batch of images against their prompts.

    Args:
        images: Array-like of shape (N, H, W, C). Floating-point data is
            treated as being in [0, 1] and scaled to [0, 255]; integer data
            is treated as already being 8-bit pixel values. An object-dtype
            array holding numeric scalars is accepted as well.
        prompts: Sequence of N text prompts, one per image.

    Returns:
        The score reported by ``clip_score_fn`` as a float rounded to four
        decimal places.

    Raises:
        ValueError: If ``images`` is not a 4-D batch (for example because
            the images do not all share the same shape).
    """
    batch = np.asarray(images)
    if batch.ndim != 4:
        raise ValueError(
            f"expected images of shape (N, H, W, C), got array of shape {batch.shape}"
        )

    if batch.dtype == object:
        # Rebuild a native numeric array so the arithmetic below is vectorised
        # and NumPy can tell integer pixel data from float pixel data.
        batch = np.array(batch.tolist())

    if batch.shape[-1] == 4:
        # Drop the alpha channel of RGBA images; only RGB is scored.
        batch = batch[..., :3]

    if np.issubdtype(batch.dtype, np.integer):
        # Already 8-bit pixel values: scaling by 255 again would corrupt them.
        pixels = np.clip(batch, 0, 255).astype(np.uint8)
    else:
        # Round instead of truncating: a value stored as k/255 in floating
        # point can land just below k after scaling and would become k - 1.
        pixels = np.rint(np.clip(batch * 255.0, 0, 255)).astype(np.uint8)

    # NHWC -> NCHW, then move the batch to the configured device.
    images_tensor = torch.from_numpy(pixels).permute(0, 3, 1, 2).to(device)

    # Detach and bring the result back to the CPU before converting to float.
    clip_score_value = clip_score_fn(images_tensor, prompts).detach().cpu()
    return round(float(clip_score_value), 4)
# (image directory, display name) for every model to evaluate.
models = [
    (streetview_path, "StreetView360X"),
    (sd_path, "Stable Diffusion 2.1"),
    (ldm3d_path, "LDM3D-pano"),
]

# Score each model: load every image in its directory, recover the prompt
# from the file name, and report one CLIP score per model.
for path, name in models:
    imgs = []
    prompts = []
    # Sort the listing so the batch order (and thus the result) is reproducible.
    for file_name in sorted(os.listdir(path)):
        imgs.append(plt.imread(os.path.join(path, file_name)))
        # The prompt is the second underscore-separated field of the file name.
        prompts.append(file_name.split("_")[1])

    # Stack into one numeric (N, H, W, C) array; unlike an object-dtype array
    # this stays vectorised and fails immediately if image shapes differ.
    batch = np.stack(imgs)

    # Use a distinct name so the imported `clip_score` function is not rebound.
    score = calculate_clip_score(batch, prompts)
    print(f"CLIP Score with {name}: {score}")