AndreiUrsu
committed on
Commit
•
6ffd722
1
Parent(s):
1f06dab
Create main.py
Browse files
main.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
from PIL import Image
|
4 |
+
from torch.utils.data import Dataset, DataLoader
|
5 |
+
from torchvision import transforms
|
6 |
+
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
|
7 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
8 |
+
from huggingface_hub import HfApi
|
9 |
+
from torch.optim import AdamW
|
10 |
+
from tqdm import tqdm
|
11 |
+
import gc
|
12 |
+
from torch.cuda.amp import autocast
|
13 |
+
|
14 |
+
# Configure the CUDA caching allocator to reduce memory fragmentation.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

# Report whether a GPU is detected.
cuda_detected = torch.cuda.is_available()
print(cuda_detected)

# Directory that holds the training images.
img_dir = (
    '/media/andrei_ursu/storage2/chess/branches/chessgpt/backend/src'
    '/experiments/full/primulTest/SD21data'
)
|
21 |
+
|
22 |
+
# Dataset definition
class ManualCaptionDataset(Dataset):
    """Image dataset in which every image shares one manually written caption.

    Args:
        img_dir: Directory that contains the training images.
        transform: Optional callable applied to each loaded RGB image.
        caption: Caption attached to every image. ``None`` (the default)
            selects ``DEFAULT_CAPTION``.

    Attributes:
        img_names: Sorted file names of the images found in ``img_dir``.
        captions: One caption per entry of ``img_names``.
    """

    # Only files with these (case-insensitive) suffixes are treated as images,
    # so stray entries such as ``.DS_Store`` or notes never reach Image.open.
    IMAGE_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff",
    )
    DEFAULT_CAPTION = 'Photo of Andrei smiling and dressed in winter clothes at a Christmas market'

    def __init__(self, img_dir, transform=None, caption=None):
        self.img_dir = img_dir
        self.transform = transform

        # os.listdir returns entries in arbitrary order and includes
        # sub-directories; keep regular image files only and sort them so the
        # index -> file mapping is reproducible between runs.
        self.img_names = sorted(
            name
            for name in os.listdir(img_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(img_dir, name))
        )

        # The same manual description is used for every image.
        text = self.DEFAULT_CAPTION if caption is None else caption
        self.captions = [text] * len(self.img_names)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for position ``idx``.

        The image is opened from disk, converted to RGB and, when a transform
        was supplied, passed through it.
        """
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        image = Image.open(img_path).convert("RGB")
        caption = self.captions[idx]

        if self.transform:
            image = self.transform(image)

        return image, caption
|
47 |
+
|
48 |
+
# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # reduced image size
    transforms.ToTensor(),  # PIL image -> float tensor with values in [0, 1]
    # The Stable Diffusion VAE is trained on inputs in [-1, 1]; feeding it
    # [0, 1] tensors shifts the latent distribution the UNet was trained on.
    transforms.Normalize([0.5], [0.5]),
])

# Build the dataset and its loader
dataset = ManualCaptionDataset(img_dir=img_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)  # reduced batch size
|
57 |
+
|
58 |
+
# Every component is loaded from the same Stable Diffusion checkpoint so the
# text-encoder output width matches unet.config.cross_attention_dim and no
# ad-hoc projection layer is needed.
model_id = "stabilityai/stable-diffusion-2-1-base"

# Load the UNet (the only trainable component). Its weights stay in float32:
# optimising float16 master weights with AdamW loses small updates and can
# produce NaNs. Mixed precision is applied through autocast + GradScaler.
# NOTE(review): float32 weights, gradients and optimizer state need more GPU
# memory than the previous float16 setup -- confirm it fits the target GPU.
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
unet.to("cuda")

# Load the autoencoder. It is frozen, so half precision is sufficient.
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float16)
vae.to("cuda")
vae.requires_grad_(False)
vae.eval()

# Load the tokenizer and text encoder that belong to this checkpoint.
# The text encoder is frozen as well: it is not part of the optimizer below.
text_model = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float16)
text_model.to("cuda")
text_model.requires_grad_(False)
text_model.eval()
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

# Noise scheduler
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
prediction_type = scheduler.config.prediction_type

# Optimizer and loss scaler for mixed-precision training
optimizer = AdamW(unet.parameters(), lr=5e-6)
scaler = torch.cuda.amp.GradScaler()

# Put the trainable model in training mode
unet.train()

# Number of epochs; adjust according to the available resources
num_epochs = 5

# The VAE latent scaling factor; 0.18215 is the value previously hard-coded
# here and is used as a fallback when the config does not expose it.
latent_scale = getattr(vae.config, "scaling_factor", 0.18215)

# Training loop
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for images, captions in tqdm(dataloader):
        images = images.to("cuda", dtype=torch.float16)

        # Tokenize the captions; truncation guards against captions longer
        # than the text encoder's maximum sequence length.
        inputs = tokenizer(
            list(captions),
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).to("cuda")

        # The frozen models only produce inputs for the UNet, so no autograd
        # graph is built for them.
        with torch.no_grad():
            # Encode images into latents
            latents = vae.encode(images).latent_dist.sample()
            latents = (latents * latent_scale).float()

            # Text conditioning
            encoder_hidden_states = text_model(inputs.input_ids)[0]

        # Sample the noise in latent space (same shape as the UNet output)
        noise = torch.randn_like(latents)

        # Sample one random timestep per image
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device="cuda",
        ).long()

        # Forward diffusion: the UNet must see the *noised* latents
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # Select the regression target that matches the scheduler's setup
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unsupported prediction type: {prediction_type}")

        # Predict the noise with the UNet under mixed precision
        with autocast():
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Compute the loss in float32 against the target of identical shape
        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())

        # Backpropagation with loss scaling to avoid float16 gradient underflow
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        num_batches += 1

    # Release cached GPU memory once per epoch; doing it on every iteration
    # forces a device synchronisation and slows training down.
    gc.collect()
    torch.cuda.empty_cache()

    # Report the mean loss of the epoch
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss / max(1, num_batches)}")
|
153 |
+
|
154 |
+
# Save the trained models locally
unet.save_pretrained("./finetuned-unet")
# The text model is written to disk only; it is not uploaded below.
text_model.save_pretrained("./finetuned-text-model")

# Upload the UNet to the Hugging Face Hub
api = HfApi()
unet_repo_id = "AndreiUrsu/finetuned-stable-diffusion-unet"
# Create the target repository when it is missing; exist_ok=True makes the
# call a no-op on later runs instead of raising because the repo exists.
api.create_repo(repo_id=unet_repo_id, repo_type="model", exist_ok=True)
api.upload_folder(
    folder_path="./finetuned-unet",
    path_in_repo=".",
    repo_id=unet_repo_id,
    repo_type="model",
)

# Free GPU memory at the end
gc.collect()
torch.cuda.empty_cache()
|