Update pipeline.py
Browse files- pipeline.py +22 -15
pipeline.py
CHANGED
@@ -33,24 +33,31 @@ class SuperDiffPipeline(DiffusionPipeline, ConfigMixin):
|
|
33 |
|
34 |
"""
|
35 |
super().__init__()
|
36 |
-
self.register_to_config(
|
37 |
-
batch_size=kwargs.get("batch_size", 1),
|
38 |
-
device=kwargs.get("device", "cuda"),
|
39 |
-
guidance_scale=kwargs.get("guidance_scale", 7.5),
|
40 |
-
lift=kwargs.get("lift", 0.0),
|
41 |
-
num_inference_steps=kwargs.get("num_inference_steps", 50),
|
42 |
-
seed=kwargs.get("seed", 42)
|
43 |
-
)
|
44 |
-
|
45 |
-
# Assign model components
|
46 |
-
self.vae = vae
|
47 |
-
self.scheduler = scheduler
|
48 |
-
self.tokenizer = tokenizer
|
49 |
self.unet = unet
|
|
|
50 |
self.text_encoder = text_encoder
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
|
53 |
-
self.to(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
@torch.no_grad
|
56 |
def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable:
|
|
|
33 |
|
34 |
"""
|
35 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
self.unet = unet
|
37 |
+
self.vae = vae
|
38 |
self.text_encoder = text_encoder
|
39 |
+
self.tokenizer = tokenizer
|
40 |
+
self.scheduler = scheduler
|
41 |
+
|
42 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
43 |
|
44 |
+
self.vae.to(device)
|
45 |
+
self.unet.to(device)
|
46 |
+
self.text_encoder.to(device)
|
47 |
+
|
48 |
+
self.register_to_config(
|
49 |
+
vae=vae.__class__.__name__,
|
50 |
+
scheduler=scheduler.__class__.__name__,
|
51 |
+
tokenizer=tokenizer.__class__.__name__,
|
52 |
+
unet=unet.__class__.__name__,
|
53 |
+
text_encoder=text_encoder.__class__.__name__,
|
54 |
+
device=device,
|
55 |
+
batch_size=None,
|
56 |
+
num_inference_steps=None,
|
57 |
+
guidance_scale=None,
|
58 |
+
lift=None,
|
59 |
+
seed=None,
|
60 |
+
)
|
61 |
|
62 |
@torch.no_grad
|
63 |
def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable:
|