1aurent commited on
Commit
5fe830e
1 Parent(s): 270ab30

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +83 -0
pipeline.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+
5
+ from diffusers.schedulers import DDIMScheduler
6
+ from diffusers.utils.torch_utils import randn_tensor
7
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
8
+
9
+
10
+ class DDIMPipelineCustom(DiffusionPipeline):
11
+ model_cpu_offload_seq = "unet"
12
+
13
+ def __init__(self, unet, scheduler):
14
+ super().__init__()
15
+
16
+ # make sure scheduler can always be converted to DDIM
17
+ scheduler = DDIMScheduler.from_config(scheduler.config)
18
+
19
+ self.register_modules(unet=unet, scheduler=scheduler)
20
+
21
+ @torch.no_grad()
22
+ def __call__(
23
+ self,
24
+ condition = None,
25
+ guidance: float = 1,
26
+ batch_size: int = 1,
27
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
28
+ eta: float = 0.0,
29
+ num_inference_steps: int = 50,
30
+ use_clipped_model_output: Optional[bool] = None,
31
+ output_type: Optional[str] = "pil",
32
+ return_dict: bool = True,
33
+ ) -> Union[ImagePipelineOutput, Tuple]:
34
+ # Sample gaussian noise to begin loop
35
+ if isinstance(self.unet.config.sample_size, int):
36
+ image_shape = (
37
+ batch_size,
38
+ self.unet.config.in_channels,
39
+ self.unet.config.sample_size,
40
+ self.unet.config.sample_size,
41
+ )
42
+ else:
43
+ image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
44
+
45
+ if isinstance(generator, list) and len(generator) != batch_size:
46
+ raise ValueError(
47
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
48
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
49
+ )
50
+
51
+ image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype)
52
+
53
+ # set step values
54
+ self.scheduler.set_timesteps(num_inference_steps)
55
+
56
+ for t in self.progress_bar(self.scheduler.timesteps):
57
+ # 1. predict noise model_output
58
+ uncond = -torch.ones(batch_size, device=self.device)
59
+
60
+ if condition is not None:
61
+ model_output_uncond = self.unet(image, t, uncond).sample
62
+ model_output_cond = self.unet(image, t, condition).sample
63
+
64
+ model_output = torch.lerp(model_output_uncond, model_output_cond, guidance)
65
+ else:
66
+ model_output = self.unet(image, t, uncond).sample
67
+
68
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
69
+ # eta corresponds to η in paper and should be between [0, 1]
70
+ # do x_t -> x_t-1
71
+ image = self.scheduler.step(
72
+ model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
73
+ ).prev_sample
74
+
75
+ image = (image / 2 + 0.5).clamp(0, 1)
76
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
77
+ if output_type == "pil":
78
+ image = self.numpy_to_pil(image)
79
+
80
+ if not return_dict:
81
+ return (image,)
82
+
83
+ return ImagePipelineOutput(images=image)