oahzxl commited on
Commit
9c115c5
1 Parent(s): 14a04f5

update tiling

Browse files
videosys/pipelines/cogvideox/pipeline_cogvideox.py CHANGED
@@ -50,8 +50,7 @@ class CogVideoXConfig:
50
  self,
51
  model_path: str = "THUDM/CogVideoX-2b",
52
  world_size: int = 1,
53
- num_inference_steps: int = 50,
54
- guidance_scale: float = 6.0,
55
  enable_pab: bool = False,
56
  pab_config=CogVideoXPABConfig(),
57
  ):
@@ -61,15 +60,17 @@ class CogVideoXConfig:
61
  # ======= pipeline ========
62
  self.pipeline_cls = CogVideoXPipeline
63
 
 
 
64
  # ======= model ========
65
  self.model_path = model_path
66
- self.num_inference_steps = num_inference_steps
67
- self.guidance_scale = guidance_scale
68
  self.enable_pab = enable_pab
69
  self.pab_config = pab_config
70
 
71
 
72
  class CogVideoXPipeline(VideoSysPipeline):
 
 
73
  _callback_tensor_inputs = [
74
  "latents",
75
  "prompt_embeds",
@@ -98,6 +99,8 @@ class CogVideoXPipeline(VideoSysPipeline):
98
  )
99
  if vae is None:
100
  vae = AutoencoderKLCogVideoX.from_pretrained(config.model_path, subfolder="vae", torch_dtype=self._dtype)
 
 
101
  if tokenizer is None:
102
  tokenizer = T5Tokenizer.from_pretrained(config.model_path, subfolder="tokenizer")
103
  if text_encoder is None:
 
50
  self,
51
  model_path: str = "THUDM/CogVideoX-2b",
52
  world_size: int = 1,
53
+ vae_tiling: bool = True,
 
54
  enable_pab: bool = False,
55
  pab_config=CogVideoXPABConfig(),
56
  ):
 
60
  # ======= pipeline ========
61
  self.pipeline_cls = CogVideoXPipeline
62
 
63
+ self.vae_tiling = vae_tiling
64
+
65
  # ======= model ========
66
  self.model_path = model_path
 
 
67
  self.enable_pab = enable_pab
68
  self.pab_config = pab_config
69
 
70
 
71
  class CogVideoXPipeline(VideoSysPipeline):
72
+ _optional_components = []
73
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
74
  _callback_tensor_inputs = [
75
  "latents",
76
  "prompt_embeds",
 
99
  )
100
  if vae is None:
101
  vae = AutoencoderKLCogVideoX.from_pretrained(config.model_path, subfolder="vae", torch_dtype=self._dtype)
102
+ if config.vae_tiling:
103
+ vae.enable_tiling()
104
  if tokenizer is None:
105
  tokenizer = T5Tokenizer.from_pretrained(config.model_path, subfolder="tokenizer")
106
  if text_encoder is None: