from diffusers import AutoencoderKL
import torch
import torchvision.transforms as transforms
import cv2
import numpy as np
import os

class VAE():
    """
    VAE (Variational Autoencoder) class for image processing.
    """

    def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
        """
        Initialize the VAE instance.

        :param model_path: Path to the trained model.
        :param resized_img: The size to which images are resized.
        :param use_float16: Whether to use float16 precision.
        """
        self.model_path = model_path
        self.vae = AutoencoderKL.from_pretrained(self.model_path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.vae.to(self.device)

        if use_float16:
            self.vae = self.vae.half()
            self._use_float16 = True
        else:
            self._use_float16 = False

        # scaling_factor (0.18215 for the Stable Diffusion VAE family) is applied on encode and undone on decode.
        self.scaling_factor = self.vae.config.scaling_factor
        # Normalize [0, 1] pixels to [-1, 1], the input range the VAE expects.
        self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self._resized_img = resized_img
        self._mask_tensor = self.get_mask_tensor()
        
    def get_mask_tensor(self):
        """
        Creates a mask tensor for image processing.
        :return: A mask tensor.
        """
        # Upper half of the image is kept visible (1); lower half is masked out (0).
        mask_tensor = torch.zeros((self._resized_img, self._resized_img))
        mask_tensor[:self._resized_img // 2, :] = 1
        return mask_tensor
            
    def preprocess_img(self, img_name, half_mask=False):
        """
        Preprocess an image for the VAE.

        :param img_name: The image file path or a list of image file paths.
        :param half_mask: Whether to apply a half mask to the image.
        :return: A preprocessed image tensor.
        """
        window = []
        if isinstance(img_name, str):
            window_fnames = [img_name]
            for fname in window_fnames:
                img = cv2.imread(fname)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, (self._resized_img, self._resized_img),
                                     interpolation=cv2.INTER_LANCZOS4)
                window.append(img)
        else:
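            # img_name is assumed to be a BGR image array (e.g. from cv2.imread) already at the target resolution.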
            img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
            window.append(img)
            
        x = np.asarray(window) / 255.
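        # [N, H, W, C] -> [C, N, H, W]; the squeeze below drops the N=1 frame axis, leaving [C, H, W].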
        x = np.transpose(x, (3, 0, 1, 2))
        x = torch.squeeze(torch.FloatTensor(x))
        if half_mask:
            x = x * (self._mask_tensor>0.5)
        x = self.transform(x)
        
        x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
        x = x.to(self.vae.device)

        return x

    def encode_latents(self, image):
        """
        Encode an image into latent variables.

        :param image: The image tensor to encode.
        :return: The encoded latent variables.
        """
        with torch.no_grad():
            init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
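        # Sample a latent from the encoder's posterior and apply the scaling factor expected by SD-style U-Nets.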
        init_latents = self.scaling_factor * init_latent_dist.sample()
        return init_latents
    
    def decode_latents(self, latents):
        """
        Decode latent variables back into an image.
        :param latents: The latent variables to decode.
        :return: A NumPy array representing the decoded image.
        """
        # Undo the scaling factor, decode, and map from [-1, 1] back to uint8 BGR pixels.
        with torch.no_grad():
            latents = (1 / self.scaling_factor) * latents
            image = self.vae.decode(latents.to(self.vae.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        image = image[...,::-1] # RGB to BGR
        return image
    
    def get_latents_for_unet(self, img):
        """
        Prepare latent variables for a U-Net model.
        :param img: The image to process.
        :return: A concatenated tensor of latents for U-Net input.
        """
        
        ref_image = self.preprocess_img(img,half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
        masked_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
        ref_image = self.preprocess_img(img,half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
        ref_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
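        # Channel-wise concat of masked and reference latents -> [1, 8, 32, 32] conditioning input for the U-Net.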
        latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
        return latent_model_input

if __name__ == "__main__":
    vae_model_path = "./models/sd-vae-ft-mse/"
    vae = VAE(model_path=vae_model_path, use_float16=False)
    
    crop_imgs_path = "./results/sun001_crop/"
    latents_out_path = "./results/latents/"
    if not os.path.exists(latents_out_path):
        os.mkdir(latents_out_path)

    files = os.listdir(crop_imgs_path)
    files.sort()
    files = [file for file in files if file.split(".")[-1] == "png"]

    for file in files:
        index = file.split(".")[0]
        img_path = crop_imgs_path + file
        latents = vae.get_latents_for_unet(img_path)
        print(img_path,"latents",latents.size())
        #torch.save(latents,os.path.join(latents_out_path,index+".pt"))
        #reload_tensor = torch.load('tensor.pt')
        #print(reload_tensor.size())
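    # Sketch (not part of the original script): decoding latents back to an image with the
    # same VAE instance. Note the concatenated U-Net input above has 8 channels and cannot be
    # decoded directly; decode a 4-channel latent such as the output of encode_latents().
    # The names below (img_tensor, latents4, recon, recon_check.png) are illustrative only.
    #   img_tensor = vae.preprocess_img(img_path, half_mask=False)
    #   latents4 = vae.encode_latents(img_tensor)   # [1, 4, 32, 32]
    #   recon = vae.decode_latents(latents4)        # [1, 256, 256, 3] BGR uint8
    #   cv2.imwrite("recon_check.png", recon[0])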