import functools
import tempfile

import gradio as gr
import torch
from diffusers import AnimateDiffSparseControlNetPipeline, AutoencoderKL, MotionAdapter, SparseControlNetModel, AnimateDiffPipeline, EulerAncestralDiscreteScheduler
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image

# Prefer the GPU whenever CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Reference scribbles from the diffusers documentation, used as sparse conditioning frames.
_SCRIBBLE_IMAGE_URLS = (
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png",
)
# Length of the generated clip. Passed explicitly to the pipeline so that index
# validation and generation agree on the same value.
_SPARSE_NUM_FRAMES = 16


@functools.lru_cache(maxsize=1)
def _load_sparse_controlnet_pipeline():
    """Build the SparseCtrl scribble pipeline once and reuse it for every request."""
    # float16 is only dependable on GPU; use float32 on the CPU fallback.
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    motion_adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=dtype).to(device)
    controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectrl-scribble", torch_dtype=dtype).to(device)
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype).to(device)
    pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
        "SG161222/Realistic_Vision_V5.1_noVAE",
        motion_adapter=motion_adapter,
        controlnet=controlnet,
        vae=vae,
        torch_dtype=dtype,
    ).to(device)

    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        beta_schedule="linear",
        algorithm_type="dpmsolver++",
        use_karras_sigmas=True,
    )
    # The pipeline is cached, so the motion LoRA is loaded and fused exactly once.
    pipe.load_lora_weights("guoyww/animatediff-motion-lora-v1-5-3", adapter_name="motion_lora")
    pipe.fuse_lora(lora_scale=1.0)
    return pipe


@functools.lru_cache(maxsize=1)
def _load_conditioning_frames():
    """Download the scribble images once; a tuple keeps the cached sequence unchangeable by callers."""
    return tuple(load_image(url) for url in _SCRIBBLE_IMAGE_URLS)


def _parse_frame_indices(text, max_count, num_frames):
    """Parse comma-separated frame indices such as "0, 8, 15".

    Empty entries (e.g. from a trailing comma) are ignored. Raises ``gr.Error``
    with a user-facing message when the text is not a list of integers, is
    empty, lists more indices than there are conditioning images, or names a
    frame outside a ``num_frames``-long clip.
    """
    try:
        indices = [int(part) for part in text.split(",") if part.strip()]
    except ValueError as err:
        raise gr.Error(f"Conditioning frame indices must be comma-separated integers, got {text!r}.") from err
    if not indices:
        raise gr.Error("Provide at least one conditioning frame index.")
    if len(indices) > max_count:
        raise gr.Error(f"Got {len(indices)} frame indices but only {max_count} conditioning images are available.")
    out_of_range = [i for i in indices if not -num_frames <= i < num_frames]
    if out_of_range:
        raise gr.Error(f"Frame indices {out_of_range} are out of range for a {num_frames}-frame video.")
    return indices


def generate_video(prompt, negative_prompt, num_inference_steps, conditioning_frame_indices, controlnet_conditioning_scale):
    """Generate a scribble-guided clip with AnimateDiff SparseCtrl and return the GIF path.

    Args:
        prompt: Text prompt describing the video.
        negative_prompt: Text describing what to avoid.
        num_inference_steps: Number of denoising steps. It is cast to ``int``
            in case the Gradio slider delivers a float.
        conditioning_frame_indices: Comma-separated frame indices, e.g.
            "0, 8, 15". The i-th index positions the i-th scribble image.
        controlnet_conditioning_scale: Strength of the ControlNet guidance.

    Returns:
        Path to a newly created GIF file, unique per call.

    Raises:
        gr.Error: If the frame indices cannot be parsed or are invalid.
    """
    # Validate user input before any network or model work.
    frame_indices = _parse_frame_indices(
        conditioning_frame_indices, len(_SCRIBBLE_IMAGE_URLS), _SPARSE_NUM_FRAMES
    )
    conditioning_frames = list(_load_conditioning_frames())
    pipe = _load_sparse_controlnet_pipeline()

    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=int(num_inference_steps),
        num_frames=_SPARSE_NUM_FRAMES,
        conditioning_frames=conditioning_frames,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        controlnet_frame_indices=frame_indices,
        generator=torch.Generator().manual_seed(1337),
    )
    # `.frames` is batched (one frame list per generated video); take the single video.
    video = result.frames[0]

    # Use a unique file per request so concurrent users don't overwrite each other's output.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        output_file = tmp.name
    export_to_gif(video, output_file)
    return output_file

@functools.lru_cache(maxsize=1)
def _load_freenoise_pipeline():
    """Build the 64-frame AnimateDiff pipeline once and reuse it for every request."""
    # float16 is only dependable on GPU; use float32 on the CPU fallback.
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=dtype)
    pipe = AnimateDiffPipeline.from_pretrained(
        "SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=dtype
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler(
        beta_schedule="linear",
        beta_start=0.00085,
        beta_end=0.012,
    )

    # Enable FreeNoise for the long (64-frame) clip requested below.
    pipe.enable_free_noise()
    # Decode the VAE in slices to limit peak memory.
    pipe.vae.enable_slicing()
    if device.type == "cuda":
        # Model CPU offload manages device placement itself, so there is no
        # preceding .to(device), which the offload would immediately undo.
        pipe.enable_model_cpu_offload()
    else:
        # Offloading targets an accelerator; on the CPU fallback, place the pipeline directly.
        pipe.to(device)
    return pipe


def generate_simple_video(prompt):
    """Generate a 64-frame AnimateDiff clip for ``prompt`` and return the GIF path.

    Args:
        prompt: Text prompt describing the video.

    Returns:
        Path to a newly created GIF file, unique per call.
    """
    pipe = _load_freenoise_pipeline()
    result = pipe(
        prompt,
        num_frames=64,
        num_inference_steps=20,
        guidance_scale=7.0,
        decode_chunk_size=2,
    )
    # `.frames` is batched (one frame list per generated video); take the single video.
    frames = result.frames[0]

    # Use a unique file per request so concurrent users don't overwrite each other's output.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        output_file = tmp.name
    export_to_gif(frames, output_file)
    return output_file

# Tab 1: scribble-conditioned generation via AnimateDiff SparseCtrl.
demo1 = gr.Interface(
    fn=generate_video,
    inputs=[
        gr.Textbox(label="Prompt", value="an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality"),
        gr.Textbox(label="Negative Prompt", value="low quality, worst quality, letterboxed"),
        gr.Slider(label="Number of Inference Steps", minimum=1, maximum=100, step=1, value=25),
        gr.Textbox(label="Conditioning Frame Indices", value="0, 8, 15"),
        gr.Slider(label="ControlNet Conditioning Scale", minimum=0.1, maximum=2.0, step=0.1, value=1.0)
    ],
    outputs=gr.Video(label="Generated Video"),
    title="Generate Video with AnimateDiffSparseControlNetPipeline",
    description="Generate a video using the AnimateDiffSparseControlNetPipeline."
)

# Tab 2: prompt-only generation via the plain AnimateDiff pipeline.
demo2 = gr.Interface(
    fn=generate_simple_video,
    inputs=gr.Textbox(label="Prompt", value="An astronaut riding a horse on Mars."),
    outputs=gr.Video(label="Generated Simple Video"),
    title="Generate Simple Video with AnimateDiff",
    description="Generate a simple video using the AnimateDiffPipeline."
)

demo = gr.TabbedInterface([demo1, demo2], ["Advanced Video Generation", "Simple Video Generation"])

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module does not start a server.
    # To listen on all interfaces instead, use:
    #     demo.launch(server_name="0.0.0.0", server_port=7910)
    demo.launch()