import gradio as gr
import os
from lumaai import AsyncLumaAI
import asyncio
import aiohttp

async def generate_video(api_key, prompt, loop=False, aspect_ratio="16:9", progress=gr.Progress()):
    """Generate a video from a text prompt with Luma AI and save it locally.

    Creates a generation, polls it every 5 seconds until its state is
    ``"completed"`` or ``"failed"``, then downloads the resulting video into
    the current working directory.

    Args:
        api_key: Auth token passed to ``AsyncLumaAI``.
        prompt: Text prompt for the generation request.
        loop: Forwarded as the ``loop`` option of the generation request.
        aspect_ratio: Forwarded as the ``aspect_ratio`` option.
        progress: Gradio progress tracker, called with a fraction and a
            description string.

    Returns:
        The name of the downloaded file, ``luma_ai_generated_<id>.mp4``.

    Raises:
        RuntimeError: If the generation reports ``"failed"``, completes
            without a video URL, or the download does not answer HTTP 200.
    """
    client = AsyncLumaAI(auth_token=api_key)

    progress(0, desc="Initiating video generation...")
    generation = await client.generations.create(
        prompt=prompt,
        loop=loop,
        aspect_ratio=aspect_ratio,
    )

    progress(0.1, desc="Video generation started. Waiting for completion...")

    # Poll for completion.
    event_loop = asyncio.get_running_loop()
    start_time = event_loop.time()
    while True:
        status = await client.generations.get(id=generation.id)
        if status.state == "completed":
            break
        if status.state == "failed":
            # Surface the API's failure reason when the response carries one.
            reason = getattr(status, "failure_reason", None)
            message = "Video generation failed"
            if reason:
                message = f"{message}: {reason}"
            raise RuntimeError(message)

        # The bar is only an estimate: it assumes roughly 60 seconds in total
        # and is capped at 0.9 so it never looks finished before the download.
        elapsed_time = event_loop.time() - start_time
        progress_value = min(0.1 + (elapsed_time / 60) * 0.8, 0.9)
        progress(progress_value, desc="Generating video...")

        await asyncio.sleep(5)

    progress(0.9, desc="Downloading generated video...")

    # A completed generation without a usable URL is reported explicitly
    # instead of failing later with an opaque attribute or request error.
    assets = status.assets
    video_url = assets.video if assets is not None else None
    if not video_url:
        raise RuntimeError("Video generation completed but no video URL was returned")

    # Download the video.
    file_name = f"luma_ai_generated_{generation.id}.mp4"
    async with aiohttp.ClientSession() as session:
        async with session.get(video_url) as resp:
            if resp.status != 200:
                # Previously this path fell through to `return file_name`
                # with the name never assigned.
                raise RuntimeError(
                    f"Failed to download generated video (HTTP {resp.status})"
                )
            with open(file_name, "wb") as fd:
                async for chunk in resp.content.iter_chunked(64 * 1024):
                    fd.write(chunk)

    progress(1.0, desc="Video generation complete!")
    return file_name

async def text_to_video(api_key, prompt, loop, aspect_ratio, progress=gr.Progress()):
    """Gradio handler for the text-to-video button.

    Args:
        api_key: Luma AI API key entered by the user; must be non-empty.
        prompt: Text prompt forwarded to ``generate_video``.
        loop: Loop option forwarded to ``generate_video``.
        aspect_ratio: Aspect-ratio option forwarded to ``generate_video``.
        progress: Gradio progress tracker forwarded to ``generate_video``.

    Returns:
        A ``(video_path, error_message)`` pair: the path and an empty string
        on success, or ``None`` and a description of the error on failure.

    Raises:
        gr.Error: If no API key was supplied.
    """
    if not api_key:
        raise gr.Error("Please enter your Luma AI API key.")

    try:
        saved_path = await generate_video(api_key, prompt, loop, aspect_ratio, progress)
    except Exception as exc:
        # Any failure is reported through the error textbox rather than raised.
        failure_text = f"An error occurred: {exc!s}"
        return None, failure_text
    else:
        return saved_path, ""

