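Update tiling: add `vae_tiling` option to CogVideoXConfig and enable VAE tiling in CogVideoXPipeline when it is set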
Browse files
videosys/pipelines/cogvideox/pipeline_cogvideox.py
CHANGED
@@ -50,8 +50,7 @@ class CogVideoXConfig:
|
|
50 |
self,
|
51 |
model_path: str = "THUDM/CogVideoX-2b",
|
52 |
world_size: int = 1,
|
53 |
-
|
54 |
-
guidance_scale: float = 6.0,
|
55 |
enable_pab: bool = False,
|
56 |
pab_config=CogVideoXPABConfig(),
|
57 |
):
|
@@ -61,15 +60,17 @@ class CogVideoXConfig:
|
|
61 |
# ======= pipeline ========
|
62 |
self.pipeline_cls = CogVideoXPipeline
|
63 |
|
|
|
|
|
64 |
# ======= model ========
|
65 |
self.model_path = model_path
|
66 |
-
self.num_inference_steps = num_inference_steps
|
67 |
-
self.guidance_scale = guidance_scale
|
68 |
self.enable_pab = enable_pab
|
69 |
self.pab_config = pab_config
|
70 |
|
71 |
|
72 |
class CogVideoXPipeline(VideoSysPipeline):
|
|
|
|
|
73 |
_callback_tensor_inputs = [
|
74 |
"latents",
|
75 |
"prompt_embeds",
|
@@ -98,6 +99,8 @@ class CogVideoXPipeline(VideoSysPipeline):
|
|
98 |
)
|
99 |
if vae is None:
|
100 |
vae = AutoencoderKLCogVideoX.from_pretrained(config.model_path, subfolder="vae", torch_dtype=self._dtype)
|
|
|
|
|
101 |
if tokenizer is None:
|
102 |
tokenizer = T5Tokenizer.from_pretrained(config.model_path, subfolder="tokenizer")
|
103 |
if text_encoder is None:
|
|
|
50 |
self,
|
51 |
model_path: str = "THUDM/CogVideoX-2b",
|
52 |
world_size: int = 1,
|
53 |
+
vae_tiling: bool = True,
|
|
|
54 |
enable_pab: bool = False,
|
55 |
pab_config=CogVideoXPABConfig(),
|
56 |
):
|
|
|
60 |
# ======= pipeline ========
|
61 |
self.pipeline_cls = CogVideoXPipeline
|
62 |
|
63 |
+
self.vae_tiling = vae_tiling
|
64 |
+
|
65 |
# ======= model ========
|
66 |
self.model_path = model_path
|
|
|
|
|
67 |
self.enable_pab = enable_pab
|
68 |
self.pab_config = pab_config
|
69 |
|
70 |
|
71 |
class CogVideoXPipeline(VideoSysPipeline):
|
72 |
+
_optional_components = []
|
73 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
74 |
_callback_tensor_inputs = [
|
75 |
"latents",
|
76 |
"prompt_embeds",
|
|
|
99 |
)
|
100 |
if vae is None:
|
101 |
vae = AutoencoderKLCogVideoX.from_pretrained(config.model_path, subfolder="vae", torch_dtype=self._dtype)
|
102 |
+
if config.vae_tiling:
|
103 |
+
vae.enable_tiling()
|
104 |
if tokenizer is None:
|
105 |
tokenizer = T5Tokenizer.from_pretrained(config.model_path, subfolder="tokenizer")
|
106 |
if text_encoder is None:
|