import torch import os from PIL import Image from torch.utils.data import Dataset, DataLoader from torchvision import transforms from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler from transformers import CLIPTextModel, CLIPTokenizer from huggingface_hub import HfApi from torch.optim import AdamW from tqdm import tqdm import gc from torch.cuda.amp import autocast # Setare configurare CUDA pentru a reduce fragmentarea memoriei os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Verifică dacă GPU-ul este detectat print(torch.cuda.is_available()) img_dir = '/media/andrei_ursu/storage2/chess/branches/chessgpt/backend/src/experiments/full/primulTest/SD21data' # Definirea dataset-ului class ManualCaptionDataset(Dataset): def __init__(self, img_dir, transform=None): self.img_dir = img_dir self.img_names = os.listdir(img_dir) self.transform = transform self.captions = [] # Introducem manual descrierile pentru fiecare imagine for img_name in self.img_names: caption = 'Photo of Andrei smiling and dressed in winter clothes at a Christmas market' self.captions.append(caption) def __len__(self): return len(self.img_names) def __getitem__(self, idx): img_name = os.path.join(self.img_dir, self.img_names[idx]) image = Image.open(img_name).convert("RGB") caption = self.captions[idx] if self.transform: image = self.transform(image) return image, caption # Configurare transformări transform = transforms.Compose([ transforms.Resize((256, 256)), # Dimensiune imagine redusă transforms.ToTensor(), ]) # Crearea dataset-ului dataset = ManualCaptionDataset(img_dir=img_dir, transform=transform) dataloader = DataLoader(dataset, batch_size=1, shuffle=True) # Dimensiune batch redusă # Încărcare model UNet unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet", torch_dtype=torch.float16) unet.to("cuda") # Încărcare model pentru autoencoder vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae", torch_dtype=torch.float16) vae.to("cuda") # Încărcare tokenizer și text model pentru CLIP text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") text_model.to("cuda") tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") # Scheduler scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler") # Pregătire optimizer optimizer = AdamW(unet.parameters(), lr=5e-6) # Setare model în modul de antrenament unet.train() text_model.train() # Definire număr de epoci num_epochs = 5 # Poți ajusta acest număr în funcție de resurse # Training loop for epoch in range(num_epochs): for images, captions in tqdm(dataloader): images = images.to("cuda", dtype=torch.float16) # Curăță memoria GPU înainte de fiecare iterare gc.collect() torch.cuda.empty_cache() # Tokenizare captions inputs = tokenizer(captions, padding="max_length", max_length=77, return_tensors="pt").to("cuda") # Generare zgomot aleatoriu noise = torch.randn_like(images).to("cuda", dtype=torch.float16) # Codificare imagini în latențe latents = vae.encode(images).latent_dist.sample() latents = latents * 0.18215 # Generare timesteps timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (images.shape[0],), device="cuda").long() # Forward pass prin UNet encoder_hidden_states = text_model(inputs.input_ids)[0] # Convertim encoder_hidden_states la float16 encoder_hidden_states = encoder_hidden_states.to(dtype=torch.float16) # Proiectăm dimensiunile `encoder_hidden_states` pentru a se potrivi cu cele așteptate de UNet expected_dim = unet.config.cross_attention_dim if encoder_hidden_states.shape[-1] != expected_dim: projection_layer = torch.nn.Linear(encoder_hidden_states.shape[-1], expected_dim).to("cuda", dtype=torch.float16) encoder_hidden_states = projection_layer(encoder_hidden_states) # Generare predicție de zgomot with autocast(): noise_pred = unet(latents, timesteps, encoder_hidden_states).sample # Verifică dimensiunile tensorilor print(f"noise_pred shape: {noise_pred.shape}") print(f"noise shape: {noise.shape}") # Redimensionare noise_pred pentru a se potrivi cu dimensiunea noise if noise_pred.shape[1] != noise.shape[1]: # Ajustează numărul de canale pentru noise_pred conv_layer = torch.nn.Conv2d( in_channels=noise_pred.shape[1], out_channels=noise.shape[1], kernel_size=1 ).to("cuda", dtype=torch.float16) noise_pred = conv_layer(noise_pred) # Redimensionare noise_pred pentru a se potrivi cu dimensiunea noise if noise_pred.shape[2:] != noise.shape[2:]: noise_pred = torch.nn.functional.interpolate(noise_pred, size=images.shape[2:], mode='bilinear', align_corners=False) # Calcul pierdere (loss) comparând ieșirea modelului cu zgomotul original loss = torch.nn.functional.mse_loss(noise_pred, noise) # Backpropagation optimizer.zero_grad() loss.backward() optimizer.step() # Curăță memoria GPU după fiecare iterare gc.collect() torch.cuda.empty_cache() print(f"Epoch {epoch + 1}, Loss: {loss.item()}") # Salvarea modelului antrenat unet.save_pretrained("./finetuned-unet") text_model.save_pretrained("./finetuned-text-model") api = HfApi() #api.create_repo(repo_id="AndreiUrsu/finetuned-stable-diffusion-unet", repo_type="model") #api.create_repo(repo_id="AndreiUrsu/finetuned-stable-diffusion-text-model", repo_type="model") # Încărcarea pe Hugging Face api.upload_folder( folder_path="./finetuned-unet", path_in_repo=".", repo_id="AndreiUrsu/finetuned-stable-diffusion-unet", repo_type="model" ) # Curăță memoria GPU la final gc.collect() torch.cuda.empty_cache()