File size: 1,569 Bytes
f610e83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from PIL import Image
import os
import numpy as np
from torchvision.transforms import functional as F
import torch
from torchmetrics.image.fid import FrechetInceptionDistance


# Paths setup
generated_dataset_path = "output/tryon_results"
original_dataset_path = "data/VITON-HD/test/image"  # Replace with your actual original dataset path

# Only decode files that look like images: os.listdir also returns stray
# entries (e.g. ".DS_Store", sub-directories) that PIL cannot open.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
# Prefix the try-on pipeline puts in front of the original file name.
GENERATED_PREFIX = "tryon_"


def _load_rgb(path):
    """Return the image at *path* converted to RGB, as a numpy array.

    The file is opened in a ``with`` block so its handle is released even
    when decoding fails.
    """
    with Image.open(path) as img:
        return np.array(img.convert("RGB"))


# Get generated images (sorted so that runs are reproducible)
image_paths = sorted(
    os.path.join(generated_dataset_path, name)
    for name in os.listdir(generated_dataset_path)
    if name.lower().endswith(IMAGE_EXTENSIONS)
)
generated_images = [_load_rgb(path) for path in image_paths]

# Get corresponding original images
original_images = []
for gen_path in image_paths:
    # Generated files are named "tryon_XXXXXX.jpg"; the matching original is
    # "XXXXXX.jpg". Strip only a *leading* prefix -- str.replace would also
    # remove any later occurrence of "tryon_" inside the name.
    base_name = os.path.basename(gen_path)
    if base_name.startswith(GENERATED_PREFIX):
        original_id = base_name[len(GENERATED_PREFIX):]
    else:
        original_id = base_name

    # Construct original image path; a missing original raises
    # FileNotFoundError naming the offending path.
    original_path = os.path.join(original_dataset_path, original_id)
    original_images.append(_load_rgb(original_path))


def preprocess_image(image, size=(1024, 768)):
    """Convert one RGB image array into a float batch tensor for FID.

    Args:
        image: ``(H, W, 3)`` array of 0-255 values, as produced by
            ``np.array`` on an RGB PIL image.
        size: ``(height, width)`` of the center crop. torchvision's
            ``center_crop`` reads the pair height-first, so the previous
            hard-coded ``(768, 1024)`` had the axes swapped. It cropped
            portrait frames to 768 rows and zero-padded them out to 1024
            columns. NOTE(review): the default assumes the VITON-HD inputs
            are 1024x768 (HxW); confirm against the actual image sizes.
            Images smaller than *size* are zero-padded by torchvision.

    Returns:
        Float tensor of shape ``(1, 3, height, width)`` with values in
        ``[0, 1]`` (input divided by 255), the range used with
        ``FrechetInceptionDistance(normalize=True)``.
    """
    tensor = torch.tensor(image).unsqueeze(0)     # (1, H, W, 3)
    tensor = tensor.permute(0, 3, 1, 2) / 255.0   # (1, 3, H, W), float in [0, 1]
    return F.center_crop(tensor, size)

# Images are fed to the metric in batches. Concatenating every full-resolution
# frame into one float32 tensor, and running Inception over all of them in a
# single forward pass, needs memory proportional to the whole dataset.
# FrechetInceptionDistance accumulates its statistics across update() calls,
# so batching does not change the result (up to float summation order).
FID_BATCH_SIZE = 32


def _update_fid(metric, images, *, real, batch_size=FID_BATCH_SIZE):
    """Preprocess ``images`` batch by batch and feed them to ``metric``.

    Args:
        metric: the ``FrechetInceptionDistance`` instance to update.
        images: list of RGB arrays accepted by ``preprocess_image``.
        real: forwarded to ``metric.update`` (True for the reference set).
        batch_size: number of images preprocessed and sent per update call.

    Returns:
        The ``torch.Size`` the fully stacked tensor would have had, so the
        script can still report it without materialising that tensor.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if not images:
        raise ValueError(f"no {'real' if real else 'generated'} images to evaluate")
    batch_shape = None
    for start in range(0, len(images), batch_size):
        chunk = images[start:start + batch_size]
        batch = torch.cat([preprocess_image(image) for image in chunk])
        metric.update(batch, real=real)
        batch_shape = batch.shape
    # preprocess_image crops every image to the same size, so all batches
    # share their trailing dimensions.
    return torch.Size((len(images), *batch_shape[1:]))


fid = FrechetInceptionDistance(normalize=True)
real_shape = _update_fid(fid, original_images, real=True)
fake_shape = _update_fid(fid, generated_images, real=False)
print(real_shape, fake_shape)

print(f"FID: {float(fid.compute())}")