"""Compute CLIP scores for images generated by several text-to-image models.

Each model's outputs live in one directory. The prompt for an image is the
second underscore-separated field of its file name with the extension removed,
e.g. ``"0_a city street.png"`` -> ``"a city street"``.
"""

import os
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchmetrics.functional.multimodal import clip_score

streetview_path = "Generated Images/streetview_clip"
sd_path = "Generated Images/sd_clip"
ldm3d_path = "Generated Images/ldm3d_clip"

# Only files with these extensions are loaded. Anything else in a directory
# (e.g. ``.DS_Store``) is skipped instead of crashing ``plt.imread``.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}

clip_score_fn = partial(
    clip_score, model_name_or_path="checkpoints/clip-vit-base-patch16"
)

# Fall back to the CPU so the script still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _to_uint8_rgb(image):
    """Return *image* as an ``(H, W, 3)`` uint8 array.

    ``plt.imread`` returns floats in [0, 1] for PNG files but uint8 values in
    [0, 255] for formats read through Pillow (e.g. JPEG). Only floating-point
    input is therefore rescaled. Grayscale images are replicated to three
    channels, and an alpha channel is dropped because CLIP expects RGB.
    """
    image = np.asarray(image)
    if image.dtype == object:
        # Turn object arrays of Python numbers into a concrete numeric dtype
        # so the float/int check below is meaningful.
        image = np.array(image.tolist())
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    image = image[..., :3]
    if np.issubdtype(image.dtype, np.floating):
        # Clip first so out-of-range values saturate instead of wrapping.
        # The cast itself truncates, as (x * 255).astype("uint8") always did.
        image = (np.clip(image, 0.0, 1.0) * 255).astype("uint8")
    return image.astype("uint8", copy=False)


def calculate_clip_score(images, prompts):
    """Return the CLIP score of *images* against *prompts*, rounded to 4 places.

    Args:
        images: Sequence (or ``(N, H, W, C)`` array) of ``(H, W[, C])``
            images. Float images are expected in [0, 1]; integer images are
            used as uint8 values. All images must share the same height and
            width so they can be stacked into one batch.
        prompts: One text prompt per image, in the same order.

    Raises:
        ValueError: If there are no images, the counts differ, or the images
            cannot be stacked because their sizes differ.
    """
    if len(images) == 0:
        raise ValueError("calculate_clip_score needs at least one image")
    if len(images) != len(prompts):
        raise ValueError(
            f"got {len(images)} images but {len(prompts)} prompts"
        )
    images_int = np.stack([_to_uint8_rgb(img) for img in images])
    # (N, H, W, C) -> (N, C, H, W), the layout torchmetrics expects.
    images_tensor = torch.from_numpy(images_int).permute(0, 3, 1, 2).to(device)
    # Assumes torchmetrics runs the CLIP model on the images' device.
    with torch.no_grad():
        clip_score_value = clip_score_fn(images_tensor, prompts).detach().cpu()
    return round(float(clip_score_value), 4)


def _prompt_from_filename(file_name):
    """Extract the prompt (second ``_``-separated field, extension removed)."""
    stem = os.path.splitext(file_name)[0]
    parts = stem.split("_")
    if len(parts) < 2:
        raise ValueError(
            f"cannot extract a prompt from file name {file_name!r}; "
            "expected '<prefix>_<prompt>[_...]'"
        )
    return parts[1]


def _load_images_and_prompts(path):
    """Read every image file in *path*, sorted by name, with its prompt."""
    images, prompts = [], []
    for file_name in sorted(os.listdir(path)):
        if os.path.splitext(file_name)[1].lower() not in IMAGE_EXTENSIONS:
            continue
        images.append(plt.imread(os.path.join(path, file_name)))
        prompts.append(_prompt_from_filename(file_name))
    return images, prompts


models = [
    (streetview_path, "StreetView360X"),
    (sd_path, "Stable Diffusion 2.1"),
    (ldm3d_path, "LDM3D-pano"),
]


def main():
    """Print the CLIP score of each model's generated images."""
    for path, name in models:
        imgs, prompts = _load_images_and_prompts(path)
        if not imgs:
            print(f"No images found for {name} in {path!r}; skipping.")
            continue
        # Named `score`, not `clip_score`, to avoid shadowing the import.
        score = calculate_clip_score(imgs, prompts)
        print(f"CLIP Score with {name}: {score}")


if __name__ == "__main__":
    main()