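"""Compute CLIP scores for generated images.

Loads the images produced by StreetView360X, Stable Diffusion 2.1, and
LDM3D-pano, recovers each image's text prompt from its filename, and reports
the mean CLIP score (CLIP ViT-B/16 via torchmetrics) per model.
"""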
import os
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchmetrics.functional.multimodal import clip_score
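# One folder of generated images per model under comparison.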
streetview_path = "Generated Images/streetview_clip"
sd_path = "Generated Images/sd_clip"
ldm3d_path = "Generated Images/ldm3d_clip"
# CLIP ViT-B/16 loaded from a local checkpoint directory; a Hugging Face hub
# id such as "openai/clip-vit-base-patch16" would also work here.
clip_score_fn = partial(clip_score, model_name_or_path="checkpoints/clip-vit-base-patch16")
# Fall back to CPU when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def calculate_clip_score(images, prompts):
    """Return the mean CLIP score for a batch of images and their prompts.

    `images` is expected as floats in [0, 1] with shape (N, H, W, 3).
    """
    images_int = (images * 255).astype("uint8")
    # CLIP expects channels-first tensors: (N, H, W, 3) -> (N, 3, H, W).
    images_tensor = torch.from_numpy(images_int).permute(0, 3, 1, 2).to(device)
    # Assumes clip_score handles device placement internally; detach and move
    # the result back to the CPU before converting to a Python float.
    clip_score_value = clip_score_fn(images_tensor, prompts).detach().cpu()
    return round(float(clip_score_value), 4)
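
# Example (hypothetical file name and prompt) of scoring a single image:
#   img = plt.imread("Generated Images/streetview_clip/0_a sunny beach_.png")
#   print(calculate_clip_score(img[None, ..., :3], ["a sunny beach"]))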
models = [
(streetview_path, "StreetView360X"),
(sd_path, "Stable Diffusion 2.1"),
(ldm3d_path, "LDM3D-pano"),
]
for path, name in models:
    imgs = []
    prompts = []
    for file_name in os.listdir(path):
        file_path = os.path.join(path, file_name)
        # Keep RGB only in case plt.imread returns an RGBA image.
        image = plt.imread(file_path)[..., :3]
        imgs.append(image)
        # Filenames are assumed to follow an "<index>_<prompt>_..." pattern,
        # so the second underscore-delimited field is the prompt.
        prompt = file_name.split('_')[1]
        prompts.append(prompt)
    # Stack into one (N, H, W, 3) array; all images in a folder must share the
    # same dimensions. (A dtype=object array would break torch.from_numpy.)
    imgs = np.stack(imgs)
    # Use a name that does not shadow the imported clip_score function.
    score = calculate_clip_score(imgs, prompts)
    print(f"CLIP Score with {name}: {score}")