fffiloni committed
Commit 51df367
1 Parent(s): 1b6ca43

Update app.py

Files changed (1)
  1. app.py +34 -35
app.py CHANGED
@@ -35,6 +35,9 @@ vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dt
 tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
 pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_id, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, torch_dtype=torch.float16)
 
+# Add this near the top after imports
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
 def find_and_move_object_to_cpu():
     for obj in gc.get_objects():
         try:
@@ -52,52 +55,48 @@ def clear_gpu():
     gc.collect()
 
 def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True)):
+    # Move everything to CPU initially
+    pipe.to("cpu")
+    torch.cuda.empty_cache()
 
     lora_path = "checkpoints/"
-    if orbit_type == "Left":
-        weight_name = "orbit_left_lora_weights.safetensors"
-        #adapter_name = "orbit_left_lora_weights"
-    elif orbit_type == "Up":
-        weight_name = "orbit_up_lora_weights.safetensors"
-        #adapter_name = "orbit_up_lora_weights"
+    weight_name = "orbit_left_lora_weights.safetensors" if orbit_type == "Left" else "orbit_up_lora_weights.safetensors"
     lora_rank = 128
-
     adapter_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 
-    # Load LoRA weights on CPU, move to GPU afterward
+    # Load LoRA weights on CPU
     pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=f"adapter_{adapter_timestamp}")
     pipe.fuse_lora(lora_scale=1 / lora_rank)
 
-    # Move the pipeline to GPU for inference
-    pipe.to("cuda")
+    try:
+        # Move to GPU just before inference
+        pipe.to("cuda")
+        torch.cuda.empty_cache()
+
+        prompt = f"{prompt}. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+        image = load_image(image_path)
+        seed = random.randint(0, 2**8 - 1)
+
+        with torch.inference_mode():
+            video = pipe(
+                image,
+                prompt,
+                num_inference_steps=25,
+                guidance_scale=7.0,
+                use_dynamic_cfg=True,
+                generator=torch.Generator(device="cpu").manual_seed(seed)
+            )
+    finally:
+        # Ensure cleanup happens even if inference fails
+        pipe.to("cpu")
+        pipe.unfuse_lora()
+        pipe.unload_lora_weights()
+        torch.cuda.empty_cache()
+        gc.collect()
 
-    # Set the inference prompt
-    prompt = f"{prompt}. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
-    image = load_image(image_path)
-    seed = random.randint(0, 2**8 - 1)
-
-
-    video = pipe(
-        image,
-        prompt,
-        num_inference_steps=25,
-        guidance_scale=7.0,
-        use_dynamic_cfg=True,
-        generator=torch.Generator(device="cpu").manual_seed(seed)
-    )
-
-    torch.cuda.empty_cache()
-    pipe.unfuse_lora()
-    pipe.unload_lora_weights()
-
-
-    # Generate and save output video
+    # Generate output video
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     export_to_video(video.frames[0], f"output_{timestamp}.mp4", fps=8)
-
-    # Move objects to CPU and clear GPU memory immediately after inference
-    find_and_move_object_to_cpu()
-    clear_gpu()
 
     return f"output_{timestamp}.mp4"
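Note on the new allocator setting: 'expandable_segments:True' asks PyTorch's CUDA caching allocator to grow memory segments instead of reserving fixed-size blocks, which reduces fragmentation when allocation sizes vary between requests, as they do here with LoRA weights loaded and unloaded per call. A minimal sketch of the usual placement follows; the exact position in app.py is an assumption, since the setting only has to be in place before the first CUDA allocation.

import os

# Reduce CUDA memory fragmentation. The caching allocator appears to read this
# setting lazily, at the first CUDA allocation rather than at import time, which
# is why setting it near the top of app.py, before any .to("cuda") call, is enough.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch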
 
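The bodies of find_and_move_object_to_cpu() and clear_gpu() are cut off in the context lines above; only their opening lines are shown. The following is a hypothetical reconstruction of what helpers shaped like that typically do, not the code from this repository: walk the objects tracked by the garbage collector, push CUDA-resident tensors and modules back to the CPU, then release the cached GPU memory.

import gc
import torch

def find_and_move_object_to_cpu():
    # Hypothetical sketch: scan gc-tracked objects and move anything living
    # on the GPU back to the CPU so its memory can be reclaimed.
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                obj.data = obj.data.cpu()
            elif isinstance(obj, torch.nn.Module):
                obj.to("cpu")
        except Exception:
            continue  # some tracked objects raise on attribute access; skip them

def clear_gpu():
    # Hypothetical sketch: return cached CUDA blocks to the driver, then collect garbage.
    torch.cuda.empty_cache()
    gc.collect()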
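The substantive change to infer() is the try/finally around the GPU phase: the pipeline now stays on the CPU between requests, is moved to the GPU only for the duration of one generation, and is always moved back, un-fused and unloaded even when inference raises. The same lifecycle can be written once as a context manager; this is a sketch of the pattern only (the name on_gpu is invented here) and not part of the commit.

import gc
from contextlib import contextmanager

import torch

@contextmanager
def on_gpu(pipeline):
    # Borrow the GPU for a single call, mirroring the try/finally added to infer().
    pipeline.to("cuda")
    torch.cuda.empty_cache()
    try:
        yield pipeline
    finally:
        pipeline.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()

# Usage sketch:
# with on_gpu(pipe) as p:
#     video = p(image, prompt, num_inference_steps=25, guidance_scale=7.0)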
 
 
 
 
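For reference, infer() takes an image path, a prompt and an orbit type ("Left" or "Up") and returns the path of the rendered MP4. The Gradio UI itself is defined elsewhere in app.py and is not part of this diff; a minimal hypothetical wiring consistent with that signature could look roughly like this (component choices and labels are assumptions, not the app's actual layout).

import gradio as gr

demo = gr.Interface(
    fn=infer,  # the function updated in this commit
    inputs=[
        gr.Image(type="filepath", label="Input image"),
        gr.Textbox(label="Prompt"),
        gr.Radio(choices=["Left", "Up"], value="Left", label="Camera orbit"),
    ],
    outputs=gr.Video(label="Generated video"),
)

if __name__ == "__main__":
    demo.launch()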