# luma-api-dreammachine-gui / gradio_app.py
# Last commit by Remsky (54c49c7): Update app file reference in README and
# clean up unused import in gradio_app.py
import random
import tempfile
import time
import traceback

import gradio as gr
import requests
from lumaai import LumaAI

from lib.image_utils import prepare_image
from lib.status_utils import load_messages, StatusTracker
from lib.ui_components import create_interface
# Delay between status polls of a running generation.
_POLL_INTERVAL_SECONDS = 1
# Give up on a generation after this long (5 minutes).
_GENERATION_TIMEOUT_SECONDS = 300
# Expected generation time; only used to pace the progress bar (0.2 -> 0.8).
_EXPECTED_GENERATION_SECONDS = 60
# Status messages rotate at a random interval drawn from this range (seconds).
_MESSAGE_INTERVAL_RANGE = (2, 8)


def _build_generation_params(prompt, camera_motion, loop_video):
    """Return the keyword arguments for ``client.generations.create``.

    ``camera_motion`` is appended to the prompt unless it is empty/None or the
    literal string ``'None'`` (which means "no camera motion").
    """
    if camera_motion and camera_motion != 'None':
        full_prompt = f"{prompt} {camera_motion}"
    else:
        full_prompt = prompt
    return {
        "prompt": full_prompt,
        "loop": loop_video,
        "aspect_ratio": "1:1",  # Force square aspect ratio
    }


def _poll_generation(client, generation_id, progress):
    """Poll a LumaAI generation until it finishes and return its video URL.

    While waiting, shuffled status messages from ``load_messages()`` are shown
    through ``progress`` and the bar advances from 0.2 toward 0.8.

    Raises:
        gr.Error: if the generation fails, exceeds the timeout, or completes
            without a video URL.
    """
    # Fallback keeps the modulo below safe if no messages are configured.
    status_messages = list(load_messages()) or ["Generating video..."]
    random.shuffle(status_messages)
    message_index = 0

    start_time = time.time()
    last_message_time = start_time
    # Draw the interval once per message; re-drawing on every poll would
    # bias the effective interval toward the low end of the range.
    next_message_delay = random.uniform(*_MESSAGE_INTERVAL_RANGE)

    while True:
        current_time = time.time()
        elapsed_time = current_time - start_time
        # Checked outside the try below so repeated polling errors cannot
        # keep the loop running past the timeout.
        if elapsed_time > _GENERATION_TIMEOUT_SECONDS:
            raise gr.Error("Generation timed out after 5 minutes")

        try:
            generation_status = client.generations.get(generation_id)
        except Exception as e:
            # Treat request errors as transient and retry until the timeout.
            # Only the status request is guarded, so a 'failed' state below
            # is not swallowed.
            print(f"Error during generation polling: {str(e)}")
            print(traceback.format_exc())
            time.sleep(_POLL_INTERVAL_SECONDS)
            continue

        status = generation_status.state
        if status == 'completed':
            progress(0.9, desc="Generation completed!")
            # getattr guards against a missing assets object.
            download_url = getattr(generation_status.assets, "video", None)
            if not download_url:
                raise gr.Error("Generation completed but no video URL was returned")
            return download_url
        if status == 'failed':
            # NOTE(review): failure_reason is read defensively; confirm the
            # attribute name against the lumaai SDK's Generation model.
            reason = getattr(generation_status, "failure_reason", None)
            if reason:
                raise gr.Error(f"Video generation failed: {reason}")
            raise gr.Error("Video generation failed")

        if current_time - last_message_time >= next_message_delay:
            progress_val = min(0.2 + elapsed_time / _EXPECTED_GENERATION_SECONDS, 0.8)
            progress(progress_val, desc=status_messages[message_index % len(status_messages)])
            message_index += 1
            last_message_time = current_time
            next_message_delay = random.uniform(*_MESSAGE_INTERVAL_RANGE)

        time.sleep(_POLL_INTERVAL_SECONDS)


def _download_video(download_url):
    """Stream the video at ``download_url`` to a new temporary ``.mp4`` file.

    Returns:
        Path of the written file. The file is not deleted automatically.
    """
    # The context manager releases the connection even if writing fails.
    with requests.get(download_url, stream=True, timeout=30) as response:
        response.raise_for_status()
        # A unique file per request, so queued or concurrent users cannot
        # overwrite each other's output the way a fixed filename would.
        with tempfile.NamedTemporaryFile(
            prefix="output_video_", suffix=".mp4", delete=False
        ) as file:
            # Write in chunks instead of buffering the whole video in memory.
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    file.write(chunk)
            return file.name


def generate_video(api_key, prompt, camera_motion, loop_video, image=None, progress=gr.Progress()):
    """Generate a square video with LumaAI and return the downloaded file path.

    Args:
        api_key: LumaAI auth token.
        prompt: Text prompt for the generation.
        camera_motion: Optional camera-motion text appended to the prompt;
            ``'None'``, ``None`` or ``""`` mean no camera motion.
        loop_video: Passed to the API as the ``loop`` flag.
        image: Optional input image, uploaded via ``prepare_image`` and used
            as the first keyframe.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path to the downloaded ``.mp4`` file.

    Raises:
        gr.Error: on missing input or any failure. Unexpected exceptions are
            logged and re-raised as a generic ``gr.Error``.
    """
    if not api_key or not prompt:
        raise gr.Error("Please enter your LumaAI API key and prompt")

    try:
        progress(0, desc="Initializing LumaAI...")
        client = LumaAI(auth_token=api_key)

        generation_params = _build_generation_params(prompt, camera_motion, loop_video)

        if image is not None:
            progress(0.1, desc="Preparing image...")
            # Presumably prepare_image reports its own progress through this
            # tracker; only the progress callback is wired up here.
            status_tracker = StatusTracker(
                progress=lambda x: progress(x),
                status_box=None,
            )
            try:
                cdn_url = prepare_image(image, status_tracker)
            except Exception as e:
                raise gr.Error("Failed to process the input image") from e
            generation_params["keyframes"] = {
                "frame0": {"type": "image", "url": cdn_url}
            }

        progress(0.2, desc="Starting generation...")
        try:
            generation = client.generations.create(**generation_params)
        except Exception as e:
            raise gr.Error("Failed to start video generation. Please check your API key.") from e

        download_url = _poll_generation(client, generation.id, progress)

        progress(0.95, desc="Downloading video...")
        try:
            file_path = _download_video(download_url)
        except Exception as e:
            raise gr.Error("Failed to download the generated video") from e

        progress(1.0, desc="Video ready!")
        return file_path
    except gr.Error:
        # Already user-facing; pass through unchanged.
        raise
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        print(traceback.format_exc())
        raise gr.Error("An unexpected error occurred") from e
def _build_app():
    """Wrap ``generate_video`` in the project's Gradio UI and enable queuing.

    The queue is capped at ``max_size=5``, a smaller limit chosen for
    Hugging Face Spaces.
    """
    blocks = create_interface(generate_video)
    blocks.queue(max_size=5)
    return blocks


# Module-level Gradio app (built at import time).
app = _build_app()

if __name__ == "__main__":
    app.launch()