benibraz committed on
Commit
27be903
·
1 Parent(s): a40928e

Optimize video generation by adding torch.no_grad() context to reduce memory usage

Browse files
Files changed (1) hide show
  1. app.py +36 -34
app.py CHANGED
@@ -198,23 +198,24 @@ def generate_video_from_text(
198
  def gradio_progress_callback(self, step, timestep, kwargs):
199
  progress((step + 1) / num_inference_steps)
200
 
201
- images = pipeline(
202
- num_inference_steps=num_inference_steps,
203
- num_images_per_prompt=1,
204
- guidance_scale=guidance_scale,
205
- generator=generator,
206
- output_type="pt",
207
- height=height,
208
- width=width,
209
- num_frames=num_frames,
210
- frame_rate=frame_rate,
211
- **sample,
212
- is_video=True,
213
- vae_per_channel_normalize=True,
214
- conditioning_method=ConditioningMethod.FIRST_FRAME,
215
- mixed_precision=True,
216
- callback_on_step_end=gradio_progress_callback,
217
- ).images
 
218
 
219
  output_path = tempfile.mktemp(suffix=".mp4")
220
  print(images.shape)
@@ -268,23 +269,24 @@ def generate_video_from_image(
268
  def gradio_progress_callback(self, step, timestep, kwargs):
269
  progress((step + 1) / num_inference_steps)
270
 
271
- images = pipeline(
272
- num_inference_steps=num_inference_steps,
273
- num_images_per_prompt=1,
274
- guidance_scale=guidance_scale,
275
- generator=generator,
276
- output_type="pt",
277
- height=height,
278
- width=width,
279
- num_frames=num_frames,
280
- frame_rate=frame_rate,
281
- **sample,
282
- is_video=True,
283
- vae_per_channel_normalize=True,
284
- conditioning_method=ConditioningMethod.FIRST_FRAME,
285
- mixed_precision=True,
286
- callback_on_step_end=gradio_progress_callback,
287
- ).images
 
288
 
289
  output_path = tempfile.mktemp(suffix=".mp4")
290
  video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
 
198
  def gradio_progress_callback(self, step, timestep, kwargs):
199
  progress((step + 1) / num_inference_steps)
200
 
201
+ with torch.no_grad():
202
+ images = pipeline(
203
+ num_inference_steps=num_inference_steps,
204
+ num_images_per_prompt=1,
205
+ guidance_scale=guidance_scale,
206
+ generator=generator,
207
+ output_type="pt",
208
+ height=height,
209
+ width=width,
210
+ num_frames=num_frames,
211
+ frame_rate=frame_rate,
212
+ **sample,
213
+ is_video=True,
214
+ vae_per_channel_normalize=True,
215
+ conditioning_method=ConditioningMethod.FIRST_FRAME,
216
+ mixed_precision=True,
217
+ callback_on_step_end=gradio_progress_callback,
218
+ ).images
219
 
220
  output_path = tempfile.mktemp(suffix=".mp4")
221
  print(images.shape)
 
269
  def gradio_progress_callback(self, step, timestep, kwargs):
270
  progress((step + 1) / num_inference_steps)
271
 
272
+ with torch.no_grad():
273
+ images = pipeline(
274
+ num_inference_steps=num_inference_steps,
275
+ num_images_per_prompt=1,
276
+ guidance_scale=guidance_scale,
277
+ generator=generator,
278
+ output_type="pt",
279
+ height=height,
280
+ width=width,
281
+ num_frames=num_frames,
282
+ frame_rate=frame_rate,
283
+ **sample,
284
+ is_video=True,
285
+ vae_per_channel_normalize=True,
286
+ conditioning_method=ConditioningMethod.FIRST_FRAME,
287
+ mixed_precision=True,
288
+ callback_on_step_end=gradio_progress_callback,
289
+ ).images
290
 
291
  output_path = tempfile.mktemp(suffix=".mp4")
292
  video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()