AndreiUrsu commited on
Commit
6ffd722
1 Parent(s): 1f06dab

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +171 -0
main.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from PIL import Image
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from torchvision import transforms
6
+ from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
7
+ from transformers import CLIPTextModel, CLIPTokenizer
8
+ from huggingface_hub import HfApi
9
+ from torch.optim import AdamW
10
+ from tqdm import tqdm
11
+ import gc
12
+ from torch.cuda.amp import autocast
13
+
14
+ # Setare configurare CUDA pentru a reduce fragmentarea memoriei
15
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
16
+
17
+ # Verifică dacă GPU-ul este detectat
18
+ print(torch.cuda.is_available())
19
+
20
+ img_dir = '/media/andrei_ursu/storage2/chess/branches/chessgpt/backend/src/experiments/full/primulTest/SD21data'
21
+
22
+ # Definirea dataset-ului
23
+ class ManualCaptionDataset(Dataset):
24
+ def __init__(self, img_dir, transform=None):
25
+ self.img_dir = img_dir
26
+ self.img_names = os.listdir(img_dir)
27
+ self.transform = transform
28
+ self.captions = []
29
+
30
+ # Introducem manual descrierile pentru fiecare imagine
31
+ for img_name in self.img_names:
32
+ caption = 'Photo of Andrei smiling and dressed in winter clothes at a Christmas market'
33
+ self.captions.append(caption)
34
+
35
+ def __len__(self):
36
+ return len(self.img_names)
37
+
38
+ def __getitem__(self, idx):
39
+ img_name = os.path.join(self.img_dir, self.img_names[idx])
40
+ image = Image.open(img_name).convert("RGB")
41
+ caption = self.captions[idx]
42
+
43
+ if self.transform:
44
+ image = self.transform(image)
45
+
46
+ return image, caption
47
+
48
+ # Configurare transformări
49
+ transform = transforms.Compose([
50
+ transforms.Resize((256, 256)), # Dimensiune imagine redusă
51
+ transforms.ToTensor(),
52
+ ])
53
+
54
+ # Crearea dataset-ului
55
+ dataset = ManualCaptionDataset(img_dir=img_dir, transform=transform)
56
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=True) # Dimensiune batch redusă
57
+
58
+ # Încărcare model UNet
59
+ unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet", torch_dtype=torch.float16)
60
+ unet.to("cuda")
61
+
62
+ # Încărcare model pentru autoencoder
63
+ vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae", torch_dtype=torch.float16)
64
+ vae.to("cuda")
65
+
66
+ # Încărcare tokenizer și text model pentru CLIP
67
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
68
+ text_model.to("cuda")
69
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
70
+
71
+ # Scheduler
72
+ scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
73
+
74
+ # Pregătire optimizer
75
+ optimizer = AdamW(unet.parameters(), lr=5e-6)
76
+
77
+ # Setare model în modul de antrenament
78
+ unet.train()
79
+ text_model.train()
80
+
81
+ # Definire număr de epoci
82
+ num_epochs = 5 # Poți ajusta acest număr în funcție de resurse
83
+
84
+ # Training loop
85
+ for epoch in range(num_epochs):
86
+ for images, captions in tqdm(dataloader):
87
+ images = images.to("cuda", dtype=torch.float16)
88
+
89
+ # Curăță memoria GPU înainte de fiecare iterare
90
+ gc.collect()
91
+ torch.cuda.empty_cache()
92
+
93
+ # Tokenizare captions
94
+ inputs = tokenizer(captions, padding="max_length", max_length=77, return_tensors="pt").to("cuda")
95
+
96
+ # Generare zgomot aleatoriu
97
+ noise = torch.randn_like(images).to("cuda", dtype=torch.float16)
98
+
99
+ # Codificare imagini în latențe
100
+ latents = vae.encode(images).latent_dist.sample()
101
+ latents = latents * 0.18215
102
+
103
+ # Generare timesteps
104
+ timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (images.shape[0],), device="cuda").long()
105
+
106
+ # Forward pass prin UNet
107
+ encoder_hidden_states = text_model(inputs.input_ids)[0]
108
+
109
+ # Convertim encoder_hidden_states la float16
110
+ encoder_hidden_states = encoder_hidden_states.to(dtype=torch.float16)
111
+
112
+ # Proiectăm dimensiunile `encoder_hidden_states` pentru a se potrivi cu cele așteptate de UNet
113
+ expected_dim = unet.config.cross_attention_dim
114
+ if encoder_hidden_states.shape[-1] != expected_dim:
115
+ projection_layer = torch.nn.Linear(encoder_hidden_states.shape[-1], expected_dim).to("cuda", dtype=torch.float16)
116
+ encoder_hidden_states = projection_layer(encoder_hidden_states)
117
+
118
+ # Generare predicție de zgomot
119
+ with autocast():
120
+ noise_pred = unet(latents, timesteps, encoder_hidden_states).sample
121
+
122
+ # Verifică dimensiunile tensorilor
123
+ print(f"noise_pred shape: {noise_pred.shape}")
124
+ print(f"noise shape: {noise.shape}")
125
+
126
+ # Redimensionare noise_pred pentru a se potrivi cu dimensiunea noise
127
+ if noise_pred.shape[1] != noise.shape[1]:
128
+ # Ajustează numărul de canale pentru noise_pred
129
+ conv_layer = torch.nn.Conv2d(
130
+ in_channels=noise_pred.shape[1],
131
+ out_channels=noise.shape[1],
132
+ kernel_size=1
133
+ ).to("cuda", dtype=torch.float16)
134
+ noise_pred = conv_layer(noise_pred)
135
+
136
+ # Redimensionare noise_pred pentru a se potrivi cu dimensiunea noise
137
+ if noise_pred.shape[2:] != noise.shape[2:]:
138
+ noise_pred = torch.nn.functional.interpolate(noise_pred, size=images.shape[2:], mode='bilinear', align_corners=False)
139
+
140
+ # Calcul pierdere (loss) comparând ieșirea modelului cu zgomotul original
141
+ loss = torch.nn.functional.mse_loss(noise_pred, noise)
142
+
143
+ # Backpropagation
144
+ optimizer.zero_grad()
145
+ loss.backward()
146
+ optimizer.step()
147
+
148
+ # Curăță memoria GPU după fiecare iterare
149
+ gc.collect()
150
+ torch.cuda.empty_cache()
151
+
152
+ print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
153
+
154
+ # Salvarea modelului antrenat
155
+ unet.save_pretrained("./finetuned-unet")
156
+ text_model.save_pretrained("./finetuned-text-model")
157
+ api = HfApi()
158
+ #api.create_repo(repo_id="AndreiUrsu/finetuned-stable-diffusion-unet", repo_type="model")
159
+ #api.create_repo(repo_id="AndreiUrsu/finetuned-stable-diffusion-text-model", repo_type="model")
160
+ # Încărcarea pe Hugging Face
161
+ api.upload_folder(
162
+ folder_path="./finetuned-unet",
163
+ path_in_repo=".",
164
+ repo_id="AndreiUrsu/finetuned-stable-diffusion-unet",
165
+ repo_type="model"
166
+ )
167
+
168
+
169
+ # Curăță memoria GPU la final
170
+ gc.collect()
171
+ torch.cuda.empty_cache()