import os
import gc

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from huggingface_hub import HfApi
from torch.optim import AdamW
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
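
# Opt in to expandable CUDA allocator segments, which reduces out-of-memory
# errors caused by fragmentation on small GPUs. Must be set before the first
# CUDA allocation.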
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

print(torch.cuda.is_available())

img_dir = '/media/andrei_ursu/storage2/chess/branches/chessgpt/backend/src/experiments/full/primulTest/SD21data'

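# Minimal dataset: every image in img_dir is paired with the same hand-written
# caption, which is enough for a single-subject fine-tune.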
class ManualCaptionDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.img_names = os.listdir(img_dir)
        self.transform = transform
        self.captions = []

        for img_name in self.img_names:
            caption = 'Photo of Andrei smiling and dressed in winter clothes at a Christmas market'
            self.captions.append(caption)

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        image = Image.open(img_path).convert("RGB")
        caption = self.captions[idx]

        if self.transform:
            image = self.transform(image)

        return image, caption


# ToTensor() yields values in [0, 1]; the VAE expects inputs in [-1, 1],
# so a Normalize step is required.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

dataset = ManualCaptionDataset(img_dir=img_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
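# batch_size=1 keeps VRAM use minimal; if updates turn out too noisy, gradient
# accumulation over several steps is a cheap way to simulate a larger batch.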

# The trainable UNet stays in fp32 so AdamW/GradScaler can update it safely;
# autocast below runs its forward pass in fp16.
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet")
unet.to("cuda")

# The VAE is frozen, so it can live entirely in fp16 to save memory.
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae", torch_dtype=torch.float16)
vae.to("cuda")
vae.requires_grad_(False)
vae.eval()

# SD 2.1 ships its own text encoder whose hidden size (1024) matches the UNet's
# cross_attention_dim; openai/clip-vit-large-patch14 outputs 768-dim states and
# would force an untrained projection layer into the forward pass. Load the
# matching encoder and tokenizer from the checkpoint instead, frozen.
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="text_encoder", torch_dtype=torch.float16)
text_model.to("cuda")
text_model.requires_grad_(False)
text_model.eval()
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer")

scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
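# Note: the 512px "base" checkpoint predicts epsilon (plain noise); if you swap
# in a v-prediction checkpoint, check scheduler.config.prediction_type and
# adjust the training target accordingly.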

# Only the UNet is fine-tuned (the VAE and text encoder are frozen above);
# GradScaler pairs with autocast for stable mixed-precision updates.
optimizer = AdamW(unet.parameters(), lr=5e-6)
scaler = GradScaler()

unet.train()

num_epochs = 5
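
# Standard DDPM training step: scheduler.add_noise computes
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
# and the UNet learns to recover `noise` from x_t and the caption embedding.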
for epoch in range(num_epochs):
    for images, captions in tqdm(dataloader):
        images = images.to("cuda", dtype=torch.float16)

        inputs = tokenizer(
            list(captions), padding="max_length", max_length=77,
            truncation=True, return_tensors="pt"
        ).to("cuda")

        # The VAE and text encoder are frozen, so run them without autograd.
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample()
            latents = latents * vae.config.scaling_factor  # 0.18215 for this VAE
            encoder_hidden_states = text_model(inputs.input_ids)[0]

        # Noise lives in latent space (same shape as the latents, not the images).
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (latents.shape[0],), device="cuda"
        ).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # noise_pred has the same latent shape as noise by construction, so no
        # reshaping, projection, or interpolation layers are needed.
        with autocast():
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
            loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Release cached blocks between steps to keep peak memory low.
        gc.collect()
        torch.cuda.empty_cache()

    # Reports the last batch's loss; averaging over the epoch would be more robust.
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

unet.save_pretrained("./finetuned-unet")
# The text encoder was kept frozen, so this snapshot is identical to the base
# weights; it is saved only so the two folders travel together.
text_model.save_pretrained("./finetuned-text-model")

api = HfApi()
api.upload_folder(
    folder_path="./finetuned-unet",
    path_in_repo=".",
    repo_id="AndreiUrsu/finetuned-stable-diffusion-unet",
    repo_type="model"
)

gc.collect()
torch.cuda.empty_cache()
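
# Quick smoke test (a sketch, left commented out): load the fine-tuned UNet back
# into the base pipeline and sample one image. The prompt and output path below
# are illustrative.
#
# from diffusers import StableDiffusionPipeline
# pipe = StableDiffusionPipeline.from_pretrained(
#     "stabilityai/stable-diffusion-2-1-base",
#     unet=UNet2DConditionModel.from_pretrained("./finetuned-unet", torch_dtype=torch.float16),
#     torch_dtype=torch.float16,
# ).to("cuda")
# image = pipe("Photo of Andrei smiling and dressed in winter clothes at a Christmas market").images[0]
# image.save("sample.png")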