Optimize video generation by adding torch.no_grad() context to reduce memory usage
app.py

@@ -198,23 +198,24 @@ def generate_video_from_text(
     def gradio_progress_callback(self, step, timestep, kwargs):
         progress((step + 1) / num_inference_steps)
 
-    images = pipeline(
-        num_inference_steps=num_inference_steps,
-        num_images_per_prompt=1,
-        guidance_scale=guidance_scale,
-        generator=generator,
-        output_type="pt",
-        height=height,
-        width=width,
-        num_frames=num_frames,
-        frame_rate=frame_rate,
-        **sample,
-        is_video=True,
-        vae_per_channel_normalize=True,
-        conditioning_method=ConditioningMethod.FIRST_FRAME,
-        mixed_precision=True,
-        callback_on_step_end=gradio_progress_callback,
-    ).images
+    with torch.no_grad():
+        images = pipeline(
+            num_inference_steps=num_inference_steps,
+            num_images_per_prompt=1,
+            guidance_scale=guidance_scale,
+            generator=generator,
+            output_type="pt",
+            height=height,
+            width=width,
+            num_frames=num_frames,
+            frame_rate=frame_rate,
+            **sample,
+            is_video=True,
+            vae_per_channel_normalize=True,
+            conditioning_method=ConditioningMethod.FIRST_FRAME,
+            mixed_precision=True,
+            callback_on_step_end=gradio_progress_callback,
+        ).images
 
     output_path = tempfile.mktemp(suffix=".mp4")
     print(images.shape)

@@ -268,23 +269,24 @@ def generate_video_from_image(
     def gradio_progress_callback(self, step, timestep, kwargs):
         progress((step + 1) / num_inference_steps)
 
-    images = pipeline(
-        num_inference_steps=num_inference_steps,
-        num_images_per_prompt=1,
-        guidance_scale=guidance_scale,
-        generator=generator,
-        output_type="pt",
-        height=height,
-        width=width,
-        num_frames=num_frames,
-        frame_rate=frame_rate,
-        **sample,
-        is_video=True,
-        vae_per_channel_normalize=True,
-        conditioning_method=ConditioningMethod.FIRST_FRAME,
-        mixed_precision=True,
-        callback_on_step_end=gradio_progress_callback,
-    ).images
+    with torch.no_grad():
+        images = pipeline(
+            num_inference_steps=num_inference_steps,
+            num_images_per_prompt=1,
+            guidance_scale=guidance_scale,
+            generator=generator,
+            output_type="pt",
+            height=height,
+            width=width,
+            num_frames=num_frames,
+            frame_rate=frame_rate,
+            **sample,
+            is_video=True,
+            vae_per_channel_normalize=True,
+            conditioning_method=ConditioningMethod.FIRST_FRAME,
+            mixed_precision=True,
+            callback_on_step_end=gradio_progress_callback,
+        ).images
 
     output_path = tempfile.mktemp(suffix=".mp4")
     video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
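Why this helps, in brief: by default PyTorch's autograd keeps every intermediate tensor of a forward pass alive so gradients can be computed later; wrapping the pipeline call in torch.no_grad() disables that recording, so activations are freed as soon as they are consumed and peak inference memory drops. A minimal sketch of the pattern, using a stand-in torch.nn.Linear model rather than the LTX pipeline from app.py:

import torch

# Minimal sketch, not from app.py: inside torch.no_grad(), autograd records
# no graph, so forward-pass intermediates are freed immediately instead of
# being retained for a backward pass that inference never runs.
model = torch.nn.Linear(1024, 1024)
x = torch.randn(8, 1024)

with torch.no_grad():
    y = model(x)

# The output carries no autograd history.
assert not y.requires_grad

On recent PyTorch releases, torch.inference_mode() is a drop-in alternative with the same call shape and slightly lower overhead.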