mskrt committed on
Commit
e80c177
verified
1 Parent(s): a4aee32

uploading remaining files

model_index.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "_class_name": "SuperDiffPipeline",
+   "_diffusers_version": "0.31.0",
+   "batch_size": null,
+   "device": "cuda",
+   "guidance_scale": null,
+   "lift": null,
+   "num_inference_steps": null,
+   "seed": null
+ }
pipeline.py ADDED
@@ -0,0 +1,375 @@
+ import random
+ from typing import Callable, Dict, List, Optional
+
+ import torch
+ from diffusers import DiffusionPipeline
+ from diffusers.configuration_utils import ConfigMixin
+
+
+ class SuperDiffPipeline(DiffusionPipeline, ConfigMixin):
+     """Diffusion pipeline that samples a superposition of two text prompts."""
+
+     def __init__(self, model: Callable, vae: Callable, text_encoder: Callable, scheduler: Callable, tokenizer: Callable, **kwargs) -> None:
+         """Initialize the pipeline and move its components to the available device.
+
+         Parameters
+         ----------
+         model : Callable
+             conditional UNet denoiser
+         vae : Callable
+             variational autoencoder used to decode latents into images
+         text_encoder : Callable
+             CLIP text encoder
+         scheduler : Callable
+             noise scheduler (EulerDiscreteScheduler in this repository)
+         tokenizer : Callable
+             CLIP tokenizer
+         kwargs :
+             additional keyword arguments (unused)
+
+         Returns
+         -------
+         None
+
+         """
+         super().__init__()
+         self.model = model
+         self.vae = vae
+         self.text_encoder = text_encoder
+         self.tokenizer = tokenizer
+         self.scheduler = scheduler
+
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         self.vae.to(device)
+         self.model.to(device)
+         self.text_encoder.to(device)
+
+         self.register_to_config(
+             # model=model,
+             # vae=vae,
+             # tokenizer=tokenizer,
+             # text_encoder=text_encoder,
+             # scheduler=scheduler,
+             device=device,
+             batch_size=None,
+             num_inference_steps=None,
+             guidance_scale=None,
+             lift=None,
+             seed=None,
+         )
+
+     @torch.no_grad()
+     def get_batch(self, latents: torch.Tensor, nrow: int, ncol: int) -> torch.Tensor:
+         """Decode a batch of latents into uint8 images.
+
+         Parameters
+         ----------
+         latents : torch.Tensor
+             latents to decode
+         nrow : int
+             number of grid rows (currently unused)
+         ncol : int
+             number of grid columns (currently unused)
+
+         Returns
+         -------
+         torch.Tensor
+             images of shape (batch, height, width, 3), dtype uint8
+
+         """
+         image = self.vae.decode(
+             latents / self.vae.config.scaling_factor, return_dict=False
+         )[0]
+         image = (image / 2 + 0.5).clamp(0, 1).squeeze()
+         if len(image.shape) < 4:
+             image = image.unsqueeze(0)
+         image = (image.permute(0, 2, 3, 1) * 255).to(torch.uint8)
+         return image
+
+     @torch.no_grad()
+     def get_text_embedding(self, prompt: List[str]) -> torch.Tensor:
+         """Tokenize and encode a batch of prompts with the CLIP text encoder.
+
+         Parameters
+         ----------
+         prompt : List[str]
+             prompts to encode
+
+         Returns
+         -------
+         torch.Tensor
+             text embeddings used to condition the UNet
+
+         """
+         text_input = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=self.tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         return self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+     @torch.no_grad()
+     def get_vel(self, t: float, sigma: float, latents: torch.Tensor, embeddings: List[torch.Tensor]) -> torch.Tensor:
+         """Evaluate the UNet velocity (noise prediction) at timestep t.
+
+         Parameters
+         ----------
+         t : float
+             current timestep
+         sigma : float
+             noise level at this timestep
+         latents : torch.Tensor
+             current latents
+         embeddings : List[torch.Tensor]
+             text embeddings to condition on
+
+         Returns
+         -------
+         torch.Tensor
+             predicted velocity with the same shape as the latents
+
+         """
+         def v(_x, _e):
+             return self.model(
+                 _x / ((sigma**2 + 1) ** 0.5), t, encoder_hidden_states=_e
+             ).sample
+
+         embeds = torch.cat(embeddings)
+         latent_input = latents
+         vel = v(latent_input, embeds)
+         return vel
+
+     def preprocess(
+         self,
+         prompt_1: str,
+         prompt_2: str,
+         seed: Optional[int] = None,
+         num_inference_steps: int = 1000,
+         batch_size: int = 1,
+         lift: float = 0.0,
+         height: int = 512,
+         width: int = 512,
+         guidance_scale: float = 7.5,
+     ) -> Dict:
+         """Encode the prompts and prepare the initial latents.
+
+         Parameters
+         ----------
+         prompt_1 : str
+             first prompt (the "object" concept)
+         prompt_2 : str
+             second prompt (the "background" concept)
+         seed : Optional[int]
+             random seed; drawn at random when None
+         num_inference_steps : int
+             number of denoising steps
+         batch_size : int
+             number of images to generate
+         lift : float
+             bias added to the superposition coefficient kappa
+         height : int
+             image height in pixels
+         width : int
+             image width in pixels
+         guidance_scale : float
+             classifier-free guidance scale
+
+         Returns
+         -------
+         Dict
+             initial latents plus the object, background and unconditional embeddings
+
+         """
+         # Store the sampling settings on the pipeline
+         self.batch_size = batch_size
+         self.num_inference_steps = num_inference_steps
+         self.guidance_scale = guidance_scale
+         self.lift = lift
+         self.seed = seed
+         if self.seed is None:
+             self.seed = random.randint(0, 2**32 - 1)
+         obj_prompt = [prompt_1]
+         bg_prompt = [prompt_2]
+         obj_embeddings = self.get_text_embedding(obj_prompt * batch_size)
+         bg_embeddings = self.get_text_embedding(bg_prompt * batch_size)
+
+         uncond_embeddings = self.get_text_embedding([""] * batch_size)
+
+         generator = torch.cuda.manual_seed(
+             self.seed
+         )  # Seed generator to create the initial latent noise
+         latents = torch.randn(
+             (batch_size, self.model.config.in_channels, height // 8, width // 8),
+             generator=generator,
+             device=self.device,
+         )
+
+         latents_og = latents.clone().detach()
+         latents_uncond_og = latents.clone().detach()
+
+         self.scheduler.set_timesteps(num_inference_steps)
+         latents = latents * self.scheduler.init_noise_sigma
+
+         latents_uncond = latents.clone().detach()
+         return {
+             "latents": latents,
+             "obj_embeddings": obj_embeddings,
+             "uncond_embeddings": uncond_embeddings,
+             "bg_embeddings": bg_embeddings,
+         }
+
+     def _forward(self, model_inputs: Dict) -> torch.Tensor:
+         """Run the reverse-diffusion loop that superposes the two prompts.
+
+         Parameters
+         ----------
+         model_inputs : Dict
+             output of `preprocess`: initial latents and text embeddings
+
+         Returns
+         -------
+         torch.Tensor
+             denoised latents
+
+         """
+         latents = model_inputs["latents"]
+         obj_embeddings = model_inputs["obj_embeddings"]
+         uncond_embeddings = model_inputs["uncond_embeddings"]
+         bg_embeddings = model_inputs["bg_embeddings"]
+
+         # kappa mixes the two conditional velocities; ll_* track running log-likelihood terms
+         kappa = 0.5 * torch.ones(
+             (self.num_inference_steps + 1, self.batch_size), device=self.device
+         )
+         ll_obj = torch.ones(
+             (self.num_inference_steps + 1, self.batch_size), device=self.device
+         )
+         ll_bg = torch.ones(
+             (self.num_inference_steps + 1, self.batch_size), device=self.device
+         )
+         ll_uncond = torch.ones(
+             (self.num_inference_steps + 1, self.batch_size), device=self.device
+         )
+         with torch.no_grad():
+             for i, t in enumerate(self.scheduler.timesteps):
+                 dsigma = self.scheduler.sigmas[i + 1] - self.scheduler.sigmas[i]
+                 sigma = self.scheduler.sigmas[i]
+                 # Velocity predictions for each prompt and for the unconditional input
+                 vel_obj = self.get_vel(t, sigma, latents, [obj_embeddings])
+                 vel_uncond = self.get_vel(t, sigma, latents, [uncond_embeddings])
+
+                 vel_bg = self.get_vel(t, sigma, latents, [bg_embeddings])
+                 noise = torch.sqrt(2 * torch.abs(dsigma) * sigma) * torch.randn_like(
+                     latents
+                 )
+
+                 # Update that would follow the background prompt alone (used to solve for kappa)
+                 dx_ind = (
+                     2
+                     * dsigma
+                     * (vel_uncond + self.guidance_scale * (vel_bg - vel_uncond))
+                     + noise
+                 )
+                 # Solve for the mixing coefficient kappa between the two conditional velocities
+                 kappa[i + 1] = (
+                     (torch.abs(dsigma) * (vel_bg - vel_obj) * (vel_bg + vel_obj)).sum(
+                         (1, 2, 3)
+                     )
+                     - (dx_ind * (vel_obj - vel_bg)).sum((1, 2, 3))
+                     + sigma * self.lift / self.num_inference_steps
+                 )
+                 kappa[i + 1] /= (
+                     2
+                     * dsigma
+                     * self.guidance_scale
+                     * ((vel_obj - vel_bg) ** 2).sum((1, 2, 3))
+                 )
+
+                 # Superposed velocity field with classifier-free guidance
+                 vf = vel_uncond + self.guidance_scale * (
+                     (vel_bg - vel_uncond)
+                     + kappa[i + 1][:, None, None, None] * (vel_obj - vel_bg)
+                 )
+                 dx = 2 * dsigma * vf + noise
+                 latents += dx
+
+                 # Running log-likelihood estimates under each prompt
+                 ll_obj[i + 1] = ll_obj[i] + (
+                     -torch.abs(dsigma) / sigma * (vel_obj) ** 2
+                     - (dx * (vel_obj / sigma))
+                 ).sum((1, 2, 3))
+                 ll_bg[i + 1] = ll_bg[i] + (
+                     -torch.abs(dsigma) / sigma * (vel_bg) ** 2
+                     - (dx * (vel_bg / sigma))
+                 ).sum((1, 2, 3))
+
+         return latents
+
+     def postprocess(self, latents: torch.Tensor) -> torch.Tensor:
+         """Decode the final latents into images.
+
+         Parameters
+         ----------
+         latents : torch.Tensor
+             denoised latents returned by `_forward`
+
+         Returns
+         -------
+         torch.Tensor
+             uint8 images of shape (batch, height, width, 3)
+
+         """
+         image = self.get_batch(latents, 1, self.batch_size)
+         # get_batch returns (batch, height, width, 3); reject grayscale or otherwise invalid shapes
+         assert image.shape[-1] == 3
+
+         # Convert to uint8 if not already, so the output can be passed straight to PIL
+         image = image.to(torch.uint8)
+
+         return image
+
+     def __call__(
+         self,
+         prompt_1: str,
+         prompt_2: str,
+         seed: Optional[int] = None,
+         num_inference_steps: int = 1000,
+         batch_size: int = 1,
+         lift: float = 0.0,
+         height: int = 512,
+         width: int = 512,
+         guidance_scale: float = 7.5,
+     ) -> torch.Tensor:
+         """Generate images that superpose the two prompts.
+
+         Parameters
+         ----------
+         prompt_1 : str
+             first prompt (the "object" concept)
+         prompt_2 : str
+             second prompt (the "background" concept)
+         seed : Optional[int]
+             random seed; drawn at random when None
+         num_inference_steps : int
+             number of denoising steps
+         batch_size : int
+             number of images to generate
+         lift : float
+             bias added to the superposition coefficient kappa
+         height : int
+             image height in pixels
+         width : int
+             image width in pixels
+         guidance_scale : float
+             classifier-free guidance scale
+
+         Returns
+         -------
+         torch.Tensor
+             uint8 images of shape (batch, height, width, 3)
+
+         """
+         # Preprocess inputs
+         model_inputs = self.preprocess(
+             prompt_1,
+             prompt_2,
+             seed,
+             num_inference_steps,
+             batch_size,
+             lift,
+             height,
+             width,
+             guidance_scale,
+         )
+
+         # Forward pass through the pipeline
+         latents = self._forward(model_inputs)
+
+         # Postprocess to generate the final output
+         images = self.postprocess(latents)
+         return images
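
For reference, a minimal usage sketch of the pipeline added above (not part of this commit). It assumes the Stable Diffusion v1.4 components that the configs in this commit point to; the prompts, step count, and output filename are illustrative placeholders.

from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image

from pipeline import SuperDiffPipeline

# Components from the base model referenced by this repository's configs.
base = "CompVis/stable-diffusion-v1-4"
pipe = SuperDiffPipeline(
    model=UNet2DConditionModel.from_pretrained(base, subfolder="unet"),
    vae=AutoencoderKL.from_pretrained(base, subfolder="vae"),
    text_encoder=CLIPTextModel.from_pretrained(base, subfolder="text_encoder"),
    scheduler=EulerDiscreteScheduler.from_pretrained(base, subfolder="scheduler"),
    tokenizer=CLIPTokenizer.from_pretrained(base, subfolder="tokenizer"),
)

# Superpose two prompts; __call__ returns a uint8 tensor of shape (batch, height, width, 3).
images = pipe(
    prompt_1="a photo of a lion",      # illustrative prompt
    prompt_2="a photo of a sunset",    # illustrative prompt
    seed=0,
    num_inference_steps=200,           # the pipeline default is 1000
    guidance_scale=7.5,
)
Image.fromarray(images[0].cpu().numpy()).save("superdiff_sample.png")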
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "_class_name": "EulerDiscreteScheduler",
+   "_diffusers_version": "0.31.0",
+   "beta_end": 0.012,
+   "beta_schedule": "scaled_linear",
+   "beta_start": 0.00085,
+   "clip_sample": false,
+   "final_sigmas_type": "zero",
+   "interpolation_type": "linear",
+   "num_train_timesteps": 1000,
+   "prediction_type": "epsilon",
+   "rescale_betas_zero_snr": false,
+   "set_alpha_to_one": false,
+   "sigma_max": null,
+   "sigma_min": null,
+   "skip_prk_steps": true,
+   "steps_offset": 1,
+   "timestep_spacing": "linspace",
+   "timestep_type": "discrete",
+   "trained_betas": null,
+   "use_beta_sigmas": false,
+   "use_exponential_sigmas": false,
+   "use_karras_sigmas": false
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_name_or_path": "CompVis/stable-diffusion-v1-4",
+   "architectures": [
+     "CLIPTextModel"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "dropout": 0.0,
+   "eos_token_id": 2,
+   "hidden_act": "quick_gelu",
+   "hidden_size": 768,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 77,
+   "model_type": "clip_text_model",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 1,
+   "projection_dim": 512,
+   "torch_dtype": "float32",
+   "transformers_version": "4.46.2",
+   "vocab_size": 49408
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:778d02eb9e707c3fbaae0b67b79ea0d1399b52e624fb634f2f19375ae7c047c3
+ size 492265168
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "bos_token": {
+     "content": "<|startoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<|endoftext|>",
+   "unk_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "49406": {
+       "content": "<|startoftext|>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "49407": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "<|startoftext|>",
+   "clean_up_tokenization_spaces": false,
+   "do_lower_case": true,
+   "eos_token": "<|endoftext|>",
+   "errors": "replace",
+   "model_max_length": 77,
+   "pad_token": "<|endoftext|>",
+   "tokenizer_class": "CLIPTokenizer",
+   "unk_token": "<|endoftext|>"
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
vae/config.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.31.0",
+   "_name_or_path": "CompVis/stable-diffusion-v1-4",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": true,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "latents_mean": null,
+   "latents_std": null,
+   "layers_per_block": 2,
+   "mid_block_add_attention": true,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 512,
+   "scaling_factor": 0.18215,
+   "shift_factor": null,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ],
+   "use_post_quant_conv": true,
+   "use_quant_conv": true
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b4d2b5932bb4151e54e694fd31ccf51fca908223c9485bd56cd0e1d83ad94c49
+ size 334643268