Spaces: Running on Zero
陈硕 committed · 3a2f1ee · Parent: f8acb76
update orbit lora
app.py CHANGED
@@ -55,11 +55,21 @@ pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
     text_encoder=pipe.text_encoder,
     torch_dtype=torch.bfloat16,
 )
-
-
-
-
-
+
+os.makedirs("checkpoints", exist_ok=True)
+
+# Download LoRA weights
+hf_hub_download(
+    repo_id="wenqsun/DimensionX",
+    filename="orbit_left_lora_weights.safetensors",
+    local_dir="checkpoints"
+)
+
+hf_hub_download(
+    repo_id="wenqsun/DimensionX",
+    filename="orbit_up_lora_weights.safetensors",
+    local_dir="checkpoints"
+)
 
 
 # pipe.transformer.to(memory_format=torch.channels_last)
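
Note: the added block assumes os and hf_hub_download (from huggingface_hub) are already imported elsewhere in app.py; the diff does not add those imports. As a standalone sketch, with the repo id and filenames exactly as in the commit, the download step looks like this:

import os
from huggingface_hub import hf_hub_download

os.makedirs("checkpoints", exist_ok=True)

# hf_hub_download returns the local path of the fetched file.
lora_file = hf_hub_download(
    repo_id="wenqsun/DimensionX",
    filename="orbit_left_lora_weights.safetensors",
    local_dir="checkpoints",
)
print(lora_file)  # e.g. checkpoints/orbit_left_lora_weights.safetensors

Downloading both weight files once at module load keeps the per-request infer() path free of network calls.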
@@ -213,6 +223,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
 @spaces.GPU
 def infer(
     prompt: str,
+    orbit_type: str,
     image_input: str,
     num_inference_steps: int,
     guidance_scale: float,
@@ -235,6 +246,16 @@ def infer(
     # guidance_scale=guidance_scale,
     # generator=torch.Generator(device="cpu").manual_seed(seed),
     # ).frames
+
+    lora_path = "checkpoints/"
+    weight_name = "orbit_left_lora_weights.safetensors" if orbit_type == "Left" else "orbit_up_lora_weights.safetensors"
+    lora_rank = 256
+    adapter_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+    # Load LoRA weights on CPU
+    pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=f"adapter_{adapter_timestamp}")
+    pipe.fuse_lora(lora_scale=1 / lora_rank)
+
     if image_input is not None:
         image_input = Image.fromarray(image_input).resize(size=(720, 480))  # Convert to PIL
         image = load_image(image_input)
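
The per-request LoRA swap added here can be read as a small helper. The following is a hypothetical refactor of the diff's logic, not code from the commit; it assumes pipe is the already-constructed CogVideoX pipeline and that datetime is imported:

from datetime import datetime

def apply_orbit_lora(pipe, orbit_type: str, lora_rank: int = 256) -> str:
    # Pick the weight file matching the requested camera orbit.
    weight_name = (
        "orbit_left_lora_weights.safetensors"
        if orbit_type == "Left"
        else "orbit_up_lora_weights.safetensors"
    )
    # Timestamped adapter name so repeated calls never collide.
    adapter_name = f"adapter_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    pipe.load_lora_weights("checkpoints/", weight_name=weight_name, adapter_name=adapter_name)
    # DimensionX fuses with scale 1/rank (rank 256 in this commit).
    pipe.fuse_lora(lora_scale=1 / lora_rank)
    return adapter_name

One caveat worth noting: diffusers' fuse_lora() bakes the adapter into the base weights, so calling this on every request stacks adapters on top of each other unless the previous one is undone first (e.g. with pipe.unfuse_lora()).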
@@ -301,6 +322,12 @@ with gr.Blocks() as demo:
     </div>
     """)
     with gr.Row():
+        with gr.Column():
+            image_in = gr.Image(label="Image Input", type="filepath")
+            prompt = gr.Textbox(label="Prompt")
+            orbit_type = gr.Radio(label="Orbit type", choices=["Left", "Up"], value="Left", interactive=True)
+            submit_btn = gr.Button("Submit")
+
         with gr.Column():
             with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
                 image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
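
For context, a gr.Radio value reaches the handler as the selected string, so "Left" or "Up" flows straight into infer()'s orbit_type parameter. A minimal self-contained sketch of that pattern (not the Space's full UI):

import gradio as gr

def handler(prompt: str, orbit_type: str) -> str:
    # orbit_type arrives as the selected choice string, e.g. "Left".
    return f"orbit={orbit_type}, prompt={prompt}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    orbit_type = gr.Radio(label="Orbit type", choices=["Left", "Up"], value="Left")
    out = gr.Textbox(label="Result")
    gr.Button("Submit").click(handler, inputs=[prompt, orbit_type], outputs=out)

demo.launch()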
@@ -340,6 +367,7 @@ with gr.Blocks() as demo:
 
     def generate(
         prompt,
+        orbit_type,
         image_input,
         # video_input,
         # video_strength,
@@ -350,6 +378,7 @@ with gr.Blocks() as demo:
     ):
         latents, seed = infer(
             prompt,
+            orbit_type,
             image_input,
             # video_input,
             # video_strength,
@@ -386,7 +415,7 @@ with gr.Blocks() as demo:
 
     generate_button.click(
         generate,
-        inputs=[prompt, image_input, seed_param, enable_scale, enable_rife],
+        inputs=[prompt, orbit_type, image_input, seed_param, enable_scale, enable_rife],
         outputs=[video_output, download_video_button, download_gif_button, seed_text],
     )
 
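
Gradio wires inputs to handler parameters positionally: inputs[i] is passed as the i-th argument of generate. That is why the commit inserts orbit_type at the same position three times, in the click() inputs list, in generate()'s signature, and in the arguments forwarded to infer(); dropping any one of the three would silently shift every later argument by one.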