陈硕 committed
Commit 3a2f1ee
Parent: f8acb76

update orbit lora

Files changed (1)
1. app.py +35 -6
app.py CHANGED
@@ -55,11 +55,21 @@ pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
     text_encoder=pipe.text_encoder,
     torch_dtype=torch.bfloat16,
 )
-lora_path = "wenqsun/DimensionX"
-lora_rank = 256
-pipe_image.load_lora_weights(lora_path, weight_name="orbit_left_lora_weights.safetensors", adapter_name="orbit_left")
-pipe_image.fuse_lora(lora_scale=1 / lora_rank)
-pipe_image = pipe_image.to(device)
+
+os.makedirs("checkpoints", exist_ok=True)
+
+# Download LoRA weights
+hf_hub_download(
+    repo_id="wenqsun/DimensionX",
+    filename="orbit_left_lora_weights.safetensors",
+    local_dir="checkpoints"
+)
+
+hf_hub_download(
+    repo_id="wenqsun/DimensionX",
+    filename="orbit_up_lora_weights.safetensors",
+    local_dir="checkpoints"
+)
 
 
 # pipe.transformer.to(memory_format=torch.channels_last)
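
For reference, a standalone sketch of the download step above: hf_hub_download fetches a single file from the Hub (reusing the local cache where possible) and returns the path it resolved to, which with local_dir set lands under checkpoints/.

    from huggingface_hub import hf_hub_download

    # Fetch both orbit adapters up front so inference can switch between
    # them without touching the network.
    for filename in (
        "orbit_left_lora_weights.safetensors",
        "orbit_up_lora_weights.safetensors",
    ):
        local_path = hf_hub_download(
            repo_id="wenqsun/DimensionX",
            filename=filename,
            local_dir="checkpoints",
        )
        print(local_path)  # e.g. checkpoints/orbit_left_lora_weights.safetensors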
@@ -213,6 +223,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
 @spaces.GPU
 def infer(
     prompt: str,
+    orbit_type: str,
     image_input: str,
     num_inference_steps: int,
     guidance_scale: float,
@@ -235,6 +246,16 @@ def infer(
     # guidance_scale=guidance_scale,
     # generator=torch.Generator(device="cpu").manual_seed(seed),
     # ).frames
+
+    lora_path = "checkpoints/"
+    weight_name = "orbit_left_lora_weights.safetensors" if orbit_type == "Left" else "orbit_up_lora_weights.safetensors"
+    lora_rank = 256
+    adapter_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+    # Load LoRA weights on CPU
+    pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=f"adapter_{adapter_timestamp}")
+    pipe.fuse_lora(lora_scale=1 / lora_rank)
+
     if image_input is not None:
         image_input = Image.fromarray(image_input).resize(size=(720, 480))  # Convert to PIL
         image = load_image(image_input)
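
Worth noting when reading this hunk: fuse_lora merges the adapter into the base weights in place, and infer has no matching unfuse, so successive calls fuse one adapter on top of another. A minimal sketch of one way to keep requests isolated, assuming diffusers' standard unfuse_lora / unload_lora_weights API (the helper name and the run callback are hypothetical, not part of this commit):

    from datetime import datetime

    def run_with_orbit_lora(pipe, weight_name, run, lora_rank=256):
        # Hypothetical helper: fuse exactly one adapter for the duration
        # of run(pipe), then restore the base weights afterwards.
        adapter_name = "adapter_" + datetime.now().strftime("%Y%m%d_%H%M%S")
        pipe.load_lora_weights("checkpoints/", weight_name=weight_name, adapter_name=adapter_name)
        pipe.fuse_lora(lora_scale=1 / lora_rank)
        try:
            return run(pipe)  # the caller's sampling call
        finally:
            pipe.unfuse_lora()          # undo the in-place weight merge
            pipe.unload_lora_weights()  # drop the adapter modules entirely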
@@ -301,6 +322,12 @@ with gr.Blocks() as demo:
     </div>
     """)
     with gr.Row():
+        with gr.Column():
+            image_in = gr.Image(label="Image Input", type="filepath")
+            prompt = gr.Textbox(label="Prompt")
+            orbit_type = gr.Radio(label="Orbit type", choices=["Left", "Up"], value="Left", interactive=True)
+            submit_btn = gr.Button("Submit")
+
         with gr.Column():
             with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
                 image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
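
The new column introduces submit_btn, which is not wired to a handler in the hunks shown; orbit_type instead reaches the pipeline through the existing generate_button below. For illustration, a self-contained sketch of how a gr.Radio value flows into a click handler (all names here are standalone, not from app.py):

    import gradio as gr

    def describe(prompt, orbit_type):
        # Gradio passes each component's value positionally, in the
        # order given by inputs=[...].
        return f"orbit={orbit_type!r}, prompt={prompt!r}"

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        orbit_type = gr.Radio(label="Orbit type", choices=["Left", "Up"], value="Left")
        result = gr.Textbox(label="Result")
        gr.Button("Submit").click(describe, inputs=[prompt, orbit_type], outputs=[result])

    demo.launch()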
@@ -340,6 +367,7 @@ with gr.Blocks() as demo:
 
     def generate(
         prompt,
+        orbit_type,
         image_input,
         # video_input,
         # video_strength,
@@ -350,6 +378,7 @@ with gr.Blocks() as demo:
     ):
         latents, seed = infer(
             prompt,
+            orbit_type,
             image_input,
             # video_input,
             # video_strength,
@@ -386,7 +415,7 @@ with gr.Blocks() as demo:
 
     generate_button.click(
         generate,
-        inputs=[prompt, image_input, seed_param, enable_scale, enable_rife],
+        inputs=[prompt, orbit_type, image_input, seed_param, enable_scale, enable_rife],
         outputs=[video_output, download_video_button, download_gif_button, seed_text],
     )
 