Esmail-AGumaan committed on
Commit a8ff063
1 Parent(s): dd7cfee

Update pipeline.py

Files changed (1)
  1. pipeline.py +141 -141
pipeline.py CHANGED
@@ -1,141 +1,141 @@
- import torch
- import numpy as np
- from tqdm import tqdm
- from nanograd.models.stable_diffusion.ddpm import DDPMSampler
-
- WIDTH = 512
- HEIGHT = 512
- LATENTS_WIDTH = WIDTH // 8
- LATENTS_HEIGHT = HEIGHT // 8
-
- def generate(
-     prompt,
-     uncond_prompt=None,
-     input_image=None,
-     strength=0.8,
-     do_cfg=True,
-     cfg_scale=7.5,
-     sampler_name="ddpm",
-     n_inference_steps=50,
-     models={},
-     seed=None,
-     device=None,
-     idle_device=None,
-     tokenizer=None,
- ):
-     with torch.no_grad():
-         if not 0 < strength <= 1:
-             raise ValueError("strength must be between 0 and 1")
-
-         if idle_device:
-             to_idle = lambda x: x.to(idle_device)
-         else:
-             to_idle = lambda x: x
-
-         generator = torch.Generator(device=device)
-         if seed is None:
-             generator.seed()
-         else:
-             generator.manual_seed(seed)
-
-         clip = models["clip"]
-         clip.to(device)
-
-         if do_cfg:
-             cond_tokens = tokenizer.batch_encode_plus(
-                 [prompt], padding="max_length", max_length=77
-             ).input_ids
-             cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
-             cond_context = clip(cond_tokens)
-             uncond_tokens = tokenizer.batch_encode_plus(
-                 [uncond_prompt], padding="max_length", max_length=77
-             ).input_ids
-             uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
-             uncond_context = clip(uncond_tokens)
-             context = torch.cat([cond_context, uncond_context])
-         else:
-             tokens = tokenizer.batch_encode_plus(
-                 [prompt], padding="max_length", max_length=77
-             ).input_ids
-             tokens = torch.tensor(tokens, dtype=torch.long, device=device)
-             context = clip(tokens)
-         to_idle(clip)
-
-         if sampler_name == "ddpm":
-             sampler = DDPMSampler(generator)
-             sampler.set_inference_timesteps(n_inference_steps)
-         else:
-             raise ValueError("Unknown sampler value %s. ")
-
-         latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
-
-         if input_image:
-             encoder = models["encoder"]
-             encoder.to(device)
-
-             input_image_tensor = input_image.resize((WIDTH, HEIGHT))
-             input_image_tensor = np.array(input_image_tensor)
-             input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
-             input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
-             input_image_tensor = input_image_tensor.unsqueeze(0)
-             input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
-
-             encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
-             latents = encoder(input_image_tensor, encoder_noise)
-
-             sampler.set_strength(strength=strength)
-             latents = sampler.add_noise(latents, sampler.timesteps[0])
-
-             to_idle(encoder)
-         else:
-             latents = torch.randn(latents_shape, generator=generator, device=device)
-
-         diffusion = models["diffusion"]
-         diffusion.to(device)
-
-         timesteps = tqdm(sampler.timesteps)
-         for i, timestep in enumerate(timesteps):
-             time_embedding = get_time_embedding(timestep).to(device)
-
-             model_input = latents
-
-             if do_cfg:
-                 model_input = model_input.repeat(2, 1, 1, 1)
-
-             model_output = diffusion(model_input, context, time_embedding)
-
-             if do_cfg:
-                 output_cond, output_uncond = model_output.chunk(2)
-                 model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
-
-             latents = sampler.step(timestep, latents, model_output)
-
-         to_idle(diffusion)
-
-         decoder = models["decoder"]
-         decoder.to(device)
-         images = decoder(latents)
-         to_idle(decoder)
-
-         images = rescale(images, (-1, 1), (0, 255), clamp=True)
-         images = images.permute(0, 2, 3, 1)
-         images = images.to("cpu", torch.uint8).numpy()
-         return images[0]
-
- def rescale(x, old_range, new_range, clamp=False):
-     old_min, old_max = old_range
-     new_min, new_max = new_range
-     x -= old_min
-     x *= (new_max - new_min) / (old_max - old_min)
-     x += new_min
-     if clamp:
-         x = x.clamp(new_min, new_max)
-     return x
-
- def get_time_embedding(timestep):
-     # Shape: (160,)
-     freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
-     # Shape: (1, 160)
-     x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
-     # Shape: (1, 160 * 2)
-     return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
 
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from ddpm import DDPMSampler
+
+ WIDTH = 512
+ HEIGHT = 512
+ LATENTS_WIDTH = WIDTH // 8
+ LATENTS_HEIGHT = HEIGHT // 8
+
+ def generate(
+     prompt,
+     uncond_prompt=None,
+     input_image=None,
+     strength=0.8,
+     do_cfg=True,
+     cfg_scale=7.5,
+     sampler_name="ddpm",
+     n_inference_steps=50,
+     models={},
+     seed=None,
+     device=None,
+     idle_device=None,
+     tokenizer=None,
+ ):
+     with torch.no_grad():
+         if not 0 < strength <= 1:
+             raise ValueError("strength must be between 0 and 1")
+
+         if idle_device:
+             to_idle = lambda x: x.to(idle_device)
+         else:
+             to_idle = lambda x: x
+
+         generator = torch.Generator(device=device)
+         if seed is None:
+             generator.seed()
+         else:
+             generator.manual_seed(seed)
+
+         clip = models["clip"]
+         clip.to(device)
+
+         if do_cfg:
+             cond_tokens = tokenizer.batch_encode_plus(
+                 [prompt], padding="max_length", max_length=77
+             ).input_ids
+             cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
+             cond_context = clip(cond_tokens)
+             uncond_tokens = tokenizer.batch_encode_plus(
+                 [uncond_prompt], padding="max_length", max_length=77
+             ).input_ids
+             uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
+             uncond_context = clip(uncond_tokens)
+             context = torch.cat([cond_context, uncond_context])
+         else:
+             tokens = tokenizer.batch_encode_plus(
+                 [prompt], padding="max_length", max_length=77
+             ).input_ids
+             tokens = torch.tensor(tokens, dtype=torch.long, device=device)
+             context = clip(tokens)
+         to_idle(clip)
+
+         if sampler_name == "ddpm":
+             sampler = DDPMSampler(generator)
+             sampler.set_inference_timesteps(n_inference_steps)
+         else:
+             raise ValueError("Unknown sampler value %s. ")
+
+         latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
+
+         if input_image:
+             encoder = models["encoder"]
+             encoder.to(device)
+
+             input_image_tensor = input_image.resize((WIDTH, HEIGHT))
+             input_image_tensor = np.array(input_image_tensor)
+             input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
+             input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
+             input_image_tensor = input_image_tensor.unsqueeze(0)
+             input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
+
+             encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
+             latents = encoder(input_image_tensor, encoder_noise)
+
+             sampler.set_strength(strength=strength)
+             latents = sampler.add_noise(latents, sampler.timesteps[0])
+
+             to_idle(encoder)
+         else:
+             latents = torch.randn(latents_shape, generator=generator, device=device)
+
+         diffusion = models["diffusion"]
+         diffusion.to(device)
+
+         timesteps = tqdm(sampler.timesteps)
+         for i, timestep in enumerate(timesteps):
+             time_embedding = get_time_embedding(timestep).to(device)
+
+             model_input = latents
+
+             if do_cfg:
+                 model_input = model_input.repeat(2, 1, 1, 1)
+
+             model_output = diffusion(model_input, context, time_embedding)
+
+             if do_cfg:
+                 output_cond, output_uncond = model_output.chunk(2)
+                 model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
+
+             latents = sampler.step(timestep, latents, model_output)
+
+         to_idle(diffusion)
+
+         decoder = models["decoder"]
+         decoder.to(device)
+         images = decoder(latents)
+         to_idle(decoder)
+
+         images = rescale(images, (-1, 1), (0, 255), clamp=True)
+         images = images.permute(0, 2, 3, 1)
+         images = images.to("cpu", torch.uint8).numpy()
+         return images[0]
+
+ def rescale(x, old_range, new_range, clamp=False):
+     old_min, old_max = old_range
+     new_min, new_max = new_range
+     x -= old_min
+     x *= (new_max - new_min) / (old_max - old_min)
+     x += new_min
+     if clamp:
+         x = x.clamp(new_min, new_max)
+     return x
+
+ def get_time_embedding(timestep):
+     # Shape: (160,)
+     freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
+     # Shape: (1, 160)
+     x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
+     # Shape: (1, 160 * 2)
+     return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
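
For reference, below is a minimal sketch of how generate() might be driven after this change, assuming ddpm.py now sits next to pipeline.py on the import path. The model_loader module, its preload_models_from_standard_weights helper, the checkpoint filename, and the tokenizer file paths are assumptions for illustration only, not part of this commit.

# Hypothetical usage sketch; the weights file, tokenizer files, and
# model_loader helper below are assumed, not defined by this commit.
import torch
from PIL import Image
from transformers import CLIPTokenizer

import model_loader  # assumed companion module that builds the models dict
import pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed local tokenizer files compatible with the CLIP text encoder.
tokenizer = CLIPTokenizer("vocab.json", merges_file="merges.txt")

# Assumed helper returning {"clip": ..., "encoder": ..., "decoder": ..., "diffusion": ...}.
models = model_loader.preload_models_from_standard_weights("v1-5-pruned-emaonly.ckpt", device)

# generate() returns a (HEIGHT, WIDTH, 3) uint8 numpy array.
output_image = pipeline.generate(
    prompt="A photograph of a cat wearing a space suit",
    uncond_prompt="",
    input_image=None,  # or Image.open("input.jpg") for image-to-image
    strength=0.8,
    do_cfg=True,
    cfg_scale=7.5,
    sampler_name="ddpm",
    n_inference_steps=50,
    models=models,
    seed=42,
    device=device,
    idle_device="cpu",
    tokenizer=tokenizer,
)

Image.fromarray(output_image).save("output.png")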