async def image_to_video(api_key, prompt, image_url, loop, aspect_ratio, progress=gr.Progress()):
    """Gradio handler that generates a video starting from an image URL.

    Creates a Luma AI generation whose ``frame0`` keyframe is the given image,
    polls it every 5 seconds until its state is ``"completed"`` or
    ``"failed"``, then downloads the video into the current working directory.

    Args:
        api_key: Luma AI API key entered by the user; must be non-empty.
        prompt: Text prompt for the generation request.
        image_url: URL of the image used as the ``frame0`` keyframe; must be
            non-empty.
        loop: Forwarded as the ``loop`` option of the generation request.
        aspect_ratio: Forwarded as the ``aspect_ratio`` option.
        progress: Gradio progress tracker, called with a fraction and a
            description string.

    Returns:
        A ``(video_path, error_message)`` pair: the downloaded file name and
        an empty string on success, or ``None`` and a description of the
        error on failure.

    Raises:
        gr.Error: If the API key or the image URL is missing.
    """
    if not api_key:
        raise gr.Error("Please enter your Luma AI API key.")
    if not image_url:
        # Reject the empty field up front instead of sending it to the API.
        raise gr.Error("Please enter an image URL.")

    try:
        client = AsyncLumaAI(auth_token=api_key)

        progress(0, desc="Initiating video generation from image...")
        generation = await client.generations.create(
            prompt=prompt,
            loop=loop,
            aspect_ratio=aspect_ratio,
            keyframes={
                "frame0": {
                    "type": "image",
                    "url": image_url,
                }
            },
        )

        progress(0.1, desc="Video generation started. Waiting for completion...")

        # Poll for completion.
        event_loop = asyncio.get_running_loop()
        start_time = event_loop.time()
        while True:
            status = await client.generations.get(id=generation.id)
            if status.state == "completed":
                break
            if status.state == "failed":
                # Surface the API's failure reason when the response carries one.
                reason = getattr(status, "failure_reason", None)
                message = "Video generation failed"
                if reason:
                    message = f"{message}: {reason}"
                raise RuntimeError(message)

            # The bar is only an estimate: it assumes roughly 60 seconds in
            # total and is capped at 0.9 until the download starts.
            elapsed_time = event_loop.time() - start_time
            progress_value = min(0.1 + (elapsed_time / 60) * 0.8, 0.9)
            progress(progress_value, desc="Generating video...")

            await asyncio.sleep(5)

        progress(0.9, desc="Downloading generated video...")

        # A completed generation without a usable URL is reported explicitly.
        assets = status.assets
        video_url = assets.video if assets is not None else None
        if not video_url:
            raise RuntimeError("Video generation completed but no video URL was returned")

        # Download the video.
        file_name = f"luma_ai_generated_{generation.id}.mp4"
        async with aiohttp.ClientSession() as session:
            async with session.get(video_url) as resp:
                if resp.status != 200:
                    # Previously this path reached `return file_name` with the
                    # name unassigned, producing a misleading error message.
                    raise RuntimeError(
                        f"Failed to download generated video (HTTP {resp.status})"
                    )
                with open(file_name, "wb") as fd:
                    async for chunk in resp.content.iter_chunked(64 * 1024):
                        fd.write(chunk)

        progress(1.0, desc="Video generation complete!")
        return file_name, ""
    except Exception as e:
        # Any failure is reported through the error textbox rather than raised.
        return None, f"An error occurred: {e!s}"

# Build the two-tab Gradio interface. Components are created in the same
# order as they should appear on the page.
with gr.Blocks() as demo:
    gr.Markdown("# Luma AI Text-to-Video Demo")

    # One API key field, shared by both tabs.
    api_key = gr.Textbox(
        label="Luma AI API Key",
        type="password",
    )

    # --- Tab 1: prompt only -------------------------------------------------
    with gr.Tab("Text to Video"):
        prompt = gr.Textbox(label="Prompt")
        generate_btn = gr.Button("Generate Video")
        video_output = gr.Video(label="Generated Video")
        error_output = gr.Textbox(
            label="Error Messages",
            visible=True,
        )

        with gr.Accordion("Advanced Options", open=False):
            loop = gr.Checkbox(
                label="Loop",
                value=False,
            )
            aspect_ratio = gr.Dropdown(
                label="Aspect Ratio",
                choices=["16:9", "1:1", "9:16", "4:3", "3:4"],
                value="16:9",
            )

        # The handler returns (video path, error text) for these two outputs.
        generate_btn.click(
            fn=text_to_video,
            inputs=[api_key, prompt, loop, aspect_ratio],
            outputs=[video_output, error_output],
        )

    # --- Tab 2: prompt plus a starting image URL ----------------------------
    with gr.Tab("Image to Video"):
        img_prompt = gr.Textbox(label="Prompt")
        img_url = gr.Textbox(label="Image URL")
        img_generate_btn = gr.Button("Generate Video from Image")
        img_video_output = gr.Video(label="Generated Video")
        img_error_output = gr.Textbox(
            label="Error Messages",
            visible=True,
        )

        with gr.Accordion("Advanced Options", open=False):
            img_loop = gr.Checkbox(
                label="Loop",
                value=False,
            )
            img_aspect_ratio = gr.Dropdown(
                label="Aspect Ratio",
                choices=["16:9", "1:1", "9:16", "4:3", "3:4"],
                value="16:9",
            )

        # The handler returns (video path, error text) for these two outputs.
        img_generate_btn.click(
            fn=image_to_video,
            inputs=[api_key, img_prompt, img_url, img_loop, img_aspect_ratio],
            outputs=[img_video_output, img_error_output],
        )

# Enable the request queue, then start the server with a share link.
demo.queue().launch(share=True)