File size: 3,069 Bytes
81ccbca
 
 
 
 
 
 
 
 
 
 
 
7021212
 
 
81ccbca
 
 
 
 
 
 
 
 
 
 
 
7021212
 
 
 
 
81ccbca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7021212
 
 
 
81ccbca
 
 
 
fd9afda
 
81ccbca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm

def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
          *, nsteps=50, device='cuda', sampling_guidance_scale=3):
    """Fine-tune selected diffuser modules so the prediction for ``prompt`` is steered away from it.

    Each iteration:

    1. Picks a random DDIM step.
    2. Partially denoises fresh initial latents conditioned on ``prompt``, inside
       ``with finetuner``.
    3. Computes prompt-conditioned and empty-prompt noise predictions outside
       ``with finetuner`` (presumably the original weights; confirm against
       ``FineTunedModel``).
    4. Regresses the prompt-conditioned prediction made inside ``with finetuner``
       onto the target ``e_neutral - negative_guidance * (e_prompt - e_neutral)``
       with an MSE loss.

    Args:
        prompt: Text prompt whose conditioned prediction is being steered.
        modules: Passed to ``FineTunedModel`` as the modules to fine-tune.
        freeze_modules: Passed to ``FineTunedModel`` as ``frozen_modules``.
        iterations: Number of optimisation steps.
        negative_guidance: Scale applied to ``(e_prompt - e_neutral)`` in the
            training target.
        lr: Adam learning rate.
        save_path: Where ``torch.save`` writes ``finetuner.state_dict()``.
        nsteps: Number of DDIM scheduler steps used when generating the
            partially denoised latents. Must be at least 3.
        device: Device the diffuser is moved to.
        sampling_guidance_scale: ``guidance_scale`` passed to
            ``diffuser.diffusion`` while generating the partially denoised
            latents.

    Raises:
        ValueError: If ``nsteps`` is smaller than 3.
    """
    if nsteps < 3:
        # torch.randint(1, nsteps - 1) below needs a non-empty range.
        raise ValueError(f'nsteps must be at least 3, got {nsteps}')

    diffuser = StableDiffuser(scheduler='DDIM').to(device)
    diffuser.train()

    finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

    optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
    criteria = torch.nn.MSELoss()

    # The two text embeddings are fixed for the whole run, so compute them once.
    with torch.no_grad():
        neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
        positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)

    # The loop below never touches these components directly.
    # Drop them to release their memory.
    del diffuser.vae
    del diffuser.text_encoder
    del diffuser.tokenizer

    torch.cuda.empty_cache()

    pbar = tqdm(range(iterations))

    for _ in pbar:

        optimizer.zero_grad()

        with torch.no_grad():

            diffuser.set_scheduler_timesteps(nsteps)

            # Random DDIM step index in [1, nsteps - 2].
            iteration = torch.randint(1, nsteps - 1, (1,)).item()

            # NOTE(review): 512 is presumably the image resolution; confirm
            # against StableDiffuser.get_initial_latents.
            latents = diffuser.get_initial_latents(1, 512, 1)

            with finetuner:

                latents_steps, _ = diffuser.diffusion(
                    latents,
                    positive_text_embeddings,
                    start_iteration=0,
                    end_iteration=iteration,
                    guidance_scale=sampling_guidance_scale,
                    show_progress=False,
                )

            # Switch to the 1000-step schedule and map the DDIM step index onto it.
            diffuser.set_scheduler_timesteps(1000)

            iteration = int(iteration / nsteps * 1000)

            # Target-side predictions are computed under no_grad, so they carry
            # no autograd history.
            positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
            neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

        # This is the only prediction tracked by autograd; gradients flow into
        # the fine-tuner's parameters through it.
        with finetuner:
            negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

        target = neutral_latents - (negative_guidance * (positive_latents - neutral_latents))
        # Author's note: loss = criteria(e_n, e_0) works the best, try 5000 epochs.
        loss = criteria(negative_latents, target)

        # Free per-step tensors before backward to keep peak memory down.
        # The autograd graph keeps whatever backward still needs.
        del negative_latents, neutral_latents, positive_latents, latents_steps, latents, target
        torch.cuda.empty_cache()

        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())

    torch.save(finetuner.state_dict(), save_path)
    del diffuser
    torch.cuda.empty_cache()


if __name__ == '__main__':

    import argparse

    # Command-line interface: every option is mandatory and maps one-to-one
    # onto a parameter of train(). The tuple order is the order the options
    # are registered (and listed in --help / error messages).
    cli = argparse.ArgumentParser()

    option_specs = (
        ('--prompt', {}),
        ('--modules', {}),
        ('--freeze_modules', {'nargs': '+'}),
        ('--save_path', {}),
        ('--iterations', {'type': int}),
        ('--lr', {'type': float}),
        ('--negative_guidance', {'type': float}),
    )

    for flag, extra_kwargs in option_specs:
        cli.add_argument(flag, required=True, **extra_kwargs)

    cli_namespace = cli.parse_args()
    train(**vars(cli_namespace))