|
import gradio as gr |
|
import time |
|
import random |
|
import requests |
|
from lumaai import LumaAI |
|
import traceback |
|
|
|
from lib.status_utils import load_messages, StatusTracker |
|
from lib.image_utils import prepare_image |
|
from lib.api_utils import get_camera_motions |
|
from lib.ui_components import create_interface |
|
|
|
# Seconds to wait between status polls of a pending generation.
_POLL_INTERVAL_SECONDS = 1
# Give up on a generation that has not finished after this many seconds.
_GENERATION_TIMEOUT_SECONDS = 300
# Where the finished video is written (relative to the working directory).
_OUTPUT_VIDEO_PATH = "output_video.mp4"
# Used while polling if load_messages() yields no status messages.
_FALLBACK_STATUS_MESSAGE = "Generating video..."


def _log_exception(context, exc):
    """Print a one-line summary of *exc* followed by the active traceback."""
    print(f"{context}: {exc}")
    print(traceback.format_exc())


def _build_prompt(prompt, camera_motion):
    """Return *prompt* with the camera-motion phrase appended, if one was chosen.

    The sentinel string 'None' (or an empty/missing value) means "no camera
    motion"; in that case the prompt is returned unchanged instead of with a
    dangling trailing space.
    """
    if camera_motion and camera_motion != "None":
        return f"{prompt} {camera_motion}"
    return prompt


def _wait_for_generation(client, generation_id, progress):
    """Poll LumaAI until the generation completes and return its video URL.

    While waiting, shuffled status messages are rotated every 2-8 seconds and
    the progress bar advances from 0.2 towards a 0.8 cap based on elapsed time.

    Raises:
        gr.Error: if the generation reports ``failed``, completes without a
            video asset, or has not completed within
            ``_GENERATION_TIMEOUT_SECONDS``. Errors while fetching the status
            are logged and retried, bounded by the same timeout.
    """
    # Copy so shuffling never mutates a list load_messages() may reuse.
    status_messages = list(load_messages() or ())
    if not status_messages:
        status_messages = [_FALLBACK_STATUS_MESSAGE]
    random.shuffle(status_messages)
    message_index = 0

    start_time = time.time()
    last_message_time = start_time
    next_message_delay = random.uniform(2, 8)

    while True:
        # Only the network call is guarded: gr.Error raised below must
        # propagate instead of being swallowed by the retry handler.
        try:
            generation_status = client.generations.get(generation_id)
        except Exception as exc:
            _log_exception("Error during generation polling", exc)
            generation_status = None

        now = time.time()
        elapsed_time = now - start_time

        if now - last_message_time >= next_message_delay:
            progress_val = min(0.2 + (elapsed_time / 60), 0.8)
            progress(progress_val, desc=status_messages[message_index % len(status_messages)])
            message_index += 1
            last_message_time = now
            next_message_delay = random.uniform(2, 8)

        if generation_status is not None:
            status = generation_status.state
            if status == "completed":
                progress(0.9, desc="Generation completed!")
                # getattr tolerates a missing/None `assets` object.
                video_url = getattr(generation_status.assets, "video", None)
                if not video_url:
                    raise gr.Error("Generation completed but no video was returned")
                return video_url
            if status == "failed":
                # NOTE(review): assumes the SDK may expose `failure_reason`;
                # getattr keeps this safe if it does not.
                reason = getattr(generation_status, "failure_reason", None)
                message = "Video generation failed"
                raise gr.Error(f"{message}: {reason}" if reason else message)

        # Checked on every iteration, including after a failed poll, so the
        # loop always terminates.
        if elapsed_time > _GENERATION_TIMEOUT_SECONDS:
            raise gr.Error(
                f"Generation timed out after {_GENERATION_TIMEOUT_SECONDS // 60} minutes"
            )

        time.sleep(_POLL_INTERVAL_SECONDS)


def _download_video(url, file_path):
    """Stream the video at *url* into *file_path* and return *file_path*.

    The response is consumed in chunks (not buffered whole in memory) and is
    closed when done. HTTP error statuses raise via ``raise_for_status``.
    """
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()
        with open(file_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=64 * 1024):
                if chunk:  # skip keep-alive chunks
                    file.write(chunk)
    return file_path


def generate_video(api_key, prompt, camera_motion, loop_video, image=None, progress=gr.Progress()):
    """Generate a video with LumaAI and return the path of the downloaded file.

    Args:
        api_key: LumaAI auth token used to construct the client.
        prompt: Text prompt for the generation.
        camera_motion: Camera-motion phrase appended to the prompt; the
            string 'None' means no camera motion.
        loop_video: Passed through as the generation's ``loop`` option.
        image: Optional input image. When given, it is passed to
            ``prepare_image`` and the returned URL becomes the first
            keyframe (``frame0``) of the generation.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        Path of the written video file (``_OUTPUT_VIDEO_PATH``).

    Raises:
        gr.Error: on missing inputs, image-preparation failure, failure to
            start, a failed or timed-out generation, a failed download, or
            any other unexpected error.
    """
    if not api_key or not prompt:
        raise gr.Error("Please enter your LumaAI API key and prompt")

    try:
        progress(0, desc="Initializing LumaAI...")
        client = LumaAI(auth_token=api_key)

        status_tracker = StatusTracker(
            progress=lambda x: progress(x),
            status_box=None,
        )

        generation_params = {
            "prompt": _build_prompt(prompt, camera_motion),
            "loop": loop_video,
            "aspect_ratio": "1:1",
        }

        if image is not None:
            progress(0.1, desc="Preparing image...")
            try:
                cdn_url = prepare_image(image, status_tracker)
            except gr.Error:
                # A helper-provided user-facing message is more specific.
                raise
            except Exception as exc:
                _log_exception("Error preparing image", exc)
                raise gr.Error("Failed to process the input image") from exc
            generation_params["keyframes"] = {
                "frame0": {
                    "type": "image",
                    "url": cdn_url,
                }
            }

        progress(0.2, desc="Starting generation...")
        try:
            generation = client.generations.create(**generation_params)
        except Exception as exc:
            _log_exception("Error starting generation", exc)
            raise gr.Error("Failed to start video generation. Please check your API key.") from exc

        download_url = _wait_for_generation(client, generation.id, progress)

        progress(0.95, desc="Downloading video...")
        try:
            file_path = _download_video(download_url, _OUTPUT_VIDEO_PATH)
        except Exception as exc:
            _log_exception("Error downloading video", exc)
            raise gr.Error("Failed to download the generated video") from exc

        progress(1.0, desc="Video ready!")
        return file_path

    except gr.Error:
        # Already user-facing; re-raise with the original traceback intact.
        raise
    except Exception as exc:
        _log_exception("Error during generation", exc)
        raise gr.Error("An unexpected error occurred") from exc
|
|
|
|
|
# Upper bound on events the Gradio queue will hold at once.
_QUEUE_MAX_SIZE = 5

# Build the UI around the generation handler and enable its bounded queue.
app = create_interface(generate_video)
app.queue(max_size=_QUEUE_MAX_SIZE)


if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    app.launch()
|
|