mskrt commited on
Commit
1732305
verified
1 Parent(s): 45fe31c

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +22 -15
pipeline.py CHANGED
@@ -33,24 +33,31 @@ class SuperDiffPipeline(DiffusionPipeline, ConfigMixin):
33
 
34
  """
35
  super().__init__()
36
- self.register_to_config(
37
- batch_size=kwargs.get("batch_size", 1),
38
- device=kwargs.get("device", "cuda"),
39
- guidance_scale=kwargs.get("guidance_scale", 7.5),
40
- lift=kwargs.get("lift", 0.0),
41
- num_inference_steps=kwargs.get("num_inference_steps", 50),
42
- seed=kwargs.get("seed", 42)
43
- )
44
-
45
- # Assign model components
46
- self.vae = vae
47
- self.scheduler = scheduler
48
- self.tokenizer = tokenizer
49
  self.unet = unet
 
50
  self.text_encoder = text_encoder
 
 
 
 
51
 
52
- # Move components to device
53
- self.to(torch.device(self.config.device))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  @torch.no_grad
56
  def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable:
 
33
 
34
  """
35
  super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  self.unet = unet
37
+ self.vae = vae
38
  self.text_encoder = text_encoder
39
+ self.tokenizer = tokenizer
40
+ self.scheduler = scheduler
41
+
42
+ device = "cuda" if torch.cuda.is_available() else "cpu"
43
 
44
+ self.vae.to(device)
45
+ self.unet.to(device)
46
+ self.text_encoder.to(device)
47
+
48
+ self.register_to_config(
49
+ vae=vae.__class__.__name__,
50
+ scheduler=scheduler.__class__.__name__,
51
+ tokenizer=tokenizer.__class__.__name__,
52
+ unet=unet.__class__.__name__,
53
+ text_encoder=text_encoder.__class__.__name__,
54
+ device=device,
55
+ batch_size=None,
56
+ num_inference_steps=None,
57
+ guidance_scale=None,
58
+ lift=None,
59
+ seed=None,
60
+ )
61
 
62
  @torch.no_grad
63
  def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable: